diff --git a/database.py b/database.py index 0828691..3fc93c5 100644 --- a/database.py +++ b/database.py @@ -27,7 +27,17 @@ class Path(): class RulePath(Path): def __str__(self) -> str: - return '(rules)' + return '(rule)' + + +class RuleFirstPath(RulePath): + def __str__(self) -> str: + return '(first-party rule)' + + +class RuleMultiPath(RulePath): + def __str__(self) -> str: + return '(multi-party rule)' class DomainPath(Path): @@ -67,14 +77,18 @@ class Ip4Path(Path): class Match(): def __init__(self) -> None: - self.updated: int = 0 - self.level: int = 0 self.source: typing.Optional[Path] = None - self.references: int = 0 - # FP dupplicate args + self.updated: int = 0 - def active(self) -> bool: - return self.updated > 0 + # Cache + self.level: int = 0 + self.first_party: bool = False + self.references: int = 0 + + def active(self, first_party: bool = None) -> bool: + if self.updated == 0 or (first_party and not self.first_party): + return False + return True class AsnNode(Match): @@ -133,13 +147,21 @@ class Profiler(): class Database(Profiler): - VERSION = 14 + VERSION = 17 PATH = "blocking.p" def initialize(self) -> None: self.log.warning( "Creating database version: %d ", Database.VERSION) + # Dummy match objects that everything refer to + self.rules: typing.List[Match] = list() + for first_party in (False, True): + m = Match() + m.updated = 1 + m.level = 0 + m.first_party = first_party + self.rules.append(m) self.domtree = DomainTreeNode() self.asns: typing.Dict[Asn, AsnNode] = dict() self.ip4tree = IpTreeNode() @@ -150,7 +172,7 @@ class Database(Profiler): with open(self.PATH, 'rb') as db_fdsec: version, data = pickle.load(db_fdsec) if version == Database.VERSION: - self.domtree, self.asns, self.ip4tree = data + self.rules, self.domtree, self.asns, self.ip4tree = data return self.log.warning( "Outdated database version found: %d, " @@ -167,7 +189,7 @@ class Database(Profiler): def save(self) -> None: self.enter_step('save') with open(self.PATH, 'wb') as db_fdsec: - data = self.domtree, self.asns, self.ip4tree + data = self.rules, self.domtree, self.asns, self.ip4tree pickle.dump((self.VERSION, data), db_fdsec) self.profile() @@ -232,8 +254,10 @@ class Database(Profiler): return '.'.join(map(str, octets)) + '/' + str(network.prefixlen) def get_match(self, path: Path) -> Match: - if isinstance(path, RulePath): - return Match() + if isinstance(path, RuleMultiPath): + return self.rules[0] + elif isinstance(path, RuleFirstPath): + return self.rules[1] elif isinstance(path, AsnPath): return self.asns[path.asn] elif isinstance(path, DomainPath): @@ -275,7 +299,6 @@ class Database(Profiler): except TypeError: # not iterable pass - def exec_each_domain(self, callback: MatchCallable, arg: typing.Any = None, @@ -374,8 +397,8 @@ class Database(Profiler): pass def increment_references_cb(path: Path, - match: Match, _: typing.Any - ) -> None: + match: Match, _: typing.Any + ) -> None: if match.source: source = self.get_match(match.source) source.references += 1 @@ -387,9 +410,7 @@ class Database(Profiler): def explain(self, path: Path) -> str: match = self.get_match(path) - string = f'{path}' - if not isinstance(path, RulePath): - string += f' #{match.references}' + string = f'{path} #{match.references}' if match.source: string += f' ← {self.explain(match.source)}' return string @@ -399,14 +420,14 @@ class Database(Profiler): end_chain_only: bool = False, explain: bool = False, ) -> typing.Iterable[str]: - if first_party_only: - raise NotImplementedError def export_cb(path: Path, match: Match, _: typing.Any ) -> typing.Iterable[str]: assert isinstance(path, DomainPath) if not isinstance(path, HostnamePath): return + if first_party_only and not match.first_party: + return if end_chain_only and match.references > 0: return if explain: @@ -419,11 +440,11 @@ class Database(Profiler): def list_rules(self, first_party_only: bool = False, ) -> typing.Iterable[str]: - if first_party_only: - raise NotImplementedError def list_rules_cb(path: Path, match: Match, _: typing.Any ) -> typing.Iterable[str]: + if first_party_only and not match.first_party: + return if isinstance(path, ZonePath) \ or (isinstance(path, Ip4Path) and path.prefixlen < 32): # if match.level == 0: @@ -465,10 +486,10 @@ class Database(Profiler): dic = self.ip4tree for i in range(31, 31-ip4.prefixlen, -1): bit = (ip4.value >> i) & 0b1 + # TODO PERF copy value and slide once every loop if dic.active(): self.enter_step('get_ip4_yield') - a = Ip4Path(ip4.value >> (i+1) << (i+1), 31-i) - yield a + yield Ip4Path(ip4.value >> (i+1) << (i+1), 31-i) self.enter_step('get_ip4_brws') next_dic = dic.one if bit else dic.zero if next_dic is None: @@ -478,50 +499,58 @@ class Database(Profiler): self.enter_step('get_ip4_yield') yield ip4 - def set_match(self, - match: Match, - updated: int, - source: Path, - ) -> None: - new_source = self.get_match(source) - new_level = new_source.level + 1 - if updated > match.updated or new_level > match.level: + def _set_match(self, + match: Match, + updated: int, + source: Path, + source_match: Match = None, + ) -> None: + # source_match is in parameters because most of the time + # its parent function needs it too, + # so it can pass it to save a traversal + source_match = source_match or self.get_match(source) + new_level = source_match.level + 1 + if updated > match.updated or new_level < match.level \ + or source_match.first_party > match.first_party: + # NOTE FP and level of matches referencing this one + # won't be updated until run or prune if match.source: old_source = self.get_match(match.source) old_source.references -= 1 match.updated = updated match.level = new_level + match.first_party = source_match.first_party match.source = source - new_source.references += 1 - # FP dupplicate function + source_match.references += 1 def _set_domain(self, hostname: bool, domain_str: str, updated: int, - is_first_party: bool = None, - source: Path = None) -> None: + source: Path) -> None: self.enter_step('set_domain_pack') - if is_first_party: - raise NotImplementedError domain = self.pack_domain(domain_str) + self.enter_step('set_domain_fp') + source_match = self.get_match(source) + is_first_party = source_match.first_party self.enter_step('set_domain_brws') dic = self.domtree for part in domain.parts: if part not in dic.children: dic.children[part] = DomainTreeNode() dic = dic.children[part] - if dic.match_zone.active(): + if dic.match_zone.active(is_first_party): # Refuse to add domain whose zone is already matching return if hostname: match = dic.match_hostname else: match = dic.match_zone - self.set_match( + self._set_match( match, updated, - source or RulePath(), + source, + source_match=source_match, ) def set_hostname(self, @@ -537,30 +566,27 @@ class Database(Profiler): def set_asn(self, asn_str: str, updated: int, - is_first_party: bool = None, - source: Path = None) -> None: + source: Path) -> None: self.enter_step('set_asn') - if is_first_party: - raise NotImplementedError path = self.pack_asn(asn_str) if path.asn in self.asns: match = self.asns[path.asn] else: match = AsnNode() self.asns[path.asn] = match - self.set_match( + self._set_match( match, updated, - source or RulePath(), + source, ) def _set_ip4(self, ip4: Ip4Path, updated: int, - is_first_party: bool = None, - source: Path = None) -> None: - if is_first_party: - raise NotImplementedError + source: Path) -> None: + self.enter_step('set_ip4_fp') + source_match = self.get_match(source) + is_first_party = source_match.first_party self.enter_step('set_ip4_brws') dic = self.ip4tree for i in range(31, 31-ip4.prefixlen, -1): @@ -573,13 +599,14 @@ class Database(Profiler): else: dic.zero = next_dic dic = next_dic - if dic.active(): + if dic.active(is_first_party): # Refuse to add ip4* whose network is already matching return - self.set_match( + self._set_match( dic, updated, - source or RulePath(), + source, + source_match=source_match, ) def set_ip4address(self, diff --git a/feed_rules.py b/feed_rules.py index cca1261..2b5596e 100755 --- a/feed_rules.py +++ b/feed_rules.py @@ -32,10 +32,16 @@ if __name__ == '__main__': fun = FUNCTION_MAP[args.type] + source: database.RulePath + if args.first_party: + source = database.RuleFirstPath() + else: + source = database.RuleMultiPath() + for rule in args.input: fun(DB, rule.strip(), - # is_first_party=args.first_party, + source=source, updated=int(time.time()), ) diff --git a/import_rules.sh b/import_rules.sh index cdeec93..14c8c78 100755 --- a/import_rules.sh +++ b/import_rules.sh @@ -6,11 +6,11 @@ function log() { log "Importing rules…" BEFORE="$(date +%s)" -# cat rules_adblock/*.txt | grep -v '^!' | grep -v '^\[Adblock' | ./adblock_to_domain_list.py | ./feed_rules.py zone -# cat rules_hosts/*.txt | grep -v '^#' | grep -v '^$' | cut -d ' ' -f2 | ./feed_rules.py zone -# cat rules/*.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone -# cat rules_ip/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py ip4network -# cat rules_asn/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py asn +cat rules_adblock/*.txt | grep -v '^!' | grep -v '^\[Adblock' | ./adblock_to_domain_list.py | ./feed_rules.py zone +cat rules_hosts/*.txt | grep -v '^#' | grep -v '^$' | cut -d ' ' -f2 | ./feed_rules.py zone +cat rules/*.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone +cat rules_ip/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py ip4network +cat rules_asn/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py asn cat rules/first-party.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone --first-party cat rules_ip/first-party.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py ip4network --first-party