diff --git a/database.py b/database.py index 13f8876..ea41361 100644 --- a/database.py +++ b/database.py @@ -395,9 +395,7 @@ class Database(Profiler): ) -> str: raise NotImplementedError - def get_domain(self, domain_str: str) -> typing.Iterable[DomainPath]: - self.enter_step('get_domain_pack') - domain = self.pack_domain(domain_str) + def get_domain(self, domain: DomainPath) -> typing.Iterable[DomainPath]: self.enter_step('get_domain_brws') dic = self.domtree depth = 0 @@ -417,9 +415,7 @@ class Database(Profiler): self.enter_step('get_domain_yield') yield HostnamePath(domain.parts) - def get_ip4(self, ip4_str: str) -> typing.Iterable[Path]: - self.enter_step('get_ip4_pack') - ip4 = self.pack_ip4address(ip4_str) + def get_ip4(self, ip4: Ip4Path) -> typing.Iterable[Path]: self.enter_step('get_ip4_brws') dic = self.ip4tree for i in range(31, 31-ip4.prefixlen, -1): @@ -443,14 +439,12 @@ class Database(Profiler): def _set_domain(self, hostname: bool, - domain_str: str, + domain: DomainPath, updated: int, is_first_party: bool = None, source: Path = None) -> None: - self.enter_step('set_domain_pack') if is_first_party: raise NotImplementedError - domain = self.pack_domain(domain_str) self.enter_step('set_domain_src') if source is None: level = 0 @@ -488,7 +482,7 @@ class Database(Profiler): self._set_domain(False, *args, **kwargs) def set_asn(self, - asn_str: str, + asn: AsnPath, updated: int, is_first_party: bool = None, source: Path = None) -> None: @@ -501,23 +495,22 @@ class Database(Profiler): else: match = self.get_match(source) level = match.level + 1 - path = self.pack_asn(asn_str) - if path.asn in self.asns: - match = self.asns[path.asn] + if asn.asn in self.asns: + match = self.asns[asn.asn] else: match = AsnNode() - self.asns[path.asn] = match + self.asns[asn.asn] = match match.set( updated, level, source, ) - def _set_ip4(self, - ip4: Ip4Path, - updated: int, - is_first_party: bool = None, - source: Path = None) -> None: + def set_ip4network(self, + ip4: Ip4Path, + updated: int, + is_first_party: bool = None, + source: Path = None) -> None: if is_first_party: raise NotImplementedError self.enter_step('set_ip4_src') @@ -549,17 +542,8 @@ class Database(Profiler): ) def set_ip4address(self, - ip4address_str: str, + ip4: Ip4Path, *args: typing.Any, **kwargs: typing.Any ) -> None: - self.enter_step('set_ip4add_pack') - ip4 = self.pack_ip4address(ip4address_str) - self._set_ip4(ip4, *args, **kwargs) - - def set_ip4network(self, - ip4network_str: str, - *args: typing.Any, **kwargs: typing.Any - ) -> None: - self.enter_step('set_ip4net_pack') - ip4 = self.pack_ip4network(ip4network_str) - self._set_ip4(ip4, *args, **kwargs) + assert ip4.prefixlen == 32 + self.set_ip4network(ip4, *args, **kwargs) diff --git a/feed_dns.py b/feed_dns.py index 43df1fd..39b561f 100755 --- a/feed_dns.py +++ b/feed_dns.py @@ -8,21 +8,28 @@ import typing import multiprocessing import enum -Record = typing.Tuple[typing.Callable, typing.Callable, int, str, str] +Record = typing.Tuple[typing.Callable, + typing.Callable, int, database.Path, database.Path] -# select, write +# select, write, name_packer, value_packer FUNCTION_MAP: typing.Any = { 'a': ( database.Database.get_ip4, database.Database.set_hostname, + database.Database.pack_domain, + database.Database.pack_ip4address, ), 'cname': ( database.Database.get_domain, database.Database.set_hostname, + database.Database.pack_domain, + database.Database.pack_domain, ), 'ptr': ( database.Database.get_domain, database.Database.set_ip4address, + database.Database.pack_ip4address, + database.Database.pack_domain, ), } @@ -49,11 +56,8 @@ class Writer(multiprocessing.Process): select, write, updated, name, value = record self.db.enter_step('feed_switch') - try: - for source in select(self.db, value): - write(self.db, name, updated, source=source) - except ValueError: - self.log.exception("Cannot execute: %s", record) + for source in select(self.db, value): + write(self.db, name, updated, source=source) self.db.enter_step('block_wait') @@ -76,8 +80,33 @@ class Parser(): self.prof = database.Profiler() self.prof.log = logging.getLogger('pr') - def register(self, record: Record) -> None: - self.prof.enter_step('register') + def register(self, + rtype: str, + timestamp: int, + name_str: str, + value_str: str, + ) -> None: + self.prof.enter_step('pack') + try: + select, write, name_packer, value_packer = FUNCTION_MAP[rtype] + except KeyError: + self.log.exception("Unknown record type") + return + try: + name = name_packer(name_str) + except ValueError: + self.log.exception("Cannot parse name ('%s' with %s)", + name_str, name_packer) + return + try: + value = value_packer(value_str) + except ValueError: + self.log.exception("Cannot parse value ('%s' with %s)", + value_str, value_packer) + return + record = (select, write, timestamp, name, value) + + self.prof.enter_step('grow_block') self.block.append(record) if len(self.block) >= self.block_size: self.prof.enter_step('put_block') @@ -96,6 +125,7 @@ class Parser(): class Rapid7Parser(Parser): def consume(self) -> None: data = dict() + self.prof.enter_step('iowait') for line in self.buf: self.prof.enter_step('parse_rapid7') split = line.split('"') @@ -106,26 +136,25 @@ class Rapid7Parser(Parser): val = split[k+2] data[key] = val - select, writer = FUNCTION_MAP[data['type']] - record = ( - select, - writer, + self.register( + data['type'], int(data['timestamp']), data['name'], - data['value'] + data['value'], ) - except IndexError: + self.prof.enter_step('iowait') + except KeyError: + # Sometimes JSON records are off the place self.log.exception("Cannot parse: %s", line) - self.register(record) class DnsMassParser(Parser): # dnsmass --output Snrql # --retry REFUSED,SERVFAIL --resolvers nameservers-ipv4 TYPES = { - 'A': (FUNCTION_MAP['a'][0], FUNCTION_MAP['a'][1], -1, None), - # 'AAAA': (FUNCTION_MAP['aaaa'][0], FUNCTION_MAP['aaaa'][1], -1, None), - 'CNAME': (FUNCTION_MAP['cname'][0], FUNCTION_MAP['cname'][1], -1, -1), + 'A': ('a', -1, None), + # 'AAAA': ('aaaa', -1, None), + 'CNAME': ('cname', -1, -1), } def consume(self) -> None: @@ -144,19 +173,19 @@ class DnsMassParser(Parser): timestamp = int(split[1]) header = False else: - select, write, name_offset, value_offset = \ + rtype, name_offset, value_offset = \ DnsMassParser.TYPES[split[1]] - record = ( - select, - write, + self.register( + rtype, timestamp, split[0][:name_offset], split[2][:value_offset], ) - self.register(record) self.prof.enter_step('parse_dnsmass') except KeyError: - continue + # Malformed records are less likely to happen, + # but we may never be sure + self.log.exception("Cannot parse: %s", line) PARSERS = { @@ -189,7 +218,7 @@ if __name__ == '__main__': args = args_parser.parse_args() recs_queue: multiprocessing.Queue = multiprocessing.Queue( - maxsize=args.queue_size) + maxsize=args.queue_size) writer = Writer(recs_queue) writer.start() diff --git a/feed_rules.py b/feed_rules.py index cca1261..7197498 100755 --- a/feed_rules.py +++ b/feed_rules.py @@ -4,11 +4,22 @@ import database import argparse import sys import time +import typing -FUNCTION_MAP = { - 'zone': database.Database.set_zone, - 'ip4network': database.Database.set_ip4network, - 'asn': database.Database.set_asn, +FUNCTION_MAP: typing.Dict[str, typing.Tuple[ + typing.Callable[[database.Database, database.Path, int], None], + typing.Callable[[str], database.Path], + ]] = { + 'hostname': (database.Database.set_hostname, + database.Database.pack_domain), + 'zone': (database.Database.set_zone, + database.Database.pack_domain), + 'asn': (database.Database.set_asn, + database.Database.pack_asn), + 'ip4address': (database.Database.set_ip4address, + database.Database.pack_ip4address), + 'ip4network': (database.Database.set_ip4network, + database.Database.pack_ip4network), } if __name__ == '__main__': @@ -30,11 +41,12 @@ if __name__ == '__main__': DB = database.Database() - fun = FUNCTION_MAP[args.type] + fun, packer = FUNCTION_MAP[args.type] for rule in args.input: + packed = packer(rule.strip()) fun(DB, - rule.strip(), + packed, # is_first_party=args.first_party, updated=int(time.time()), )