From 45325782d2c5c6dfda93d12f4468588871c5a8ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20=E2=80=9CFrogeye=E2=80=9D=20Preud=27homme?= Date: Sun, 15 Dec 2019 17:05:41 +0100 Subject: [PATCH] Multi-processed parser ---
# NOTE(review): reviewer commentary, placed in the free-text area that follows
# the "---" separator of a format-patch. Every statement below is read off the
# hunks of this patch; anything not visible in them is marked as an assumption.
#
# What the patch does
# - Adds `Writer`, a `multiprocessing.Process` subclass. Its `run()` creates the
#   `database.Database()`, pulls blocks (lists of record tuples) from
#   `recs_queue` until it receives `None`, looks up `select, write` in
#   `FUNCTION_MAP[rtype]` for each record, and finally calls `self.db.save()`.
# - Reworks `Parser` so it no longer holds a database. `register()` appends a
#   record to `self.block` and puts the block on the queue once it holds
#   `block_size` records; `run()` calls `consume()`, puts the remaining block,
#   then calls `self.prof.profile()`.
# - `Rapid7Parser.consume` / `DnsMassParser.consume` now build a tuple and pass
#   it to `register(record)`, and time their steps with `self.prof` instead of
#   `self.db`.
# - `__main__` adds -j/--workers, -b/--block-size and -q/--queue-size, starts
#   one `Writer`, runs the parser, sends the `None` sentinel and joins the
#   writer.
#
# Points to confirm before merging
# - `--workers` is parsed but not read anywhere in these hunks; exactly one
#   `Writer` is started.
# - `Writer.__init__` accepts `index` but does not use it, and both loggers are
#   named with `f'wr'`, an f-string without placeholders (presumably the index
#   was meant to be part of the logger name -- confirm).
# - The previous `try/except KeyboardInterrupt` around `consume()` is removed.
#   If `parser.run()` raises, `recs_queue.put(None)` is never reached; the
#   writer would presumably keep waiting on `recs_queue.get` -- confirm, and
#   consider a try/finally around `parser.run()`.
# - `Parser.run()` puts the last block unconditionally, so that block may be an
#   empty list.
# - `Record` is used in annotations (`typing.List[Record]`, `record: Record`)
#   but is not defined in any hunk shown here; presumably it is declared in an
#   unchanged part of feed_dns.py -- verify.
# - `for source in select(self.db, value)` still ignores `source`; the
#   `source=source` call remains commented out.
# - The help text of the three new options is still "TODO".
# - In this copy the patch's line breaks are collapsed into spaces, so the
#   original layout (blank context lines, indentation) is not recoverable from
#   this text alone.
feed_dns.py | 111 +++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 83 insertions(+), 28 deletions(-) diff --git a/feed_dns.py b/feed_dns.py index b106968..c2438d8 100755 --- a/feed_dns.py +++ b/feed_dns.py @@ -5,6 +5,7 @@ import database import logging import sys import typing +import multiprocessing import enum RecordType = enum.Enum('RecordType', 'A AAAA CNAME PTR') @@ -27,27 +28,66 @@ FUNCTION_MAP: typing.Any = { } -class Parser(): - def __init__(self, buf: typing.Any) -> None: - self.buf = buf - self.log = logging.getLogger('parser') - self.db = database.Database() +class Writer(multiprocessing.Process): + def __init__(self, + recs_queue: multiprocessing.Queue, + index: int = 0): + super(Writer, self).__init__() + self.log = logging.getLogger(f'wr') + self.recs_queue = recs_queue - def end(self) -> None: + def run(self) -> None: + self.db = database.Database() + self.db.log = logging.getLogger(f'wr') + + self.db.enter_step('block_wait') + block: typing.List[Record] + for block in iter(self.recs_queue.get, None): + + record: Record + for record in block: + + rtype, updated, name, value = record + self.db.enter_step('feed_switch') + + select, write = FUNCTION_MAP[rtype] + for source in select(self.db, value): + # write(self.db, name, updated, source=source) + write(self.db, name, updated) + + self.db.enter_step('block_wait') + + self.db.enter_step('end') self.db.save() - def register(self, - rtype: RecordType, - updated: int, - name: str, - value: str - ) -> None: - self.db.enter_step('register') - select, write = FUNCTION_MAP[rtype] - for source in select(self.db, value): - # write(self.db, name, updated, source=source) - write(self.db, name, updated) +class Parser(): + def __init__(self, + buf: typing.Any, + recs_queue: 
multiprocessing.Queue, + block_size: int, + ): + super(Parser, self).__init__() + self.buf = buf + self.log = logging.getLogger('pr') + self.recs_queue = recs_queue + self.block: typing.List[Record] = list() + self.block_size = block_size + self.prof = database.Profiler() + self.prof.log = logging.getLogger('pr') + + def register(self, record: Record) -> None: + self.prof.enter_step('register') + self.block.append(record) + if len(self.block) >= self.block_size: + self.prof.enter_step('put_block') + self.recs_queue.put(self.block) + self.block = list() + + def run(self) -> None: + self.consume() + self.recs_queue.put(self.block) + self.prof.profile() def consume(self) -> None: raise NotImplementedError @@ -64,7 +104,7 @@ class Rapid7Parser(Parser): def consume(self) -> None: data = dict() for line in self.buf: - self.db.enter_step('parse_rapid7') + self.prof.enter_step('parse_rapid7') split = line.split('"') for k in range(1, 14, 4): @@ -72,12 +112,13 @@ class Rapid7Parser(Parser): val = split[k+2] data[key] = val - self.register( + record = ( Rapid7Parser.TYPES[data['type']], int(data['timestamp']), data['name'], data['value'] ) + self.register(record) class DnsMassParser(Parser): @@ -90,7 +131,7 @@ class DnsMassParser(Parser): } def consume(self) -> None: - self.db.enter_step('parse_dnsmass') + self.prof.enter_step('parse_dnsmass') timestamp = 0 header = True for line in self.buf: @@ -107,13 +148,14 @@ class DnsMassParser(Parser): else: dtype, name_offset, value_offset = \ DnsMassParser.TYPES[split[1]] - self.register( + record = ( dtype, timestamp, split[0][:name_offset], split[2][:value_offset], ) - self.db.enter_step('parse_dnsmass') + self.register(record) + self.prof.enter_step('parse_dnsmass') except KeyError: continue @@ -136,12 +178,25 @@ if __name__ == '__main__': args_parser.add_argument( '-i', '--input', type=argparse.FileType('r'), default=sys.stdin, help="TODO") + args_parser.add_argument( + '-j', '--workers', type=int, default=4, + help="TODO") + 
args_parser.add_argument( + '-b', '--block-size', type=int, default=100, + help="TODO") + args_parser.add_argument( + '-q', '--queue-size', type=int, default=10, + help="TODO") args = args_parser.parse_args() - parser = PARSERS[args.parser](args.input) - try: - parser.consume() - except KeyboardInterrupt: - pass - parser.end() + recs_queue: multiprocessing.Queue = multiprocessing.Queue( + maxsize=args.queue_size) + writer = Writer(recs_queue) + writer.start() + + parser = PARSERS[args.parser](args.input, recs_queue, args.block_size) + parser.run() + + recs_queue.put(None) + writer.join()