diff --git a/database.py b/database.py
index ee51829..19fbe97 100755
--- a/database.py
+++ b/database.py
@@ -284,14 +284,14 @@ class Database():
             'UNION '
             'SELECT * FROM ('
             'SELECT val, entry FROM zone '
             'WHERE val<=:d '
-            'ORDER BY val DESC LIMIT 1'
+            'AND instr(:d, val) = 1'
             ')',
             {'d': domain_prep}
         )
         for val, entry in cursor:
             self.enter_step('get_domain_confirm')
             if not (val is None or domain_prep.startswith(val)):
                 continue
             self.enter_step('get_domain_yield')
             yield entry
diff --git a/feed_dns.py b/feed_dns.py
index 585a211..4b01814 100755
--- a/feed_dns.py
+++ b/feed_dns.py
@@ -7,23 +7,24 @@ import logging
 import sys
 import typing
 import multiprocessing
+import enum
 
-NUMBER_THREADS = 2
-BLOCK_SIZE = 100
+RecordType = enum.Enum('RecordType', 'A AAAA CNAME PTR')
+Record = typing.Tuple[RecordType, int, str, str]
 
 # select, confirm, write
 FUNCTION_MAP: typing.Any = {
-    'a': (
+    RecordType.A: (
         database.Database.get_ip4,
         database.Database.get_domain_in_zone,
         database.Database.set_hostname,
     ),
-    'cname': (
+    RecordType.CNAME: (
         database.Database.get_domain,
         database.Database.get_domain_in_zone,
         database.Database.set_hostname,
     ),
-    'ptr': (
+    RecordType.PTR: (
         database.Database.get_domain,
         database.Database.get_ip4_in_network,
         database.Database.set_ip4address,
@@ -33,12 +34,12 @@ FUNCTION_MAP: typing.Any = {
 
 class Reader(multiprocessing.Process):
     def __init__(self,
-                 lines_queue: multiprocessing.Queue,
+                 recs_queue: multiprocessing.Queue,
                  write_queue: multiprocessing.Queue,
                  index: int = 0):
         super(Reader, self).__init__()
         self.log = logging.getLogger(f'rd{index:03d}')
-        self.lines_queue = lines_queue
+        self.recs_queue = recs_queue
         self.write_queue = write_queue
         self.index = index
 
@@ -48,15 +49,16 @@ class Reader(multiprocessing.Process):
         self.db.enter_step('line_wait')
-        block: typing.List[str]
+        block: typing.List[Record]
         try:
-            for block in iter(self.lines_queue.get, None):
-                for line in block:
-                    dtype, updated, name, value = line
+            for block in iter(self.recs_queue.get, None):
+                record: Record
+                for record in block:
+                    dtype, updated, name, value = record
                     self.db.enter_step('feed_switch')
                     select, confirm, write = FUNCTION_MAP[dtype]
                     for rule in select(self.db, value):
                         if not any(confirm(self.db, name)):
                             self.db.enter_step('wait_put')
-                            self.write_queue.put((write, name, updated))
+                            self.write_queue.put((write, name, updated, rule))
                             self.db.enter_step('line_wait')
         except KeyboardInterrupt:
             self.log.error('Interrupted')
@@ -82,9 +84,10 @@
             fun: typing.Callable
             name: str
             updated: int
-            for fun, name, updated in iter(self.write_queue.get, None):
+            source: int
+            for fun, name, updated, source in iter(self.write_queue.get, None):
                 self.db.enter_step('exec')
-                fun(self.db, name, updated)
+                fun(self.db, name, updated, source=source)
                 self.db.enter_step('line_wait')
         except KeyboardInterrupt:
             self.log.error('Interrupted')
@@ -93,29 +96,142 @@
         self.db.close()
 
 
+class Parser():
+    def __init__(self,
+                 buf: typing.Any,
+                 recs_queue: multiprocessing.Queue,
+                 block_size: int,
+                 ):
+        super(Parser, self).__init__()
+        self.buf = buf
+        self.log = logging.getLogger('pr ')
+        self.recs_queue = recs_queue
+        self.block: typing.List[Record] = list()
+        self.block_size = block_size
+        self.db = database.Database()  # Just for timing
+        self.db.log = logging.getLogger('pr ')
+
+    def register(self, record: Record) -> None:
+        self.db.enter_step('register')
+        self.block.append(record)
+        if len(self.block) >= self.block_size:
+            self.db.enter_step('put_block')
+            self.recs_queue.put(self.block)
+            self.block = list()
+
+    def run(self) -> None:
+        self.consume()
+        self.recs_queue.put(self.block)
+        self.db.close()
+
+    def consume(self) -> None:
+        raise NotImplementedError
+
+
+class Rapid7Parser(Parser):
+    TYPES = {
+        'a': RecordType.A,
+        'aaaa': RecordType.AAAA,
+        'cname': RecordType.CNAME,
+        'ptr': RecordType.PTR,
+    }
+
+    def consume(self) -> None:
+        for line in self.buf:
+            self.db.enter_step('parse_rapid7')
+            try:
+                data = json.loads(line)
+            except json.decoder.JSONDecodeError:
+                continue
+            record = (
+                Rapid7Parser.TYPES[data['type']],
+                int(data['timestamp']),
+                data['name'],
+                data['value']
+            )
+            self.register(record)
+
+
+class DnsMassParser(Parser):
+    # massdns --output Snrql
+    # --retry REFUSED,SERVFAIL --resolvers nameservers-ipv4
+    TYPES = {
+        'A': (RecordType.A, -1, None),
+        'AAAA': (RecordType.AAAA, -1, None),
+        'CNAME': (RecordType.CNAME, -1, -1),
+    }
+
+    def consume(self) -> None:
+        self.db.enter_step('parse_dnsmass')
+        timestamp = 0
+        header = True
+        for line in self.buf:
+            line = line[:-1]
+            if not line:
+                header = True
+                continue
+
+            split = line.split(' ')
+            try:
+                if header:
+                    timestamp = int(split[1])
+                    header = False
+                else:
+                    dtype, name_offset, value_offset = \
+                        DnsMassParser.TYPES[split[1]]
+                    record = (
+                        dtype,
+                        timestamp,
+                        split[0][:name_offset],
+                        split[2][:value_offset],
+                    )
+                    self.register(record)
+                    self.db.enter_step('parse_dnsmass')
+            except KeyError:
+                continue
+
+
+PARSERS = {
+    'rapid7': Rapid7Parser,
+    'dnsmass': DnsMassParser,
+}
+
 if __name__ == '__main__':
 
     # Parsing arguments
     log = logging.getLogger('feed_dns')
-    parser = argparse.ArgumentParser(
+    args_parser = argparse.ArgumentParser(
         description="TODO")
-    parser.add_argument(
-        # '-i', '--input', type=argparse.FileType('rb'), default=sys.stdin.buffer,
+    args_parser.add_argument(
+        'parser',
+        choices=PARSERS.keys(),
+        help="Input format")
+    args_parser.add_argument(
         '-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
         help="TODO")
-    args = parser.parse_args()
+    args_parser.add_argument(
+        '-j', '--workers', type=int, default=4,
+        help="Number of reader processes")
+    args_parser.add_argument(
+        '-b', '--block-size', type=int, default=100,
+        help="Number of records per queue block")
+    args = args_parser.parse_args()
 
     DB = database.Database(write=False)  # Not needed, just for timing
     DB.log = logging.getLogger('db ')
 
-    lines_queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=100)
-    write_queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=100)
+    recs_queue: multiprocessing.Queue = multiprocessing.Queue(
+        maxsize=10*args.workers)
+    write_queue: multiprocessing.Queue = multiprocessing.Queue(
+        maxsize=10*args.workers)
 
     DB.enter_step('proc_create')
     readers: typing.List[Reader] = list()
-    for w in range(NUMBER_THREADS):
-        readers.append(Reader(lines_queue, write_queue, w))
+    for w in range(args.workers):
+        readers.append(Reader(recs_queue, write_queue, w))
     writer = Writer(write_queue)
+    parser = PARSERS[args.parser](
+        args.input, recs_queue, args.block_size)
 
     DB.enter_step('proc_start')
     for reader in readers:
@@ -123,28 +239,12 @@ if __name__ == '__main__':
     writer.start()
 
     try:
-        block: typing.List[str] = list()
-        DB.enter_step('iowait')
-        for line in args.input:
-            DB.enter_step('block_append')
-            DB.enter_step('feed_json_parse')
-            data = json.loads(line)
-            line = (data['type'],
-                    int(data['timestamp']),
-                    data['name'],
-                    data['value'])
-            block.append(line)
-            if len(block) >= BLOCK_SIZE:
-                DB.enter_step('wait_put')
-                lines_queue.put(block)
-                block = list()
-            DB.enter_step('iowait')
-        DB.enter_step('wait_put')
-        lines_queue.put(block)
+        DB.enter_step('parser_run')
+        parser.run()
 
         DB.enter_step('end_put')
-        for _ in range(NUMBER_THREADS):
-            lines_queue.put(None)
+        for _ in range(args.workers):
+            recs_queue.put(None)
         write_queue.put(None)
 
         DB.enter_step('proc_join')
diff --git a/json_to_csv.py b/json_to_csv.py
deleted file mode 100755
index 39ca1b7..0000000
--- a/json_to_csv.py
+++ /dev/null
@@ -1,36 +0,0 @@
-#!/usr/bin/env python3
-
-import argparse
-import sys
-import logging
-import json
-import csv
-
-if __name__ == '__main__':
-
-    # Parsing arguments
-    log = logging.getLogger('json_to_csv')
-    parser = argparse.ArgumentParser(
-        description="TODO")
-    parser.add_argument(
-        # '-i', '--input', type=argparse.FileType('rb'), default=sys.stdin.buffer,
-        '-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
-        help="TODO")
-    parser.add_argument(
-        # '-i', '--output', type=argparse.FileType('wb'), default=sys.stdout.buffer,
-        '-o', '--output', type=argparse.FileType('w'), default=sys.stdout,
-        help="TODO")
-    args = parser.parse_args()
-
-    writer = csv.writer(args.output)
-    for line in args.input:
-        data = json.loads(line)
-        try:
-            writer.writerow([
-                data['type'][0],  # First letter, will need to do something special for AAAA
-                data['timestamp'],
-                data['name'],
-                data['value']])
-        except (KeyError, json.decoder.JSONDecodeError):
-            log.error('Could not parse line: %s', line)
-            pass
diff --git a/new_workflow.sh b/new_workflow.sh
index e21b426..c98cd46 100755
--- a/new_workflow.sh
+++ b/new_workflow.sh
@@ -9,11 +9,11 @@ function log() {
 
 # TODO Fetch 'em
 log "Reading PTR records…"
-pv ptr.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
+pv ptr.json.gz | gunzip | ./feed_dns.py rapid7
 log "Reading A records…"
-pv a.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
+pv a.json.gz | gunzip | ./feed_dns.py rapid7
 log "Reading CNAME records…"
-pv cname.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
+pv cname.json.gz | gunzip | ./feed_dns.py rapid7
 
 log "Pruning old data…"
 ./database.py --prune
diff --git a/resolve_subdomains.py b/resolve_subdomains.py
deleted file mode 100755
index bc26e34..0000000
--- a/resolve_subdomains.py
+++ /dev/null
@@ -1,264 +0,0 @@
-#!/usr/bin/env python3
-
-"""
-From a list of subdomains, output only
-the ones resolving to a first-party tracker.
-"""
-
-import argparse
-import logging
-import os
-import queue
-import sys
-import threading
-import typing
-import time
-
-import coloredlogs
-import dns.exception
-import dns.resolver
-
-DNS_TIMEOUT = 5.0
-NUMBER_TRIES = 5
-
-
-class Worker(threading.Thread):
-    """
-    Worker process for a DNS resolver.
-    Will resolve DNS to match first-party subdomains.
-    """
-
-    def change_nameserver(self) -> None:
-        """
-        Assign a this worker another nameserver from the queue.
- """ - server = None - while server is None: - try: - server = self.orchestrator.nameservers_queue.get(block=False) - except queue.Empty: - self.orchestrator.refill_nameservers_queue() - self.log.info("Using nameserver: %s", server) - self.resolver.nameservers = [server] - - def __init__(self, - orchestrator: 'Orchestrator', - index: int = 0): - super(Worker, self).__init__() - self.log = logging.getLogger(f'worker{index:03d}') - self.orchestrator = orchestrator - - self.resolver = dns.resolver.Resolver() - self.change_nameserver() - - def resolve_subdomain(self, subdomain: str) -> typing.Optional[ - typing.List[ - dns.rrset.RRset - ] - ]: - """ - Returns the resolution chain of the subdomain to an A record, - including any intermediary CNAME. - The last element is an IP address. - Returns None if the nameserver was unable to satisfy the request. - Returns [] if the requests points to nothing. - """ - self.log.debug("Querying %s", subdomain) - try: - query = self.resolver.query(subdomain, 'A', lifetime=DNS_TIMEOUT) - except dns.resolver.NXDOMAIN: - return [] - except dns.resolver.NoAnswer: - return [] - except dns.resolver.YXDOMAIN: - self.log.warning("Query name too long for %s", subdomain) - return None - except dns.resolver.NoNameservers: - # NOTE Most of the time this error message means that the domain - # does not exists, but sometimes it means the that the server - # itself is broken. So we count on the retry logic. - self.log.warning("All nameservers broken for %s", subdomain) - return None - except dns.exception.Timeout: - # NOTE Same as above - self.log.warning("Timeout for %s", subdomain) - return None - except dns.name.EmptyLabel: - self.log.warning("Empty label for %s", subdomain) - return None - return query.response.answer - - def run(self) -> None: - self.log.info("Started") - subdomain: str - for subdomain in iter(self.orchestrator.subdomains_queue.get, None): - - for _ in range(NUMBER_TRIES): - resolved = self.resolve_subdomain(subdomain) - # Retry with another nameserver if error - if resolved is None: - self.change_nameserver() - else: - break - - # If it wasn't found after multiple tries - if resolved is None: - self.log.error("Gave up on %s", subdomain) - resolved = [] - - assert isinstance(resolved, list) - self.orchestrator.results_queue.put(resolved) - - self.orchestrator.results_queue.put(None) - self.log.info("Stopped") - - -class Orchestrator(): - """ - Orchestrator of the different Worker threads. - """ - - def refill_nameservers_queue(self) -> None: - """ - Re-fill the given nameservers into the nameservers queue. - Done every-time the queue is empty, making it - basically looping and infinite. - """ - # Might be in a race condition but that's probably fine - for nameserver in self.nameservers: - self.nameservers_queue.put(nameserver) - self.log.info("Refilled nameserver queue") - - def __init__(self, subdomains: typing.Iterable[str], - nameservers: typing.List[str] = None, - nb_workers: int = 1, - ): - self.log = logging.getLogger('orchestrator') - self.subdomains = subdomains - self.nb_workers = nb_workers - - # Use interal resolver by default - self.nameservers = nameservers or dns.resolver.Resolver().nameservers - - self.subdomains_queue: queue.Queue = queue.Queue( - maxsize=self.nb_workers) - self.results_queue: queue.Queue = queue.Queue() - self.nameservers_queue: queue.Queue = queue.Queue() - - self.refill_nameservers_queue() - - def fill_subdomain_queue(self) -> None: - """ - Read the subdomains in input and put them into the queue. 
- Done in a thread so we can both: - - yield the results as they come - - not store all the subdomains at once - """ - self.log.info("Started reading subdomains") - # Send data to workers - for subdomain in self.subdomains: - self.subdomains_queue.put(subdomain) - - self.log.info("Finished reading subdomains") - # Send sentinel to each worker - # sentinel = None ~= EOF - for _ in range(self.nb_workers): - self.subdomains_queue.put(None) - - @staticmethod - def format_rrset(rrset: dns.rrset.RRset) -> typing.Iterable[str]: - if rrset.rdtype == dns.rdatatype.CNAME: - dtype = 'c' - elif rrset.rdtype == dns.rdatatype.A: - dtype = 'a' - else: - raise NotImplementedError - name = rrset.name.to_text()[:-1] - for item in rrset.items: - value = item.to_text() - if rrset.rdtype == dns.rdatatype.CNAME: - value = value[:-1] - yield f'{dtype},{int(time.time())},{name},{value}\n' - - def run(self) -> typing.Iterable[str]: - """ - Yield the results. - """ - # Create workers - self.log.info("Creating workers") - for i in range(self.nb_workers): - Worker(self, i).start() - - fill_thread = threading.Thread(target=self.fill_subdomain_queue) - fill_thread.start() - - # Wait for one sentinel per worker - # In the meantime output results - for _ in range(self.nb_workers): - resolved: typing.List[dns.rrset.RRset] - for resolved in iter(self.results_queue.get, None): - for rrset in resolved: - yield from self.format_rrset(rrset) - - self.log.info("Waiting for reader thread") - fill_thread.join() - - self.log.info("Done!") - - -def main() -> None: - """ - Main function when used directly. - Read the subdomains provided and output it, - the last CNAME resolved and the IP adress it resolves to. - Takes as an input a filename (or nothing, for stdin), - and as an output a filename (or nothing, for stdout). - The input must be a subdomain per line, the output is a TODO - Use the file `nameservers` as the list of nameservers - to use, or else it will use the system defaults. - """ - - # Initialization - coloredlogs.install( - level='DEBUG', - fmt='%(asctime)s %(name)s %(levelname)s %(message)s' - ) - - # Parsing arguments - parser = argparse.ArgumentParser( - description="Massively resolves subdomains and store them in a file.") - parser.add_argument( - '-i', '--input', type=argparse.FileType('r'), default=sys.stdin, - help="Input file with one subdomain per line") - parser.add_argument( - '-o', '--output', type=argparse.FileType('w'), default=sys.stdout, - help="Outptut file with DNS chains") - parser.add_argument( - '-n', '--nameservers', default='nameservers', - help="File with one nameserver per line") - parser.add_argument( - '-j', '--workers', type=int, default=512, - help="Number of threads to use") - args = parser.parse_args() - - # Cleaning input - iterator = iter(args.input) - iterator = map(str.strip, iterator) - iterator = filter(None, iterator) - - # Reading nameservers - servers: typing.List[str] = list() - if os.path.isfile(args.nameservers): - servers = open(args.nameservers).readlines() - servers = list(filter(None, map(str.strip, servers))) - - for resolved in Orchestrator( - iterator, - servers, - nb_workers=args.workers - ).run(): - args.output.write(resolved) - - -if __name__ == '__main__': - main()
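Note on the database.py hunk above: SQLite's instr(X, Y) returns the 1-based position of Y inside X (0 if absent), so instr(:d, val) = 1 holds exactly when val is a prefix of the probed name. The snippet below is a minimal sketch, not part of the patch: the reversed-name convention ('tracker.example.com' stored as 'com.example.tracker.') is an assumption inferred from the prefix test, and the throwaway table only mirrors the val/entry columns the query touches.

#!/usr/bin/env python3
# Standalone sketch of the new prefix-based zone lookup.
import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE zone (val TEXT, entry INTEGER)')
conn.executemany(
    'INSERT INTO zone VALUES (?, ?)',
    [('com.example.', 1), ('com.example.ads.', 2), ('org.unrelated.', 3)],
)

domain_prep = 'com.example.ads.tracker.'
cursor = conn.execute(
    'SELECT val, entry FROM zone '
    'WHERE val<=:d '
    'AND instr(:d, val) = 1',
    {'d': domain_prep},
)
# Unlike the old "ORDER BY val DESC LIMIT 1", this yields every
# enclosing zone, not just the closest one:
print(cursor.fetchall())  # [('com.example.', 1), ('com.example.ads.', 2)]

The val<=:d bound presumably keeps the query friendly to an index on val, while instr() performs the exact prefix test; the Python-side domain_prep.startswith(val) check in get_domain_confirm remains as a final guard.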
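On the feed_dns.py side, the new Parser base class makes input formats pluggable: a subclass only implements consume() and calls register() for each record, while batching into --block-size blocks and the final flush in run() are inherited. Here is a hypothetical sketch of what a third format could look like; the 'tsv' name, the TsvParser class, and the tab-separated layout are all invented for illustration and are not part of the patch.

# Hypothetical extra input format under the Parser contract of feed_dns.py.
# Assumes one record per line: <type>\t<timestamp>\t<name>\t<value>
from feed_dns import PARSERS, Parser, RecordType


class TsvParser(Parser):
    def consume(self) -> None:
        for line in self.buf:
            self.db.enter_step('parse_tsv')
            dtype, timestamp, name, value = line.rstrip('\n').split('\t')
            self.register((
                RecordType[dtype],  # lookup by member name: 'CNAME' -> RecordType.CNAME
                int(timestamp),
                name,
                value,
            ))


# Registering the class in feed_dns.py's own PARSERS dict would expose it
# as a command-line choice, e.g.: ./feed_dns.py tsv -i records.tsv -j 8 -b 500

With the patch as it stands, the equivalent invocations are ./feed_dns.py rapid7 for Rapid7 JSON (as new_workflow.sh now does) and ./feed_dns.py dnsmass for massdns output.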