eulaurarien/feed_dns.py

261 lines
7.7 KiB
Python
Executable File

#!/usr/bin/env python3
import argparse
import database
import json
import logging
import sys
import typing
import multiprocessing
import enum
class RecordType(enum.Enum):
    """DNS record types that the feed parsers can emit."""
    A = 1
    AAAA = 2
    CNAME = 3
    PTR = 4


# A parsed DNS record: (record type, timestamp, name, value).
Record = typing.Tuple[RecordType, int, str, str]
# Dispatch table: record type -> (select, confirm, write).
#   select(db, value) is iterated by the reader; each item it yields is
#       forwarded to the writer as the `source` of the write.
#   confirm(db, name) is iterated by the reader; if any item it yields
#       is truthy, nothing is written for that record.
#   write(db, name, updated, source=...) is called by the writer process.
# RecordType.AAAA has no entry in this table.
FUNCTION_MAP: typing.Any = {
    # A record: `value` is looked up as an IPv4 address.
    RecordType.A: (
        database.Database.get_ip4,
        database.Database.get_domain_in_zone,
        database.Database.set_hostname,
    ),
    # CNAME record: `value` is looked up as a domain.
    RecordType.CNAME: (
        database.Database.get_domain,
        database.Database.get_domain_in_zone,
        database.Database.set_hostname,
    ),
    # PTR record: `value` is looked up as a domain.
    RecordType.PTR: (
        database.Database.get_domain,
        database.Database.get_ip4_in_network,
        database.Database.set_ip4address,
    ),
}
class Reader(multiprocessing.Process):
    """Worker process that turns parsed records into write requests.

    Blocks of records are taken from ``recs_queue`` until a ``None``
    sentinel is received.  For every record, the functions registered for
    its type in ``FUNCTION_MAP`` decide whether a write request is put on
    ``write_queue``.
    """

    def __init__(self,
                 recs_queue: multiprocessing.Queue,
                 write_queue: multiprocessing.Queue,
                 index: int = 0):
        """Store the queues; ``index`` only names the loggers."""
        super(Reader, self).__init__()
        self.log = logging.getLogger(f'rd{index:03d}')
        self.recs_queue = recs_queue
        self.write_queue = write_queue
        self.index = index

    def run(self) -> None:
        """Process record blocks until the ``None`` sentinel arrives."""
        # The database is opened here rather than in __init__ so that the
        # connection is created when run() executes, not at construction.
        self.db = database.Database(write=False)
        self.db.log = logging.getLogger(f'db{self.index:03d}')
        self.db.enter_step('line_wait')
        block: typing.List[Record]
        try:
            for block in iter(self.recs_queue.get, None):
                record: Record
                for record in block:
                    dtype, updated, name, value = record
                    self.db.enter_step('feed_switch')
                    functions = FUNCTION_MAP.get(dtype)
                    if functions is None:
                        # The parsers emit record types (e.g. AAAA) that
                        # have no entry in FUNCTION_MAP.  Skip them rather
                        # than dying on a KeyError, which would leave the
                        # producer blocked on a queue nobody reads.
                        continue
                    select, confirm, write = functions
                    for rule in select(self.db, value):
                        # Only request a write when confirm() yields
                        # nothing truthy for this name.
                        if not any(confirm(self.db, name)):
                            self.db.enter_step('wait_put')
                            self.write_queue.put((write, name, updated, rule))
                self.db.enter_step('line_wait')
        except KeyboardInterrupt:
            self.log.error('Interrupted')

        self.db.enter_step('end')
        self.db.close()
class Writer(multiprocessing.Process):
    """Process that applies the write requests produced by the readers.

    Each item taken from ``write_queue`` is a ``(fun, name, updated,
    source)`` tuple; ``None`` tells the writer to stop.
    """

    def __init__(self,
                 write_queue: multiprocessing.Queue,
                 ):
        """Remember the queue to consume from and set up the logger."""
        super().__init__()
        self.write_queue = write_queue
        self.log = logging.getLogger(f'wr ')

    def run(self) -> None:
        """Execute queued write requests until the ``None`` sentinel."""
        db = database.Database(write=True)
        self.db = db
        db.log = logging.getLogger(f'dbw ')
        db.enter_step('line_wait')
        try:
            while True:
                request = self.write_queue.get()
                if request is None:
                    break
                fun, name, updated, source = request
                db.enter_step('exec')
                fun(db, name, updated, source=source)
                # Back to waiting for the next request.
                db.enter_step('line_wait')
        except KeyboardInterrupt:
            self.log.error('Interrupted')

        db.enter_step('end')
        db.close()
class Parser():
    """Base class for input parsers.

    Subclasses implement ``consume()`` and call ``register()`` for every
    record they parse.  Records are grouped into lists of ``block_size``
    items before being put on ``recs_queue``.
    """

    def __init__(self,
                 buf: typing.Any,
                 recs_queue: multiprocessing.Queue,
                 block_size: int,
                 ):
        """Store the input buffer, the output queue and the block size."""
        super().__init__()
        self.buf = buf
        self.log = logging.getLogger('pr ')
        self.recs_queue = recs_queue
        self.block_size = block_size
        self.block: typing.List[Record] = []
        timer = database.Database()  # Just for timing
        timer.log = logging.getLogger('pr ')
        self.db = timer

    def register(self, record: Record) -> None:
        """Add ``record`` to the current block, flushing it once full."""
        self.db.enter_step('register')
        pending = self.block
        pending.append(record)
        if len(pending) < self.block_size:
            return
        self.db.enter_step('put_block')
        self.recs_queue.put(pending)
        # Start a fresh list: the one just queued must not be mutated.
        self.block = []

    def run(self) -> None:
        """Parse the whole input, then flush the last (partial) block."""
        self.consume()
        self.recs_queue.put(self.block)
        self.db.close()

    def consume(self) -> None:
        """Read ``self.buf`` and register its records (abstract)."""
        raise NotImplementedError
class Rapid7Parser(Parser):
    """Parser for input with one JSON object per line.

    Each object is read for its ``type``, ``timestamp``, ``name`` and
    ``value`` keys.
    """

    # Maps the JSON ``type`` field to the record type.
    TYPES = {
        'a': RecordType.A,
        'aaaa': RecordType.AAAA,
        'cname': RecordType.CNAME,
        'ptr': RecordType.PTR,
    }

    def consume(self) -> None:
        """Register one record per supported, decodable input line."""
        for line in self.buf:
            self.db.enter_step('parse_rapid7')
            try:
                data = json.loads(line)
            except json.decoder.JSONDecodeError:
                # Undecodable lines are ignored.
                continue
            dtype = Rapid7Parser.TYPES.get(data['type'])
            if dtype is None:
                # Record types without a mapping are skipped, as
                # DnsMassParser does, instead of aborting the whole
                # parse with a KeyError.
                continue
            record = (
                dtype,
                int(data['timestamp']),
                data['name'],
                data['value']
            )
            self.register(record)
class DnsMassParser(Parser):
    """Parser for space-separated resolver output.

    The input is made of groups of lines separated by blank lines.  The
    first line of a group is a header whose second field is the
    timestamp; each following line is ``name type value``.
    """

    # dnsmass --output Snrql
    # --retry REFUSED,SERVFAIL --resolvers nameservers-ipv4
    # type field -> (record type, end of name slice, end of value slice).
    # An end of -1 drops the last character (presumably the trailing dot
    # of a fully qualified name -- TODO confirm); None keeps everything.
    TYPES = {
        'A': (RecordType.A, -1, None),
        'AAAA': (RecordType.AAAA, -1, None),
        'CNAME': (RecordType.CNAME, -1, -1),
    }

    def consume(self) -> None:
        """Register one record per answer line of a supported type."""
        self.db.enter_step('parse_dnsmass')
        timestamp = 0
        header = True
        for line in self.buf:
            # Strip only the line terminator: a last line without a
            # trailing newline must not lose its final character.
            line = line.rstrip('\n')
            if not line:
                # A blank line ends the group; a header comes next.
                header = True
                continue

            split = line.split(' ')
            try:
                if header:
                    timestamp = int(split[1])
                    header = False
                else:
                    dtype, name_offset, value_offset = \
                        DnsMassParser.TYPES[split[1]]
                    record = (
                        dtype,
                        timestamp,
                        split[0][:name_offset],
                        split[2][:value_offset],
                    )
                    self.register(record)
                    # register() switched the timing step; switch back.
                    self.db.enter_step('parse_dnsmass')
            except KeyError:
                # Record type not listed in TYPES: ignore the line.
                continue
# Parser classes, keyed by the name accepted on the command line.
PARSERS = dict(
    rapid7=Rapid7Parser,
    dnsmass=DnsMassParser,
)
if __name__ == '__main__':
    # Entry point: the selected parser runs in this process and feeds
    # `recs_queue`; reader processes consume it and feed `write_queue`,
    # which a single writer process consumes.

    # Parsing arguments
    log = logging.getLogger('feed_dns')
    args_parser = argparse.ArgumentParser(
        description="TODO")
    args_parser.add_argument(
        'parser',
        choices=PARSERS.keys(),
        help="TODO")
    args_parser.add_argument(
        '-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
        help="TODO")
    args_parser.add_argument(
        '-j', '--workers', type=int, default=4,
        help="TODO")
    args_parser.add_argument(
        '-b', '--block-size', type=int, default=100,
        help="TODO")
    args = args_parser.parse_args()

    DB = database.Database(write=False)  # Not needed, just for timing
    DB.log = logging.getLogger('db ')

    # Bounded queues, so the parser cannot run far ahead of the workers.
    recs_queue: multiprocessing.Queue = multiprocessing.Queue(
        maxsize=10*args.workers)
    write_queue: multiprocessing.Queue = multiprocessing.Queue(
        maxsize=10*args.workers)

    DB.enter_step('proc_create')
    readers: typing.List[Reader] = list()
    for w in range(args.workers):
        readers.append(Reader(recs_queue, write_queue, w))
    writer = Writer(write_queue)
    parser = PARSERS[args.parser](
        args.input, recs_queue, args.block_size)

    DB.enter_step('proc_start')
    for reader in readers:
        reader.start()
    writer.start()

    try:
        DB.enter_step('parser_run')
        parser.run()

        DB.enter_step('end_put')
        # One sentinel per reader: each reader stops at the first None.
        for _ in range(args.workers):
            recs_queue.put(None)

        DB.enter_step('proc_join')
        for reader in readers:
            reader.join()
        # Stop the writer only once every reader has exited.  Queuing the
        # sentinel earlier lets the writer quit while readers are still
        # producing, which loses their writes or leaves them blocked
        # forever on a full write_queue.
        write_queue.put(None)
        writer.join()
    except KeyboardInterrupt:
        log.error('Interrupted')

    DB.close()