From 231bb83667710a2c110f60e9dd48028d4743de3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20=E2=80=9CFrogeye=E2=80=9D=20Preud=27homme?= Date: Fri, 13 Dec 2019 12:36:11 +0100 Subject: [PATCH] Threaded feed_dns Largely disapointing --- feed_dns.py | 117 ++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 90 insertions(+), 27 deletions(-) mode change 100755 => 100644 feed_dns.py diff --git a/feed_dns.py b/feed_dns.py old mode 100755 new mode 100644 index 2993d6d..fed322d --- a/feed_dns.py +++ b/feed_dns.py @@ -4,27 +4,31 @@ import database import argparse import sys import logging +import threading +import queue +import typing -if __name__ == '__main__': +NUMBER_THREADS = 8 - # Parsing arguments - log = logging.getLogger('feed_dns') - parser = argparse.ArgumentParser( - description="TODO") - parser.add_argument( - # '-i', '--input', type=argparse.FileType('rb'), default=sys.stdin.buffer, - '-i', '--input', type=argparse.FileType('r'), default=sys.stdin, - help="TODO") - args = parser.parse_args() - DB = database.Database(write=True) +class Worker(threading.Thread): + def __init__(self, + lines_queue: queue.Queue, + write_queue: queue.Queue, + index: int = 0): + super(Worker, self).__init__() + self.log = logging.getLogger(f'worker{index:03d}') + self.lines_queue = lines_queue + self.write_queue = write_queue + self.index = index - try: - DB.enter_step('iowait') - # line: bytes + def run(self) -> None: + self.db = database.Database(write=False) + self.db.log = logging.getLogger(f'db{self.index:03d}') + self.db.enter_step('wait_line') line: str - for line in args.input: - DB.enter_step('feed_json_parse') + for line in iter(self.lines_queue.get, None): + self.db.enter_step('feed_json_parse') # split = line.split(b'"') split = line.split('"') try: @@ -40,19 +44,78 @@ if __name__ == '__main__': # assert name == data['name'] # assert value == data['value'] - DB.enter_step('feed_switch') + self.db.enter_step('feed_switch') if dtype == 'a': - for rule in DB.get_ip4(value): - DB.set_hostname(name, source=rule) + for rule in self.db.get_ip4(value): + self.db.enter_step('wait_put') + self.write_queue.put( + (database.Database.set_hostname, name, rule)) elif dtype == 'cname': - for rule in DB.get_domain(value): - DB.set_hostname(name, source=rule) + for rule in self.db.get_domain(value): + self.db.enter_step('wait_put') + self.write_queue.put( + (database.Database.set_hostname, name, rule)) elif dtype == 'ptr': - for rule in DB.get_domain(value): - DB.set_ip4address(name, source=rule) - DB.enter_step('iowait') - except KeyboardInterrupt: - log.warning("Interupted.") - pass + for rule in self.db.get_domain(value): + self.db.enter_step('wait_put') + self.write_queue.put( + (database.Database.set_ip4address, name, rule)) + self.db.enter_step('wait_line') + self.db.enter_step('end') + self.write_queue.put(None) + self.db.close() + + +if __name__ == '__main__': + + # Parsing arguments + log = logging.getLogger('feed_dns') + parser = argparse.ArgumentParser( + description="TODO") + parser.add_argument( + # '-i', '--input', type=argparse.FileType('rb'), default=sys.stdin.buffer, + '-i', '--input', type=argparse.FileType('r'), default=sys.stdin, + help="TODO") + args = parser.parse_args() + + DB = database.Database(write=False) # Not needed, just for timing + DB.log = logging.getLogger('dbf') + DBW = database.Database(write=True) + DBW.log = logging.getLogger('dbw') + + lines_queue: queue.Queue = queue.Queue(maxsize=NUMBER_THREADS) + write_queue: queue.Queue = queue.Queue(maxsize=NUMBER_THREADS) + + def fill_lines_queue() -> None: + DB.enter_step('iowait') + for line in args.input: + DB.enter_step('wait_put') + lines_queue.put(line) + DB.enter_step('iowait') + + DB.enter_step('end_put') + for _ in range(NUMBER_THREADS): + lines_queue.put(None) + + for w in range(NUMBER_THREADS): + Worker(lines_queue, write_queue, w).start() + + threading.Thread(target=fill_lines_queue).start() + + for _ in range(NUMBER_THREADS): + fun: typing.Callable + name: str + source: int + DBW.enter_step('wait_fun') + for fun, name, source in iter(write_queue.get, None): + DBW.enter_step('exec_fun') + fun(DBW, name, source=source) + DBW.enter_step('commit') + DBW.conn.commit() + DBW.enter_step('wait_fun') + + DBW.enter_step('end') + + DBW.close() DB.close()