Workflow: Can now import DnsMass output

Well, in a specific format but DnsMass nonetheless
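The "specific format" is the plain-text block output of dnsmass (see the `# dnsmass --output Snrql` comment in the new DnsMassParser below): records grouped into blank-line-separated blocks, each block opening with a header line whose second space-separated field is a Unix timestamp, followed by `name. TYPE value` lines whose trailing dots get trimmed. A self-contained sketch of that shape and of the trimming the parser applies; the header wording and sample records are invented placeholders, not real dnsmass output:

# Hedged sketch only: synthetic input in the shape DnsMassParser appears to
# expect, plus the same trailing-dot trimming it performs.
SAMPLE = (
    ';; 1576366790\n'                      # header: 2nd field is the timestamp (assumed layout)
    'tracker.example.com. A 192.0.2.10\n'
    'cdn.example.com. CNAME tracker.example.com.\n'
    '\n'                                   # blank line closes the block
)

timestamp = 0
header = True
for line in SAMPLE.splitlines():
    if not line:
        header = True
        continue
    fields = line.split(' ')
    if header:
        timestamp = int(fields[1])         # as DnsMassParser.consume() does
        header = False
        continue
    name, rtype, value = fields
    name = name[:-1]                       # drop the trailing dot
    if rtype == 'CNAME':
        value = value[:-1]                 # CNAME values carry one too
    print(rtype, timestamp, name, value)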
Geoffrey Frogeye 2019-12-14 23:59:50 +01:00
parent 189deeb559
commit ddceed3d25
Signed by: geoffrey
GPG key ID: D8A7ECA00A8CD3DD
5 changed files with 152 additions and 345 deletions


@@ -284,14 +284,18 @@ class Database():
             'UNION '
             'SELECT * FROM ('
             'SELECT val, entry FROM zone '
+            # 'WHERE val>=:d '
+            # 'ORDER BY val ASC LIMIT 1'
             'WHERE val<=:d '
-            'ORDER BY val DESC LIMIT 1'
+            'AND instr(:d, val) = 1'
             ')',
             {'d': domain_prep}
         )
         for val, entry in cursor:
+            # print(293, val, entry)
             self.enter_step('get_domain_confirm')
             if not (val is None or domain_prep.startswith(val)):
+                # print(297)
                 continue
             self.enter_step('get_domain_yield')
             yield entry
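For context on the query change above: SQLite's `instr(X, Y)` returns the 1-based position of Y inside X, so `instr(:d, val) = 1` keeps every zone whose `val` is a string prefix of the prepared domain, where the previous `ORDER BY val DESC LIMIT 1` returned only the single closest candidate. A minimal sqlite3 sketch of that behaviour, assuming (as the `domain_prep.startswith(val)` check suggests) that `prepare_domain()` produces keys in which ancestor zones are prefixes; the table rows are invented:

# Minimal sketch, not from the commit: behaviour of the new WHERE clause.
import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE zone (val TEXT, entry INTEGER)')
conn.executemany('INSERT INTO zone VALUES (?, ?)', [
    ('com.example.', 1),        # invented sample rows
    ('com.example.ads.', 2),
    ('org.unrelated.', 3),
])

domain_prep = 'com.example.ads.tracker.'
rows = conn.execute(
    'SELECT val, entry FROM zone '
    'WHERE val<=:d '
    'AND instr(:d, val) = 1',   # val is a prefix of :d, i.e. an ancestor zone
    {'d': domain_prep},
).fetchall()
print(rows)  # [('com.example.', 1), ('com.example.ads.', 2)]: both ancestors, not just the closest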


@@ -7,23 +7,24 @@ import logging
 import sys
 import typing
 import multiprocessing
+import enum
 
-NUMBER_THREADS = 2
-BLOCK_SIZE = 100
+RecordType = enum.Enum('RecordType', 'A AAAA CNAME PTR')
+Record = typing.Tuple[RecordType, int, str, str]
 
 # select, confirm, write
 FUNCTION_MAP: typing.Any = {
-    'a': (
+    RecordType.A: (
         database.Database.get_ip4,
         database.Database.get_domain_in_zone,
         database.Database.set_hostname,
     ),
-    'cname': (
+    RecordType.CNAME: (
         database.Database.get_domain,
         database.Database.get_domain_in_zone,
         database.Database.set_hostname,
     ),
-    'ptr': (
+    RecordType.PTR: (
         database.Database.get_domain,
         database.Database.get_ip4_in_network,
         database.Database.set_ip4address,
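A note on the new keys above: `enum.Enum('RecordType', 'A AAAA CNAME PTR')` uses the functional enum API to declare one member per name, so a parser can map a textual record type to a `RecordType` member and the dispatch table no longer relies on bare strings. A toy illustration (the mapped values here are placeholder strings, not the real `(select, confirm, write)` method triples):

# Toy illustration of RecordType dispatch; values are placeholders.
import enum

RecordType = enum.Enum('RecordType', 'A AAAA CNAME PTR')

FUNCTION_MAP = {
    RecordType.A: 'match IPs, write hostnames',
    RecordType.CNAME: 'match domains, write hostnames',
    RecordType.PTR: 'match domains, write IP addresses',
}

rtype = RecordType['CNAME']        # how a parser can turn 'CNAME' into a member
print(rtype, FUNCTION_MAP[rtype])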
@@ -33,12 +34,12 @@ FUNCTION_MAP: typing.Any = {
 
 class Reader(multiprocessing.Process):
     def __init__(self,
-                 lines_queue: multiprocessing.Queue,
+                 recs_queue: multiprocessing.Queue,
                  write_queue: multiprocessing.Queue,
                  index: int = 0):
         super(Reader, self).__init__()
         self.log = logging.getLogger(f'rd{index:03d}')
-        self.lines_queue = lines_queue
+        self.recs_queue = recs_queue
         self.write_queue = write_queue
         self.index = index
@@ -48,15 +49,19 @@ class Reader(multiprocessing.Process):
         self.db.enter_step('line_wait')
         block: typing.List[str]
         try:
-            for block in iter(self.lines_queue.get, None):
-                for line in block:
-                    dtype, updated, name, value = line
+            for block in iter(self.recs_queue.get, None):
+                record: Record
+                for record in block:
+                    # print(55, record)
+                    dtype, updated, name, value = record
                     self.db.enter_step('feed_switch')
                     select, confirm, write = FUNCTION_MAP[dtype]
                     for rule in select(self.db, value):
+                        # print(60, rule, list(confirm(self.db, name)))
                         if not any(confirm(self.db, name)):
+                            # print(62, write, name, updated, rule)
                             self.db.enter_step('wait_put')
-                            self.write_queue.put((write, name, updated))
+                            self.write_queue.put((write, name, updated, rule))
                     self.db.enter_step('line_wait')
         except KeyboardInterrupt:
             self.log.error('Interrupted')
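The loop above is the heart of the feed: for each record, `select` yields the rules already matching the record's value, `confirm` checks whether the record's name is already covered, and only when it is not does a write instruction go onto the queue. A stand-in sketch of that control flow with plain dicts instead of `database.Database` (all three callables and the sample record are invented):

# Stand-in sketch of Reader's select/confirm/write flow.
def select(db, value):            # rules already matching the record's value
    return [rule for rule in db['rules'] if rule == value]

def confirm(db, name):            # anything already covering the record's name?
    return (True for known in db['known'] if name.endswith(known))

def write(db, name, updated, source=None):
    db['known'].append(name)

db = {'rules': ['192.0.2.10'], 'known': []}
dtype, updated, name, value = ('A', 1576366790, 'tracker.example.com', '192.0.2.10')

for rule in select(db, value):
    if not any(confirm(db, name)):
        write(db, name, updated, source=rule)   # queued via write_queue in the real script

print(db['known'])  # ['tracker.example.com']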
@@ -82,9 +87,10 @@ class Writer(multiprocessing.Process):
             fun: typing.Callable
             name: str
             updated: int
-            for fun, name, updated in iter(self.write_queue.get, None):
+            source: int
+            for fun, name, updated, source in iter(self.write_queue.get, None):
                 self.db.enter_step('exec')
-                fun(self.db, name, updated)
+                fun(self.db, name, updated, source=source)
                 self.db.enter_step('line_wait')
         except KeyboardInterrupt:
             self.log.error('Interrupted')
@@ -93,29 +99,142 @@ class Writer(multiprocessing.Process):
         self.db.close()
 
 
+class Parser():
+    def __init__(self,
+                 buf: typing.Any,
+                 recs_queue: multiprocessing.Queue,
+                 block_size: int,
+                 ):
+        super(Parser, self).__init__()
+        self.buf = buf
+        self.log = logging.getLogger('pr ')
+        self.recs_queue = recs_queue
+        self.block: typing.List[Record] = list()
+        self.block_size = block_size
+        self.db = database.Database()  # Just for timing
+        self.db.log = logging.getLogger('pr ')
+
+    def register(self, record: Record) -> None:
+        self.db.enter_step('register')
+        self.block.append(record)
+        if len(self.block) >= self.block_size:
+            self.db.enter_step('put_block')
+            self.recs_queue.put(self.block)
+            self.block = list()
+
+    def run(self) -> None:
+        self.consume()
+        self.recs_queue.put(self.block)
+        self.db.close()
+
+    def consume(self) -> None:
+        raise NotImplementedError
+
+
+class Rapid7Parser(Parser):
+    TYPES = {
+        'a': RecordType.A,
+        'aaaa': RecordType.AAAA,
+        'cname': RecordType.CNAME,
+        'ptr': RecordType.PTR,
+    }
+
+    def consume(self) -> None:
+        for line in self.buf:
+            self.db.enter_step('parse_rapid7')
+            try:
+                data = json.loads(line)
+            except json.decoder.JSONDecodeError:
+                continue
+            record = (
+                Rapid7Parser.TYPES[data['type']],
+                int(data['timestamp']),
+                data['name'],
+                data['value']
+            )
+            self.register(record)
+
+
+class DnsMassParser(Parser):
+    # dnsmass --output Snrql
+    # --retry REFUSED,SERVFAIL --resolvers nameservers-ipv4
+    TYPES = {
+        'A': (RecordType.A, -1, None),
+        'AAAA': (RecordType.AAAA, -1, None),
+        'CNAME': (RecordType.CNAME, -1, -1),
+    }
+
+    def consume(self) -> None:
+        self.db.enter_step('parse_dnsmass')
+        timestamp = 0
+        header = True
+        for line in self.buf:
+            line = line[:-1]
+            if not line:
+                header = True
+                continue
+
+            split = line.split(' ')
+            try:
+                if header:
+                    timestamp = int(split[1])
+                    header = False
+                else:
+                    dtype, name_offset, value_offset = \
+                        DnsMassParser.TYPES[split[1]]
+                    record = (
+                        dtype,
+                        timestamp,
+                        split[0][:name_offset],
+                        split[2][:value_offset],
+                    )
+                    self.register(record)
+                    self.db.enter_step('parse_dnsmass')
+            except KeyError:
+                continue
+
+
+PARSERS = {
+    'rapid7': Rapid7Parser,
+    'dnsmass': DnsMassParser,
+}
+
 if __name__ == '__main__':
 
     # Parsing arguments
     log = logging.getLogger('feed_dns')
-    parser = argparse.ArgumentParser(
+    args_parser = argparse.ArgumentParser(
         description="TODO")
-    parser.add_argument(
-        # '-i', '--input', type=argparse.FileType('rb'), default=sys.stdin.buffer,
+    args_parser.add_argument(
+        'parser',
+        choices=PARSERS.keys(),
+        help="TODO")
+    args_parser.add_argument(
         '-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
         help="TODO")
-    args = parser.parse_args()
+    args_parser.add_argument(
+        '-j', '--workers', type=int, default=4,
+        help="TODO")
+    args_parser.add_argument(
+        '-b', '--block-size', type=int, default=100,
+        help="TODO")
+    args = args_parser.parse_args()
 
     DB = database.Database(write=False)  # Not needed, just for timing
     DB.log = logging.getLogger('db ')
-    lines_queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=100)
-    write_queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=100)
+    recs_queue: multiprocessing.Queue = multiprocessing.Queue(
+        maxsize=10*args.workers)
+    write_queue: multiprocessing.Queue = multiprocessing.Queue(
+        maxsize=10*args.workers)
 
     DB.enter_step('proc_create')
     readers: typing.List[Reader] = list()
-    for w in range(NUMBER_THREADS):
-        readers.append(Reader(lines_queue, write_queue, w))
+    for w in range(args.workers):
+        readers.append(Reader(recs_queue, write_queue, w))
     writer = Writer(write_queue)
+    parser = PARSERS[args.parser](
+        args.input, recs_queue, args.block_size)
 
     DB.enter_step('proc_start')
     for reader in readers:
@@ -123,28 +242,12 @@ if __name__ == '__main__':
     writer.start()
 
     try:
-        block: typing.List[str] = list()
-        DB.enter_step('iowait')
-        for line in args.input:
-            DB.enter_step('block_append')
-            DB.enter_step('feed_json_parse')
-            data = json.loads(line)
-            line = (data['type'],
-                    int(data['timestamp']),
-                    data['name'],
-                    data['value'])
-            block.append(line)
-            if len(block) >= BLOCK_SIZE:
-                DB.enter_step('wait_put')
-                lines_queue.put(block)
-                block = list()
-            DB.enter_step('iowait')
-        DB.enter_step('wait_put')
-        lines_queue.put(block)
+        DB.enter_step('parser_run')
+        parser.run()
 
         DB.enter_step('end_put')
-        for _ in range(NUMBER_THREADS):
-            lines_queue.put(None)
+        for _ in range(args.workers):
+            recs_queue.put(None)
         write_queue.put(None)
 
         DB.enter_step('proc_join')
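Taken together, feed_dns.py now runs as a small pipeline: the main process runs the chosen Parser, which batches records into `recs_queue`; several Reader processes turn each record into a write instruction on `write_queue`; a single Writer applies them in order. A reduced sketch of that layout using threads and `queue.Queue` so it runs standalone (the real script uses multiprocessing and database-backed functions; the sample block is invented):

# Reduced sketch of the recs_queue -> Reader -> write_queue -> Writer layout.
import queue
import threading

recs_queue: queue.Queue = queue.Queue(maxsize=10)
write_queue: queue.Queue = queue.Queue(maxsize=10)
WORKERS = 2

def reader() -> None:
    for block in iter(recs_queue.get, None):           # None is the sentinel
        for record in block:
            write_queue.put(('set_hostname', record))  # placeholder instruction

def writer() -> None:
    for instruction in iter(write_queue.get, None):
        print('write', instruction)

readers = [threading.Thread(target=reader) for _ in range(WORKERS)]
writer_thread = threading.Thread(target=writer)
for thread in readers:
    thread.start()
writer_thread.start()

# The "parser" side: one invented block of records, then one sentinel per reader.
recs_queue.put([('A', 1576366790, 'tracker.example.com', '192.0.2.10')])
for _ in readers:
    recs_queue.put(None)
for thread in readers:
    thread.join()
write_queue.put(None)    # readers are done, stop the writer
writer_thread.join()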


@@ -1,36 +0,0 @@
-#!/usr/bin/env python3
-
-import argparse
-import sys
-import logging
-import json
-import csv
-
-if __name__ == '__main__':
-
-    # Parsing arguments
-    log = logging.getLogger('json_to_csv')
-    parser = argparse.ArgumentParser(
-        description="TODO")
-    parser.add_argument(
-        # '-i', '--input', type=argparse.FileType('rb'), default=sys.stdin.buffer,
-        '-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
-        help="TODO")
-    parser.add_argument(
-        # '-i', '--output', type=argparse.FileType('wb'), default=sys.stdout.buffer,
-        '-o', '--output', type=argparse.FileType('w'), default=sys.stdout,
-        help="TODO")
-    args = parser.parse_args()
-
-    writer = csv.writer(args.output)
-    for line in args.input:
-        data = json.loads(line)
-        try:
-            writer.writerow([
-                data['type'][0],  # First letter, will need to do something special for AAAA
-                data['timestamp'],
-                data['name'],
-                data['value']])
-        except (KeyError, json.decoder.JSONDecodeError):
-            log.error('Could not parse line: %s', line)
-            pass


@@ -9,11 +9,11 @@ function log() {
 # TODO Fetch 'em
 log "Reading PTR records…"
-pv ptr.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
+pv ptr.json.gz | gunzip | ./feed_dns.py
 log "Reading A records…"
-pv a.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
+pv a.json.gz | gunzip | ./feed_dns.py
 log "Reading CNAME records…"
-pv cname.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
+pv cname.json.gz | gunzip | ./feed_dns.py
 log "Pruning old data…"
 ./database.py --prune


@@ -1,264 +0,0 @@
-#!/usr/bin/env python3
-
-"""
-From a list of subdomains, output only
-the ones resolving to a first-party tracker.
-"""
-
-import argparse
-import logging
-import os
-import queue
-import sys
-import threading
-import typing
-import time
-
-import coloredlogs
-import dns.exception
-import dns.resolver
-
-DNS_TIMEOUT = 5.0
-NUMBER_TRIES = 5
-
-
-class Worker(threading.Thread):
-    """
-    Worker process for a DNS resolver.
-    Will resolve DNS to match first-party subdomains.
-    """
-
-    def change_nameserver(self) -> None:
-        """
-        Assign a this worker another nameserver from the queue.
-        """
-        server = None
-        while server is None:
-            try:
-                server = self.orchestrator.nameservers_queue.get(block=False)
-            except queue.Empty:
-                self.orchestrator.refill_nameservers_queue()
-        self.log.info("Using nameserver: %s", server)
-        self.resolver.nameservers = [server]
-
-    def __init__(self,
-                 orchestrator: 'Orchestrator',
-                 index: int = 0):
-        super(Worker, self).__init__()
-        self.log = logging.getLogger(f'worker{index:03d}')
-        self.orchestrator = orchestrator
-        self.resolver = dns.resolver.Resolver()
-        self.change_nameserver()
-
-    def resolve_subdomain(self, subdomain: str) -> typing.Optional[
-            typing.List[
-                dns.rrset.RRset
-            ]
-    ]:
-        """
-        Returns the resolution chain of the subdomain to an A record,
-        including any intermediary CNAME.
-        The last element is an IP address.
-        Returns None if the nameserver was unable to satisfy the request.
-        Returns [] if the requests points to nothing.
-        """
-        self.log.debug("Querying %s", subdomain)
-        try:
-            query = self.resolver.query(subdomain, 'A', lifetime=DNS_TIMEOUT)
-        except dns.resolver.NXDOMAIN:
-            return []
-        except dns.resolver.NoAnswer:
-            return []
-        except dns.resolver.YXDOMAIN:
-            self.log.warning("Query name too long for %s", subdomain)
-            return None
-        except dns.resolver.NoNameservers:
-            # NOTE Most of the time this error message means that the domain
-            # does not exists, but sometimes it means the that the server
-            # itself is broken. So we count on the retry logic.
-            self.log.warning("All nameservers broken for %s", subdomain)
-            return None
-        except dns.exception.Timeout:
-            # NOTE Same as above
-            self.log.warning("Timeout for %s", subdomain)
-            return None
-        except dns.name.EmptyLabel:
-            self.log.warning("Empty label for %s", subdomain)
-            return None
-        return query.response.answer
-
-    def run(self) -> None:
-        self.log.info("Started")
-        subdomain: str
-        for subdomain in iter(self.orchestrator.subdomains_queue.get, None):
-
-            for _ in range(NUMBER_TRIES):
-                resolved = self.resolve_subdomain(subdomain)
-                # Retry with another nameserver if error
-                if resolved is None:
-                    self.change_nameserver()
-                else:
-                    break
-
-            # If it wasn't found after multiple tries
-            if resolved is None:
-                self.log.error("Gave up on %s", subdomain)
-                resolved = []
-
-            assert isinstance(resolved, list)
-            self.orchestrator.results_queue.put(resolved)
-
-        self.orchestrator.results_queue.put(None)
-        self.log.info("Stopped")
-
-
-class Orchestrator():
-    """
-    Orchestrator of the different Worker threads.
-    """
-
-    def refill_nameservers_queue(self) -> None:
-        """
-        Re-fill the given nameservers into the nameservers queue.
-        Done every-time the queue is empty, making it
-        basically looping and infinite.
-        """
-        # Might be in a race condition but that's probably fine
-        for nameserver in self.nameservers:
-            self.nameservers_queue.put(nameserver)
-        self.log.info("Refilled nameserver queue")
-
-    def __init__(self, subdomains: typing.Iterable[str],
-                 nameservers: typing.List[str] = None,
-                 nb_workers: int = 1,
-                 ):
-        self.log = logging.getLogger('orchestrator')
-        self.subdomains = subdomains
-        self.nb_workers = nb_workers
-
-        # Use interal resolver by default
-        self.nameservers = nameservers or dns.resolver.Resolver().nameservers
-
-        self.subdomains_queue: queue.Queue = queue.Queue(
-            maxsize=self.nb_workers)
-        self.results_queue: queue.Queue = queue.Queue()
-        self.nameservers_queue: queue.Queue = queue.Queue()
-
-        self.refill_nameservers_queue()
-
-    def fill_subdomain_queue(self) -> None:
-        """
-        Read the subdomains in input and put them into the queue.
-        Done in a thread so we can both:
-        - yield the results as they come
-        - not store all the subdomains at once
-        """
-        self.log.info("Started reading subdomains")
-        # Send data to workers
-        for subdomain in self.subdomains:
-            self.subdomains_queue.put(subdomain)
-        self.log.info("Finished reading subdomains")
-        # Send sentinel to each worker
-        # sentinel = None ~= EOF
-        for _ in range(self.nb_workers):
-            self.subdomains_queue.put(None)
-
-    @staticmethod
-    def format_rrset(rrset: dns.rrset.RRset) -> typing.Iterable[str]:
-        if rrset.rdtype == dns.rdatatype.CNAME:
-            dtype = 'c'
-        elif rrset.rdtype == dns.rdatatype.A:
-            dtype = 'a'
-        else:
-            raise NotImplementedError
-        name = rrset.name.to_text()[:-1]
-        for item in rrset.items:
-            value = item.to_text()
-            if rrset.rdtype == dns.rdatatype.CNAME:
-                value = value[:-1]
-            yield f'{dtype},{int(time.time())},{name},{value}\n'
-
-    def run(self) -> typing.Iterable[str]:
-        """
-        Yield the results.
-        """
-        # Create workers
-        self.log.info("Creating workers")
-        for i in range(self.nb_workers):
-            Worker(self, i).start()
-
-        fill_thread = threading.Thread(target=self.fill_subdomain_queue)
-        fill_thread.start()
-
-        # Wait for one sentinel per worker
-        # In the meantime output results
-        for _ in range(self.nb_workers):
-            resolved: typing.List[dns.rrset.RRset]
-            for resolved in iter(self.results_queue.get, None):
-                for rrset in resolved:
-                    yield from self.format_rrset(rrset)
-
-        self.log.info("Waiting for reader thread")
-        fill_thread.join()
-
-        self.log.info("Done!")
-
-
-def main() -> None:
-    """
-    Main function when used directly.
-    Read the subdomains provided and output it,
-    the last CNAME resolved and the IP adress it resolves to.
-    Takes as an input a filename (or nothing, for stdin),
-    and as an output a filename (or nothing, for stdout).
-    The input must be a subdomain per line, the output is a TODO
-    Use the file `nameservers` as the list of nameservers
-    to use, or else it will use the system defaults.
-    """
-
-    # Initialization
-    coloredlogs.install(
-        level='DEBUG',
-        fmt='%(asctime)s %(name)s %(levelname)s %(message)s'
-    )
-
-    # Parsing arguments
-    parser = argparse.ArgumentParser(
-        description="Massively resolves subdomains and store them in a file.")
-    parser.add_argument(
-        '-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
-        help="Input file with one subdomain per line")
-    parser.add_argument(
-        '-o', '--output', type=argparse.FileType('w'), default=sys.stdout,
-        help="Outptut file with DNS chains")
-    parser.add_argument(
-        '-n', '--nameservers', default='nameservers',
-        help="File with one nameserver per line")
-    parser.add_argument(
-        '-j', '--workers', type=int, default=512,
-        help="Number of threads to use")
-    args = parser.parse_args()
-
-    # Cleaning input
-    iterator = iter(args.input)
-    iterator = map(str.strip, iterator)
-    iterator = filter(None, iterator)
-
-    # Reading nameservers
-    servers: typing.List[str] = list()
-    if os.path.isfile(args.nameservers):
-        servers = open(args.nameservers).readlines()
-    servers = list(filter(None, map(str.strip, servers)))
-
-    for resolved in Orchestrator(
-        iterator,
-        servers,
-        nb_workers=args.workers
-    ).run():
-        args.output.write(resolved)
-
-
-if __name__ == '__main__':
-    main()