#!/usr/bin/env python3

"""
Utility functions to interact with the database.
"""

import typing
import time
import logging
import pickle
import enum

try:
    import coloredlogs
except ImportError:  # coloredlogs is cosmetic: fall back to plain logging
    coloredlogs = None  # type: ignore

_LOG_FORMAT = '%(asctime)s %(name)s %(levelname)s %(message)s'

if coloredlogs is not None:
    coloredlogs.install(
        level='DEBUG',
        fmt=_LOG_FORMAT
    )
else:
    logging.basicConfig(level=logging.DEBUG, format=_LOG_FORMAT)

PathType = enum.Enum('PathType', 'Rule Hostname Zone Asn Ip4 Ip6')
RulePath = typing.Union[None]
Asn = int
DomainPath = typing.List[str]
Ip4Path = typing.Tuple[int, int]  # value, prefixlen
Ip6Path = typing.List[int]
Path = typing.Union[RulePath, DomainPath, Asn, Ip4Path, Ip6Path]
TypedPath = typing.Tuple[PathType, Path]
Timestamp = int
Level = int
Match = typing.Tuple[Timestamp, TypedPath, Level]

# Placeholder source recorded on every match created by the set_* methods
DebugPath = (PathType.Rule, None)


class DomainTreeNode():
    """Node of the domain tree, keyed by domain label (TLD first)."""

    def __init__(self) -> None:
        # Sub-domains of this node, indexed by their label
        self.children: typing.Dict[str, DomainTreeNode] = dict()
        # Set when this node and everything below it is matched
        self.match_zone: typing.Optional[Match] = None
        # Set when only this exact name is matched
        self.match_hostname: typing.Optional[Match] = None


class IpTreeNode():
    """Node of the binary IP tree: one level per address bit."""

    def __init__(self) -> None:
        # children[0] / children[1]: sub-tree for the next bit being 0 / 1
        self.children: typing.List[typing.Optional[IpTreeNode]] = [None, None]
        # Set when the network ending at this node is matched
        self.match: typing.Optional[Match] = None


class Profiler():
    """Accumulate the wall-clock time spent in named steps."""

    def __init__(self) -> None:
        self.log = logging.getLogger('profiler')
        self.time_last = time.perf_counter()
        self.time_step = 'init'
        # Seconds accumulated per step name
        self.time_dict: typing.Dict[str, float] = dict()
        # Number of times each step was recorded
        self.step_dict: typing.Dict[str, int] = dict()

    def enter_step(self, name: str) -> None:
        """
        Close the current step, charging it the elapsed time,
        and make `name` the current step.

        The counter of the step being closed starts at 1 the first time
        it is recorded, and is afterwards only incremented when
        the step entered differs from the one being closed.
        """
        now = time.perf_counter()
        try:
            self.time_dict[self.time_step] += now - self.time_last
            self.step_dict[self.time_step] += int(name != self.time_step)
        except KeyError:
            # First time this step is recorded
            self.time_dict[self.time_step] = now - self.time_last
            self.step_dict[self.time_step] = 1
        self.time_step = name
        # Read the clock again so that the bookkeeping above
        # is not charged to the new step
        self.time_last = time.perf_counter()

    def profile(self) -> None:
        """Log (DEBUG level) the time spent in each step, smallest first."""
        self.enter_step('profile')
        total = sum(self.time_dict.values())
        for key, secs in sorted(self.time_dict.items(), key=lambda t: t[1]):
            times = self.step_dict[key]
            # Guard against a null total (clock with a coarse resolution)
            share = secs / total if total else 0.0
            self.log.debug(f"{key:<20}: {times:9d} × {secs/times:5.3e} "
                           f"= {secs:9.2f} s ({share:7.2%}) ")
        self.log.debug(f"{'total':<20}: "
                       f"{total:9.2f} s ({1:7.2%})")


class Database(Profiler):
    """
    Store of blocking rules (domains, ASNs, IPv4 addresses and networks),
    persisted as a pickle file.
    """

    VERSION = 8
    PATH = "blocking.p"

    def __init__(self) -> None:
        super().__init__()
        self.log = logging.getLogger('db')
        self.load()

    def initialize(self) -> None:
        """Reset the database to an empty state."""
        self.log.warning(
            "Creating database version: %d ",
            Database.VERSION)
        self.domtree = DomainTreeNode()
        self.asns: typing.Set[Asn] = set()
        self.ip4tree = IpTreeNode()

    def load(self) -> None:
        """
        Load the database from `PATH`.

        An empty database is created instead when the file is missing,
        was written by another `VERSION`, or cannot be unpickled.
        """
        self.enter_step('load')
        try:
            with open(self.PATH, 'rb') as db_fdsec:
                # SECURITY: pickle executes arbitrary code on load,
                # the database file must come from a trusted source.
                version, data = pickle.load(db_fdsec)
                if version == Database.VERSION:
                    self.domtree, self.asns, self.ip4tree = data
                    return
                self.log.warning(
                    "Outdated database version found: %d, "
                    "will be rebuilt.",
                    version)
        except (pickle.UnpicklingError, TypeError, ValueError,
                AttributeError, EOFError, ImportError, IndexError):
            # Exceptions pickle documents for corrupt data, plus
            # ValueError for a payload not shaped as (version, data)
            self.log.error(
                "Corrupt database found, "
                "will be rebuilt.")
        except FileNotFoundError:
            pass
        self.initialize()

    def save(self) -> None:
        """Write the database to `PATH`, then log the profiling report."""
        self.enter_step('save')
        with open(self.PATH, 'wb') as db_fdsec:
            data = self.domtree, self.asns, self.ip4tree
            pickle.dump((self.VERSION, data), db_fdsec)
        self.profile()

    @staticmethod
    def pack_domain(domain: str) -> DomainPath:
        """Split a domain name into its labels, TLD first."""
        return domain.split('.')[::-1]

    @staticmethod
    def unpack_domain(domain: DomainPath) -> str:
        """Inverse of `pack_domain`."""
        return '.'.join(domain[::-1])

    @staticmethod
    def pack_asn(asn: str) -> int:
        """Parse an AS number, with or without its (any case) AS prefix."""
        asn = asn.upper()
        if asn.startswith('AS'):
            asn = asn[2:]
        return int(asn)

    @staticmethod
    def unpack_asn(asn: int) -> str:
        """Inverse of `pack_asn`."""
        return f'AS{asn}'

    @staticmethod
    def pack_ip4address(address: str) -> Ip4Path:
        """
        Parse a dotted-quad IPv4 address into (value, 32).

        Raises ValueError if the address does not have exactly
        4 integer octets in the 0-255 range.
        """
        splits = address.split('.')
        if len(splits) != 4:
            raise ValueError(f"Invalid IPv4 address: {address!r}")
        addr = 0
        for split in splits:
            octet = int(split)
            if not 0 <= octet <= 0xFF:
                raise ValueError(f"Invalid IPv4 address: {address!r}")
            addr = (addr << 8) + octet
        return (addr, 32)

    @staticmethod
    def unpack_ip4address(address: Ip4Path) -> str:
        """Inverse of `pack_ip4address`, the prefix length must be 32."""
        addr, prefixlen = address
        assert prefixlen == 32
        octets = [(addr >> shift) & 0xFF for shift in (24, 16, 8, 0)]
        return '.'.join(map(str, octets))

    @staticmethod
    def pack_ip4network(network: str) -> Ip4Path:
        """
        Parse an `address/prefixlen` IPv4 network into (value, prefixlen).

        The host bits of the address are kept as given.
        Raises ValueError if the network is malformed or
        if the prefix length is not in the 0-32 range.
        """
        address, prefixlen_str = network.split('/')
        prefixlen = int(prefixlen_str)
        if not 0 <= prefixlen <= 32:
            raise ValueError(f"Invalid IPv4 network: {network!r}")
        addr, _ = Database.pack_ip4address(address)
        return (addr, prefixlen)

    @staticmethod
    def unpack_ip4network(network: Ip4Path) -> str:
        """Inverse of `pack_ip4network`."""
        address, prefixlen = network
        addr = Database.unpack_ip4address((address, 32))
        return f'{addr}/{prefixlen}'

    @staticmethod
    def _ip4_bits(ip4: int, prefixlen: int) -> typing.Iterator[int]:
        """Yield the `prefixlen` leading bits of `ip4`, MSB first."""
        for shift in range(31, 31 - prefixlen, -1):
            yield (ip4 >> shift) & 0b1

    def update_references(self) -> None:
        raise NotImplementedError

    def prune(self, before: int, base_only: bool = False) -> None:
        raise NotImplementedError

    def explain(self, entry: int) -> str:
        raise NotImplementedError

    def export(self,
               first_party_only: bool = False,
               end_chain_only: bool = False,
               explain: bool = False,
               _dic: typing.Optional[DomainTreeNode] = None,
               _par: typing.Optional[DomainPath] = None,
               ) -> typing.Iterable[str]:
        """
        Yield every name of the domain tree having a hostname match.

        The filtering options are not implemented yet and
        raise NotImplementedError when set.
        `_dic` and `_par` are only meant for the recursion:
        the sub-tree to export and the path leading to it.
        """
        if first_party_only or end_chain_only or explain:
            raise NotImplementedError
        if _dic is None:
            _dic = self.domtree
        if _par is None:
            _par = list()
        if _dic.match_hostname:
            yield self.unpack_domain(_par)
        for part, child in _dic.children.items():
            yield from self.export(_dic=child, _par=_par + [part])

    def count_rules(self,
                    first_party_only: bool = False,
                    ) -> str:
        raise NotImplementedError

    def get_domain(self, domain_str: str) -> typing.Iterable[TypedPath]:
        """
        Yield the path of every rule matching the domain:
        matching zones, from the broadest to the domain itself,
        then the matching hostname.
        """
        self.enter_step('get_domain_pack')
        domain = self.pack_domain(domain_str)
        self.enter_step('get_domain_brws')
        dic = self.domtree
        for depth, part in enumerate(domain):
            if dic.match_zone:
                self.enter_step('get_domain_yield')
                yield (PathType.Zone, domain[:depth])
                self.enter_step('get_domain_brws')
            if part not in dic.children:
                return
            dic = dic.children[part]
        if dic.match_zone:
            self.enter_step('get_domain_yield')
            yield (PathType.Zone, domain)
        if dic.match_hostname:
            self.enter_step('get_domain_yield')
            yield (PathType.Hostname, domain)

    def get_ip4(self, ip4_str: str) -> typing.Iterable[TypedPath]:
        """
        Yield the path of every rule matching the IPv4 address,
        from the broadest network to the address itself.

        Each path is (address value, prefix length of the matching rule).
        """
        self.enter_step('get_ip4_pack')
        ip4, prefixlen = self.pack_ip4address(ip4_str)
        self.enter_step('get_ip4_brws')
        dic = self.ip4tree
        for depth, bit in enumerate(self._ip4_bits(ip4, prefixlen)):
            if dic.match:
                self.enter_step('get_ip4_yield')
                # `depth` bits were consumed to reach this node,
                # which is the prefix length of the matching network
                yield (PathType.Ip4, (ip4, depth))
                self.enter_step('get_ip4_brws')
            next_dic = dic.children[bit]
            if next_dic is None:
                return
            dic = next_dic
        if dic.match:
            self.enter_step('get_ip4_yield')
            yield (PathType.Ip4, (ip4, prefixlen))

    def list_asn(self) -> typing.Iterable[TypedPath]:
        """Yield the path of every stored ASN."""
        for asn in self.asns:
            yield (PathType.Asn, asn)

    def _insert_domain(self, domain: DomainPath
                       ) -> typing.Optional[DomainTreeNode]:
        """
        Return the node for `domain`, creating the missing ones,
        or None if one of its parent zones is already matching.
        """
        dic = self.domtree
        for part in domain:
            if dic.match_zone:
                return None
            if part not in dic.children:
                dic.children[part] = DomainTreeNode()
            dic = dic.children[part]
        return dic

    def _insert_ip4(self, ip4: int, prefixlen: int
                    ) -> typing.Optional[IpTreeNode]:
        """
        Return the node for the network, creating the missing ones,
        or None if one of its parent networks is already matching.
        """
        dic = self.ip4tree
        for bit in self._ip4_bits(ip4, prefixlen):
            if dic.match:
                return None
            next_dic = dic.children[bit]
            if next_dic is None:
                next_dic = IpTreeNode()
                dic.children[bit] = next_dic
            dic = next_dic
        return dic

    def set_hostname(self,
                     hostname_str: str,
                     updated: int,
                     is_first_party: typing.Optional[bool] = None,
                     source: typing.Optional[TypedPath] = None) -> None:
        """
        Add a rule matching this exact hostname, timestamped `updated`.

        `is_first_party` and `source` are not implemented yet.
        """
        self.enter_step('set_hostname_pack')
        if is_first_party or source:
            raise NotImplementedError
        hostname = self.pack_domain(hostname_str)
        self.enter_step('set_hostname_brws')
        dic = self._insert_domain(hostname)
        if dic is None:
            # Refuse to add hostname whose zone is already matching
            return
        dic.match_hostname = (updated, DebugPath, 0)

    def set_zone(self,
                 zone_str: str,
                 updated: int,
                 is_first_party: typing.Optional[bool] = None,
                 source: typing.Optional[TypedPath] = None) -> None:
        """
        Add a rule matching this domain and all its sub-domains,
        timestamped `updated`.

        `is_first_party` and `source` are not implemented yet.
        """
        self.enter_step('set_zone_pack')
        if is_first_party or source:
            raise NotImplementedError
        zone = self.pack_domain(zone_str)
        self.enter_step('set_zone_brws')
        dic = self._insert_domain(zone)
        if dic is None:
            # Refuse to add zone whose parent zone is already matching
            return
        dic.match_zone = (updated, DebugPath, 0)

    def set_asn(self,
                asn_str: str,
                updated: int,
                is_first_party: typing.Optional[bool] = None,
                source: typing.Optional[TypedPath] = None) -> None:
        """
        Add a rule matching this ASN.

        `is_first_party` and `source` are not implemented yet.
        """
        self.enter_step('set_asn_pack')
        if is_first_party or source:
            # TODO updated
            raise NotImplementedError
        asn = self.pack_asn(asn_str)
        self.enter_step('set_asn_brws')
        self.asns.add(asn)

    def set_ip4address(self,
                       ip4address_str: str,
                       updated: int,
                       is_first_party: typing.Optional[bool] = None,
                       source: typing.Optional[TypedPath] = None) -> None:
        """
        Add a rule matching this IPv4 address, timestamped `updated`.

        `is_first_party` and `source` are not implemented yet.
        """
        self.enter_step('set_ip4add_pack')
        if is_first_party or source:
            raise NotImplementedError
        ip4, prefixlen = self.pack_ip4address(ip4address_str)
        self.enter_step('set_ip4add_brws')
        dic = self._insert_ip4(ip4, prefixlen)
        if dic is None:
            # Refuse to add ip4address whose network is already matching
            return
        dic.match = (updated, DebugPath, 0)

    def set_ip4network(self,
                       ip4network_str: str,
                       updated: int,
                       is_first_party: typing.Optional[bool] = None,
                       source: typing.Optional[TypedPath] = None) -> None:
        """
        Add a rule matching this IPv4 network, timestamped `updated`.

        `is_first_party` and `source` are not implemented yet.
        """
        self.enter_step('set_ip4net_pack')
        if is_first_party or source:
            raise NotImplementedError
        ip4, prefixlen = self.pack_ip4network(ip4network_str)
        self.enter_step('set_ip4net_brws')
        dic = self._insert_ip4(ip4, prefixlen)
        if dic is None:
            # Refuse to add ip4network whose parent network
            # is already matching
            return
        dic.match = (updated, DebugPath, 0)