eulaurarien/database.py
2021-08-14 23:27:28 +02:00

800 lines
25 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Utility functions to interact with the database.
"""
import typing
import time
import logging
import coloredlogs
import pickle
import numpy
import math
import os
# Known top-level domains. Starts empty; Database.populate_tld_list() fills it
# and Database.validate_domain() triggers that the first time it is needed.
TLD_LIST: typing.Set[str] = set()
# Module-level side effect: installs coloured logging at DEBUG level on import.
coloredlogs.install(level="DEBUG", fmt="%(asctime)s %(name)s %(levelname)s %(message)s")
# Plain-int aliases used for readability in annotations.
Asn = int
Timestamp = int
Level = int
class Path:
    """Common base class of every path type declared in this module."""
class RulePath(Path):
    """Path standing for a rule; rendered as ``(rule)``."""

    def __str__(self) -> str:
        label = "(rule)"
        return label
class RuleFirstPath(RulePath):
    """Rule path rendered as ``(first-party rule)``."""

    def __str__(self) -> str:
        label = "(first-party rule)"
        return label
class RuleMultiPath(RulePath):
    """Rule path rendered as ``(multi-party rule)``."""

    def __str__(self) -> str:
        label = "(multi-party rule)"
        return label
class DomainPath(Path):
    """
    Path made of domain labels.

    ``parts`` holds the labels in the order produced by
    ``Database.pack_domain`` (the dotted name split and reversed).
    """

    def __init__(self, parts: typing.List[str]):
        self.parts = parts

    def __str__(self) -> str:
        # Generic domain paths are shown with a "?." prefix
        dotted = Database.unpack_domain(self)
        return "?." + dotted
class HostnamePath(DomainPath):
    """Domain path rendered as the bare dotted name."""

    def __str__(self) -> str:
        dotted = Database.unpack_domain(self)
        return dotted
class ZonePath(DomainPath):
    """Domain path rendered with a ``*.`` prefix."""

    def __str__(self) -> str:
        dotted = Database.unpack_domain(self)
        return "*." + dotted
class AsnPath(Path):
    """Path wrapping an AS number; rendered through ``Database.unpack_asn``."""

    def __init__(self, asn: Asn):
        self.asn = asn

    def __str__(self) -> str:
        rendered = Database.unpack_asn(self)
        return rendered
class Ip4Path(Path):
    """
    Path for an IPv4 address or network.

    ``value`` is the address packed into an integer and ``prefixlen``
    the number of leading bits that are significant.
    """

    def __init__(self, value: int, prefixlen: int):
        self.prefixlen = prefixlen
        self.value = value

    def __str__(self) -> str:
        rendered = Database.unpack_ip4network(self)
        return rendered
class Match:
    """
    State attached to one entry of the database.

    A match whose ``updated`` is 0 is considered disabled.
    """

    def __init__(self) -> None:
        # Path of the entry this match derives from, if any
        self.source: typing.Optional[Path] = None
        # Value given on the last update; 0 means disabled
        self.updated: int = 0
        self.dupplicate: bool = False
        # Cache
        self.level: int = 0
        self.first_party: bool = False
        self.references: int = 0

    def active(self, first_party: bool = None) -> bool:
        """
        Tell whether the match is enabled.

        When ``first_party`` is truthy the match must additionally be
        flagged first-party to count as active.
        """
        if self.updated == 0:
            return False
        if first_party and not self.first_party:
            return False
        return True

    def disable(self) -> None:
        """Mark the match as disabled by zeroing ``updated``."""
        self.updated = 0
class AsnNode(Match):
    """Match for an autonomous system, carrying a display ``name``."""

    def __init__(self) -> None:
        super().__init__()
        self.name = ""
class DomainTreeNode:
    """
    One node of the domain tree.

    Each node carries two independent matches (hostname and zone) and
    its children keyed by the next domain label.
    """

    def __init__(self) -> None:
        self.children: typing.Dict[str, DomainTreeNode] = {}
        self.match_zone = Match()
        self.match_hostname = Match()
class IpTreeNode(Match):
    """
    One node of the binary IPv4 tree; the node is itself the match.

    ``zero`` and ``one`` are the children for the next address bit.
    """

    def __init__(self) -> None:
        super().__init__()
        self.zero: typing.Optional[IpTreeNode] = None
        self.one: typing.Optional[IpTreeNode] = None
# Any of the node classes stored in the database structures.
Node = typing.Union[DomainTreeNode, IpTreeNode, AsnNode]
# Signature of the callbacks given to the Database.exec_each* iterators:
# they receive the path of an entry and its match.
MatchCallable = typing.Callable[[Path, Match], typing.Any]
class Profiler:
    """
    Optional step timer.

    Profiling is only enabled when the ``PROFILE`` environment variable
    parses to a non-zero integer; otherwise ``enter_step`` and
    ``profile`` are bound to no-op methods.
    """

    def __init__(self) -> None:
        enabled = int(os.environ.get("PROFILE", "0"))
        if not enabled:
            self.enter_step = self.enter_step_dummy
            self.profile = self.profile_dummy
            return
        self.log = logging.getLogger("profiler")
        self.time_last = time.perf_counter()
        self.time_step = "init"
        # Seconds accumulated per step name
        self.time_dict: typing.Dict[str, float] = {}
        # Counter per step name, see enter_step_real
        self.step_dict: typing.Dict[str, int] = {}
        self.enter_step = self.enter_step_real
        self.profile = self.profile_real

    def enter_step_dummy(self, name: str) -> None:
        """No-op used when profiling is disabled."""
        return None

    def enter_step_real(self, name: str) -> None:
        """Charge the elapsed time to the current step, then switch to ``name``."""
        now = time.perf_counter()
        current = self.time_step
        elapsed = now - self.time_last
        if current in self.time_dict:
            self.time_dict[current] += elapsed
            # Only count the step again when we are actually leaving it
            self.step_dict[current] += int(name != current)
        else:
            self.time_dict[current] = elapsed
            self.step_dict[current] = 1
        self.time_step = name
        self.time_last = time.perf_counter()

    def profile_dummy(self) -> None:
        """No-op used when profiling is disabled."""
        return None

    def profile_real(self) -> None:
        """Log the time spent in each step, shortest first, then the total."""
        self.enter_step("profile")
        total = sum(self.time_dict.values())
        by_duration = sorted(self.time_dict.items(), key=lambda item: item[1])
        for key, secs in by_duration:
            times = self.step_dict[key]
            self.log.debug(
                f"{key:<20}: {times:9d} × {secs/times:5.3e} "
                f"= {secs:9.2f} s ({secs/total:7.2%}) "
            )
        self.log.debug(f"{'total':<20}: {total:9.2f} s ({1:7.2%})")
# Store of rules, domain tree, ASN table and IPv4 tree, persisted with pickle.
class Database(Profiler):
# Version tag saved with the data; load() rebuilds the database when the
# version found in the file differs from this one.
VERSION = 18
# File read by load() and written by save().
PATH = "blocking.p"
def initialize(self) -> None:
    """Replace the in-memory content with a fresh, empty database."""
    self.log.warning("Creating database version: %d ", Database.VERSION)
    # Dummy match objects that everything refers to:
    # index 0 is the multi-party rule, index 1 the first-party rule
    # (see get_match)
    rules: typing.List[Match] = []
    for is_first_party in (False, True):
        rule = Match()
        rule.updated = 1
        rule.level = 0
        rule.first_party = is_first_party
        rules.append(rule)
    self.rules = rules
    self.domtree = DomainTreeNode()
    self.asns: typing.Dict[Asn, AsnNode] = {}
    self.ip4tree = IpTreeNode()
def load(self) -> None:
self.enter_step("load")
try:
with open(self.PATH, "rb") as db_fdsec:
version, data = pickle.load(db_fdsec)
if version == Database.VERSION:
self.rules, self.domtree, self.asns, self.ip4tree = data
return
self.log.warning(
"Outdated database version found: %d, " "it will be rebuilt.",
version,
)
except (TypeError, AttributeError, EOFError):
self.log.error(
"Corrupt (or heavily outdated) database found, " "it will be rebuilt."
)
except FileNotFoundError:
pass
self.initialize()
def save(self) -> None:
self.enter_step("save")
with open(self.PATH, "wb") as db_fdsec:
data = self.rules, self.domtree, self.asns, self.ip4tree
pickle.dump((self.VERSION, data), db_fdsec)
self.profile()
def __init__(self) -> None:
    """Set up profiling and logging, then load the database from disk."""
    Profiler.__init__(self)
    # Set after Profiler.__init__, which may have assigned its own logger
    self.log = logging.getLogger("db")
    self.load()
    # Single-cell cache set to one, so get_ip4 always falls through to
    # the tree walk until fill_ip4cache builds a real cache
    self.ip4cache = numpy.ones(1)
    self.ip4cache_shift: int = 32
def _set_ip4cache(self, path: Path, _: Match) -> None:
    """
    Mark in the IPv4 cache every cell overlapped by the network ``path``.

    Has the callback signature expected by ``exec_each_ip4``; the match
    argument is unused.
    """
    assert isinstance(path, Ip4Path)
    self.enter_step("set_ip4cache")
    host_bits = 32 - path.prefixlen
    # The tree only looks at the first `prefixlen` bits, so clear any host
    # bit (e.g. "10.0.1.200/23") to start from the real network address;
    # otherwise the first cells of the network would be left unmarked.
    first = (path.value >> host_bits) << host_bits
    mini = first >> self.ip4cache_shift
    maxi = (first + 2 ** host_bits) >> self.ip4cache_shift
    if mini == maxi:
        # Network smaller than one cache cell
        self.ip4cache[mini] = True
    else:
        self.ip4cache[mini:maxi] = True
def fill_ip4cache(self, max_size: int = 512 * 1024 ** 2) -> None:
"""
Size in bytes
"""
if max_size > 2 ** 32 / 8:
self.log.warning(
"Allocating more than 512 MiB of RAM for "
"the Ip4 cache is not necessary."
)
max_cache_width = int(math.log2(max(1, max_size * 8)))
allocated = False
cache_width = min(32, max_cache_width)
while not allocated:
cache_size = 2 ** cache_width
try:
self.ip4cache = numpy.zeros(cache_size, dtype=bool)
except MemoryError:
self.log.exception("Could not allocate cache. Retrying a smaller one.")
cache_width -= 1
continue
allocated = True
self.ip4cache_shift = 32 - cache_width
for _ in self.exec_each_ip4(self._set_ip4cache):
pass
@staticmethod
def populate_tld_list() -> None:
    """Add every stripped line of ``temp/all_tld.list`` to ``TLD_LIST``."""
    with open("temp/all_tld.list", "r") as tld_fdesc:
        TLD_LIST.update(line.strip() for line in tld_fdesc)
@staticmethod
def validate_domain(path: str) -> bool:
    """
    Tell whether ``path`` is an acceptable domain name.

    Requires at most 255 characters overall, a last label present in
    ``TLD_LIST`` (loaded on first use) and labels of 1 to 63 characters.
    """
    if len(path) > 255:
        return False
    if not TLD_LIST:
        Database.populate_tld_list()
    labels = path.split(".")
    if labels[-1] not in TLD_LIST:
        return False
    return all(1 <= len(label) <= 63 for label in labels)
@staticmethod
def pack_domain(domain: str) -> DomainPath:
    """Split a dotted name into labels, stored in reverse order."""
    labels = domain.split(".")
    labels.reverse()
    return DomainPath(labels)
@staticmethod
def unpack_domain(domain: DomainPath) -> str:
return ".".join(domain.parts[::-1])
@staticmethod
def pack_asn(asn: str) -> AsnPath:
    """Parse an AS number, with or without a case-insensitive ``AS`` prefix."""
    digits = asn.upper()
    if digits.startswith("AS"):
        digits = digits[2:]
    number = int(digits)
    return AsnPath(number)
@staticmethod
def unpack_asn(asn: AsnPath) -> str:
return f"AS{asn.asn}"
@staticmethod
def validate_ip4address(path: str) -> bool:
splits = path.split(".")
if len(splits) != 4:
return False
for split in splits:
try:
if not 0 <= int(split) <= 255:
return False
except ValueError:
return False
return True
@staticmethod
def pack_ip4address_low(address: str) -> int:
addr = 0
for split in address.split("."):
octet = int(split)
addr = (addr << 8) + octet
return addr
@staticmethod
def pack_ip4address(address: str) -> Ip4Path:
    """Build the /32 path of a dotted IPv4 address."""
    value = Database.pack_ip4address_low(address)
    return Ip4Path(value, 32)
@staticmethod
def unpack_ip4address(address: Ip4Path) -> str:
addr = address.value
assert address.prefixlen == 32
octets: typing.List[int] = list()
octets = [0] * 4
for o in reversed(range(4)):
octets[o] = addr & 0xFF
addr >>= 8
return ".".join(map(str, octets))
@staticmethod
def validate_ip4network(path: str) -> bool:
# A bit generous but ok for our usage
splits = path.split("/")
if len(splits) != 2:
return False
if not Database.validate_ip4address(splits[0]):
return False
try:
if not 0 <= int(splits[1]) <= 32:
return False
except ValueError:
return False
return True
@staticmethod
def pack_ip4network(network: str) -> Ip4Path:
    """
    Build the path of an ``address/prefixlen`` network.

    The address is packed as given: bits beyond the prefix length are
    kept in ``value``.
    """
    address, prefixlen_str = network.split("/")
    prefixlen = int(prefixlen_str)
    packed = Database.pack_ip4address(address)
    packed.prefixlen = prefixlen
    return packed
@staticmethod
def unpack_ip4network(network: Ip4Path) -> str:
addr = network.value
octets: typing.List[int] = list()
octets = [0] * 4
for o in reversed(range(4)):
octets[o] = addr & 0xFF
addr >>= 8
return ".".join(map(str, octets)) + "/" + str(network.prefixlen)
def get_match(self, path: Path) -> Match:
    """
    Return the match stored for ``path``.

    Raises ``KeyError`` for an unknown ASN or domain label,
    ``IndexError`` for an IPv4 path leaving the tree, and ``ValueError``
    for a path type that maps to no match.
    """
    if isinstance(path, RuleMultiPath):
        return self.rules[0]
    if isinstance(path, RuleFirstPath):
        return self.rules[1]
    if isinstance(path, AsnPath):
        return self.asns[path.asn]
    if isinstance(path, DomainPath):
        dom_node = self.domtree
        for label in path.parts:
            dom_node = dom_node.children[label]
        if isinstance(path, HostnamePath):
            return dom_node.match_hostname
        if isinstance(path, ZonePath):
            return dom_node.match_zone
        # Plain DomainPath: neither a hostname nor a zone
        raise ValueError
    if isinstance(path, Ip4Path):
        ip_node = self.ip4tree
        # Walk down one bit at a time, most significant first
        for shift in range(31, 31 - path.prefixlen, -1):
            bit = (path.value >> shift) & 0b1
            child = ip_node.one if bit else ip_node.zero
            if not child:
                raise IndexError
            ip_node = child
        return ip_node
    raise ValueError
def exec_each_asn(
    self,
    callback: MatchCallable,
) -> typing.Any:
    """
    Call ``callback(path, match)`` for every active ASN.

    Generator: yields whatever the callback results yield, and only
    runs while it is being iterated.
    """
    for number, node in self.asns.items():
        if not node.active():
            continue
        result = callback(AsnPath(number), node)
        try:
            yield from result
        except TypeError:  # not iterable
            pass
def exec_each_domain(
    self,
    callback: MatchCallable,
    _dic: DomainTreeNode = None,
    _par: DomainPath = None,
) -> typing.Any:
    """
    Call ``callback(path, match)`` for every active domain match.

    At each node the hostname match is visited before the zone match,
    then the children are walked recursively. ``_dic`` and ``_par``
    carry the recursion state and default to the root of the tree.

    Generator: yields whatever the callback results yield.
    """
    node = _dic or self.domtree
    prefix = _par or DomainPath([])
    for path_cls, attr in (
        (HostnamePath, "match_hostname"),
        (ZonePath, "match_zone"),
    ):
        match = getattr(node, attr)
        if match.active():
            result = callback(path_cls(prefix.parts), match)
            try:
                yield from result
            except TypeError:  # not iterable
                pass
    for label, child in node.children.items():
        yield from self.exec_each_domain(
            callback, _dic=child, _par=DomainPath(prefix.parts + [label])
        )
def exec_each_ip4(
    self,
    callback: MatchCallable,
    _dic: IpTreeNode = None,
    _par: Ip4Path = None,
) -> typing.Any:
    """
    Call ``callback(path, match)`` for every active IPv4 node.

    The node itself is visited first, then its ``zero`` branch, then its
    ``one`` branch. ``_dic`` and ``_par`` carry the recursion state and
    default to the root of the tree (0.0.0.0/0).

    Generator: yields whatever the callback results yield.
    """
    node = _dic or self.ip4tree
    here = _par or Ip4Path(0, 0)
    if node.active():
        result = callback(here, node)
        try:
            yield from result
        except TypeError:  # not iterable
            pass
    child_prefixlen = here.prefixlen + 1
    # 0: the extra bit is unset, the address is the one of the parent
    low = node.zero
    if low:
        yield from self.exec_each_ip4(
            callback, _dic=low, _par=Ip4Path(here.value, child_prefixlen)
        )
    # 1: set the extra bit (computed only when the branch exists)
    high = node.one
    if high:
        high_value = here.value | (1 << (32 - child_prefixlen))
        yield from self.exec_each_ip4(
            callback, _dic=high, _par=Ip4Path(high_value, child_prefixlen)
        )
def exec_each(
self,
callback: MatchCallable,
) -> typing.Any:
yield from self.exec_each_domain(callback)
yield from self.exec_each_ip4(callback)
yield from self.exec_each_asn(callback)
def update_references(self) -> None:
# Should be correctly calculated normally,
# keeping this just in case
def reset_references_cb(path: Path, match: Match) -> None:
match.references = 0
for _ in self.exec_each(reset_references_cb):
pass
def increment_references_cb(path: Path, match: Match) -> None:
if match.source:
source = self.get_match(match.source)
source.references += 1
for _ in self.exec_each(increment_references_cb):
pass
def _clean_deps(self) -> None:
# Disable the matches that depends on the targeted
# matches until all disabled matches reference count = 0
did_something = True
def clean_deps_cb(path: Path, match: Match) -> None:
nonlocal did_something
if not match.source:
return
source = self.get_match(match.source)
if not source.active():
self._unset_match(match)
elif match.first_party > source.first_party:
match.first_party = source.first_party
else:
return
did_something = True
while did_something:
did_something = False
self.enter_step("pass_clean_deps")
for _ in self.exec_each(clean_deps_cb):
pass
def prune(self, before: int, base_only: bool = False) -> None:
# Disable the matches targeted
def prune_cb(path: Path, match: Match) -> None:
if base_only and match.level > 1:
return
if match.updated > before:
return
self._unset_match(match)
self.log.debug("Print: disabled %s", path)
self.enter_step("pass_prune")
for _ in self.exec_each(prune_cb):
pass
self._clean_deps()
# Remove branches with no match
# TODO
def explain(self, path: Path) -> str:
    """
    Describe ``path`` and, recursively, the chain of its sources.

    Each element shows the path (plus the name for an ASN), then its
    level, ``F``/``M`` for first/multi-party, ``D``/``_`` for the
    duplicate flag and its reference count.
    """
    match = self.get_match(path)
    pieces = [str(path)]
    if isinstance(match, AsnNode):
        pieces.append(f" ({match.name})")
    party_char = "F" if match.first_party else "M"
    dup_char = "D" if match.dupplicate else "_"
    pieces.append(f" {match.level}{party_char}{dup_char}{match.references}")
    if match.source:
        pieces.append(f"{self.explain(match.source)}")
    return "".join(pieces)
def list_records(
self,
first_party_only: bool = False,
end_chain_only: bool = False,
no_dupplicates: bool = False,
rules_only: bool = False,
hostnames_only: bool = False,
explain: bool = False,
) -> typing.Iterable[str]:
def export_cb(path: Path, match: Match) -> typing.Iterable[str]:
if first_party_only and not match.first_party:
return
if end_chain_only and match.references > 0:
return
if no_dupplicates and match.dupplicate:
return
if rules_only and match.level > 1:
return
if hostnames_only and not isinstance(path, HostnamePath):
return
if explain:
yield self.explain(path)
else:
yield str(path)
yield from self.exec_each(export_cb)
def count_records(
self,
first_party_only: bool = False,
end_chain_only: bool = False,
no_dupplicates: bool = False,
rules_only: bool = False,
hostnames_only: bool = False,
) -> str:
memo: typing.Dict[str, int] = dict()
def count_records_cb(path: Path, match: Match) -> None:
if first_party_only and not match.first_party:
return
if end_chain_only and match.references > 0:
return
if no_dupplicates and match.dupplicate:
return
if rules_only and match.level > 1:
return
if hostnames_only and not isinstance(path, HostnamePath):
return
try:
memo[path.__class__.__name__] += 1
except KeyError:
memo[path.__class__.__name__] = 1
for _ in self.exec_each(count_records_cb):
pass
split: typing.List[str] = list()
for key, value in sorted(memo.items(), key=lambda s: s[0]):
split.append(f"{key[:-4].lower()}s: {value}")
return ", ".join(split)
def get_domain(self, domain_str: str) -> typing.Iterable[DomainPath]:
    """
    Yield the paths of the active matches covering ``domain_str``.

    Active zones are yielded from the root of the tree down to the
    domain itself, followed by the hostname match of the domain when it
    is active. The walk stops at the first label missing from the tree.
    """
    self.enter_step("get_domain_pack")
    domain = self.pack_domain(domain_str)
    self.enter_step("get_domain_brws")
    node = self.domtree
    for depth, label in enumerate(domain.parts):
        if node.match_zone.active():
            self.enter_step("get_domain_yield")
            yield ZonePath(domain.parts[:depth])
            self.enter_step("get_domain_brws")
        if label not in node.children:
            return
        node = node.children[label]
    # Node of the full domain: both of its matches are candidates
    if node.match_zone.active():
        self.enter_step("get_domain_yield")
        yield ZonePath(domain.parts)
    if node.match_hostname.active():
        self.enter_step("get_domain_yield")
        yield HostnamePath(domain.parts)
def get_ip4(self, ip4_str: str) -> typing.Iterable[Path]:
self.enter_step("get_ip4_pack")
ip4val = self.pack_ip4address_low(ip4_str)
self.enter_step("get_ip4_cache")
if not self.ip4cache[ip4val >> self.ip4cache_shift]:
return
self.enter_step("get_ip4_brws")
dic = self.ip4tree
for i in range(31, -1, -1):
bit = (ip4val >> i) & 0b1
if dic.active():
self.enter_step("get_ip4_yield")
yield Ip4Path(ip4val >> (i + 1) << (i + 1), 31 - i)
self.enter_step("get_ip4_brws")
next_dic = dic.one if bit else dic.zero
if next_dic is None:
return
dic = next_dic
if dic.active():
self.enter_step("get_ip4_yield")
yield Ip4Path(ip4val, 32)
def _unset_match(
self,
match: Match,
) -> None:
match.disable()
if match.source:
source_match = self.get_match(match.source)
source_match.references -= 1
def _set_match(
self,
match: Match,
updated: int,
source: Path,
source_match: Match = None,
dupplicate: bool = False,
) -> None:
# source_match is in parameters because most of the time
# its parent function needs it too,
# so it can pass it to save a traversal
source_match = source_match or self.get_match(source)
new_level = source_match.level + 1
if (
updated > match.updated
or new_level < match.level
or source_match.first_party > match.first_party
):
# NOTE FP and level of matches referencing this one
# won't be updated until run or prune
if match.source:
old_source = self.get_match(match.source)
old_source.references -= 1
match.updated = updated
match.level = new_level
match.first_party = source_match.first_party
match.source = source
source_match.references += 1
match.dupplicate = dupplicate
def _set_domain(
    self, hostname: bool, domain_str: str, updated: int, source: Path
) -> None:
    """
    Register ``domain_str`` as a hostname or as a zone.

    Missing tree nodes are created on the way. The entry is flagged as
    a duplicate when a zone match met on the way is already active.

    Raises ``ValueError`` when the domain does not validate.
    """
    self.enter_step("set_domain_val")
    if not Database.validate_domain(domain_str):
        raise ValueError(f"Invalid domain: {domain_str}")
    self.enter_step("set_domain_pack")
    domain = self.pack_domain(domain_str)
    self.enter_step("set_domain_fp")
    source_match = self.get_match(source)
    is_first_party = source_match.first_party
    self.enter_step("set_domain_brws")
    node = self.domtree
    dupplicate = False
    for label in domain.parts:
        if label not in node.children:
            node.children[label] = DomainTreeNode()
        node = node.children[label]
        # NOTE(review): the node of the domain itself is checked too, so
        # setting again an already active zone flags it as a duplicate
        # of itself. Confirm this is intended.
        if node.match_zone.active(is_first_party):
            dupplicate = True
    match = node.match_hostname if hostname else node.match_zone
    self._set_match(
        match,
        updated,
        source,
        source_match=source_match,
        dupplicate=dupplicate,
    )
def set_hostname(self, *args: typing.Any, **kwargs: typing.Any) -> None:
self._set_domain(True, *args, **kwargs)
def set_zone(self, *args: typing.Any, **kwargs: typing.Any) -> None:
self._set_domain(False, *args, **kwargs)
def set_asn(self, asn_str: str, updated: int, source: Path) -> None:
    """Register an ASN, creating its node when it is not known yet."""
    self.enter_step("set_asn")
    path = self.pack_asn(asn_str)
    try:
        node = self.asns[path.asn]
    except KeyError:
        node = AsnNode()
        self.asns[path.asn] = node
    self._set_match(
        node,
        updated,
        source,
    )
def _set_ip4(self, ip4: Ip4Path, updated: int, source: Path) -> None:
self.enter_step("set_ip4_fp")
source_match = self.get_match(source)
is_first_party = source_match.first_party
self.enter_step("set_ip4_brws")
dic = self.ip4tree
dupplicate = False
for i in range(31, 31 - ip4.prefixlen, -1):
bit = (ip4.value >> i) & 0b1
next_dic = dic.one if bit else dic.zero
if next_dic is None:
next_dic = IpTreeNode()
if bit:
dic.one = next_dic
else:
dic.zero = next_dic
dic = next_dic
if dic.active(is_first_party):
dupplicate = True
self._set_match(
dic,
updated,
source,
source_match=source_match,
dupplicate=dupplicate,
)
self._set_ip4cache(ip4, dic)
def set_ip4address(
self, ip4address_str: str, *args: typing.Any, **kwargs: typing.Any
) -> None:
self.enter_step("set_ip4add_val")
if not Database.validate_ip4address(ip4address_str):
raise ValueError(f"Invalid ip4address: {ip4address_str}")
self.enter_step("set_ip4add_pack")
ip4 = self.pack_ip4address(ip4address_str)
self._set_ip4(ip4, *args, **kwargs)
def set_ip4network(
self, ip4network_str: str, *args: typing.Any, **kwargs: typing.Any
) -> None:
self.enter_step("set_ip4net_val")
if not Database.validate_ip4network(ip4network_str):
raise ValueError(f"Invalid ip4network: {ip4network_str}")
self.enter_step("set_ip4net_pack")
ip4 = self.pack_ip4network(ip4network_str)
self._set_ip4(ip4, *args, **kwargs)