eulaurarien/feed_dns.py

122 lines
3.8 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
import database
import argparse
import sys
import logging
2019-12-13 12:36:11 +01:00
import threading
import queue
import typing
2019-12-13 12:36:11 +01:00
# Number of worker threads; also the capacity of both inter-thread queues.
NUMBER_THREADS: int = 8
2019-12-13 12:36:11 +01:00
class Worker(threading.Thread):
    """Resolve DNS records against the rule database in a background thread.

    A worker consumes raw input lines from ``lines_queue`` until it receives
    ``None``. For each record it looks up matching rules in its own read-only
    ``database.Database`` and pushes ``(setter, name, rule)`` tuples onto
    ``write_queue``, where ``setter`` is an unbound ``database.Database``
    method for the writer thread to call. When it stops -- normally or
    because of an error -- it pushes exactly one ``None`` onto
    ``write_queue`` so the writer can count finished workers.
    """

    def __init__(self,
                 lines_queue: queue.Queue,
                 write_queue: queue.Queue,
                 index: int = 0):
        """Create a worker.

        :param lines_queue: source of input lines; ``None`` means "stop".
        :param write_queue: sink for ``(setter, name, rule)`` write requests.
        :param index: worker number, used to name this worker's loggers.
        """
        super().__init__()
        self.log = logging.getLogger(f'worker{index:03d}')
        self.lines_queue = lines_queue
        self.write_queue = write_queue
        self.index = index

    def run(self) -> None:
        """Thread body: process lines, then always signal completion."""
        self.db = None
        try:
            # Each thread opens its own read-only database handle.
            self.db = database.Database(write=False)
            self.db.log = logging.getLogger(f'db{self.index:03d}')
            self._consume_lines()
            self.db.enter_step('end')
        finally:
            # The writer loop waits for one ``None`` per worker; emitting it
            # even on failure keeps the main thread from blocking forever.
            self.write_queue.put(None)
            if self.db is not None:
                self.db.close()

    def _consume_lines(self) -> None:
        """Handle input lines until the ``None`` sentinel arrives."""
        self.db.enter_step('wait_line')
        line: str
        for line in iter(self.lines_queue.get, None):
            self.db.enter_step('feed_json_parse')
            # Fast path instead of json.loads: splitting on '"' leaves the
            # name, type and value strings at fixed indices.
            # NOTE(review): assumes one JSON object per line with a fixed key
            # order and no escaped quotes -- confirm against the producer.
            split = line.split('"')
            try:
                name = split[7]
                dtype = split[11]
                value = split[15]
            except IndexError:
                self.log.error("Invalid JSON: %s", line)
                continue

            self.db.enter_step('feed_switch')
            if dtype == 'a':
                rules = self.db.get_ip4(value)
                setter = database.Database.set_hostname
            elif dtype == 'cname':
                rules = self.db.get_domain(value)
                setter = database.Database.set_hostname
            elif dtype == 'ptr':
                rules = self.db.get_domain(value)
                setter = database.Database.set_ip4address
            else:
                # Any other record type produces no write requests.
                rules, setter = (), None

            for rule in rules:
                self.db.enter_step('wait_put')
                self.write_queue.put((setter, name, rule))
            self.db.enter_step('wait_line')
if __name__ == '__main__':
    # Parsing arguments
    log = logging.getLogger('feed_dns')
    parser = argparse.ArgumentParser(
        description="TODO")
    parser.add_argument(
        '-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
        help="TODO")
    args = parser.parse_args()

    DB = database.Database(write=False)  # Not needed, just for timing
    DB.log = logging.getLogger('dbf')
    DBW = database.Database(write=True)
    DBW.log = logging.getLogger('dbw')

    # Both queues are bounded so the reader cannot run far ahead of the
    # workers, nor the workers far ahead of the single writer below.
    lines_queue: queue.Queue = queue.Queue(maxsize=NUMBER_THREADS)
    write_queue: queue.Queue = queue.Queue(maxsize=NUMBER_THREADS)

    def fill_lines_queue() -> None:
        """Feed every input line to the workers, then one ``None`` each."""
        DB.enter_step('iowait')
        for line in args.input:
            DB.enter_step('wait_put')
            lines_queue.put(line)
            DB.enter_step('iowait')

        DB.enter_step('end_put')
        for _ in range(NUMBER_THREADS):
            lines_queue.put(None)

    for w in range(NUMBER_THREADS):
        Worker(lines_queue, write_queue, w).start()
    filler = threading.Thread(target=fill_lines_queue)
    filler.start()

    fun: typing.Callable
    name: str
    source: int
    try:
        # Only this thread writes through DBW. Each worker ends its stream
        # with one ``None``, so draining NUMBER_THREADS streams means every
        # worker has finished.
        for _ in range(NUMBER_THREADS):
            DBW.enter_step('wait_fun')
            for fun, name, source in iter(write_queue.get, None):
                DBW.enter_step('exec_fun')
                fun(DBW, name, source=source)
                DBW.enter_step('commit')
                DBW.conn.commit()
                DBW.enter_step('wait_fun')
        # The filler uses DB, so make sure it is done before DB is closed.
        filler.join()
        DBW.enter_step('end')
    finally:
        # Release both database handles even if a write request failed.
        DBW.close()
        DB.close()