#!/usr/bin/env python3

"""
Utility functions to interact with the database.
"""

import typing
import time
import logging
import pickle
import numpy
import math
import os

try:
    import coloredlogs
except ImportError:  # Colored output is cosmetic: fall back to plain logging
    coloredlogs = None  # type: ignore

TLD_LIST: typing.Set[str] = set()

_LOG_FORMAT = '%(asctime)s %(name)s %(levelname)s %(message)s'
if coloredlogs is not None:
    coloredlogs.install(
        level='DEBUG',
        fmt=_LOG_FORMAT
    )
else:
    logging.basicConfig(level=logging.DEBUG, format=_LOG_FORMAT)

Asn = int
Timestamp = int
Level = int


class Path():
    """Base class of every key that can designate a match."""
    pass


class RulePath(Path):
    """Path designating one of the dummy rule matches."""

    def __str__(self) -> str:
        return '(rule)'


class RuleFirstPath(RulePath):
    """Path resolved by Database.get_match to the first-party rule match."""

    def __str__(self) -> str:
        return '(first-party rule)'


class RuleMultiPath(RulePath):
    """Path resolved by Database.get_match to the multi-party rule match."""

    def __str__(self) -> str:
        return '(multi-party rule)'


class DomainPath(Path):
    """Domain name stored as its labels in reverse order (TLD first)."""

    def __init__(self, parts: typing.List[str]):
        self.parts = parts

    def __str__(self) -> str:
        return '?.' + Database.unpack_domain(self)


class HostnamePath(DomainPath):
    """Domain path designating the hostname match of a tree node."""

    def __str__(self) -> str:
        return Database.unpack_domain(self)


class ZonePath(DomainPath):
    """Domain path designating the zone match of a tree node."""

    def __str__(self) -> str:
        return '*.' + Database.unpack_domain(self)


class AsnPath(Path):
    """Path designating an autonomous system by its number."""

    def __init__(self, asn: Asn):
        self.asn = asn

    def __str__(self) -> str:
        return Database.unpack_asn(self)


class Ip4Path(Path):
    """IPv4 network: a 32-bit integer value and a prefix length."""

    def __init__(self, value: int, prefixlen: int):
        self.value = value
        self.prefixlen = prefixlen

    def __str__(self) -> str:
        return Database.unpack_ip4network(self)


class Match():
    """State attached to a path: where it comes from and when it was set."""

    def __init__(self) -> None:
        self.source: typing.Optional[Path] = None
        # 0 means the match is disabled (see active() and disable())
        self.updated: int = 0
        self.dupplicate: bool = False

        # Cache
        self.level: int = 0
        self.first_party: bool = False
        self.references: int = 0

    def active(self, first_party: typing.Optional[bool] = None) -> bool:
        """
        Tell whether the match is enabled.

        If first_party is truthy, the match must also be first-party
        to be considered active.
        """
        if self.updated == 0 or (first_party and not self.first_party):
            return False
        return True

    def disable(self) -> None:
        """Mark the match as inactive by resetting its update time."""
        self.updated = 0


class AsnNode(Match):
    """Match for an autonomous system, with an additional name."""

    def __init__(self) -> None:
        Match.__init__(self)
        self.name = ''


class DomainTreeNode():
    """
    Node of the domain tree, children being indexed by label.
    Holds one match for the zone and one for the hostname.
    """

    def __init__(self) -> None:
        self.children: typing.Dict[str, DomainTreeNode] = dict()
        self.match_zone = Match()
        self.match_hostname = Match()


class IpTreeNode(Match):
    """Node of the binary IPv4 tree, one child per value of the next bit."""

    def __init__(self) -> None:
        Match.__init__(self)
        self.zero: typing.Optional[IpTreeNode] = None
        self.one: typing.Optional[IpTreeNode] = None


Node = typing.Union[DomainTreeNode, IpTreeNode, AsnNode]
MatchCallable = typing.Callable[[Path, Match], typing.Any]


def _iter_callback_result(result: typing.Any) -> typing.Iterator[typing.Any]:
    """
    Yield the items of a callback result, or nothing if it is not iterable.

    Only the iterability probe is guarded: a TypeError raised while
    the callback's generator is running is propagated to the caller
    instead of being silently discarded.
    """
    try:
        iterator = iter(result)
    except TypeError:  # not iterable
        return
    yield from iterator


class Profiler():
    """
    Accumulate the time spent in named steps.

    Profiling is enabled when the PROFILE environment variable is
    a non-zero integer, otherwise enter_step and profile do nothing.
    """

    def __init__(self) -> None:
        do_profile = int(os.environ.get('PROFILE', '0'))
        if do_profile:
            self.log = logging.getLogger('profiler')
            self.time_last = time.perf_counter()
            self.time_step = 'init'
            self.time_dict: typing.Dict[str, float] = dict()
            self.step_dict: typing.Dict[str, int] = dict()
            self.enter_step = self.enter_step_real
            self.profile = self.profile_real
        else:
            self.enter_step = self.enter_step_dummy
            self.profile = self.profile_dummy

    def enter_step_dummy(self, name: str) -> None:
        """No-op used when profiling is disabled."""
        return

    def enter_step_real(self, name: str) -> None:
        """Charge the elapsed time to the current step, then enter `name`."""
        now = time.perf_counter()
        try:
            self.time_dict[self.time_step] += now - self.time_last
            # Re-entering the same step does not count as a new occurence
            self.step_dict[self.time_step] += int(name != self.time_step)
        except KeyError:
            self.time_dict[self.time_step] = now - self.time_last
            self.step_dict[self.time_step] = 1
        self.time_step = name
        self.time_last = time.perf_counter()

    def profile_dummy(self) -> None:
        """No-op used when profiling is disabled."""
        return

    def profile_real(self) -> None:
        """Log the time spent in each step, shortest first, and the total."""
        self.enter_step('profile')
        total = sum(self.time_dict.values())
        for key, secs in sorted(self.time_dict.items(), key=lambda t: t[1]):
            times = self.step_dict[key]
            self.log.debug(f"{key:<20}: {times:9d} × {secs/times:5.3e} "
                           f"= {secs:9.2f} s ({secs/total:7.2%}) ")
        self.log.debug(f"{'total':<20}: "
                       f"{total:9.2f} s ({1:7.2%})")


class Database(Profiler):
    """
    Store of matches for domains, IPv4 networks and ASNs,
    persisted with pickle in the file designated by PATH.
    """
    VERSION = 18
    PATH = "blocking.p"

    def initialize(self) -> None:
        """Create an empty database."""
        self.log.warning(
            "Creating database version: %d ",
            Database.VERSION)
        # Dummy match objects that everything refer to
        self.rules: typing.List[Match] = list()
        # Index 0 is the multi-party rule, index 1 the first-party one
        for first_party in (False, True):
            m = Match()
            m.updated = 1
            m.level = 0
            m.first_party = first_party
            self.rules.append(m)
        self.domtree = DomainTreeNode()
        self.asns: typing.Dict[Asn, AsnNode] = dict()
        self.ip4tree = IpTreeNode()

    def load(self) -> None:
        """
        Load the database from PATH.

        A new database is created if the file is missing, unreadable
        as a database, or of a different version.
        """
        self.enter_step('load')
        try:
            with open(self.PATH, 'rb') as db_fdsec:
                # SECURITY: pickle can execute arbitrary code while loading,
                # the database file must come from a trusted source.
                version, data = pickle.load(db_fdsec)
                if version == Database.VERSION:
                    self.rules, self.domtree, self.asns, self.ip4tree = data
                    return
                self.log.warning(
                    "Outdated database version found: %d, "
                    "it will be rebuilt.",
                    version)
        except (TypeError, AttributeError, EOFError,
                ValueError, pickle.UnpicklingError):
            self.log.error(
                "Corrupt (or heavily outdated) database found, "
                "it will be rebuilt.")
        except FileNotFoundError:
            pass
        self.initialize()

    def save(self) -> None:
        """Write the database to PATH, then output the profiling report."""
        self.enter_step('save')
        with open(self.PATH, 'wb') as db_fdsec:
            data = self.rules, self.domtree, self.asns, self.ip4tree
            pickle.dump((self.VERSION, data), db_fdsec)
        self.profile()

    def __init__(self) -> None:
        Profiler.__init__(self)
        self.log = logging.getLogger('db')
        self.load()
        # Until fill_ip4cache is called, the cache is a single truthy cell
        # every address maps to, so get_ip4 always browses the tree.
        self.ip4cache_shift: int = 32
        self.ip4cache = numpy.ones(1)

    def _set_ip4cache(self, path: Path, _: Match) -> None:
        """Flag the cache cells covered by the given network."""
        assert isinstance(path, Ip4Path)
        self.enter_step('set_ip4cache')
        # Networks are accepted with host bits set (i.e. 10.200.0.0/8),
        # clear them so that the cells flagged start at the network address.
        host_bits = 32 - path.prefixlen
        start = (path.value >> host_bits) << host_bits
        mini = start >> self.ip4cache_shift
        maxi = (start + 2**host_bits) >> self.ip4cache_shift
        if mini == maxi:
            # The network is smaller than a cell
            self.ip4cache[mini] = True
        else:
            self.ip4cache[mini:maxi] = True

    @staticmethod
    def _ip4cache_width(max_size: int) -> int:
        """
        Number of address bits indexing a cache of max_size*8 cells.
        No more than 32 since this is the width of an IPv4 address.
        """
        return min(32, int(math.log2(max(1, max_size*8))))

    def fill_ip4cache(self, max_size: int = 512*1024**2) -> None:
        """
        Size in bytes

        Allocate the IPv4 cache, halving it as long as the allocation
        raises MemoryError, then flag the cells of every active network.
        """
        # NOTE(review): the width is computed as if a cell took one bit,
        # but numpy booleans take one byte each. The allocation may
        # therefore be up to 8 times max_size - confirm the intent.
        if max_size > 2**32/8:
            self.log.warning("Allocating more than 512 MiB of RAM for "
                             "the Ip4 cache is not necessary.")
        cache_width = self._ip4cache_width(max_size)
        allocated = False
        while not allocated:
            cache_size = 2**cache_width
            try:
                # numpy.bool_ exists in every NumPy version, while the
                # numpy.bool alias was removed in NumPy 1.24
                self.ip4cache = numpy.zeros(cache_size, dtype=numpy.bool_)
            except MemoryError:
                self.log.exception(
                    "Could not allocate cache. Retrying a smaller one.")
                cache_width -= 1
                continue
            allocated = True
        self.ip4cache_shift = 32-cache_width
        for _ in self.exec_each_ip4(self._set_ip4cache):
            pass

    @staticmethod
    def populate_tld_list() -> None:
        """Add each line of temp/all_tld.list to TLD_LIST."""
        with open('temp/all_tld.list', 'r') as tld_fdesc:
            for tld in tld_fdesc:
                tld = tld.strip()
                TLD_LIST.add(tld)

    @staticmethod
    def validate_domain(path: str) -> bool:
        """
        Tell if the string is at most 255 characters long, ends with
        a TLD from TLD_LIST and is made of labels of 1 to 63 characters.
        """
        if len(path) > 255:
            return False
        splits = path.split('.')
        if not TLD_LIST:
            Database.populate_tld_list()
        if splits[-1] not in TLD_LIST:
            return False
        for split in splits:
            if not 1 <= len(split) <= 63:
                return False
        return True

    @staticmethod
    def pack_domain(domain: str) -> DomainPath:
        """Turn a dotted domain into a path of reversed labels."""
        return DomainPath(domain.split('.')[::-1])

    @staticmethod
    def unpack_domain(domain: DomainPath) -> str:
        """Turn a path of reversed labels back into a dotted domain."""
        return '.'.join(domain.parts[::-1])

    @staticmethod
    def pack_asn(asn: str) -> AsnPath:
        """Parse an AS number, with or without the AS prefix (any case)."""
        asn = asn.upper()
        if asn.startswith('AS'):
            asn = asn[2:]
        return AsnPath(int(asn))

    @staticmethod
    def unpack_asn(asn: AsnPath) -> str:
        """Format an AS number with the AS prefix."""
        return f'AS{asn.asn}'

    @staticmethod
    def validate_ip4address(path: str) -> bool:
        """Tell if the string is made of 4 dot-separated integers in 0-255."""
        splits = path.split('.')
        if len(splits) != 4:
            return False
        for split in splits:
            try:
                if not 0 <= int(split) <= 255:
                    return False
            except ValueError:
                return False
        return True

    @staticmethod
    def pack_ip4address_low(address: str) -> int:
        """Turn a dotted address into its integer value, first octet high."""
        addr = 0
        for split in address.split('.'):
            octet = int(split)
            addr = (addr << 8) + octet
        return addr

    @staticmethod
    def pack_ip4address(address: str) -> Ip4Path:
        """Turn a dotted address into a path of prefix length 32."""
        return Ip4Path(Database.pack_ip4address_low(address), 32)

    @staticmethod
    def _format_ip4(value: int) -> str:
        """Format the 32 lower bits of an integer as a dotted address."""
        return '.'.join(
            str((value >> shift) & 0xFF) for shift in (24, 16, 8, 0))

    @staticmethod
    def unpack_ip4address(address: Ip4Path) -> str:
        """Format a path of prefix length 32 as a dotted address."""
        assert address.prefixlen == 32
        return Database._format_ip4(address.value)

    @staticmethod
    def validate_ip4network(path: str) -> bool:
        """Tell if the string is a valid address, a slash, then 0 to 32."""
        # A bit generous but ok for our usage
        splits = path.split('/')
        if len(splits) != 2:
            return False
        if not Database.validate_ip4address(splits[0]):
            return False
        try:
            if not 0 <= int(splits[1]) <= 32:
                return False
        except ValueError:
            return False
        return True

    @staticmethod
    def pack_ip4network(network: str) -> Ip4Path:
        """Turn an address/prefixlen string into a path."""
        address, prefixlen_str = network.split('/')
        prefixlen = int(prefixlen_str)
        addr = Database.pack_ip4address(address)
        addr.prefixlen = prefixlen
        return addr

    @staticmethod
    def unpack_ip4network(network: Ip4Path) -> str:
        """Format a path as an address/prefixlen string."""
        return Database._format_ip4(network.value) \
            + '/' + str(network.prefixlen)

    def get_match(self, path: Path) -> Match:
        """
        Return the match stored for the given path.

        Raises KeyError if the ASN or a domain label is not stored,
        IndexError if the IPv4 tree has no node for the network,
        ValueError if the kind of path is not handled.
        """
        if isinstance(path, RuleMultiPath):
            return self.rules[0]
        elif isinstance(path, RuleFirstPath):
            return self.rules[1]
        elif isinstance(path, AsnPath):
            return self.asns[path.asn]
        elif isinstance(path, DomainPath):
            dicd = self.domtree
            for part in path.parts:
                dicd = dicd.children[part]
            if isinstance(path, HostnamePath):
                return dicd.match_hostname
            elif isinstance(path, ZonePath):
                return dicd.match_zone
            else:
                raise ValueError
        elif isinstance(path, Ip4Path):
            dici = self.ip4tree
            # Follow the prefixlen most significant bits of the value
            for i in range(31, 31-path.prefixlen, -1):
                bit = (path.value >> i) & 0b1
                dici_next = dici.one if bit else dici.zero
                if not dici_next:
                    raise IndexError
                dici = dici_next
            return dici
        else:
            raise ValueError

    def exec_each_asn(self,
                      callback: MatchCallable,
                      ) -> typing.Any:
        """
        Call the callback on every active ASN match,
        yielding what it returns if that is iterable.
        """
        for asn in self.asns:
            match = self.asns[asn]
            if match.active():
                c = callback(
                    AsnPath(asn),
                    match,
                )
                yield from _iter_callback_result(c)

    def exec_each_domain(self,
                         callback: MatchCallable,
                         _dic: typing.Optional[DomainTreeNode] = None,
                         _par: typing.Optional[DomainPath] = None,
                         ) -> typing.Any:
        """
        Call the callback on every active hostname and zone match,
        yielding what it returns if that is iterable.

        _dic and _par are the node and path of the subtree browsed,
        they are meant for the recursion only.
        """
        if _dic is None:
            _dic = self.domtree
        if _par is None:
            _par = DomainPath([])
        if _dic.match_hostname.active():
            c = callback(
                HostnamePath(_par.parts),
                _dic.match_hostname,
            )
            yield from _iter_callback_result(c)
        if _dic.match_zone.active():
            c = callback(
                ZonePath(_par.parts),
                _dic.match_zone,
            )
            yield from _iter_callback_result(c)
        for part in _dic.children:
            dic = _dic.children[part]
            yield from self.exec_each_domain(
                callback,
                _dic=dic,
                _par=DomainPath(_par.parts + [part])
            )

    def exec_each_ip4(self,
                      callback: MatchCallable,
                      _dic: typing.Optional[IpTreeNode] = None,
                      _par: typing.Optional[Ip4Path] = None,
                      ) -> typing.Any:
        """
        Call the callback on every active IPv4 match,
        yielding what it returns if that is iterable.

        _dic and _par are the node and path of the subtree browsed,
        they are meant for the recursion only.
        """
        if _dic is None:
            _dic = self.ip4tree
        if _par is None:
            _par = Ip4Path(0, 0)
        if _dic.active():
            c = callback(
                _par,
                _dic,
            )
            yield from _iter_callback_result(c)

        # 0
        pref = _par.prefixlen + 1
        dic = _dic.zero
        if dic:
            # The new bit is zero: the value is the one of the parent
            addr0 = _par.value
            yield from self.exec_each_ip4(
                callback,
                _dic=dic,
                _par=Ip4Path(addr0, pref)
            )
        # 1
        dic = _dic.one
        if dic:
            addr1 = _par.value | (1 << (32-pref))
            yield from self.exec_each_ip4(
                callback,
                _dic=dic,
                _par=Ip4Path(addr1, pref)
            )

    def exec_each(self,
                  callback: MatchCallable,
                  ) -> typing.Any:
        """Call the callback on every active domain, IPv4 and ASN match."""
        yield from self.exec_each_domain(callback)
        yield from self.exec_each_ip4(callback)
        yield from self.exec_each_asn(callback)

    def update_references(self) -> None:
        """Recount how many active matches have each match as source."""
        # Should be correctly calculated normally,
        # keeping this just in case
        def reset_references_cb(path: Path,
                                match: Match
                                ) -> None:
            match.references = 0
        for _ in self.exec_each(reset_references_cb):
            pass

        def increment_references_cb(path: Path,
                                    match: Match
                                    ) -> None:
            if match.source:
                source = self.get_match(match.source)
                source.references += 1
        for _ in self.exec_each(increment_references_cb):
            pass

    def _clean_deps(self) -> None:
        """
        Propagate the state of the sources to the matches depending
        on them, passing over the database until nothing changes.
        """
        # Disable the matches that depends on the targeted
        # matches until all disabled matches reference count = 0
        did_something = True

        def clean_deps_cb(path: Path,
                          match: Match
                          ) -> None:
            nonlocal did_something
            if not match.source:
                return
            source = self.get_match(match.source)
            if not source.active():
                self._unset_match(match)
            elif match.first_party > source.first_party:
                # A match cannot be first-party if its source is not
                match.first_party = source.first_party
            else:
                return
            did_something = True

        while did_something:
            did_something = False
            self.enter_step('pass_clean_deps')
            for _ in self.exec_each(clean_deps_cb):
                pass

    def prune(self, before: int, base_only: bool = False) -> None:
        """
        Disable the matches not updated after `before`,
        then the matches depending on them.

        If base_only is set, matches of level above 1 are not
        disabled directly (they still are if their source is).
        """
        # Disable the matches targeted
        def prune_cb(path: Path,
                     match: Match
                     ) -> None:
            if base_only and match.level > 1:
                return
            if match.updated > before:
                return
            self._unset_match(match)
            self.log.debug("Print: disabled %s", path)

        self.enter_step('pass_prune')
        for _ in self.exec_each(prune_cb):
            pass

        self._clean_deps()

        # Remove branches with no match
        # TODO

    def explain(self, path: Path) -> str:
        """
        Describe the match of a path (level, F for first-party or M,
        D if dupplicate, reference count), followed by its sources.
        """
        match = self.get_match(path)
        string = str(path)
        if isinstance(match, AsnNode):
            string += f' ({match.name})'
        party_char = 'F' if match.first_party else 'M'
        dup_char = 'D' if match.dupplicate else '_'
        string += f' {match.level}{party_char}{dup_char}{match.references}'
        if match.source:
            string += f' ← {self.explain(match.source)}'
        return string

    def list_records(self,
                     first_party_only: bool = False,
                     end_chain_only: bool = False,
                     no_dupplicates: bool = False,
                     rules_only: bool = False,
                     hostnames_only: bool = False,
                     explain: bool = False,
                     ) -> typing.Iterable[str]:
        """
        Yield the active paths passing the filters as strings,
        or their explanation if explain is set.
        """

        def export_cb(path: Path, match: Match
                      ) -> typing.Iterable[str]:
            if first_party_only and not match.first_party:
                return
            if end_chain_only and match.references > 0:
                return
            if no_dupplicates and match.dupplicate:
                return
            if rules_only and match.level > 1:
                return
            if hostnames_only and not isinstance(path, HostnamePath):
                return

            if explain:
                yield self.explain(path)
            else:
                yield str(path)

        yield from self.exec_each(export_cb)

    def count_records(self,
                      first_party_only: bool = False,
                      end_chain_only: bool = False,
                      no_dupplicates: bool = False,
                      rules_only: bool = False,
                      hostnames_only: bool = False,
                      ) -> str:
        """
        Count the active paths passing the filters per kind of path,
        formatted like 'hostnames: 1, zones: 2'.
        """
        memo: typing.Dict[str, int] = dict()

        def count_records_cb(path: Path, match: Match) -> None:
            if first_party_only and not match.first_party:
                return
            if end_chain_only and match.references > 0:
                return
            if no_dupplicates and match.dupplicate:
                return
            if rules_only and match.level > 1:
                return
            if hostnames_only and not isinstance(path, HostnamePath):
                return

            name = path.__class__.__name__
            memo[name] = memo.get(name, 0) + 1

        for _ in self.exec_each(count_records_cb):
            pass

        split: typing.List[str] = list()
        for key, value in sorted(memo.items(), key=lambda s: s[0]):
            # Class names end with 'Path', which is stripped
            split.append(f'{key[:-4].lower()}s: {value}')
        return ', '.join(split)

    def get_domain(self, domain_str: str) -> typing.Iterable[DomainPath]:
        """
        Yield the paths of the active zones containing the domain,
        then the one of the hostname if it is active.
        """
        self.enter_step('get_domain_pack')
        domain = self.pack_domain(domain_str)
        self.enter_step('get_domain_brws')
        dic = self.domtree
        depth = 0
        for part in domain.parts:
            if dic.match_zone.active():
                self.enter_step('get_domain_yield')
                yield ZonePath(domain.parts[:depth])
                self.enter_step('get_domain_brws')
            if part not in dic.children:
                return
            dic = dic.children[part]
            depth += 1
        if dic.match_zone.active():
            self.enter_step('get_domain_yield')
            yield ZonePath(domain.parts)
        if dic.match_hostname.active():
            self.enter_step('get_domain_yield')
            yield HostnamePath(domain.parts)

    def get_ip4(self, ip4_str: str) -> typing.Iterable[Path]:
        """Yield the paths of the active networks containing the address."""
        self.enter_step('get_ip4_pack')
        ip4val = self.pack_ip4address_low(ip4_str)
        self.enter_step('get_ip4_cache')
        # No need to browse the tree if the cell of the address is not flagged
        if not self.ip4cache[ip4val >> self.ip4cache_shift]:
            return
        self.enter_step('get_ip4_brws')
        dic = self.ip4tree
        for i in range(31, -1, -1):
            bit = (ip4val >> i) & 0b1
            if dic.active():
                self.enter_step('get_ip4_yield')
                # Clear the bits not browsed yet to get the network address
                yield Ip4Path(ip4val >> (i+1) << (i+1), 31-i)
                self.enter_step('get_ip4_brws')
            next_dic = dic.one if bit else dic.zero
            if next_dic is None:
                return
            dic = next_dic
        if dic.active():
            self.enter_step('get_ip4_yield')
            yield Ip4Path(ip4val, 32)

    def _unset_match(self,
                     match: Match,
                     ) -> None:
        """Disable a match and remove one reference from its source."""
        match.disable()
        if match.source:
            source_match = self.get_match(match.source)
            source_match.references -= 1

    def _set_match(self,
                   match: Match,
                   updated: int,
                   source: Path,
                   source_match: typing.Optional[Match] = None,
                   dupplicate: bool = False,
                   ) -> None:
        """
        Attach the match to the given source if it is more recent,
        of a lower level or more first-party than what is stored.
        """
        # source_match is in parameters because most of the time
        # its parent function needs it too,
        # so it can pass it to save a traversal
        source_match = source_match or self.get_match(source)
        new_level = source_match.level + 1
        if updated > match.updated or new_level < match.level \
                or source_match.first_party > match.first_party:
            # NOTE FP and level of matches referencing this one
            # won't be updated until run or prune
            if match.source:
                old_source = self.get_match(match.source)
                old_source.references -= 1
            match.updated = updated
            match.level = new_level
            match.first_party = source_match.first_party
            match.source = source
            source_match.references += 1
            match.dupplicate = dupplicate

    def _set_domain(self,
                    hostname: bool,
                    domain_str: str,
                    updated: int,
                    source: Path) -> None:
        """
        Set the hostname or zone match of a domain, creating
        the nodes leading to it if needed.

        Raises ValueError if the domain is not valid.
        """
        self.enter_step('set_domain_val')
        if not Database.validate_domain(domain_str):
            raise ValueError(f"Invalid domain: {domain_str}")
        self.enter_step('set_domain_pack')
        domain = self.pack_domain(domain_str)
        self.enter_step('set_domain_fp')
        source_match = self.get_match(source)
        is_first_party = source_match.first_party
        self.enter_step('set_domain_brws')
        dic = self.domtree
        dupplicate = False
        for part in domain.parts:
            if part not in dic.children:
                dic.children[part] = DomainTreeNode()
            dic = dic.children[part]
            # An active zone on the way already covers the domain
            if dic.match_zone.active(is_first_party):
                dupplicate = True
        if hostname:
            match = dic.match_hostname
        else:
            match = dic.match_zone
        self._set_match(
            match,
            updated,
            source,
            source_match=source_match,
            dupplicate=dupplicate,
        )

    def set_hostname(self,
                     *args: typing.Any, **kwargs: typing.Any
                     ) -> None:
        """Set a hostname match, arguments being those of _set_domain."""
        self._set_domain(True, *args, **kwargs)

    def set_zone(self,
                 *args: typing.Any, **kwargs: typing.Any
                 ) -> None:
        """Set a zone match, arguments being those of _set_domain."""
        self._set_domain(False, *args, **kwargs)

    def set_asn(self,
                asn_str: str,
                updated: int,
                source: Path) -> None:
        """Set the match of an ASN, creating its node if needed."""
        self.enter_step('set_asn')
        path = self.pack_asn(asn_str)
        if path.asn in self.asns:
            match = self.asns[path.asn]
        else:
            match = AsnNode()
            self.asns[path.asn] = match
        self._set_match(
            match,
            updated,
            source,
        )

    def _set_ip4(self,
                 ip4: Ip4Path,
                 updated: int,
                 source: Path) -> None:
        """
        Set the match of an IPv4 network, creating the nodes
        leading to it if needed, and flag it in the cache.
        """
        self.enter_step('set_ip4_fp')
        source_match = self.get_match(source)
        is_first_party = source_match.first_party
        self.enter_step('set_ip4_brws')
        dic = self.ip4tree
        dupplicate = False
        for i in range(31, 31-ip4.prefixlen, -1):
            bit = (ip4.value >> i) & 0b1
            next_dic = dic.one if bit else dic.zero
            if next_dic is None:
                next_dic = IpTreeNode()
                if bit:
                    dic.one = next_dic
                else:
                    dic.zero = next_dic
            dic = next_dic
            # An active network on the way already covers this one
            if dic.active(is_first_party):
                dupplicate = True
        self._set_match(
            dic,
            updated,
            source,
            source_match=source_match,
            dupplicate=dupplicate,
        )
        self._set_ip4cache(ip4, dic)

    def set_ip4address(self,
                       ip4address_str: str,
                       *args: typing.Any, **kwargs: typing.Any
                       ) -> None:
        """
        Set the match of an IPv4 address, other arguments
        being those of _set_ip4.

        Raises ValueError if the address is not valid.
        """
        self.enter_step('set_ip4add_val')
        if not Database.validate_ip4address(ip4address_str):
            raise ValueError(f"Invalid ip4address: {ip4address_str}")
        self.enter_step('set_ip4add_pack')
        ip4 = self.pack_ip4address(ip4address_str)
        self._set_ip4(ip4, *args, **kwargs)

    def set_ip4network(self,
                       ip4network_str: str,
                       *args: typing.Any, **kwargs: typing.Any
                       ) -> None:
        """
        Set the match of an IPv4 network, other arguments
        being those of _set_ip4.

        Raises ValueError if the network is not valid.
        """
        self.enter_step('set_ip4net_val')
        if not Database.validate_ip4network(ip4network_str):
            raise ValueError(f"Invalid ip4network: {ip4network_str}")
        self.enter_step('set_ip4net_pack')
        ip4 = self.pack_ip4network(ip4network_str)
        self._set_ip4(ip4, *args, **kwargs)