From 03a4042238af71fe75d94105ba8f5f3210dd8ba1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20=E2=80=9CFrogeye=E2=80=9D=20Preud=27homme?= Date: Mon, 16 Dec 2019 09:31:29 +0100 Subject: [PATCH] Added level Also fixed IP logic because this was real messed up --- database.py | 223 +++++++++++++++++++++++++++++++++++++--------------- export.py | 8 +- feed_dns.py | 3 +- 3 files changed, 167 insertions(+), 67 deletions(-) diff --git a/database.py b/database.py index 99ea3ad..13f8876 100644 --- a/database.py +++ b/database.py @@ -26,38 +26,50 @@ class Path(): class RulePath(Path): - pass + def __str__(self) -> str: + return '(rules)' class DomainPath(Path): - def __init__(self, path: typing.List[str]): - self.path = path + def __init__(self, parts: typing.List[str]): + self.parts = parts + + def __str__(self) -> str: + return '?.' + Database.unpack_domain(self) class HostnamePath(DomainPath): - pass + def __str__(self) -> str: + return Database.unpack_domain(self) class ZonePath(DomainPath): - pass + def __str__(self) -> str: + return '*.' + Database.unpack_domain(self) class AsnPath(Path): def __init__(self, asn: Asn): self.asn = asn + def __str__(self) -> str: + return Database.unpack_asn(self) + class Ip4Path(Path): def __init__(self, value: int, prefixlen: int): self.value = value self.prefixlen = prefixlen + def __str__(self) -> str: + return Database.unpack_ip4network(self) + class Match(): def __init__(self) -> None: self.updated: int = 0 self.level: int = 0 - self.source: Path = RulePath() + self.source: typing.Optional[Path] = None # FP dupplicate args def set(self, @@ -86,18 +98,18 @@ class DomainTreeNode(): self.match_hostname = Match() -class IpTreeNode(): +class IpTreeNode(Match): def __init__(self) -> None: + Match.__init__(self) self.zero: typing.Optional[IpTreeNode] = None self.one: typing.Optional[IpTreeNode] = None - self.match = Match() Node = typing.Union[DomainTreeNode, IpTreeNode, AsnNode] -NodeCallable = typing.Callable[[Path, - Node, - typing.Optional[typing.Any]], - typing.Any] +MatchCallable = typing.Callable[[Path, + Match, + typing.Optional[typing.Any]], + typing.Any] class Profiler(): @@ -109,7 +121,6 @@ class Profiler(): self.step_dict: typing.Dict[str, int] = dict() def enter_step(self, name: str) -> None: - return now = time.perf_counter() try: self.time_dict[self.time_step] += now - self.time_last @@ -132,7 +143,7 @@ class Profiler(): class Database(Profiler): - VERSION = 11 + VERSION = 13 PATH = "blocking.p" def initialize(self) -> None: @@ -181,7 +192,7 @@ class Database(Profiler): @staticmethod def unpack_domain(domain: DomainPath) -> str: - return '.'.join(domain.path[::-1]) + return '.'.join(domain.parts[::-1]) @staticmethod def pack_asn(asn: str) -> AsnPath: @@ -230,62 +241,107 @@ class Database(Profiler): addr >>= 8 return '.'.join(map(str, octets)) + '/' + str(network.prefixlen) + def get_match(self, path: Path) -> Match: + if isinstance(path, RulePath): + return Match() + elif isinstance(path, AsnPath): + return self.asns[path.asn] + elif isinstance(path, DomainPath): + dicd = self.domtree + for part in path.parts: + dicd = dicd.children[part] + if isinstance(path, HostnamePath): + return dicd.match_hostname + elif isinstance(path, ZonePath): + return dicd.match_zone + else: + raise ValueError + elif isinstance(path, Ip4Path): + dici = self.ip4tree + for i in range(31, 31-path.prefixlen, -1): + bit = (path.value >> i) & 0b1 + dici_next = dici.one if bit else dici.zero + if not dici_next: + raise IndexError + dici = dici_next + return dici + else: + raise ValueError + def exec_each_domain(self, - callback: NodeCallable, + callback: MatchCallable, arg: typing.Any = None, _dic: DomainTreeNode = None, _par: DomainPath = None, ) -> typing.Any: _dic = _dic or self.domtree _par = _par or DomainPath([]) - yield from callback(_par, _dic, arg) + if _dic.match_hostname.active(): + yield from callback( + HostnamePath(_par.parts), + _dic.match_hostname, + arg + ) + if _dic.match_zone.active(): + yield from callback( + ZonePath(_par.parts), + _dic.match_zone, + arg + ) for part in _dic.children: dic = _dic.children[part] yield from self.exec_each_domain( callback, arg, _dic=dic, - _par=DomainPath(_par.path + [part]) + _par=DomainPath(_par.parts + [part]) ) def exec_each_ip4(self, - callback: NodeCallable, + callback: MatchCallable, arg: typing.Any = None, _dic: IpTreeNode = None, _par: Ip4Path = None, ) -> typing.Any: _dic = _dic or self.ip4tree _par = _par or Ip4Path(0, 0) - callback(_par, _dic, arg) + if _dic.active(): + yield from callback( + _par, + _dic, + arg + ) # 0 + pref = _par.prefixlen + 1 dic = _dic.zero if dic: - addr0 = _par.value & (0xFFFFFFFF ^ (1 << (32-_par.prefixlen))) + addr0 = _par.value & (0xFFFFFFFF ^ (1 << (32-pref))) assert addr0 == _par.value yield from self.exec_each_ip4( callback, arg, _dic=dic, - _par=Ip4Path(addr0, _par.prefixlen+1) + _par=Ip4Path(addr0, pref) ) # 1 dic = _dic.one if dic: - addr1 = _par.value | (1 << (32-_par.prefixlen)) + addr1 = _par.value | (1 << (32-pref)) yield from self.exec_each_ip4( callback, arg, _dic=dic, - _par=Ip4Path(addr1, _par.prefixlen+1) + _par=Ip4Path(addr1, pref) ) def exec_each(self, - callback: NodeCallable, + callback: MatchCallable, arg: typing.Any = None, ) -> typing.Any: yield from self.exec_each_domain(callback) yield from self.exec_each_ip4(callback) + # TODO ASN def update_references(self) -> None: raise NotImplementedError @@ -293,27 +349,47 @@ class Database(Profiler): def prune(self, before: int, base_only: bool = False) -> None: raise NotImplementedError - def explain(self, entry: int) -> str: - raise NotImplementedError + def explain(self, path: Path) -> str: + string = str(path) + match = self.get_match(path) + if match.source: + string += f' ← {self.explain(match.source)}' + return string def export(self, first_party_only: bool = False, end_chain_only: bool = False, explain: bool = False, ) -> typing.Iterable[str]: - if first_party_only or end_chain_only or explain: + if first_party_only or end_chain_only: raise NotImplementedError - def export_cb(path: Path, node: Node, _: typing.Any + def export_cb(path: Path, match: Match, _: typing.Any ) -> typing.Iterable[str]: assert isinstance(path, DomainPath) - assert isinstance(node, DomainTreeNode) - if node.match_hostname: - a = self.unpack_domain(path) - yield a + if isinstance(path, HostnamePath): + if explain: + yield self.explain(path) + else: + yield self.unpack_domain(path) yield from self.exec_each_domain(export_cb, None) + def list_rules(self, + first_party_only: bool = False, + ) -> typing.Iterable[str]: + if first_party_only: + raise NotImplementedError + + def list_rules_cb(path: Path, match: Match, _: typing.Any + ) -> typing.Iterable[str]: + if isinstance(path, ZonePath) \ + or (isinstance(path, Ip4Path) and path.prefixlen < 32): + # if match.level == 0: + yield self.explain(path) + + yield from self.exec_each(list_rules_cb, None) + def count_rules(self, first_party_only: bool = False, ) -> str: @@ -325,10 +401,10 @@ class Database(Profiler): self.enter_step('get_domain_brws') dic = self.domtree depth = 0 - for part in domain.path: + for part in domain.parts: if dic.match_zone.active(): self.enter_step('get_domain_yield') - yield ZonePath(domain.path[:depth]) + yield ZonePath(domain.parts[:depth]) self.enter_step('get_domain_brws') if part not in dic.children: return @@ -336,27 +412,28 @@ class Database(Profiler): depth += 1 if dic.match_zone.active(): self.enter_step('get_domain_yield') - yield ZonePath(domain.path) + yield ZonePath(domain.parts) if dic.match_hostname.active(): self.enter_step('get_domain_yield') - yield HostnamePath(domain.path) + yield HostnamePath(domain.parts) def get_ip4(self, ip4_str: str) -> typing.Iterable[Path]: self.enter_step('get_ip4_pack') ip4 = self.pack_ip4address(ip4_str) self.enter_step('get_ip4_brws') dic = self.ip4tree - for i in reversed(range(ip4.prefixlen)): - part = (ip4.value >> i) & 0b1 - if dic.match.active(): + for i in range(31, 31-ip4.prefixlen, -1): + bit = (ip4.value >> i) & 0b1 + if dic.active(): self.enter_step('get_ip4_yield') - yield Ip4Path(ip4.value, 32-i) - self.enter_step('get_ip4_brws') - next_dic = dic.one if part else dic.zero + a = Ip4Path(ip4.value >> (i+1) << (i+1), 31-i) + yield a + self.enter_step('get_ip4_brws') + next_dic = dic.one if bit else dic.zero if next_dic is None: return dic = next_dic - if dic.match.active(): + if dic.active(): self.enter_step('get_ip4_yield') yield ip4 @@ -374,9 +451,16 @@ class Database(Profiler): if is_first_party: raise NotImplementedError domain = self.pack_domain(domain_str) + self.enter_step('set_domain_src') + if source is None: + level = 0 + source = RulePath() + else: + match = self.get_match(source) + level = match.level + 1 self.enter_step('set_domain_brws') dic = self.domtree - for part in domain.path: + for part in domain.parts: if dic.match_zone.active(): # Refuse to add domain whose zone is already matching return @@ -389,8 +473,8 @@ class Database(Profiler): match = dic.match_zone match.set( updated, - 0, # TODO Level - source or RulePath(), + level, + source, ) def set_hostname(self, @@ -411,14 +495,23 @@ class Database(Profiler): self.enter_step('set_asn') if is_first_party: raise NotImplementedError + if source is None: + level = 0 + source = RulePath() + else: + match = self.get_match(source) + level = match.level + 1 path = self.pack_asn(asn_str) - match = AsnNode() + if path.asn in self.asns: + match = self.asns[path.asn] + else: + match = AsnNode() + self.asns[path.asn] = match match.set( - updated, - 0, - source or RulePath() + updated, + level, + source, ) - self.asns[path.asn] = match def _set_ip4(self, ip4: Ip4Path, @@ -427,24 +520,32 @@ class Database(Profiler): source: Path = None) -> None: if is_first_party: raise NotImplementedError + self.enter_step('set_ip4_src') + if source is None: + level = 0 + source = RulePath() + else: + match = self.get_match(source) + level = match.level + 1 + self.enter_step('set_ip4_brws') dic = self.ip4tree - for i in reversed(range(ip4.prefixlen)): - part = (ip4.value >> i) & 0b1 - if dic.match.active(): + for i in range(31, 31-ip4.prefixlen, -1): + bit = (ip4.value >> i) & 0b1 + if dic.active(): # Refuse to add ip4* whose network is already matching return - next_dic = dic.one if part else dic.zero + next_dic = dic.one if bit else dic.zero if next_dic is None: next_dic = IpTreeNode() - if part: + if bit: dic.one = next_dic else: dic.zero = next_dic dic = next_dic - dic.match.set( + dic.set( updated, - 0, # TODO Level - source or RulePath(), + level, + source, ) def set_ip4address(self, @@ -453,7 +554,6 @@ class Database(Profiler): ) -> None: self.enter_step('set_ip4add_pack') ip4 = self.pack_ip4address(ip4address_str) - self.enter_step('set_ip4add_brws') self._set_ip4(ip4, *args, **kwargs) def set_ip4network(self, @@ -462,5 +562,4 @@ class Database(Profiler): ) -> None: self.enter_step('set_ip4net_pack') ip4 = self.pack_ip4network(ip4network_str) - self.enter_step('set_ip4net_brws') self._set_ip4(ip4, *args, **kwargs) diff --git a/export.py b/export.py index bca3281..0df4229 100755 --- a/export.py +++ b/export.py @@ -33,9 +33,11 @@ if __name__ == '__main__': DB = database.Database() if args.rules: - if not args.count: - raise NotImplementedError - print(DB.count_rules(first_party_only=args.first_party)) + if args.count: + print(DB.count_rules(first_party_only=args.first_party)) + else: + for line in DB.list_rules(): + print(line) else: if args.count: raise NotImplementedError diff --git a/feed_dns.py b/feed_dns.py index be08e98..43df1fd 100755 --- a/feed_dns.py +++ b/feed_dns.py @@ -51,8 +51,7 @@ class Writer(multiprocessing.Process): try: for source in select(self.db, value): - # write(self.db, name, updated, source=source) - write(self.db, name, updated) + write(self.db, name, updated, source=source) except ValueError: self.log.exception("Cannot execute: %s", record)