diff --git a/database.py b/database.py index 13f8876..0828691 100644 --- a/database.py +++ b/database.py @@ -70,19 +70,9 @@ class Match(): self.updated: int = 0 self.level: int = 0 self.source: typing.Optional[Path] = None + self.references: int = 0 # FP dupplicate args - def set(self, - updated: int, - level: int, - source: Path, - ) -> None: - if updated > self.updated or level > self.level: - self.updated = updated - self.level = level - self.source = source - # FP dupplicate function - def active(self) -> bool: return self.updated > 0 @@ -143,7 +133,7 @@ class Profiler(): class Database(Profiler): - VERSION = 13 + VERSION = 14 PATH = "blocking.p" def initialize(self) -> None: @@ -268,6 +258,24 @@ class Database(Profiler): else: raise ValueError + def exec_each_asn(self, + callback: MatchCallable, + arg: typing.Any = None, + ) -> typing.Any: + for asn in self.asns: + match = self.asns[asn] + if match.active(): + c = callback( + AsnPath(asn), + match, + arg + ) + try: + yield from c + except TypeError: # not iterable + pass + + def exec_each_domain(self, callback: MatchCallable, arg: typing.Any = None, @@ -277,17 +285,25 @@ class Database(Profiler): _dic = _dic or self.domtree _par = _par or DomainPath([]) if _dic.match_hostname.active(): - yield from callback( + c = callback( HostnamePath(_par.parts), _dic.match_hostname, arg ) + try: + yield from c + except TypeError: # not iterable + pass if _dic.match_zone.active(): - yield from callback( + c = callback( ZonePath(_par.parts), _dic.match_zone, arg ) + try: + yield from c + except TypeError: # not iterable + pass for part in _dic.children: dic = _dic.children[part] yield from self.exec_each_domain( @@ -306,11 +322,15 @@ class Database(Profiler): _dic = _dic or self.ip4tree _par = _par or Ip4Path(0, 0) if _dic.active(): - yield from callback( + c = callback( _par, _dic, arg ) + try: + yield from c + except TypeError: # not iterable + pass # 0 pref = _par.prefixlen + 1 @@ -341,17 +361,35 @@ class Database(Profiler): ) -> typing.Any: yield from self.exec_each_domain(callback) yield from self.exec_each_ip4(callback) - # TODO ASN + yield from self.exec_each_asn(callback) def update_references(self) -> None: - raise NotImplementedError + # Should be correctly calculated normally, + # keeping this just in case + def reset_references_cb(path: Path, + match: Match, _: typing.Any + ) -> None: + match.references = 0 + for _ in self.exec_each(reset_references_cb, None): + pass + + def increment_references_cb(path: Path, + match: Match, _: typing.Any + ) -> None: + if match.source: + source = self.get_match(match.source) + source.references += 1 + for _ in self.exec_each(increment_references_cb, None): + pass def prune(self, before: int, base_only: bool = False) -> None: raise NotImplementedError def explain(self, path: Path) -> str: - string = str(path) match = self.get_match(path) + string = f'{path}' + if not isinstance(path, RulePath): + string += f' #{match.references}' if match.source: string += f' ← {self.explain(match.source)}' return string @@ -361,17 +399,20 @@ class Database(Profiler): end_chain_only: bool = False, explain: bool = False, ) -> typing.Iterable[str]: - if first_party_only or end_chain_only: + if first_party_only: raise NotImplementedError def export_cb(path: Path, match: Match, _: typing.Any ) -> typing.Iterable[str]: assert isinstance(path, DomainPath) - if isinstance(path, HostnamePath): - if explain: - yield self.explain(path) - else: - yield self.unpack_domain(path) + if not isinstance(path, HostnamePath): + return + if end_chain_only and match.references > 0: + return + if explain: + yield self.explain(path) + else: + yield self.unpack_domain(path) yield from self.exec_each_domain(export_cb, None) @@ -437,9 +478,22 @@ class Database(Profiler): self.enter_step('get_ip4_yield') yield ip4 - def list_asn(self) -> typing.Iterable[AsnPath]: - for asn in self.asns: - yield AsnPath(asn) + def set_match(self, + match: Match, + updated: int, + source: Path, + ) -> None: + new_source = self.get_match(source) + new_level = new_source.level + 1 + if updated > match.updated or new_level > match.level: + if match.source: + old_source = self.get_match(match.source) + old_source.references -= 1 + match.updated = updated + match.level = new_level + match.source = source + new_source.references += 1 + # FP dupplicate function def _set_domain(self, hostname: bool, @@ -451,30 +505,23 @@ class Database(Profiler): if is_first_party: raise NotImplementedError domain = self.pack_domain(domain_str) - self.enter_step('set_domain_src') - if source is None: - level = 0 - source = RulePath() - else: - match = self.get_match(source) - level = match.level + 1 self.enter_step('set_domain_brws') dic = self.domtree for part in domain.parts: - if dic.match_zone.active(): - # Refuse to add domain whose zone is already matching - return if part not in dic.children: dic.children[part] = DomainTreeNode() dic = dic.children[part] + if dic.match_zone.active(): + # Refuse to add domain whose zone is already matching + return if hostname: match = dic.match_hostname else: match = dic.match_zone - match.set( + self.set_match( + match, updated, - level, - source, + source or RulePath(), ) def set_hostname(self, @@ -495,22 +542,16 @@ class Database(Profiler): self.enter_step('set_asn') if is_first_party: raise NotImplementedError - if source is None: - level = 0 - source = RulePath() - else: - match = self.get_match(source) - level = match.level + 1 path = self.pack_asn(asn_str) if path.asn in self.asns: match = self.asns[path.asn] else: match = AsnNode() self.asns[path.asn] = match - match.set( + self.set_match( + match, updated, - level, - source, + source or RulePath(), ) def _set_ip4(self, @@ -520,20 +561,10 @@ class Database(Profiler): source: Path = None) -> None: if is_first_party: raise NotImplementedError - self.enter_step('set_ip4_src') - if source is None: - level = 0 - source = RulePath() - else: - match = self.get_match(source) - level = match.level + 1 self.enter_step('set_ip4_brws') dic = self.ip4tree for i in range(31, 31-ip4.prefixlen, -1): bit = (ip4.value >> i) & 0b1 - if dic.active(): - # Refuse to add ip4* whose network is already matching - return next_dic = dic.one if bit else dic.zero if next_dic is None: next_dic = IpTreeNode() @@ -542,10 +573,13 @@ class Database(Profiler): else: dic.zero = next_dic dic = next_dic - dic.set( + if dic.active(): + # Refuse to add ip4* whose network is already matching + return + self.set_match( + dic, updated, - level, - source, + source or RulePath(), ) def set_ip4address(self, diff --git a/feed_asn.py b/feed_asn.py index f34773f..fbdefcd 100755 --- a/feed_asn.py +++ b/feed_asn.py @@ -32,7 +32,10 @@ if __name__ == '__main__': DB = database.Database() - for path in DB.list_asn(): + def add_ranges(path: database.Path, + match: database.Match, + _: typing.Any) -> None: + assert isinstance(path, database.AsnPath) asn_str = database.Database.unpack_asn(path) DB.enter_step('asn_get_ranges') for prefix in get_ranges(asn_str): @@ -49,4 +52,7 @@ if __name__ == '__main__': else: log.error('Unknown prefix version: %s', prefix) + for _ in DB.exec_each_asn(add_ranges, None): + pass + DB.save()