From aec8d3f8de50f4cef07f22992d1530880674faad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20=E2=80=9CFrogeye=E2=80=9D=20Preud=27homme?= Date: Sun, 15 Dec 2019 22:21:05 +0100 Subject: [PATCH] Reworked how paths work Get those tuples out of my eyes --- database.py | 278 ++++++++++++++++++++++++++++++++---------------- feed_asn.py | 7 +- feed_dns.old.py | 147 +++++++++++++++++++++++++ feed_dns.py | 36 ++++--- 4 files changed, 354 insertions(+), 114 deletions(-) create mode 100755 feed_dns.old.py diff --git a/database.py b/database.py index 7e78a9d..fc2855e 100644 --- a/database.py +++ b/database.py @@ -16,19 +16,48 @@ coloredlogs.install( fmt='%(asctime)s %(name)s %(levelname)s %(message)s' ) -PathType = enum.Enum('PathType', 'Rule Hostname Zone Asn Ip4 Ip6') -RulePath = typing.Union[None] Asn = int -DomainPath = typing.List[str] -Ip4Path = typing.Tuple[int, int] # value, prefixlen -Ip6Path = typing.List[int] -Path = typing.Union[RulePath, DomainPath, Asn, Ip4Path, Ip6Path] -TypedPath = typing.Tuple[PathType, Path] Timestamp = int Level = int -Match = typing.Tuple[Timestamp, TypedPath, Level] -DebugPath = (PathType.Rule, None) + +class Path(): + pass + + +class RulePath(Path): + pass + + +class DomainPath(Path): + def __init__(self, path: typing.List[str]): + self.path = path + + +class HostnamePath(DomainPath): + pass + + +class ZonePath(DomainPath): + pass + + +class AsnPath(Path): + def __init__(self, asn: Asn): + self.asn = asn + + +class Ip4Path(Path): + def __init__(self, value: int, prefixlen: int): + self.value = value + self.prefixlen = prefixlen + + +Match = typing.Tuple[Timestamp, Path, Level] + +# class AsnNode(): +# def __init__(self, asn: int) -> None: +# self.asn = asn class DomainTreeNode(): @@ -44,6 +73,13 @@ class IpTreeNode(): self.match: typing.Optional[Match] = None +Node = typing.Union[DomainTreeNode, IpTreeNode, Asn] +NodeCallable = typing.Callable[[Path, + Node, + typing.Optional[typing.Any]], + typing.Any] + + class Profiler(): def __init__(self) -> None: self.log = logging.getLogger('profiler') @@ -53,6 +89,7 @@ class Profiler(): self.step_dict: typing.Dict[str, int] = dict() def enter_step(self, name: str) -> None: + return now = time.perf_counter() try: self.time_dict[self.time_step] += now - self.time_last @@ -75,7 +112,7 @@ class Profiler(): class Database(Profiler): - VERSION = 8 + VERSION = 9 PATH = "blocking.p" def initialize(self) -> None: @@ -120,34 +157,34 @@ class Database(Profiler): @staticmethod def pack_domain(domain: str) -> DomainPath: - return domain.split('.')[::-1] + return DomainPath(domain.split('.')[::-1]) @staticmethod def unpack_domain(domain: DomainPath) -> str: - return '.'.join(domain[::-1]) + return '.'.join(domain.path[::-1]) @staticmethod - def pack_asn(asn: str) -> int: + def pack_asn(asn: str) -> AsnPath: asn = asn.upper() if asn.startswith('AS'): asn = asn[2:] - return int(asn) + return AsnPath(int(asn)) @staticmethod - def unpack_asn(asn: int) -> str: - return f'AS{asn}' + def unpack_asn(asn: AsnPath) -> str: + return f'AS{asn.asn}' @staticmethod def pack_ip4address(address: str) -> Ip4Path: addr = 0 for split in address.split('.'): addr = (addr << 8) + int(split) - return (addr, 32) + return Ip4Path(addr, 32) @staticmethod def unpack_ip4address(address: Ip4Path) -> str: - addr, prefixlen = address - assert prefixlen == 32 + addr = address.value + assert address.prefixlen == 32 octets: typing.List[int] = list() octets = [0] * 4 for o in reversed(range(4)): @@ -159,14 +196,76 @@ class Database(Profiler): def pack_ip4network(network: str) -> Ip4Path: address, prefixlen_str = network.split('/') prefixlen = int(prefixlen_str) - addr, _ = Database.pack_ip4address(address) - return (addr, prefixlen) + addr = Database.pack_ip4address(address) + addr.prefixlen = prefixlen + return addr @staticmethod def unpack_ip4network(network: Ip4Path) -> str: - address, prefixlen = network - addr = Database.unpack_ip4address((address, 32)) - return f'{addr}/{prefixlen}' + addr = network.value + octets: typing.List[int] = list() + octets = [0] * 4 + for o in reversed(range(4)): + octets[o] = addr & 0xFF + addr >>= 8 + return '.'.join(map(str, octets)) + '/' + str(network.prefixlen) + + def exec_each_domain(self, + callback: NodeCallable, + arg: typing.Any = None, + _dic: DomainTreeNode = None, + _par: DomainPath = None, + ) -> typing.Any: + _dic = _dic or self.domtree + _par = _par or DomainPath([]) + yield from callback(_par, _dic, arg) + for part in _dic.children: + dic = _dic.children[part] + yield from self.exec_each_domain( + callback, + arg, + _dic=dic, + _par=DomainPath(_par.path + [part]) + ) + + def exec_each_ip4(self, + callback: NodeCallable, + arg: typing.Any = None, + _dic: IpTreeNode = None, + _par: Ip4Path = None, + ) -> typing.Any: + _dic = _dic or self.ip4tree + _par = _par or Ip4Path(0, 0) + callback(_par, _dic, arg) + + # 0 + dic = _dic.children[0] + if dic: + addr0 = _par.value & (0xFFFFFFFF ^ (1 << (32-_par.prefixlen))) + assert addr0 == _par.value + yield from self.exec_each_ip4( + callback, + arg, + _dic=dic, + _par=Ip4Path(addr0, _par.prefixlen+1) + ) + # 1 + dic = _dic.children[1] + if dic: + addr1 = _par.value | (1 << (32-_par.prefixlen)) + yield from self.exec_each_ip4( + callback, + arg, + _dic=dic, + _par=Ip4Path(addr1, _par.prefixlen+1) + ) + + def exec_each(self, + callback: NodeCallable, + arg: typing.Any = None, + ) -> typing.Any: + yield from self.exec_each_domain(callback) + yield from self.exec_each_ip4(callback) def update_references(self) -> None: raise NotImplementedError @@ -181,35 +280,35 @@ class Database(Profiler): first_party_only: bool = False, end_chain_only: bool = False, explain: bool = False, - _dic: DomainTreeNode = None, - _par: DomainPath = None, ) -> typing.Iterable[str]: if first_party_only or end_chain_only or explain: raise NotImplementedError - _dic = _dic or self.domtree - _par = _par or list() - if _dic.match_hostname: - yield self.unpack_domain(_par) - for part in _dic.children: - dic = _dic.children[part] - yield from self.export(_dic=dic, - _par=_par + [part]) + + def export_cb(path: Path, node: Node, _: typing.Any + ) -> typing.Iterable[str]: + assert isinstance(path, DomainPath) + assert isinstance(node, DomainTreeNode) + if node.match_hostname: + a = self.unpack_domain(path) + yield a + + yield from self.exec_each_domain(export_cb, None) def count_rules(self, first_party_only: bool = False, ) -> str: raise NotImplementedError - def get_domain(self, domain_str: str) -> typing.Iterable[TypedPath]: + def get_domain(self, domain_str: str) -> typing.Iterable[DomainPath]: self.enter_step('get_domain_pack') domain = self.pack_domain(domain_str) self.enter_step('get_domain_brws') dic = self.domtree depth = 0 - for part in domain: + for part in domain.path: if dic.match_zone: self.enter_step('get_domain_yield') - yield (PathType.Zone, domain[:depth]) + yield ZonePath(domain.path[:depth]) self.enter_step('get_domain_brws') if part not in dic.children: return @@ -217,21 +316,21 @@ class Database(Profiler): depth += 1 if dic.match_zone: self.enter_step('get_domain_yield') - yield (PathType.Zone, domain) + yield ZonePath(domain.path) if dic.match_hostname: self.enter_step('get_domain_yield') - yield (PathType.Hostname, domain) + yield HostnamePath(domain.path) - def get_ip4(self, ip4_str: str) -> typing.Iterable[TypedPath]: + def get_ip4(self, ip4_str: str) -> typing.Iterable[Path]: self.enter_step('get_ip4_pack') - ip4, prefixlen = self.pack_ip4address(ip4_str) + ip4 = self.pack_ip4address(ip4_str) self.enter_step('get_ip4_brws') dic = self.ip4tree - for i in reversed(range(prefixlen)): - part = (ip4 >> i) & 0b1 + for i in reversed(range(ip4.prefixlen)): + part = (ip4.value >> i) & 0b1 if dic.match: self.enter_step('get_ip4_yield') - yield (PathType.Ip4, (ip4, 32-i)) + yield Ip4Path(ip4.value, 32-i) self.enter_step('get_ip4_brws') next_dic = dic.children[part] if next_dic is None: @@ -239,108 +338,99 @@ class Database(Profiler): dic = next_dic if dic.match: self.enter_step('get_ip4_yield') - yield (PathType.Ip4, ip4) + yield ip4 - def list_asn(self) -> typing.Iterable[TypedPath]: + def list_asn(self) -> typing.Iterable[AsnPath]: for asn in self.asns: - yield (PathType.Asn, asn) + yield AsnPath(asn) def set_hostname(self, hostname_str: str, updated: int, is_first_party: bool = None, - source: TypedPath = None) -> None: + source: Path = None) -> None: self.enter_step('set_hostname_pack') - if is_first_party or source: + if is_first_party: raise NotImplementedError self.enter_step('set_hostname_brws') hostname = self.pack_domain(hostname_str) dic = self.domtree - for part in hostname: + for part in hostname.path: if dic.match_zone: # Refuse to add hostname whose zone is already matching return if part not in dic.children: dic.children[part] = DomainTreeNode() dic = dic.children[part] - dic.match_hostname = (updated, DebugPath, 0) + dic.match_hostname = (updated, source or RulePath(), 0) def set_zone(self, zone_str: str, updated: int, is_first_party: bool = None, - source: TypedPath = None) -> None: + source: Path = None) -> None: self.enter_step('set_zone_pack') - if is_first_party or source: + if is_first_party: raise NotImplementedError zone = self.pack_domain(zone_str) self.enter_step('set_zone_brws') dic = self.domtree - for part in zone: + for part in zone.path: if dic.match_zone: # Refuse to add zone whose parent zone is already matching return if part not in dic.children: dic.children[part] = DomainTreeNode() dic = dic.children[part] - dic.match_zone = (updated, DebugPath, 0) + dic.match_zone = (updated, source or RulePath(), 0) def set_asn(self, asn_str: str, updated: int, is_first_party: bool = None, - source: TypedPath = None) -> None: + source: Path = None) -> None: self.enter_step('set_asn_pack') if is_first_party or source: # TODO updated raise NotImplementedError asn = self.pack_asn(asn_str) self.enter_step('set_asn_brws') - self.asns.add(asn) + self.asns.add(asn.asn) + + def _set_ip4(self, + ip4: Ip4Path, + updated: int, + is_first_party: bool = None, + source: Path = None) -> None: + if is_first_party: + raise NotImplementedError + dic = self.ip4tree + for i in reversed(range(ip4.prefixlen)): + part = (ip4.value >> i) & 0b1 + if dic.match: + # Refuse to add ip4* whose network is already matching + return + next_dic = dic.children[part] + if next_dic is None: + next_dic = IpTreeNode() + dic.children[part] = next_dic + dic = next_dic + dic.match = (updated, source or RulePath(), 0) def set_ip4address(self, ip4address_str: str, - updated: int, - is_first_party: bool = None, - source: TypedPath = None) -> None: + *args: typing.Any, **kwargs: typing.Any + ) -> None: self.enter_step('set_ip4add_pack') - if is_first_party or source: - raise NotImplementedError - ip4, prefixlen = self.pack_ip4address(ip4address_str) + ip4 = self.pack_ip4address(ip4address_str) self.enter_step('set_ip4add_brws') - dic = self.ip4tree - for i in reversed(range(prefixlen)): - part = (ip4 >> i) & 0b1 - if dic.match: - # Refuse to add ip4address whose network is already matching - return - next_dic = dic.children[part] - if next_dic is None: - next_dic = IpTreeNode() - dic.children[part] = next_dic - dic = next_dic - dic.match = (updated, DebugPath, 0) + self._set_ip4(ip4, *args, **kwargs) def set_ip4network(self, ip4network_str: str, - updated: int, - is_first_party: bool = None, - source: TypedPath = None) -> None: + *args: typing.Any, **kwargs: typing.Any + ) -> None: self.enter_step('set_ip4net_pack') - if is_first_party or source: - raise NotImplementedError + ip4 = self.pack_ip4network(ip4network_str) self.enter_step('set_ip4net_brws') - ip4, prefixlen = self.pack_ip4network(ip4network_str) - dic = self.ip4tree - for i in reversed(range(prefixlen)): - part = (ip4 >> i) & 0b1 - if dic.match: - # Refuse to add ip4network whose parent network - # is already matching - return - next_dic = dic.children[part] - if next_dic is None: - next_dic = IpTreeNode() - dic.children[part] = next_dic - dic = next_dic - dic.match = (updated, DebugPath, 0) + self._set_ip4(ip4, *args, **kwargs) diff --git a/feed_asn.py b/feed_asn.py index ead63fe..aa311f8 100755 --- a/feed_asn.py +++ b/feed_asn.py @@ -33,10 +33,7 @@ if __name__ == '__main__': DB = database.Database() for path in DB.list_asn(): - ptype, asn = path - assert ptype == database.PathType.Asn - assert isinstance(asn, int) - asn_str = database.Database.unpack_asn(asn) + asn_str = database.Database.unpack_asn(path) DB.enter_step('asn_get_ranges') for prefix in get_ranges(asn_str): parsed_prefix: IPNetwork = ipaddress.ip_network(prefix) @@ -46,7 +43,7 @@ if __name__ == '__main__': # source=path, updated=int(time.time()) ) - log.info('Added %s from %s (source=%s)', prefix, asn, path) + log.info('Added %s from %s (%s)', prefix, asn_str, path) elif parsed_prefix.version == 6: log.warning('Unimplemented prefix version: %s', prefix) else: diff --git a/feed_dns.old.py b/feed_dns.old.py new file mode 100755 index 0000000..b106968 --- /dev/null +++ b/feed_dns.old.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python3 + +import argparse +import database +import logging +import sys +import typing +import enum + +RecordType = enum.Enum('RecordType', 'A AAAA CNAME PTR') +Record = typing.Tuple[RecordType, int, str, str] + +# select, write +FUNCTION_MAP: typing.Any = { + RecordType.A: ( + database.Database.get_ip4, + database.Database.set_hostname, + ), + RecordType.CNAME: ( + database.Database.get_domain, + database.Database.set_hostname, + ), + RecordType.PTR: ( + database.Database.get_domain, + database.Database.set_ip4address, + ), +} + + +class Parser(): + def __init__(self, buf: typing.Any) -> None: + self.buf = buf + self.log = logging.getLogger('parser') + self.db = database.Database() + + def end(self) -> None: + self.db.save() + + def register(self, + rtype: RecordType, + updated: int, + name: str, + value: str + ) -> None: + + self.db.enter_step('register') + select, write = FUNCTION_MAP[rtype] + for source in select(self.db, value): + # write(self.db, name, updated, source=source) + write(self.db, name, updated) + + def consume(self) -> None: + raise NotImplementedError + + +class Rapid7Parser(Parser): + TYPES = { + 'a': RecordType.A, + 'aaaa': RecordType.AAAA, + 'cname': RecordType.CNAME, + 'ptr': RecordType.PTR, + } + + def consume(self) -> None: + data = dict() + for line in self.buf: + self.db.enter_step('parse_rapid7') + split = line.split('"') + + for k in range(1, 14, 4): + key = split[k] + val = split[k+2] + data[key] = val + + self.register( + Rapid7Parser.TYPES[data['type']], + int(data['timestamp']), + data['name'], + data['value'] + ) + + +class DnsMassParser(Parser): + # dnsmass --output Snrql + # --retry REFUSED,SERVFAIL --resolvers nameservers-ipv4 + TYPES = { + 'A': (RecordType.A, -1, None), + 'AAAA': (RecordType.AAAA, -1, None), + 'CNAME': (RecordType.CNAME, -1, -1), + } + + def consume(self) -> None: + self.db.enter_step('parse_dnsmass') + timestamp = 0 + header = True + for line in self.buf: + line = line[:-1] + if not line: + header = True + continue + + split = line.split(' ') + try: + if header: + timestamp = int(split[1]) + header = False + else: + dtype, name_offset, value_offset = \ + DnsMassParser.TYPES[split[1]] + self.register( + dtype, + timestamp, + split[0][:name_offset], + split[2][:value_offset], + ) + self.db.enter_step('parse_dnsmass') + except KeyError: + continue + + +PARSERS = { + 'rapid7': Rapid7Parser, + 'dnsmass': DnsMassParser, +} + +if __name__ == '__main__': + + # Parsing arguments + log = logging.getLogger('feed_dns') + args_parser = argparse.ArgumentParser( + description="TODO") + args_parser.add_argument( + 'parser', + choices=PARSERS.keys(), + help="TODO") + args_parser.add_argument( + '-i', '--input', type=argparse.FileType('r'), default=sys.stdin, + help="TODO") + args = args_parser.parse_args() + + parser = PARSERS[args.parser](args.input) + try: + parser.consume() + except KeyboardInterrupt: + pass + parser.end() + diff --git a/feed_dns.py b/feed_dns.py index d72dc49..be08e98 100755 --- a/feed_dns.py +++ b/feed_dns.py @@ -49,9 +49,12 @@ class Writer(multiprocessing.Process): select, write, updated, name, value = record self.db.enter_step('feed_switch') - for source in select(self.db, value): - # write(self.db, name, updated, source=source) - write(self.db, name, updated) + try: + for source in select(self.db, value): + # write(self.db, name, updated, source=source) + write(self.db, name, updated) + except ValueError: + self.log.exception("Cannot execute: %s", record) self.db.enter_step('block_wait') @@ -98,19 +101,22 @@ class Rapid7Parser(Parser): self.prof.enter_step('parse_rapid7') split = line.split('"') - for k in range(1, 14, 4): - key = split[k] - val = split[k+2] - data[key] = val + try: + for k in range(1, 14, 4): + key = split[k] + val = split[k+2] + data[key] = val - select, writer = FUNCTION_MAP[data['type']] - record = ( - select, - writer, - int(data['timestamp']), - data['name'], - data['value'] - ) + select, writer = FUNCTION_MAP[data['type']] + record = ( + select, + writer, + int(data['timestamp']), + data['name'], + data['value'] + ) + except IndexError: + self.log.exception("Cannot parse: %s", line) self.register(record)