2019-12-09 08:12:48 +01:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
|
|
|
import database
|
|
|
|
import argparse
|
|
|
|
import sys
|
2019-12-13 00:11:21 +01:00
|
|
|
import logging
|
2019-12-13 12:36:11 +01:00
|
|
|
import threading
|
|
|
|
import queue
|
|
|
|
import typing
|
2019-12-09 08:12:48 +01:00
|
|
|
|
2019-12-13 12:36:11 +01:00
|
|
|
# Number of Worker threads started by the script. It is also used as the
# maxsize of both work queues and as the count of end-of-stream ``None``
# markers sent on lines_queue and awaited on write_queue.
NUMBER_THREADS = 8
|
2019-12-09 08:12:48 +01:00
|
|
|
|
|
|
|
|
2019-12-13 12:36:11 +01:00
|
|
|
class Worker(threading.Thread):
    """Look up DNS records from ``lines_queue`` and queue the resulting writes.

    Each worker opens its own ``database.Database(write=False)`` connection.
    For every rule returned by the lookup matching a record, it pushes a
    ``(setter, name, rule)`` tuple onto ``write_queue``; a separate consumer
    is expected to apply those writes.

    A ``None`` item on ``lines_queue`` stops the worker. The worker always
    answers with exactly one ``None`` on ``write_queue``, even if it fails,
    so the consumer can count finished workers without hanging.
    """

    def __init__(self,
                 lines_queue: queue.Queue,
                 write_queue: queue.Queue,
                 index: int = 0):
        """Store the queues and set up a per-worker logger.

        :param lines_queue: source of raw input lines, terminated by ``None``.
        :param write_queue: sink for ``(setter, name, rule)`` write requests.
        :param index: worker number, used in logger names.
        """
        super(Worker, self).__init__()
        self.log = logging.getLogger(f'worker{index:03d}')
        self.lines_queue = lines_queue
        self.write_queue = write_queue
        self.index = index

    def run(self) -> None:
        """Consume lines until ``None``, then signal completion on write_queue."""
        db = None
        try:
            db = self.db = database.Database(write=False)
            db.log = logging.getLogger(f'db{self.index:03d}')
            db.enter_step('wait_line')
            line: str
            for line in iter(self.lines_queue.get, None):
                self._handle_line(line)
                # Re-entered on every path (including invalid lines), so the
                # time spent blocked on the queue is always attributed here.
                db.enter_step('wait_line')
            db.enter_step('end')
        finally:
            # Always emit the sentinel: the consumer waits for one ``None``
            # per worker and would block forever if this one went missing.
            self.write_queue.put(None)
            if db is not None:
                db.close()

    def _handle_line(self, line: str) -> None:
        """Parse one input line and enqueue a write for each matching rule."""
        self.db.enter_step('feed_json_parse')
        # Cheap field extraction instead of a JSON parser. It assumes the
        # line's keys appear in a fixed order ("name" value at quote-field 7,
        # "type" at 11, "value" at 15) and that no value contains an escaped
        # double quote -- NOTE(review): confirm against the input producer.
        split = line.split('"')
        try:
            name = split[7]
            dtype = split[11]
            value = split[15]
        except IndexError:
            self.log.error("Invalid JSON: %s", line)
            return

        self.db.enter_step('feed_switch')
        if dtype == 'a':
            rules = self.db.get_ip4(value)
            setter = database.Database.set_hostname
        elif dtype == 'cname':
            rules = self.db.get_domain(value)
            setter = database.Database.set_hostname
        elif dtype == 'ptr':
            rules = self.db.get_domain(value)
            setter = database.Database.set_ip4address
        else:
            # Record types other than a / cname / ptr are ignored.
            return

        for rule in rules:
            self.db.enter_step('wait_put')
            self.write_queue.put((setter, name, rule))
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':

    # Parsing arguments
    # Logger for the script itself (module-level name).
    log = logging.getLogger('feed_dns')
    parser = argparse.ArgumentParser(
        description="TODO")
    parser.add_argument(
        # Text mode. Switching to binary ('rb' / sys.stdin.buffer) would also
        # require the workers to split lines on b'"' instead of '"'.
        '-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
        help="TODO")
    args = parser.parse_args()

    # DB only records step timings for the reader thread. DBW is the single
    # writable connection and is used by the main thread alone.
    DB = database.Database(write=False)  # Not needed, just for timing
    DB.log = logging.getLogger('dbf')
    DBW = database.Database(write=True)
    DBW.log = logging.getLogger('dbw')

    # Bounded queues: the reader cannot run far ahead of the workers, nor the
    # workers far ahead of the writer.
    lines_queue: queue.Queue = queue.Queue(maxsize=NUMBER_THREADS)
    write_queue: queue.Queue = queue.Queue(maxsize=NUMBER_THREADS)

    def fill_lines_queue() -> None:
        """Feed every input line to the workers, then one ``None`` per worker."""
        DB.enter_step('iowait')
        for line in args.input:
            DB.enter_step('wait_put')
            lines_queue.put(line)
            DB.enter_step('iowait')

        DB.enter_step('end_put')
        for _ in range(NUMBER_THREADS):
            lines_queue.put(None)

    workers = [Worker(lines_queue, write_queue, w)
               for w in range(NUMBER_THREADS)]
    for worker in workers:
        worker.start()

    reader = threading.Thread(target=fill_lines_queue)
    reader.start()

    # Apply writes on this thread only. Each worker finishes by sending one
    # ``None``, so the outer loop waits for NUMBER_THREADS of them.
    fun: typing.Callable
    name: str
    source: int  # NOTE(review): annotated int; it is whatever rule the worker found
    for _ in range(NUMBER_THREADS):
        DBW.enter_step('wait_fun')
        for fun, name, source in iter(write_queue.get, None):
            DBW.enter_step('exec_fun')
            fun(DBW, name, source=source)
            DBW.enter_step('commit')
            DBW.conn.commit()
            DBW.enter_step('wait_fun')

    DBW.enter_step('end')

    # Let every thread finish (workers close their own connections after
    # sending their sentinel) before tearing down the shared resources.
    reader.join()
    for worker in workers:
        worker.join()
    if args.input is not sys.stdin:
        args.input.close()

    DBW.close()
    DB.close()
|