#!/usr/bin/env python3

"""Feed parsed DNS records into the database, in parallel.

Pipeline layout (all communication via multiprocessing queues):

    Parser (main process) --blocks of records--> Reader workers
    Reader workers --write requests--> single Writer process

A ``None`` item on either queue is the shutdown sentinel.
"""

import argparse
import enum
import json
import logging
import multiprocessing
import sys
import typing

import database

RecordType = enum.Enum('RecordType', 'A AAAA CNAME PTR')
# (record type, timestamp, name, value)
Record = typing.Tuple[RecordType, int, str, str]

# Per record type: (select, confirm, write) database accessors.
# - select: yields matching rules for the record's value
# - confirm: checks whether the record's name is already covered
# - write: stores the new match
# NOTE(review): there is deliberately no AAAA entry, although
# Rapid7Parser can emit AAAA records; Reader.run skips record types
# that have no handlers instead of crashing on them.
FUNCTION_MAP: typing.Any = {
    RecordType.A: (
        database.Database.get_ip4,
        database.Database.get_domain_in_zone,
        database.Database.set_hostname,
    ),
    RecordType.CNAME: (
        database.Database.get_domain,
        database.Database.get_domain_in_zone,
        database.Database.set_hostname,
    ),
    RecordType.PTR: (
        database.Database.get_domain,
        database.Database.get_ip4_in_network,
        database.Database.set_ip4address,
    ),
}


class Reader(multiprocessing.Process):
    """Worker process: match record blocks against rules.

    Consumes blocks of records from ``recs_queue``; for each record whose
    value matches a rule and whose name is not already covered, puts a
    ``(write_fun, name, updated, rule)`` request on ``write_queue``.
    """

    def __init__(self,
                 recs_queue: multiprocessing.Queue,
                 write_queue: multiprocessing.Queue,
                 index: int = 0):
        super().__init__()
        self.log = logging.getLogger(f'rd{index:03d}')
        self.recs_queue = recs_queue
        self.write_queue = write_queue
        self.index = index

    def run(self) -> None:
        # Each worker opens its own read-only database handle; the
        # handle is also used for step timing (enter_step).
        self.db = database.Database(write=False)
        self.db.log = logging.getLogger(f'db{self.index:03d}')
        self.db.enter_step('line_wait')
        block: typing.List[Record]
        try:
            # A None sentinel on the queue stops the worker.
            for block in iter(self.recs_queue.get, None):
                record: Record
                for record in block:
                    dtype, updated, name, value = record
                    self.db.enter_step('feed_switch')
                    try:
                        select, confirm, write = FUNCTION_MAP[dtype]
                    except KeyError:
                        # Record type without handlers (e.g. AAAA):
                        # skip rather than crash the worker.
                        continue
                    for rule in select(self.db, value):
                        if not any(confirm(self.db, name)):
                            self.db.enter_step('wait_put')
                            self.write_queue.put((write, name, updated, rule))
                self.db.enter_step('line_wait')
        except KeyboardInterrupt:
            self.log.error('Interrupted')
        self.db.enter_step('end')
        self.db.close()


class Writer(multiprocessing.Process):
    """Single writer process: applies queued write requests to the DB.

    Serializing all writes through one process avoids concurrent write
    access to the database.
    """

    def __init__(self,
                 write_queue: multiprocessing.Queue,
                 ):
        super().__init__()
        self.log = logging.getLogger('wr ')
        self.write_queue = write_queue

    def run(self) -> None:
        self.db = database.Database(write=True)
        self.db.log = logging.getLogger('dbw ')
        self.db.enter_step('line_wait')
        try:
            fun: typing.Callable
            name: str
            updated: int
            source: int
            # A None sentinel on the queue stops the writer.
            for fun, name, updated, source in iter(self.write_queue.get, None):
                self.db.enter_step('exec')
                fun(self.db, name, updated, source=source)
                self.db.enter_step('line_wait')
        except KeyboardInterrupt:
            self.log.error('Interrupted')
        self.db.enter_step('end')
        self.db.close()


class Parser():
    """Base class for input-format parsers.

    Subclasses implement :meth:`consume`, reading ``self.buf`` and
    calling :meth:`register` for every record found. Records are
    batched into blocks of ``block_size`` before being queued, to
    amortize queue overhead.
    """

    def __init__(self,
                 buf: typing.Any,
                 recs_queue: multiprocessing.Queue,
                 block_size: int,
                 ):
        super().__init__()
        self.buf = buf
        self.log = logging.getLogger('pr ')
        self.recs_queue = recs_queue
        self.block: typing.List[Record] = list()
        self.block_size = block_size
        self.db = database.Database()  # Just for timing
        self.db.log = logging.getLogger('pr ')

    def register(self, record: Record) -> None:
        """Buffer one record; flush a full block to the queue."""
        self.db.enter_step('register')
        self.block.append(record)
        if len(self.block) >= self.block_size:
            self.db.enter_step('put_block')
            self.recs_queue.put(self.block)
            self.block = list()

    def run(self) -> None:
        """Parse the whole input, then flush the final partial block."""
        self.consume()
        self.recs_queue.put(self.block)
        self.db.close()

    def consume(self) -> None:
        raise NotImplementedError


class Rapid7Parser(Parser):
    """Parser for Rapid7 Open Data DNS dumps (one JSON object per line)."""

    TYPES = {
        'a': RecordType.A,
        'aaaa': RecordType.AAAA,
        'cname': RecordType.CNAME,
        'ptr': RecordType.PTR,
    }

    def consume(self) -> None:
        for line in self.buf:
            self.db.enter_step('parse_rapid7')
            try:
                data = json.loads(line)
            except json.decoder.JSONDecodeError:
                # Tolerate the occasional corrupt line in huge dumps.
                continue
            record = (
                Rapid7Parser.TYPES[data['type']],
                int(data['timestamp']),
                data['name'],
                data['value']
            )
            self.register(record)


class DnsMassParser(Parser):
    """Parser for dnsmass output.

    Expected invocation:
    dnsmass --output Snrql --retry REFUSED,SERVFAIL --resolvers nameservers-ipv4
    """

    # type -> (RecordType, name slice end, value slice end);
    # -1 strips the trailing dot, None keeps the field whole.
    TYPES = {
        'A': (RecordType.A, -1, None),
        'AAAA': (RecordType.AAAA, -1, None),
        'CNAME': (RecordType.CNAME, -1, -1),
    }

    def consume(self) -> None:
        self.db.enter_step('parse_dnsmass')
        timestamp = 0
        header = True
        for line in self.buf:
            line = line[:-1]
            if not line:
                # Blank line separates answer groups; next line is a header.
                header = True
                continue

            split = line.split(' ')
            try:
                if header:
                    timestamp = int(split[1])
                    header = False
                else:
                    dtype, name_offset, value_offset = \
                        DnsMassParser.TYPES[split[1]]
                    record = (
                        dtype,
                        timestamp,
                        split[0][:name_offset],
                        split[2][:value_offset],
                    )
                    self.register(record)
                    self.db.enter_step('parse_dnsmass')
            except KeyError:
                # Unknown record type: skip the line.
                continue


PARSERS = {
    'rapid7': Rapid7Parser,
    'dnsmass': DnsMassParser,
}

if __name__ == '__main__':

    # Parsing arguments
    log = logging.getLogger('feed_dns')
    args_parser = argparse.ArgumentParser(
        description="TODO")
    args_parser.add_argument(
        'parser',
        choices=PARSERS.keys(),
        help="TODO")
    args_parser.add_argument(
        '-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
        help="TODO")
    args_parser.add_argument(
        '-j', '--workers', type=int, default=4,
        help="TODO")
    args_parser.add_argument(
        '-b', '--block-size', type=int, default=100,
        help="TODO")
    args = args_parser.parse_args()

    DB = database.Database(write=False)  # Not needed, just for timing
    DB.log = logging.getLogger('db ')

    # Bounded queues provide back-pressure on the parser.
    recs_queue: multiprocessing.Queue = multiprocessing.Queue(
        maxsize=10*args.workers)
    write_queue: multiprocessing.Queue = multiprocessing.Queue(
        maxsize=10*args.workers)

    DB.enter_step('proc_create')
    readers: typing.List[Reader] = list()
    for w in range(args.workers):
        readers.append(Reader(recs_queue, write_queue, w))
    writer = Writer(write_queue)
    parser = PARSERS[args.parser](
        args.input, recs_queue, args.block_size)

    DB.enter_step('proc_start')
    for reader in readers:
        reader.start()
    writer.start()

    try:
        DB.enter_step('parser_run')
        parser.run()

        # One sentinel per reader; a single sentinel for the one writer.
        DB.enter_step('end_put')
        for _ in range(args.workers):
            recs_queue.put(None)
        write_queue.put(None)

        DB.enter_step('proc_join')
        for reader in readers:
            reader.join()
        writer.join()
    except KeyboardInterrupt:
        log.error('Interrupted')

    DB.close()