#!/usr/bin/env python3
"""Feed DNS records (one JSON object per line) into the database.

Input lines are parsed by a pool of ``Worker`` threads, each looking up the
rules matching a record in its own read-only ``database.Database``.  Every
resulting write is sent back to the main thread, which applies it through a
single writable connection.
"""

import argparse
import logging
import queue
import sys
import threading
import typing

import database

NUMBER_THREADS = 8


class Worker(threading.Thread):
    """Turn input lines into pending database writes on a background thread.

    Consumes lines from ``lines_queue`` until it receives ``None``, and pushes
    ``(setter, name, rule)`` tuples onto ``write_queue``, where ``setter`` is
    an unbound ``database.Database`` method the writer calls as
    ``setter(db, name, source=rule)``.  When finished (or on failure) it
    pushes exactly one ``None`` onto ``write_queue``.
    """

    def __init__(self,
                 lines_queue: queue.Queue,
                 write_queue: queue.Queue,
                 index: int = 0):
        super().__init__()
        self.log = logging.getLogger(f'worker{index:03d}')
        self.lines_queue = lines_queue
        self.write_queue = write_queue
        self.index = index
        # Created in run() so the connection is opened on the worker's own
        # thread.
        self.db: typing.Optional[database.Database] = None

    def run(self) -> None:
        """Open a read-only DB, process all lines, then signal completion."""
        try:
            self.db = database.Database(write=False)
            self.db.log = logging.getLogger(f'db{self.index:03d}')
            self._feed_lines(self.db)
            self.db.enter_step('end')
        finally:
            # The main thread waits for one None per worker; send it even if
            # this worker failed, otherwise the writer loop blocks forever.
            self.write_queue.put(None)
            if self.db is not None:
                self.db.close()

    def _feed_lines(self, db: database.Database) -> None:
        """Process lines from ``lines_queue`` until the ``None`` sentinel."""
        # Record type -> (rule lookup on `db`, setter applied by the writer).
        handlers = {
            'a': (db.get_ip4, database.Database.set_hostname),
            'cname': (db.get_domain, database.Database.set_hostname),
            'ptr': (db.get_domain, database.Database.set_ip4address),
        }
        db.enter_step('wait_line')
        line: str
        for line in iter(self.lines_queue.get, None):
            db.enter_step('feed_json_parse')
            # Cheap extraction instead of json.loads: takes the 4th, 6th and
            # 8th double-quoted strings as name, type and value.
            # NOTE(review): relies on a fixed key order and on values
            # containing no escaped quotes.
            split = line.split('"')
            try:
                name = split[7]
                dtype = split[11]
                value = split[15]
            except IndexError:
                self.log.error("Invalid JSON: %s", line)
                db.enter_step('wait_line')
                continue

            db.enter_step('feed_switch')
            handler = handlers.get(dtype)
            if handler is not None:
                lookup, setter = handler
                for rule in lookup(value):
                    db.enter_step('wait_put')
                    self.write_queue.put((setter, name, rule))
            db.enter_step('wait_line')


def main() -> None:
    """Parse arguments, run the reader/worker pipeline and apply writes."""
    parser = argparse.ArgumentParser(
        description="Feed DNS records (one JSON object per line) "
                    "into the database.")
    parser.add_argument(
        '-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
        help="input file, one JSON record per line (default: stdin)")
    args = parser.parse_args()

    timing_db = database.Database(write=False)  # Not needed, just for timing
    timing_db.log = logging.getLogger('dbf')
    write_db = database.Database(write=True)
    write_db.log = logging.getLogger('dbw')

    lines_queue: queue.Queue = queue.Queue(maxsize=NUMBER_THREADS)
    write_queue: queue.Queue = queue.Queue(maxsize=NUMBER_THREADS)

    def fill_lines_queue() -> None:
        """Send every input line to the workers, then one None per worker."""
        try:
            timing_db.enter_step('iowait')
            for line in args.input:
                timing_db.enter_step('wait_put')
                lines_queue.put(line)
                timing_db.enter_step('iowait')
            timing_db.enter_step('end_put')
        finally:
            # Always release the workers, even if reading the input failed.
            for _ in range(NUMBER_THREADS):
                lines_queue.put(None)

    for w in range(NUMBER_THREADS):
        Worker(lines_queue, write_queue, w).start()
    threading.Thread(target=fill_lines_queue).start()

    fun: typing.Callable
    name: str
    source: int
    # Each worker terminates its stream with one None: drain until every
    # worker has finished.
    for _ in range(NUMBER_THREADS):
        write_db.enter_step('wait_fun')
        for fun, name, source in iter(write_queue.get, None):
            write_db.enter_step('exec_fun')
            fun(write_db, name, source=source)
            write_db.enter_step('commit')
            write_db.conn.commit()
            write_db.enter_step('wait_fun')

    write_db.enter_step('end')
    write_db.close()
    timing_db.close()


if __name__ == '__main__':
    main()