From 3dcccad39a489e76c2c66d2eebe15c0d42af21a5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Geoffrey=20=E2=80=9CFrogeye=E2=80=9D=20Preud=27homme?=
Date: Sat, 14 Aug 2021 23:27:28 +0200
Subject: [PATCH] Black pass

---
 adblock_to_domain_list.py |  33 ++--
 collect_subdomains.py     |  27 ++-
 database.py               | 366 ++++++++++++++++++--------------------
 db.py                     |  38 ++--
 export.py                 |  79 +++++---
 feed_asn.py               |  43 +++--
 feed_dns.py               | 162 +++++++++--------
 feed_rules.py             |  40 +++--
 generate_index.py         |   8 +-
 9 files changed, 416 insertions(+), 380 deletions(-)

diff --git a/adblock_to_domain_list.py b/adblock_to_domain_list.py
index 168be16..fef1bce 100755
--- a/adblock_to_domain_list.py
+++ b/adblock_to_domain_list.py
@@ -16,25 +16,36 @@ import abp.filters
 
 def get_domains(rule: abp.filters.parser.Filter) -> typing.Iterable[str]:
     if rule.options:
         return
-    selector_type = rule.selector['type']
-    selector_value = rule.selector['value']
-    if selector_type == 'url-pattern' \
-            and selector_value.startswith('||') \
-            and selector_value.endswith('^'):
+    selector_type = rule.selector["type"]
+    selector_value = rule.selector["value"]
+    if (
+        selector_type == "url-pattern"
+        and selector_value.startswith("||")
+        and selector_value.endswith("^")
+    ):
         yield selector_value[2:-1]
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
 
     # Parsing arguments
     parser = argparse.ArgumentParser(
-        description="Extract whole domains from an AdBlock blocking list")
+        description="Extract whole domains from an AdBlock blocking list"
+    )
     parser.add_argument(
-        '-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
-        help="Input file with AdBlock rules")
+        "-i",
+        "--input",
+        type=argparse.FileType("r"),
+        default=sys.stdin,
+        help="Input file with AdBlock rules",
+    )
     parser.add_argument(
-        '-o', '--output', type=argparse.FileType('w'), default=sys.stdout,
-        help="Outptut file with one rule tracking subdomain per line")
+        "-o",
+        "--output",
+        type=argparse.FileType("w"),
+        default=sys.stdout,
+        help="Output file with one rule tracking subdomain per line",
+    )
     args = parser.parse_args()
 
     # Reading rules
diff --git a/collect_subdomains.py b/collect_subdomains.py
index 33879e2..b0ecb5f 100755
--- a/collect_subdomains.py
+++ b/collect_subdomains.py
@@ -16,26 +16,25 @@ import selenium.webdriver.firefox.options
 import seleniumwire.webdriver
 import logging
 
-log = logging.getLogger('cs')
+log = logging.getLogger("cs")
 DRIVER = None
 SCROLL_TIME = 10.0
 SCROLL_STEPS = 100
-SCROLL_CMD = f'window.scrollBy(0,document.body.scrollHeight/{SCROLL_STEPS})'
+SCROLL_CMD = f"window.scrollBy(0,document.body.scrollHeight/{SCROLL_STEPS})"
 
 
 def new_driver() -> seleniumwire.webdriver.browser.Firefox:
     profile = selenium.webdriver.FirefoxProfile()
-    profile.set_preference('privacy.trackingprotection.enabled', False)
-    profile.set_preference('network.cookie.cookieBehavior', 0)
-    profile.set_preference('privacy.trackingprotection.pbmode.enabled', False)
-    profile.set_preference(
-        'privacy.trackingprotection.cryptomining.enabled', False)
-    profile.set_preference(
-        'privacy.trackingprotection.fingerprinting.enabled', False)
+    profile.set_preference("privacy.trackingprotection.enabled", False)
+    profile.set_preference("network.cookie.cookieBehavior", 0)
+    profile.set_preference("privacy.trackingprotection.pbmode.enabled", False)
+    profile.set_preference("privacy.trackingprotection.cryptomining.enabled", False)
+    profile.set_preference("privacy.trackingprotection.fingerprinting.enabled", False)
     options = selenium.webdriver.firefox.options.Options()
     # options.add_argument('-headless')
-    driver = seleniumwire.webdriver.Firefox(profile,
-        executable_path='geckodriver', options=options)
+    driver = seleniumwire.webdriver.Firefox(
+        profile, executable_path="geckodriver", options=options
+    )
     return driver
 
 
@@ -60,7 +59,7 @@ def collect_subdomains(url: str) -> typing.Iterable[str]:
     DRIVER.get(url)
     for s in range(SCROLL_STEPS):
         DRIVER.execute_script(SCROLL_CMD)
-        time.sleep(SCROLL_TIME/SCROLL_STEPS)
+        time.sleep(SCROLL_TIME / SCROLL_STEPS)
     for request in DRIVER.requests:
         if request.response:
             yield subdomain_from_url(request.path)
@@ -78,10 +77,10 @@ def collect_subdomains_standalone(url: str) -> None:
             print(subdomain)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     assert len(sys.argv) <= 2
     filename = None
-    if len(sys.argv) == 2 and sys.argv[1] != '-':
+    if len(sys.argv) == 2 and sys.argv[1] != "-":
         filename = sys.argv[1]
         num_lines = sum(1 for line in open(filename))
         iterator = progressbar.progressbar(open(filename), max_value=num_lines)
diff --git a/database.py b/database.py
index 8532bc0..e742b18 100644
--- a/database.py
+++ b/database.py
@@ -15,33 +15,30 @@ import os
 
 TLD_LIST: typing.Set[str] = set()
 
-coloredlogs.install(
-    level='DEBUG',
-    fmt='%(asctime)s %(name)s %(levelname)s %(message)s'
-)
+coloredlogs.install(level="DEBUG", fmt="%(asctime)s %(name)s %(levelname)s %(message)s")
 
 Asn = int
 Timestamp = int
 Level = int
 
 
-class Path():
+class Path:
     pass
 
 
 class RulePath(Path):
     def __str__(self) -> str:
-        return '(rule)'
+        return "(rule)"
 
 
 class RuleFirstPath(RulePath):
     def __str__(self) -> str:
-        return '(first-party rule)'
+        return "(first-party rule)"
 
 
 class RuleMultiPath(RulePath):
     def __str__(self) -> str:
-        return '(multi-party rule)'
+        return "(multi-party rule)"
 
 
 class DomainPath(Path):
@@ -49,7 +46,7 @@ def __init__(self, parts: typing.List[str]) -> None:
         self.parts = parts
 
     def __str__(self) -> str:
-        return '?.' + Database.unpack_domain(self)
+        return "?." + Database.unpack_domain(self)
 
 
 class HostnamePath(DomainPath):
@@ -59,7 +56,7 @@ class HostnamePath(DomainPath):
 
 class ZonePath(DomainPath):
     def __str__(self) -> str:
-        return '*.' + Database.unpack_domain(self)
+        return "*." + Database.unpack_domain(self)
 
 
 class AsnPath(Path):
@@ -79,7 +76,7 @@ class Ip4Path(Path):
         return Database.unpack_ip4network(self)
 
 
-class Match():
+class Match:
     def __init__(self) -> None:
         self.source: typing.Optional[Path] = None
         self.updated: int = 0
@@ -102,10 +99,10 @@ class Match():
 class AsnNode(Match):
     def __init__(self) -> None:
         Match.__init__(self)
-        self.name = ''
+        self.name = ""
 
 
-class DomainTreeNode():
+class DomainTreeNode:
     def __init__(self) -> None:
         self.children: typing.Dict[str, DomainTreeNode] = dict()
         self.match_zone = Match()
@@ -120,18 +117,16 @@ class IpTreeNode(Match):
 
 Node = typing.Union[DomainTreeNode, IpTreeNode, AsnNode]
 
-MatchCallable = typing.Callable[[Path,
-                                 Match],
-                                typing.Any]
+MatchCallable = typing.Callable[[Path, Match], typing.Any]
 
 
-class Profiler():
+class Profiler:
     def __init__(self) -> None:
-        do_profile = int(os.environ.get('PROFILE', '0'))
+        do_profile = int(os.environ.get("PROFILE", "0"))
         if do_profile:
-            self.log = logging.getLogger('profiler')
+            self.log = logging.getLogger("profiler")
             self.time_last = time.perf_counter()
-            self.time_step = 'init'
+            self.time_step = "init"
             self.time_dict: typing.Dict[str, float] = dict()
             self.step_dict: typing.Dict[str, int] = dict()
             self.enter_step = self.enter_step_real
@@ -158,14 +153,17 @@ class Profiler():
         return
 
     def profile_real(self) -> None:
-        self.enter_step('profile')
+        self.enter_step("profile")
         total = sum(self.time_dict.values())
         for key, secs in sorted(self.time_dict.items(), key=lambda t: t[1]):
             times = self.step_dict[key]
-            self.log.debug(f"{key:<20}: {times:9d} × {secs/times:5.3e} "
-                           f"= {secs:9.2f} s ({secs/total:7.2%}) ")
-        self.log.debug(f"{'total':<20}: "
-                       f"{total:9.2f} s ({1:7.2%})")
+            self.log.debug(
+                f"{key:<20}: {times:9d} × {secs/times:5.3e} "
+                f"= {secs:9.2f} s ({secs/total:7.2%}) "
+            )
+        self.log.debug(
+            f"{'total':<20}: " f"{total:9.2f} s ({1:7.2%})"
+        )
 
 
 class Database(Profiler):
@@ -173,9 +171,7 @@ class Database(Profiler):
     PATH = "blocking.p"
 
     def initialize(self) -> None:
-        self.log.warning(
-            "Creating database version: %d ",
-            Database.VERSION)
+        self.log.warning("Creating database version: %d ", Database.VERSION)
         # Dummy match objects that everything refer to
         self.rules: typing.List[Match] = list()
         for first_party in (False, True):
@@ -189,76 +185,77 @@ class Database(Profiler):
         self.ip4tree = IpTreeNode()
 
     def load(self) -> None:
-        self.enter_step('load')
+        self.enter_step("load")
         try:
-            with open(self.PATH, 'rb') as db_fdsec:
+            with open(self.PATH, "rb") as db_fdsec:
                 version, data = pickle.load(db_fdsec)
                 if version == Database.VERSION:
                     self.rules, self.domtree, self.asns, self.ip4tree = data
                     return
                 self.log.warning(
-                    "Outdated database version found: %d, "
-                    "it will be rebuilt.",
-                    version)
+                    "Outdated database version found: %d, " "it will be rebuilt.",
+                    version,
+                )
         except (TypeError, AttributeError, EOFError):
             self.log.error(
-                "Corrupt (or heavily outdated) database found, "
-                "it will be rebuilt.")
+                "Corrupt (or heavily outdated) database found, " "it will be rebuilt."
+            )
         except FileNotFoundError:
             pass
         self.initialize()
 
     def save(self) -> None:
-        self.enter_step('save')
-        with open(self.PATH, 'wb') as db_fdsec:
+        self.enter_step("save")
+        with open(self.PATH, "wb") as db_fdsec:
             data = self.rules, self.domtree, self.asns, self.ip4tree
             pickle.dump((self.VERSION, data), db_fdsec)
         self.profile()
 
     def __init__(self) -> None:
         Profiler.__init__(self)
-        self.log = logging.getLogger('db')
+        self.log = logging.getLogger("db")
         self.load()
 
         self.ip4cache_shift: int = 32
         self.ip4cache = numpy.ones(1)
 
     def _set_ip4cache(self, path: Path, _: Match) -> None:
         assert isinstance(path, Ip4Path)
-        self.enter_step('set_ip4cache')
+        self.enter_step("set_ip4cache")
         mini = path.value >> self.ip4cache_shift
-        maxi = (path.value + 2**(32-path.prefixlen)) >> self.ip4cache_shift
+        maxi = (path.value + 2 ** (32 - path.prefixlen)) >> self.ip4cache_shift
         if mini == maxi:
             self.ip4cache[mini] = True
         else:
             self.ip4cache[mini:maxi] = True
 
-    def fill_ip4cache(self, max_size: int = 512*1024**2) -> None:
+    def fill_ip4cache(self, max_size: int = 512 * 1024 ** 2) -> None:
         """
         Size in bytes
         """
-        if max_size > 2**32/8:
-            self.log.warning("Allocating more than 512 MiB of RAM for "
-                             "the Ip4 cache is not necessary.")
-        max_cache_width = int(math.log2(max(1, max_size*8)))
+        if max_size > 2 ** 32 / 8:
+            self.log.warning(
+                "Allocating more than 512 MiB of RAM for "
+                "the Ip4 cache is not necessary."
+            )
+        max_cache_width = int(math.log2(max(1, max_size * 8)))
         allocated = False
         cache_width = min(32, max_cache_width)
         while not allocated:
-            cache_size = 2**cache_width
+            cache_size = 2 ** cache_width
             try:
                 self.ip4cache = numpy.zeros(cache_size, dtype=bool)
             except MemoryError:
-                self.log.exception(
Retrying a smaller one.") cache_width -= 1 continue allocated = True - self.ip4cache_shift = 32-cache_width + self.ip4cache_shift = 32 - cache_width for _ in self.exec_each_ip4(self._set_ip4cache): pass @staticmethod def populate_tld_list() -> None: - with open('temp/all_tld.list', 'r') as tld_fdesc: + with open("temp/all_tld.list", "r") as tld_fdesc: for tld in tld_fdesc: tld = tld.strip() TLD_LIST.add(tld) @@ -267,7 +264,7 @@ class Database(Profiler): def validate_domain(path: str) -> bool: if len(path) > 255: return False - splits = path.split('.') + splits = path.split(".") if not TLD_LIST: Database.populate_tld_list() if splits[-1] not in TLD_LIST: @@ -279,26 +276,26 @@ class Database(Profiler): @staticmethod def pack_domain(domain: str) -> DomainPath: - return DomainPath(domain.split('.')[::-1]) + return DomainPath(domain.split(".")[::-1]) @staticmethod def unpack_domain(domain: DomainPath) -> str: - return '.'.join(domain.parts[::-1]) + return ".".join(domain.parts[::-1]) @staticmethod def pack_asn(asn: str) -> AsnPath: asn = asn.upper() - if asn.startswith('AS'): + if asn.startswith("AS"): asn = asn[2:] return AsnPath(int(asn)) @staticmethod def unpack_asn(asn: AsnPath) -> str: - return f'AS{asn.asn}' + return f"AS{asn.asn}" @staticmethod def validate_ip4address(path: str) -> bool: - splits = path.split('.') + splits = path.split(".") if len(splits) != 4: return False for split in splits: @@ -312,7 +309,7 @@ class Database(Profiler): @staticmethod def pack_ip4address_low(address: str) -> int: addr = 0 - for split in address.split('.'): + for split in address.split("."): octet = int(split) addr = (addr << 8) + octet return addr @@ -330,12 +327,12 @@ class Database(Profiler): for o in reversed(range(4)): octets[o] = addr & 0xFF addr >>= 8 - return '.'.join(map(str, octets)) + return ".".join(map(str, octets)) @staticmethod def validate_ip4network(path: str) -> bool: # A bit generous but ok for our usage - splits = path.split('/') + splits = path.split("/") if len(splits) != 2: return False if not Database.validate_ip4address(splits[0]): @@ -349,7 +346,7 @@ class Database(Profiler): @staticmethod def pack_ip4network(network: str) -> Ip4Path: - address, prefixlen_str = network.split('/') + address, prefixlen_str = network.split("/") prefixlen = int(prefixlen_str) addr = Database.pack_ip4address(address) addr.prefixlen = prefixlen @@ -363,7 +360,7 @@ class Database(Profiler): for o in reversed(range(4)): octets[o] = addr & 0xFF addr >>= 8 - return '.'.join(map(str, octets)) + '/' + str(network.prefixlen) + return ".".join(map(str, octets)) + "/" + str(network.prefixlen) def get_match(self, path: Path) -> Match: if isinstance(path, RuleMultiPath): @@ -384,7 +381,7 @@ class Database(Profiler): raise ValueError elif isinstance(path, Ip4Path): dici = self.ip4tree - for i in range(31, 31-path.prefixlen, -1): + for i in range(31, 31 - path.prefixlen, -1): bit = (path.value >> i) & 0b1 dici_next = dici.one if bit else dici.zero if not dici_next: @@ -394,9 +391,10 @@ class Database(Profiler): else: raise ValueError - def exec_each_asn(self, - callback: MatchCallable, - ) -> typing.Any: + def exec_each_asn( + self, + callback: MatchCallable, + ) -> typing.Any: for asn in self.asns: match = self.asns[asn] if match.active(): @@ -409,11 +407,12 @@ class Database(Profiler): except TypeError: # not iterable pass - def exec_each_domain(self, - callback: MatchCallable, - _dic: DomainTreeNode = None, - _par: DomainPath = None, - ) -> typing.Any: + def exec_each_domain( + self, + callback: MatchCallable, 
+        _dic: DomainTreeNode = None,
+        _par: DomainPath = None,
+    ) -> typing.Any:
         _dic = _dic or self.domtree
         _par = _par or DomainPath([])
         if _dic.match_hostname.active():
@@ -437,16 +436,15 @@ class Database(Profiler):
         for part in _dic.children:
             dic = _dic.children[part]
             yield from self.exec_each_domain(
-                callback,
-                _dic=dic,
-                _par=DomainPath(_par.parts + [part])
+                callback, _dic=dic, _par=DomainPath(_par.parts + [part])
             )
 
-    def exec_each_ip4(self,
-                      callback: MatchCallable,
-                      _dic: IpTreeNode = None,
-                      _par: Ip4Path = None,
-                      ) -> typing.Any:
+    def exec_each_ip4(
+        self,
+        callback: MatchCallable,
+        _dic: IpTreeNode = None,
+        _par: Ip4Path = None,
+    ) -> typing.Any:
         _dic = _dic or self.ip4tree
         _par = _par or Ip4Path(0, 0)
         if _dic.active():
@@ -466,25 +464,18 @@ class Database(Profiler):
             # addr0 = _par.value & (0xFFFFFFFF ^ (1 << (32-pref)))
             # assert addr0 == _par.value
             addr0 = _par.value
-            yield from self.exec_each_ip4(
-                callback,
-                _dic=dic,
-                _par=Ip4Path(addr0, pref)
-            )
+            yield from self.exec_each_ip4(callback, _dic=dic, _par=Ip4Path(addr0, pref))
         # 1
         dic = _dic.one
         if dic:
-            addr1 = _par.value | (1 << (32-pref))
+            addr1 = _par.value | (1 << (32 - pref))
             # assert addr1 != _par.value
-            yield from self.exec_each_ip4(
-                callback,
-                _dic=dic,
-                _par=Ip4Path(addr1, pref)
-            )
+            yield from self.exec_each_ip4(callback, _dic=dic, _par=Ip4Path(addr1, pref))
 
-    def exec_each(self,
-                  callback: MatchCallable,
-                  ) -> typing.Any:
+    def exec_each(
+        self,
+        callback: MatchCallable,
+    ) -> typing.Any:
         yield from self.exec_each_domain(callback)
         yield from self.exec_each_ip4(callback)
         yield from self.exec_each_asn(callback)
@@ -492,19 +483,17 @@ class Database(Profiler):
 
     def update_references(self) -> None:
         # Should be correctly calculated normally,
        # keeping this just in case
-        def reset_references_cb(path: Path,
-                                match: Match
-                                ) -> None:
+        def reset_references_cb(path: Path, match: Match) -> None:
             match.references = 0
+
         for _ in self.exec_each(reset_references_cb):
             pass
 
-        def increment_references_cb(path: Path,
-                                    match: Match
-                                    ) -> None:
+        def increment_references_cb(path: Path, match: Match) -> None:
             if match.source:
                 source = self.get_match(match.source)
                 source.references += 1
+
         for _ in self.exec_each(increment_references_cb):
             pass
@@ -513,9 +502,7 @@ class Database(Profiler):
         # matches until all disabled matches reference count = 0
         did_something = True
 
-        def clean_deps_cb(path: Path,
-                          match: Match
-                          ) -> None:
+        def clean_deps_cb(path: Path, match: Match) -> None:
             nonlocal did_something
             if not match.source:
                 return
@@ -530,15 +517,13 @@ class Database(Profiler):
 
         while did_something:
             did_something = False
-            self.enter_step('pass_clean_deps')
+            self.enter_step("pass_clean_deps")
             for _ in self.exec_each(clean_deps_cb):
                 pass
 
     def prune(self, before: int, base_only: bool = False) -> None:
         # Disable the matches targeted
-        def prune_cb(path: Path,
-                     match: Match
-                     ) -> None:
+        def prune_cb(path: Path, match: Match) -> None:
             if base_only and match.level > 1:
                 return
             if match.updated > before:
@@ -546,7 +531,7 @@ class Database(Profiler):
                 self._unset_match(match)
                 self.log.debug("Print: disabled %s", path)
 
-        self.enter_step('pass_prune')
+        self.enter_step("pass_prune")
         for _ in self.exec_each(prune_cb):
             pass
 
@@ -559,25 +544,24 @@ class Database(Profiler):
         match = self.get_match(path)
         string = str(path)
         if isinstance(match, AsnNode):
-            string += f' ({match.name})'
-        party_char = 'F' if match.first_party else 'M'
-        dup_char = 'D' if match.dupplicate else '_'
-        string += f' {match.level}{party_char}{dup_char}{match.references}'
+            string += f" ({match.name})"
+        party_char = "F" if match.first_party else "M"
+        dup_char = "D" if match.dupplicate else "_"
+        string += f" {match.level}{party_char}{dup_char}{match.references}"
         if match.source:
-            string += f' ← {self.explain(match.source)}'
+            string += f" ← {self.explain(match.source)}"
         return string
 
-    def list_records(self,
-                     first_party_only: bool = False,
-                     end_chain_only: bool = False,
-                     no_dupplicates: bool = False,
-                     rules_only: bool = False,
-                     hostnames_only: bool = False,
-                     explain: bool = False,
-                     ) -> typing.Iterable[str]:
-
-        def export_cb(path: Path, match: Match
-                      ) -> typing.Iterable[str]:
+    def list_records(
+        self,
+        first_party_only: bool = False,
+        end_chain_only: bool = False,
+        no_dupplicates: bool = False,
+        rules_only: bool = False,
+        hostnames_only: bool = False,
+        explain: bool = False,
+    ) -> typing.Iterable[str]:
+        def export_cb(path: Path, match: Match) -> typing.Iterable[str]:
             if first_party_only and not match.first_party:
                 return
             if end_chain_only and match.references > 0:
@@ -596,13 +580,14 @@ class Database(Profiler):
 
         yield from self.exec_each(export_cb)
 
-    def count_records(self,
-                      first_party_only: bool = False,
-                      end_chain_only: bool = False,
-                      no_dupplicates: bool = False,
-                      rules_only: bool = False,
-                      hostnames_only: bool = False,
-                      ) -> str:
+    def count_records(
+        self,
+        first_party_only: bool = False,
+        end_chain_only: bool = False,
+        no_dupplicates: bool = False,
+        rules_only: bool = False,
+        hostnames_only: bool = False,
+    ) -> str:
         memo: typing.Dict[str, int] = dict()
 
         def count_records_cb(path: Path, match: Match) -> None:
@@ -627,75 +612,80 @@ class Database(Profiler):
 
         split: typing.List[str] = list()
         for key, value in sorted(memo.items(), key=lambda s: s[0]):
-            split.append(f'{key[:-4].lower()}s: {value}')
-        return ', '.join(split)
+            split.append(f"{key[:-4].lower()}s: {value}")
+        return ", ".join(split)
 
     def get_domain(self, domain_str: str) -> typing.Iterable[DomainPath]:
-        self.enter_step('get_domain_pack')
+        self.enter_step("get_domain_pack")
         domain = self.pack_domain(domain_str)
-        self.enter_step('get_domain_brws')
+        self.enter_step("get_domain_brws")
         dic = self.domtree
         depth = 0
         for part in domain.parts:
             if dic.match_zone.active():
-                self.enter_step('get_domain_yield')
+                self.enter_step("get_domain_yield")
                 yield ZonePath(domain.parts[:depth])
-            self.enter_step('get_domain_brws')
+            self.enter_step("get_domain_brws")
             if part not in dic.children:
                 return
             dic = dic.children[part]
             depth += 1
         if dic.match_zone.active():
-            self.enter_step('get_domain_yield')
+            self.enter_step("get_domain_yield")
             yield ZonePath(domain.parts)
         if dic.match_hostname.active():
-            self.enter_step('get_domain_yield')
+            self.enter_step("get_domain_yield")
             yield HostnamePath(domain.parts)
 
     def get_ip4(self, ip4_str: str) -> typing.Iterable[Path]:
-        self.enter_step('get_ip4_pack')
+        self.enter_step("get_ip4_pack")
         ip4val = self.pack_ip4address_low(ip4_str)
-        self.enter_step('get_ip4_cache')
+        self.enter_step("get_ip4_cache")
         if not self.ip4cache[ip4val >> self.ip4cache_shift]:
             return
-        self.enter_step('get_ip4_brws')
+        self.enter_step("get_ip4_brws")
         dic = self.ip4tree
         for i in range(31, -1, -1):
             bit = (ip4val >> i) & 0b1
             if dic.active():
-                self.enter_step('get_ip4_yield')
-                yield Ip4Path(ip4val >> (i+1) << (i+1), 31-i)
-                self.enter_step('get_ip4_brws')
+                self.enter_step("get_ip4_yield")
+                yield Ip4Path(ip4val >> (i + 1) << (i + 1), 31 - i)
+                self.enter_step("get_ip4_brws")
             next_dic = dic.one if bit else dic.zero
             if next_dic is None:
                 return
             dic = next_dic
         if dic.active():
-            self.enter_step('get_ip4_yield')
+            self.enter_step("get_ip4_yield")
             yield Ip4Path(ip4val, 32)
 
-    def _unset_match(self,
-                     match: Match,
-                     ) -> None:
+    def _unset_match(
+        self,
+        match: Match,
+    ) -> None:
         match.disable()
         if match.source:
             source_match = self.get_match(match.source)
             source_match.references -= 1
 
-    def _set_match(self,
-                   match: Match,
-                   updated: int,
-                   source: Path,
-                   source_match: Match = None,
-                   dupplicate: bool = False,
-                   ) -> None:
+    def _set_match(
+        self,
+        match: Match,
+        updated: int,
+        source: Path,
+        source_match: Match = None,
+        dupplicate: bool = False,
+    ) -> None:
         # source_match is in parameters because most of the time
         # its parent function needs it too,
         # so it can pass it to save a traversal
         source_match = source_match or self.get_match(source)
         new_level = source_match.level + 1
-        if updated > match.updated or new_level < match.level \
-                or source_match.first_party > match.first_party:
+        if (
+            updated > match.updated
+            or new_level < match.level
+            or source_match.first_party > match.first_party
+        ):
             # NOTE FP and level of matches referencing this one
             # won't be updated until run or prune
             if match.source:
@@ -708,20 +698,18 @@ class Database(Profiler):
             source_match.references += 1
         match.dupplicate = dupplicate
 
-    def _set_domain(self,
-                    hostname: bool,
-                    domain_str: str,
-                    updated: int,
-                    source: Path) -> None:
-        self.enter_step('set_domain_val')
+    def _set_domain(
+        self, hostname: bool, domain_str: str, updated: int, source: Path
+    ) -> None:
+        self.enter_step("set_domain_val")
         if not Database.validate_domain(domain_str):
             raise ValueError(f"Invalid domain: {domain_str}")
-        self.enter_step('set_domain_pack')
+        self.enter_step("set_domain_pack")
         domain = self.pack_domain(domain_str)
-        self.enter_step('set_domain_fp')
+        self.enter_step("set_domain_fp")
         source_match = self.get_match(source)
         is_first_party = source_match.first_party
-        self.enter_step('set_domain_brws')
+        self.enter_step("set_domain_brws")
         dic = self.domtree
         dupplicate = False
         for part in domain.parts:
@@ -742,21 +730,14 @@ class Database(Profiler):
             dupplicate=dupplicate,
         )
 
-    def set_hostname(self,
-                     *args: typing.Any, **kwargs: typing.Any
-                     ) -> None:
+    def set_hostname(self, *args: typing.Any, **kwargs: typing.Any) -> None:
         self._set_domain(True, *args, **kwargs)
 
-    def set_zone(self,
-                 *args: typing.Any, **kwargs: typing.Any
-                 ) -> None:
+    def set_zone(self, *args: typing.Any, **kwargs: typing.Any) -> None:
         self._set_domain(False, *args, **kwargs)
 
-    def set_asn(self,
-                asn_str: str,
-                updated: int,
-                source: Path) -> None:
-        self.enter_step('set_asn')
+    def set_asn(self, asn_str: str, updated: int, source: Path) -> None:
+        self.enter_step("set_asn")
         path = self.pack_asn(asn_str)
         if path.asn in self.asns:
             match = self.asns[path.asn]
@@ -769,17 +750,14 @@ class Database(Profiler):
             source,
         )
 
-    def _set_ip4(self,
-                 ip4: Ip4Path,
-                 updated: int,
-                 source: Path) -> None:
-        self.enter_step('set_ip4_fp')
+    def _set_ip4(self, ip4: Ip4Path, updated: int, source: Path) -> None:
+        self.enter_step("set_ip4_fp")
         source_match = self.get_match(source)
         is_first_party = source_match.first_party
-        self.enter_step('set_ip4_brws')
+        self.enter_step("set_ip4_brws")
         dic = self.ip4tree
         dupplicate = False
-        for i in range(31, 31-ip4.prefixlen, -1):
+        for i in range(31, 31 - ip4.prefixlen, -1):
             bit = (ip4.value >> i) & 0b1
             next_dic = dic.one if bit else dic.zero
             if next_dic is None:
@@ -800,24 +778,22 @@ class Database(Profiler):
         )
         self._set_ip4cache(ip4, dic)
 
-    def set_ip4address(self,
-                       ip4address_str: str,
-                       *args: typing.Any, **kwargs: typing.Any
-                       ) -> None:
-        self.enter_step('set_ip4add_val')
+    def set_ip4address(
+        self, ip4address_str: str, *args: typing.Any, **kwargs: typing.Any
+    ) -> None:
+        self.enter_step("set_ip4add_val")
         if not Database.validate_ip4address(ip4address_str):
             raise ValueError(f"Invalid ip4address: {ip4address_str}")
-        self.enter_step('set_ip4add_pack')
+        self.enter_step("set_ip4add_pack")
         ip4 = self.pack_ip4address(ip4address_str)
         self._set_ip4(ip4, *args, **kwargs)
 
-    def set_ip4network(self,
-                       ip4network_str: str,
-                       *args: typing.Any, **kwargs: typing.Any
-                       ) -> None:
-        self.enter_step('set_ip4net_val')
+    def set_ip4network(
+        self, ip4network_str: str, *args: typing.Any, **kwargs: typing.Any
+    ) -> None:
+        self.enter_step("set_ip4net_val")
         if not Database.validate_ip4network(ip4network_str):
             raise ValueError(f"Invalid ip4network: {ip4network_str}")
-        self.enter_step('set_ip4net_pack')
+        self.enter_step("set_ip4net_pack")
         ip4 = self.pack_ip4network(ip4network_str)
         self._set_ip4(ip4, *args, **kwargs)
diff --git a/db.py b/db.py
index 91d00c5..2420c20 100755
--- a/db.py
+++ b/db.py
@@ -5,29 +5,37 @@ import database
 import time
 import os
 
-if __name__ == '__main__':
+if __name__ == "__main__":
 
     # Parsing arguments
-    parser = argparse.ArgumentParser(
-        description="Database operations")
+    parser = argparse.ArgumentParser(description="Database operations")
     parser.add_argument(
-        '-i', '--initialize', action='store_true',
-        help="Reconstruct the whole database")
+        "-i", "--initialize", action="store_true", help="Reconstruct the whole database"
+    )
     parser.add_argument(
-        '-p', '--prune', action='store_true',
-        help="Remove old entries from database")
+        "-p", "--prune", action="store_true", help="Remove old entries from database"
+    )
     parser.add_argument(
-        '-b', '--prune-base', action='store_true',
+        "-b",
+        "--prune-base",
+        action="store_true",
         help="With --prune, only prune base rules "
-        "(the ones added by ./feed_rules.py)")
+        "(the ones added by ./feed_rules.py)",
+    )
     parser.add_argument(
-        '-s', '--prune-before', type=int,
-        default=(int(time.time()) - 60*60*24*31*6),
+        "-s",
+        "--prune-before",
+        type=int,
+        default=(int(time.time()) - 60 * 60 * 24 * 31 * 6),
         help="With --prune, only rules updated before "
-        "this UNIX timestamp will be deleted")
+        "this UNIX timestamp will be deleted",
+    )
     parser.add_argument(
-        '-r', '--references', action='store_true',
-        help="DEBUG: Update the reference count")
+        "-r",
+        "--references",
+        action="store_true",
+        help="DEBUG: Update the reference count",
+    )
     args = parser.parse_args()
 
     if not args.initialize:
@@ -37,7 +45,7 @@ if __name__ == '__main__':
             os.unlink(database.Database.PATH)
 
     DB = database.Database()
-    DB.enter_step('main')
+    DB.enter_step("main")
     if args.prune:
         DB.prune(before=args.prune_before, base_only=args.prune_base)
     if args.references:
diff --git a/export.py b/export.py
index c5eefb2..bb172bd 100755
--- a/export.py
+++ b/export.py
@@ -5,53 +5,80 @@ import argparse
 import sys
 
-if __name__ == '__main__':
+if __name__ == "__main__":
 
     # Parsing arguments
     parser = argparse.ArgumentParser(
-        description="Export the hostnames rules stored "
-        "in the Database as plain text")
+        description="Export the hostname rules stored " "in the Database as plain text"
+    )
     parser.add_argument(
-        '-o', '--output', type=argparse.FileType('w'), default=sys.stdout,
-        help="Output file, one rule per line")
+        "-o",
+        "--output",
+        type=argparse.FileType("w"),
+        default=sys.stdout,
+        help="Output file, one rule per line",
+    )
     parser.add_argument(
-        '-f', '--first-party', action='store_true',
-        help="Only output rules issued from first-party sources")
+        "-f",
+        "--first-party",
+        action="store_true",
+        help="Only output rules issued from first-party sources",
+    )
     parser.add_argument(
-        '-e', '--end-chain', action='store_true',
-        help="Only output rules that are not referenced by any other")
+        "-e",
+        "--end-chain",
+        action="store_true",
+        help="Only output rules that are not referenced by any other",
+    )
     parser.add_argument(
-        '-r', '--rules', action='store_true',
-        help="Output all kinds of rules, not just hostnames")
+        "-r",
+        "--rules",
+        action="store_true",
+        help="Output all kinds of rules, not just hostnames",
+    )
     parser.add_argument(
-        '-b', '--base-rules', action='store_true',
+        "-b",
+        "--base-rules",
+        action="store_true",
         help="Output base rules "
         "(the ones added by ./feed_rules.py) "
-        "(implies --rules)")
+        "(implies --rules)",
+    )
     parser.add_argument(
-        '-d', '--no-dupplicates', action='store_true',
+        "-d",
+        "--no-dupplicates",
+        action="store_true",
         help="Do not output rules that already match a zone/network rule "
-        "(e.g. dummy.example.com when there's a zone example.com rule)")
+        "(e.g. dummy.example.com when there's a zone example.com rule)",
+    )
     parser.add_argument(
-        '-x', '--explain', action='store_true',
+        "-x",
+        "--explain",
+        action="store_true",
         help="Show the chain of rules leading to one "
-        "(and the number of references they have)")
+        "(and the number of references they have)",
+    )
     parser.add_argument(
-        '-c', '--count', action='store_true',
-        help="Show the number of rules per type instead of listing them")
+        "-c",
+        "--count",
+        action="store_true",
+        help="Show the number of rules per type instead of listing them",
+    )
     args = parser.parse_args()
     DB = database.Database()
 
     if args.count:
         assert not args.explain
-        print(DB.count_records(
-            first_party_only=args.first_party,
-            end_chain_only=args.end_chain,
-            no_dupplicates=args.no_dupplicates,
-            rules_only=args.base_rules,
-            hostnames_only=not (args.rules or args.base_rules),
-        ))
+        print(
+            DB.count_records(
+                first_party_only=args.first_party,
+                end_chain_only=args.end_chain,
+                no_dupplicates=args.no_dupplicates,
+                rules_only=args.base_rules,
+                hostnames_only=not (args.rules or args.base_rules),
+            )
+        )
     else:
         for domain in DB.list_records(
             first_party_only=args.first_party,
diff --git a/feed_asn.py b/feed_asn.py
index 25a35e2..e601d4a 100755
--- a/feed_asn.py
+++ b/feed_asn.py
@@ -13,57 +13,54 @@ IPNetwork = typing.Union[ipaddress.IPv4Network, ipaddress.IPv6Network]
 
 def get_ranges(asn: str) -> typing.Iterable[str]:
     req = requests.get(
-        'https://stat.ripe.net/data/as-routing-consistency/data.json',
-        params={'resource': asn}
+        "https://stat.ripe.net/data/as-routing-consistency/data.json",
+        params={"resource": asn},
     )
     data = req.json()
-    for pref in data['data']['prefixes']:
-        yield pref['prefix']
+    for pref in data["data"]["prefixes"]:
+        yield pref["prefix"]
 
 
 def get_name(asn: str) -> str:
     req = requests.get(
-        'https://stat.ripe.net/data/as-overview/data.json',
-        params={'resource': asn}
+        "https://stat.ripe.net/data/as-overview/data.json", params={"resource": asn}
    )
     data = req.json()
-    return data['data']['holder']
+    return data["data"]["holder"]
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
 
-    log = logging.getLogger('feed_asn')
+    log = logging.getLogger("feed_asn")
 
     # Parsing arguments
     parser = argparse.ArgumentParser(
-        description="Add the IP ranges associated to the AS in the database")
associated to the AS in the database") + description="Add the IP ranges associated to the AS in the database" + ) args = parser.parse_args() DB = database.Database() - def add_ranges(path: database.Path, - match: database.Match, - ) -> None: + def add_ranges( + path: database.Path, + match: database.Match, + ) -> None: assert isinstance(path, database.AsnPath) assert isinstance(match, database.AsnNode) asn_str = database.Database.unpack_asn(path) - DB.enter_step('asn_get_name') + DB.enter_step("asn_get_name") name = get_name(asn_str) match.name = name - DB.enter_step('asn_get_ranges') + DB.enter_step("asn_get_ranges") for prefix in get_ranges(asn_str): parsed_prefix: IPNetwork = ipaddress.ip_network(prefix) if parsed_prefix.version == 4: - DB.set_ip4network( - prefix, - source=path, - updated=int(time.time()) - ) - log.info('Added %s from %s (%s)', prefix, path, name) + DB.set_ip4network(prefix, source=path, updated=int(time.time())) + log.info("Added %s from %s (%s)", prefix, path, name) elif parsed_prefix.version == 6: - log.warning('Unimplemented prefix version: %s', prefix) + log.warning("Unimplemented prefix version: %s", prefix) else: - log.error('Unknown prefix version: %s', prefix) + log.error("Unknown prefix version: %s", prefix) for _ in DB.exec_each_asn(add_ranges): pass diff --git a/feed_dns.py b/feed_dns.py index bf5f296..0e19b88 100755 --- a/feed_dns.py +++ b/feed_dns.py @@ -12,15 +12,15 @@ Record = typing.Tuple[typing.Callable, typing.Callable, int, str, str] # select, write FUNCTION_MAP: typing.Any = { - 'a': ( + "a": ( database.Database.get_ip4, database.Database.set_hostname, ), - 'cname': ( + "cname": ( database.Database.get_domain, database.Database.set_hostname, ), - 'ptr': ( + "ptr": ( database.Database.get_domain, database.Database.set_ip4address, ), @@ -28,15 +28,16 @@ FUNCTION_MAP: typing.Any = { class Writer(multiprocessing.Process): - def __init__(self, - recs_queue: multiprocessing.Queue = None, - autosave_interval: int = 0, - ip4_cache: int = 0, - ): + def __init__( + self, + recs_queue: multiprocessing.Queue = None, + autosave_interval: int = 0, + ip4_cache: int = 0, + ): if recs_queue: # MP super(Writer, self).__init__() self.recs_queue = recs_queue - self.log = logging.getLogger(f'wr') + self.log = logging.getLogger(f"wr") self.autosave_interval = autosave_interval self.ip4_cache = ip4_cache if not recs_queue: # No MP @@ -44,11 +45,11 @@ class Writer(multiprocessing.Process): def open_db(self) -> None: self.db = database.Database() - self.db.log = logging.getLogger(f'wr') + self.db.log = logging.getLogger(f"wr") self.db.fill_ip4cache(max_size=self.ip4_cache) def exec_record(self, record: Record) -> None: - self.db.enter_step('exec_record') + self.db.enter_step("exec_record") select, write, updated, name, value = record try: for source in select(self.db, value): @@ -59,7 +60,7 @@ class Writer(multiprocessing.Process): self.log.exception("Cannot execute: %s", record) def end(self) -> None: - self.db.enter_step('end') + self.db.enter_step("end") self.db.save() def run(self) -> None: @@ -69,7 +70,7 @@ class Writer(multiprocessing.Process): else: next_save = 0 - self.db.enter_step('block_wait') + self.db.enter_step("block_wait") block: typing.List[Record] for block in iter(self.recs_queue.get, None): @@ -83,20 +84,21 @@ class Writer(multiprocessing.Process): self.log.info("Done!") next_save = time.time() + self.autosave_interval - self.db.enter_step('block_wait') + self.db.enter_step("block_wait") self.end() -class Parser(): - def __init__(self, - buf: 
-                 buf: typing.Any,
-                 recs_queue: multiprocessing.Queue = None,
-                 block_size: int = 0,
-                 writer: Writer = None,
-                 ):
+class Parser:
+    def __init__(
+        self,
+        buf: typing.Any,
+        recs_queue: multiprocessing.Queue = None,
+        block_size: int = 0,
+        writer: Writer = None,
+    ):
         assert bool(writer) ^ bool(block_size and recs_queue)
         self.buf = buf
-        self.log = logging.getLogger('pr')
+        self.log = logging.getLogger("pr")
         self.recs_queue = recs_queue
         if writer:  # No MP
             self.prof: database.Profiler = writer.db
@@ -105,14 +107,14 @@ class Parser():
             self.block: typing.List[Record] = list()
             self.block_size = block_size
             self.prof = database.Profiler()
-            self.prof.log = logging.getLogger('pr')
+            self.prof.log = logging.getLogger("pr")
             self.register = self.add_to_queue
 
     def add_to_queue(self, record: Record) -> None:
-        self.prof.enter_step('register')
+        self.prof.enter_step("register")
         self.block.append(record)
         if len(self.block) >= self.block_size:
-            self.prof.enter_step('put_block')
+            self.prof.enter_step("put_block")
             assert self.recs_queue
             self.recs_queue.put(self.block)
             self.block = list()
@@ -131,26 +133,26 @@ class Rapid7Parser(Parser):
     def consume(self) -> None:
         data = dict()
         for line in self.buf:
-            self.prof.enter_step('parse_rapid7')
+            self.prof.enter_step("parse_rapid7")
             split = line.split('"')
             try:
                 for k in range(1, 14, 4):
                     key = split[k]
-                    val = split[k+2]
+                    val = split[k + 2]
                     data[key] = val
-                select, writer = FUNCTION_MAP[data['type']]
+                select, writer = FUNCTION_MAP[data["type"]]
                 record = (
                     select,
                     writer,
-                    int(data['timestamp']),
-                    data['name'],
-                    data['value']
+                    int(data["timestamp"]),
+                    data["name"],
+                    data["value"],
                 )
             except (IndexError, KeyError):
-                # IndexError: missing field
-                # KeyError: Unknown type field
+                # IndexError: missing field
+                # KeyError: Unknown type field
                 self.log.exception("Cannot parse: %s", line)
             self.register(record)
@@ -159,13 +161,13 @@ class MassDnsParser(Parser):
     # massdns --output Snrql
     # --retry REFUSED,SERVFAIL --resolvers nameservers-ipv4
     TYPES = {
-        'A': (FUNCTION_MAP['a'][0], FUNCTION_MAP['a'][1], -1, None),
+        "A": (FUNCTION_MAP["a"][0], FUNCTION_MAP["a"][1], -1, None),
         # 'AAAA': (FUNCTION_MAP['aaaa'][0], FUNCTION_MAP['aaaa'][1], -1, None),
-        'CNAME': (FUNCTION_MAP['cname'][0], FUNCTION_MAP['cname'][1], -1, -1),
+        "CNAME": (FUNCTION_MAP["cname"][0], FUNCTION_MAP["cname"][1], -1, -1),
     }
 
     def consume(self) -> None:
-        self.prof.enter_step('parse_massdns')
+        self.prof.enter_step("parse_massdns")
         timestamp = 0
         header = True
         for line in self.buf:
@@ -174,14 +176,15 @@ class MassDnsParser(Parser):
                 header = True
                 continue
 
-            split = line.split(' ')
+            split = line.split(" ")
             try:
                 if header:
                     timestamp = int(split[1])
                     header = False
                 else:
-                    select, write, name_offset, value_offset = \
-                        MassDnsParser.TYPES[split[1]]
+                    select, write, name_offset, value_offset = MassDnsParser.TYPES[
+                        split[1]
+                    ]
                     record = (
                         select,
                         write,
@@ -190,75 +193,86 @@ class MassDnsParser(Parser):
                         split[2][:value_offset].lower(),
                     )
                     self.register(record)
-                    self.prof.enter_step('parse_massdns')
+                    self.prof.enter_step("parse_massdns")
             except KeyError:
                 continue
 
 
 PARSERS = {
-    'rapid7': Rapid7Parser,
-    'massdns': MassDnsParser,
+    "rapid7": Rapid7Parser,
+    "massdns": MassDnsParser,
 }
 
-if __name__ == '__main__':
+if __name__ == "__main__":
 
     # Parsing arguments
-    log = logging.getLogger('feed_dns')
+    log = logging.getLogger("feed_dns")
     args_parser = argparse.ArgumentParser(
         description="Read DNS records and import "
-        "tracking-relevant data into the database")
+        "tracking-relevant data into the database"
+    )
args_parser.add_argument("parser", choices=PARSERS.keys(), help="Input format") args_parser.add_argument( - 'parser', - choices=PARSERS.keys(), - help="Input format") + "-i", + "--input", + type=argparse.FileType("r"), + default=sys.stdin, + help="Input file", + ) args_parser.add_argument( - '-i', '--input', type=argparse.FileType('r'), default=sys.stdin, - help="Input file") + "-b", "--block-size", type=int, default=1024, help="Performance tuning value" + ) args_parser.add_argument( - '-b', '--block-size', type=int, default=1024, - help="Performance tuning value") + "-q", "--queue-size", type=int, default=128, help="Performance tuning value" + ) args_parser.add_argument( - '-q', '--queue-size', type=int, default=128, - help="Performance tuning value") + "-a", + "--autosave-interval", + type=int, + default=900, + help="Interval to which the database will save in seconds. " "0 to disable.", + ) args_parser.add_argument( - '-a', '--autosave-interval', type=int, default=900, - help="Interval to which the database will save in seconds. " - "0 to disable.") + "-s", + "--single-process", + action="store_true", + help="Only use one process. " "Might be useful for single core computers.", + ) args_parser.add_argument( - '-s', '--single-process', action='store_true', - help="Only use one process. " - "Might be useful for single core computers.") - args_parser.add_argument( - '-4', '--ip4-cache', type=int, default=0, + "-4", + "--ip4-cache", + type=int, + default=0, help="RAM cache for faster IPv4 lookup. " "Maximum useful value: 512 MiB (536870912). " "Warning: Depending on the rules, this might already " - "be a memory-heavy process, even without the cache.") + "be a memory-heavy process, even without the cache.", + ) args = args_parser.parse_args() parser_cls = PARSERS[args.parser] if args.single_process: writer = Writer( - autosave_interval=args.autosave_interval, - ip4_cache=args.ip4_cache + autosave_interval=args.autosave_interval, ip4_cache=args.ip4_cache ) parser = parser_cls(args.input, writer=writer) parser.run() writer.end() else: recs_queue: multiprocessing.Queue = multiprocessing.Queue( - maxsize=args.queue_size) + maxsize=args.queue_size + ) - writer = Writer(recs_queue, - autosave_interval=args.autosave_interval, - ip4_cache=args.ip4_cache - ) + writer = Writer( + recs_queue, + autosave_interval=args.autosave_interval, + ip4_cache=args.ip4_cache, + ) writer.start() - parser = parser_cls(args.input, - recs_queue=recs_queue, - block_size=args.block_size - ) + parser = parser_cls( + args.input, recs_queue=recs_queue, block_size=args.block_size + ) parser.run() recs_queue.put(None) diff --git a/feed_rules.py b/feed_rules.py index 9d0365f..1b8f215 100755 --- a/feed_rules.py +++ b/feed_rules.py @@ -6,28 +6,33 @@ import sys import time FUNCTION_MAP = { - 'zone': database.Database.set_zone, - 'hostname': database.Database.set_hostname, - 'asn': database.Database.set_asn, - 'ip4network': database.Database.set_ip4network, - 'ip4address': database.Database.set_ip4address, + "zone": database.Database.set_zone, + "hostname": database.Database.set_hostname, + "asn": database.Database.set_asn, + "ip4network": database.Database.set_ip4network, + "ip4address": database.Database.set_ip4address, } -if __name__ == '__main__': +if __name__ == "__main__": # Parsing arguments - parser = argparse.ArgumentParser( - description="Import base rules to the database") + parser = argparse.ArgumentParser(description="Import base rules to the database") parser.add_argument( - 'type', - choices=FUNCTION_MAP.keys(), - 
help="Type of rule inputed") + "type", choices=FUNCTION_MAP.keys(), help="Type of rule inputed" + ) parser.add_argument( - '-i', '--input', type=argparse.FileType('r'), default=sys.stdin, - help="File with one rule per line") + "-i", + "--input", + type=argparse.FileType("r"), + default=sys.stdin, + help="File with one rule per line", + ) parser.add_argument( - '-f', '--first-party', action='store_true', - help="The input only comes from verified first-party sources") + "-f", + "--first-party", + action="store_true", + help="The input only comes from verified first-party sources", + ) args = parser.parse_args() DB = database.Database() @@ -43,11 +48,12 @@ if __name__ == '__main__': for rule in args.input: rule = rule.strip() try: - fun(DB, + fun( + DB, rule, source=source, updated=int(time.time()), - ) + ) except ValueError: DB.log.error(f"Could not add rule: {rule}") diff --git a/generate_index.py b/generate_index.py index 9a5a03e..2f9415b 100755 --- a/generate_index.py +++ b/generate_index.py @@ -2,11 +2,9 @@ import markdown2 -extras = [ - "header-ids" -] +extras = ["header-ids"] -with open('dist/README.md', 'r') as fdesc: +with open("dist/README.md", "r") as fdesc: body = markdown2.markdown(fdesc.read(), extras=extras) output = f""" @@ -23,5 +21,5 @@ output = f""" """ -with open('dist/index.html', 'w') as fdesc: +with open("dist/index.html", "w") as fdesc: fdesc.write(output)