diff --git a/.gitignore b/.gitignore
index aa3f3eb..188051c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,5 +3,3 @@
 *.db-journal
 nameservers
 nameservers.head
-*.o
-*.so
diff --git a/database.py b/database.py
index d336eba..ee51829 100755
--- a/database.py
+++ b/database.py
@@ -12,6 +12,7 @@ import logging
 import argparse
 import coloredlogs
 import ipaddress
+import math
 
 coloredlogs.install(
     level='DEBUG',
@@ -22,43 +23,47 @@ DbValue = typing.Union[None, int, float, str, bytes]
 
 class Database():
-    VERSION = 3
+    VERSION = 5
     PATH = "blocking.db"
 
     def open(self) -> None:
         mode = 'rwc' if self.write else 'ro'
         uri = f'file:{self.PATH}?mode={mode}'
         self.conn = sqlite3.connect(uri, uri=True)
-        self.cursor = self.conn.cursor()
-        self.execute("PRAGMA foreign_keys = ON")
-        # self.conn.create_function("prepare_ip4address", 1,
-        #                           Database.prepare_ip4address,
-        #                           deterministic=True)
+        cursor = self.conn.cursor()
+        cursor.execute("PRAGMA foreign_keys = ON")
+        self.conn.create_function("unpack_asn", 1,
+                                  self.unpack_asn,
+                                  deterministic=True)
+        self.conn.create_function("unpack_ip4address", 1,
+                                  self.unpack_ip4address,
+                                  deterministic=True)
+        self.conn.create_function("unpack_ip4network", 2,
+                                  self.unpack_ip4network,
+                                  deterministic=True)
         self.conn.create_function("unpack_domain", 1,
                                   lambda s: s[:-1][::-1],
                                   deterministic=True)
-
-    def execute(self, cmd: str, args: typing.Union[
-            typing.Tuple[DbValue, ...],
-            typing.Dict[str, DbValue]] = None) -> None:
-        # self.log.debug(cmd)
-        # self.log.debug(args)
-        self.cursor.execute(cmd, args or tuple())
+        self.conn.create_function("format_zone", 1,
+                                  lambda s: '*' + s[::-1],
+                                  deterministic=True)
 
     def get_meta(self, key: str) -> typing.Optional[int]:
+        cursor = self.conn.cursor()
         try:
-            self.execute("SELECT value FROM meta WHERE key=?", (key,))
+            cursor.execute("SELECT value FROM meta WHERE key=?", (key,))
         except sqlite3.OperationalError:
             return None
-        for ver, in self.cursor:
+        for ver, in cursor:
             return ver
         return None
 
     def set_meta(self, key: str, val: int) -> None:
-        self.execute("INSERT INTO meta VALUES (?, ?) "
-                     "ON CONFLICT (key) DO "
-                     "UPDATE set value=?",
-                     (key, val, val))
+        cursor = self.conn.cursor()
+        cursor.execute("INSERT INTO meta VALUES (?, ?) "
+                       "ON CONFLICT (key) DO "
+                       "UPDATE set value=?",
+                       (key, val, val))
 
     def close(self) -> None:
         self.enter_step('close_commit')
@@ -76,8 +81,9 @@ class Database():
             os.unlink(self.PATH)
         self.open()
         self.log.info("Creating database version %d.", self.VERSION)
+        cursor = self.conn.cursor()
         with open("database_schema.sql", 'r') as db_schema:
-            self.cursor.executescript(db_schema.read())
+            cursor.executescript(db_schema.read())
         self.set_meta('version', self.VERSION)
         self.conn.commit()
 
@@ -98,13 +104,6 @@ class Database():
                           version)
             self.initialize()
 
-        updated = self.get_meta('updated')
-        if updated is None:
-            self.execute('SELECT max(updated) FROM rules')
-            data = self.cursor.fetchone()
-            updated, = data
-        self.updated = updated or 1
-
     def enter_step(self, name: str) -> None:
         now = time.perf_counter()
         try:
@@ -126,24 +125,32 @@ class Database():
         self.log.debug(f"{'total':<20}: "
                        f"{total:9.2f} s ({1:7.2%})")
 
-    def prepare_hostname(self, hostname: str) -> str:
+    @staticmethod
+    def pack_hostname(hostname: str) -> str:
         return hostname[::-1] + '.'
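
The reversal is what makes subdomain matching cheap: once a hostname is stored as `moc.elpmaxe.rekcart.`, every name under a zone shares a literal string prefix with that zone, so the TEXT primary keys double as prefix indexes. A minimal round-trip sketch of the scheme (standalone code, not the module itself; `deterministic=True` needs Python 3.8+):

```python
import sqlite3


def pack_hostname(hostname: str) -> str:
    # 'tracker.example.com' -> 'moc.elpmaxe.rekcart.'
    return hostname[::-1] + '.'


conn = sqlite3.connect(':memory:')
# Mirrors Database.open(); deterministic=True marks the function safe to
# reuse for identical inputs (e.g. in the ORDER BY of export()).
conn.create_function("unpack_domain", 1,
                     lambda s: s[:-1][::-1],
                     deterministic=True)

packed = pack_hostname('tracker.example.com')
unpacked, = conn.execute('SELECT unpack_domain(?)', (packed,)).fetchone()
assert unpacked == 'tracker.example.com'
```
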
-    def prepare_zone(self, zone: str) -> str:
-        return self.prepare_hostname(zone)
+    @staticmethod
+    def pack_zone(zone: str) -> str:
+        return Database.pack_hostname(zone)
 
     @staticmethod
-    def prepare_asn(asn: str) -> int:
+    def pack_asn(asn: str) -> int:
         asn = asn.upper()
         if asn.startswith('AS'):
             asn = asn[2:]
         return int(asn)
 
     @staticmethod
-    def prepare_ip4address(address: str) -> int:
+    def unpack_asn(asn: int) -> str:
+        return f'AS{asn}'
+
+    @staticmethod
+    def pack_ip4address(address: str) -> int:
         total = 0
         for i, octet in enumerate(address.split('.')):
             total += int(octet) << (3-i)*8
+        if total > 0xFFFFFFFF:
+            raise ValueError
         return total
         # return '{:02x}{:02x}{:02x}{:02x}'.format(
         #     *[int(c) for c in address.split('.')])
@@ -158,34 +165,78 @@ class Database():
         # packed = ipaddress.ip_address(address).packed
         # return packed
 
-    def prepare_ip4network(self, network: str) -> typing.Tuple[int, int]:
-    # def prepare_ip4network(network: str) -> str:
+    @staticmethod
+    def unpack_ip4address(address: int) -> str:
+        return '.'.join(str((address >> (i * 8)) & 0xFF)
+                        for i in reversed(range(4)))
+
+    @staticmethod
+    def pack_ip4network(network: str) -> typing.Tuple[int, int]:
+    # def pack_ip4network(network: str) -> str:
         net = ipaddress.ip_network(network)
-        mini = self.prepare_ip4address(net.network_address.exploded)
-        maxi = self.prepare_ip4address(net.broadcast_address.exploded)
+        mini = Database.pack_ip4address(net.network_address.exploded)
+        maxi = Database.pack_ip4address(net.broadcast_address.exploded)
         # mini = net.network_address.packed
         # maxi = net.broadcast_address.packed
         return mini, maxi
-        # return Database.prepare_ip4address(net.network_address.exploded)[:net.prefixlen]
+        # return Database.pack_ip4address(net.network_address.exploded)[:net.prefixlen]
 
-    def expire(self) -> None:
-        self.enter_step('expire')
-        self.updated += 1
-        self.set_meta('updated', self.updated)
+    @staticmethod
+    def unpack_ip4network(mini: int, maxi: int) -> str:
+        addr = Database.unpack_ip4address(mini)
+        prefixlen = 32-int(math.log2(maxi-mini+1))
+        return f'{addr}/{prefixlen}'
 
     def update_references(self) -> None:
         self.enter_step('update_refs')
-        self.execute('UPDATE rules AS r SET refs='
-                     '(SELECT count(*) FROM rules '
-                     'WHERE source=r.id)')
+        cursor = self.conn.cursor()
+        cursor.execute('UPDATE rules AS r SET refs='
+                       '(SELECT count(*) FROM rules '
+                       'WHERE source=r.id)')
 
-    def prune(self) -> None:
+    def prune(self, before: int, base_only: bool = False) -> None:
         self.enter_step('prune')
-        self.execute('DELETE FROM rules WHERE updated<?', (self.updated,))
+        cursor = self.conn.cursor()
+        cmd = 'DELETE FROM rules WHERE updated<?'
+        if base_only:
+            cmd += ' AND level=0'
+        cursor.execute(cmd, (before,))
 
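For reference, the address arithmetic above in isolation. Note that `unpack_ip4network` can only recover a prefix length when `mini` and `maxi` really bound a CIDR block, i.e. when their span is a power of two — which `pack_ip4network` guarantees:

```python
import ipaddress
import math


def pack_ip4address(address: str) -> int:
    # '198.51.100.0' -> 0xC6336400; octet i weighs 2**((3-i)*8)
    total = 0
    for i, octet in enumerate(address.split('.')):
        total += int(octet) << (3 - i) * 8
    return total


def unpack_ip4network(mini: int, maxi: int) -> str:
    # A /n block spans 2**(32-n) addresses, so the span gives n back.
    prefixlen = 32 - int(math.log2(maxi - mini + 1))
    return f'{ipaddress.ip_address(mini)}/{prefixlen}'


net = ipaddress.ip_network('198.51.100.0/24')
mini = pack_ip4address(net.network_address.exploded)    # 3325256704
maxi = pack_ip4address(net.broadcast_address.exploded)  # 3325256959
assert unpack_ip4network(mini, maxi) == '198.51.100.0/24'
```
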
-    def export(self, first_party_only: bool = False,
-               end_chain_only: bool = False) -> typing.Iterable[str]:
-        command = 'SELECT unpack_domain(val) FROM rules ' \
+    def explain(self, entry: int) -> str:
+        # Format current
+        string = '???'
+        cursor = self.conn.cursor()
+        cursor.execute(
+            'SELECT unpack_asn(val) FROM asn WHERE entry=:entry '
+            'UNION '
+            'SELECT unpack_domain(val) FROM hostname WHERE entry=:entry '
+            'UNION '
+            'SELECT format_zone(val) FROM zone WHERE entry=:entry '
+            'UNION '
+            'SELECT unpack_ip4address(val) FROM ip4address WHERE entry=:entry '
+            'UNION '
+            'SELECT unpack_ip4network(mini, maxi) '
+            'FROM ip4network WHERE entry=:entry ',
+            {"entry": entry}
+        )
+        for val, in cursor:  # only one
+            string = str(val)
+        string += f' #{entry}'
+
+        # Add source if any
+        cursor.execute('SELECT source FROM rules WHERE id=?', (entry,))
+        for source, in cursor:
+            if source:
+                string += f' ← {self.explain(source)}'
+        return string
+
+    def export(self,
+               first_party_only: bool = False,
+               end_chain_only: bool = False,
+               explain: bool = False,
+               ) -> typing.Iterable[str]:
+        selection = 'entry' if explain else 'unpack_domain(val)'
+        command = f'SELECT {selection} FROM rules ' \
             'INNER JOIN hostname ON rules.id = hostname.entry'
         restrictions: typing.List[str] = list()
         if first_party_only:
@@ -194,16 +245,40 @@ class Database():
             restrictions.append('rules.refs = 0')
         if restrictions:
             command += ' WHERE ' + ' AND '.join(restrictions)
-        command += ' ORDER BY unpack_domain(val) ASC'
-        self.execute(command)
-        for val, in self.cursor:
-            yield val
+        if not explain:
+            command += ' ORDER BY unpack_domain(val) ASC'
+        cursor = self.conn.cursor()
+        cursor.execute(command)
+        for val, in cursor:
+            if explain:
+                yield self.explain(val)
+            else:
+                yield val
+
+    def count_rules(self,
+                    first_party_only: bool = False,
+                    ) -> str:
+        counts: typing.List[str] = list()
+        cursor = self.conn.cursor()
+        for table in ['asn', 'ip4network', 'ip4address', 'zone', 'hostname']:
+            command = f'SELECT count(*) FROM rules ' \
+                      f'INNER JOIN {table} ON rules.id = {table}.entry ' \
+                      'WHERE rules.level = 0'
+            if first_party_only:
+                command += ' AND first_party=1'
+            cursor.execute(command)
+            count, = cursor.fetchone()
+            if count > 0:
+                counts.append(f'{table}: {count}')
+
+        return ', '.join(counts)
 
     def get_domain(self, domain: str) -> typing.Iterable[int]:
         self.enter_step('get_domain_prepare')
-        domain_prep = self.prepare_hostname(domain)
+        domain_prep = self.pack_hostname(domain)
+        cursor = self.conn.cursor()
         self.enter_step('get_domain_select')
-        self.execute(
+        cursor.execute(
             'SELECT null, entry FROM hostname '
             'WHERE val=:d '
             'UNION '
@@ -214,22 +289,41 @@ class Database():
             ')',
             {'d': domain_prep}
         )
-        for val, entry in self.cursor:
+        for val, entry in cursor:
             self.enter_step('get_domain_confirm')
             if not (val is None or domain_prep.startswith(val)):
                 continue
             self.enter_step('get_domain_yield')
             yield entry
 
+    def get_domain_in_zone(self, domain: str) -> typing.Iterable[int]:
+        self.enter_step('get_domainiz_prepare')
+        domain_prep = self.pack_hostname(domain)
+        cursor = self.conn.cursor()
+        self.enter_step('get_domainiz_select')
+        cursor.execute(
+            'SELECT val, entry FROM zone '
+            'WHERE val<=:d '
+            'ORDER BY val DESC LIMIT 1',
+            {'d': domain_prep}
+        )
+        for val, entry in cursor:
+            self.enter_step('get_domainiz_confirm')
+            if not (val is None or domain_prep.startswith(val)):
+                continue
+            self.enter_step('get_domainiz_yield')
+            yield entry
+
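`get_domain_in_zone` exploits the ordering of `zone.val`: the largest stored value that sorts `<=` the packed domain is the only possible matching zone, and the `startswith` confirmation rejects lexical near-misses. The same lookup sketched with `bisect` over a sorted list (hypothetical zone values):

```python
import bisect
import typing

# Packed like pack_zone(): 'example.com' -> 'moc.elpmaxe.'
zones = sorted(['moc.elpmaxe.', 'moc.rekcart.', 'ten.nairelue.'])


def zone_match(domain_packed: str) -> typing.Optional[str]:
    # SELECT val FROM zone WHERE val<=:d ORDER BY val DESC LIMIT 1
    i = bisect.bisect_right(zones, domain_packed)
    if i == 0:
        return None
    candidate = zones[i - 1]
    # get_domainiz_confirm: keep only true prefixes
    return candidate if domain_packed.startswith(candidate) else None


assert zone_match('moc.elpmaxe.www.') == 'moc.elpmaxe.'  # www.example.com
assert zone_match('moc.elpmaxf.') is None                # no such zone
```
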
     def get_ip4(self, address: str) -> typing.Iterable[int]:
         self.enter_step('get_ip4_prepare')
         try:
-            address_prep = self.prepare_ip4address(address)
+            address_prep = self.pack_ip4address(address)
         except (ValueError, IndexError):
             self.log.error("Invalid ip4address: %s", address)
             return
+        cursor = self.conn.cursor()
         self.enter_step('get_ip4_select')
-        self.execute(
+        cursor.execute(
             'SELECT entry FROM ip4address '
             # 'SELECT null, entry FROM ip4address '
             'WHERE val=:a '
             'UNION '
@@ -244,7 +338,7 @@ class Database():
             'WHERE :a BETWEEN mini AND maxi ',
             {'a': address_prep}
         )
-        for val, entry in self.cursor:
+        for entry, in cursor:
             # self.enter_step('get_ip4_confirm')
             # if not (val is None or val.startswith(address_prep)):
             #     # PERF startswith but from the end
             #     continue
             self.enter_step('get_ip4_yield')
             yield entry
 
+    def get_ip4_in_network(self, address: str) -> typing.Iterable[int]:
+        self.enter_step('get_ip4in_prepare')
+        try:
+            address_prep = self.pack_ip4address(address)
+        except (ValueError, IndexError):
+            self.log.error("Invalid ip4address: %s", address)
+            return
+        cursor = self.conn.cursor()
+        self.enter_step('get_ip4in_select')
+        cursor.execute(
+            'SELECT entry FROM ip4network '
+            'WHERE :a BETWEEN mini AND maxi ',
+            {'a': address_prep}
+        )
+        for entry, in cursor:
+            self.enter_step('get_ip4in_yield')
+            yield entry
+
     def list_asn(self) -> typing.Iterable[typing.Tuple[str, int]]:
+        cursor = self.conn.cursor()
         self.enter_step('list_asn_select')
-        self.enter_step('get_domain_select')
-        self.execute('SELECT val, entry FROM asn')
-        for val, entry in self.cursor:
+        cursor.execute('SELECT val, entry FROM asn')
+        for val, entry in cursor:
             yield f'AS{val}', entry
 
     def _set_generic(self,
@@ -264,6 +376,7 @@ class Database():
                      select_query: str,
                      insert_query: str,
                      prep: typing.Dict[str, DbValue],
+                     updated: int,
                      is_first_party: bool = False,
                      source: int = None,
                      ) -> None:
@@ -271,34 +384,36 @@ class Database():
         # here abstraction > performances
 
         # Fields based on the source
+        self.enter_step(f'set_{table}_prepare')
+        cursor = self.conn.cursor()
         if source is None:
             first_party = int(is_first_party)
             level = 0
         else:
             self.enter_step(f'set_{table}_source')
-            self.execute(
+            cursor.execute(
                 'SELECT first_party, level FROM rules '
                 'WHERE id=?',
                 (source,)
             )
-            first_party, level = self.cursor.fetchone()
+            first_party, level = cursor.fetchone()
             level += 1
 
         self.enter_step(f'set_{table}_select')
-        self.execute(select_query, prep)
+        cursor.execute(select_query, prep)
 
-        rules_prep = {
+        rules_prep: typing.Dict[str, DbValue] = {
             "source": source,
-            "updated": self.updated,
+            "updated": updated,
             "first_party": first_party,
             "level": level,
         }
 
         # If the entry already exists
-        for entry, in self.cursor:  # only one
+        for entry, in cursor:  # only one
             self.enter_step(f'set_{table}_update')
             rules_prep['entry'] = entry
-            self.execute(
+            cursor.execute(
                 'UPDATE rules SET '
                 'source=:source, updated=:updated, '
                 'first_party=:first_party, level=:level '
@@ -314,23 +429,18 @@ class Database():
 
         # If it does not exist
 
-        if source is not None:
-            self.enter_step(f'set_{table}_incsrc')
-            self.execute('UPDATE rules SET refs = refs + 1 WHERE id=?',
-                         (source,))
-
         self.enter_step(f'set_{table}_insert')
-        self.execute(
+        cursor.execute(
             'INSERT INTO rules '
-            '(source, updated, first_party, refs, level) '
-            'VALUES (:source, :updated, :first_party, 0, :level) ',
+            '(source, updated, first_party, level) '
+            'VALUES (:source, :updated, :first_party, :level) ',
             rules_prep
         )
-        self.execute('SELECT id FROM rules WHERE rowid=?',
-                     (self.cursor.lastrowid,))
-        for entry, in self.cursor:  # only one
+        cursor.execute('SELECT id FROM rules WHERE rowid=?',
+                       (cursor.lastrowid,))
+        for entry, in cursor:  # only one
             prep['entry'] = entry
-            self.execute(insert_query, prep)
+            cursor.execute(insert_query, prep)
             return
 
         assert False
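
`_set_generic` is a hand-rolled upsert: probe the per-type table for an existing entry, update the backing rule if found, otherwise insert a new rule and attach the value to `cursor.lastrowid`. The shape of it on a toy two-column schema (a simplification, not the real tables):

```python
import sqlite3

conn = sqlite3.connect(':memory:')
conn.executescript(
    'CREATE TABLE rules (id INTEGER PRIMARY KEY, updated INTEGER);'
    'CREATE TABLE hostname (val TEXT PRIMARY KEY, entry INTEGER);'
)


def set_hostname(val: str, updated: int) -> None:
    cursor = conn.cursor()
    cursor.execute('SELECT entry FROM hostname WHERE val=?', (val,))
    for entry, in cursor:  # only one: val is the primary key
        cursor.execute('UPDATE rules SET updated=? WHERE id=?',
                       (updated, entry))
        return
    cursor.execute('INSERT INTO rules (updated) VALUES (?)', (updated,))
    cursor.execute('INSERT INTO hostname VALUES (?, ?)',
                   (val, cursor.lastrowid))


set_hostname('moc.elpmaxe.', 1)
set_hostname('moc.elpmaxe.', 2)  # second call takes the UPDATE path
assert conn.execute('SELECT count(*) FROM rules').fetchone()[0] == 1
```
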
@@ -338,7 +448,7 @@ class Database():
                      *args: typing.Any, **kwargs: typing.Any) -> None:
         self.enter_step('set_hostname_prepare')
         prep: typing.Dict[str, DbValue] = {
-            'val': self.prepare_hostname(hostname),
+            'val': self.pack_hostname(hostname),
         }
         self._set_generic(
             'hostname',
@@ -353,7 +463,7 @@ class Database():
                 *args: typing.Any, **kwargs: typing.Any) -> None:
         self.enter_step('set_asn_prepare')
         try:
-            asn_prep = self.prepare_asn(asn)
+            asn_prep = self.pack_asn(asn)
         except ValueError:
             self.log.error("Invalid asn: %s", asn)
             return
@@ -371,10 +481,9 @@ class Database():
 
     def set_ip4address(self, ip4address: str,
                        *args: typing.Any, **kwargs: typing.Any) -> None:
-        # TODO Do not add if already in ip4network
         self.enter_step('set_ip4add_prepare')
         try:
-            ip4address_prep = self.prepare_ip4address(ip4address)
+            ip4address_prep = self.pack_ip4address(ip4address)
         except (ValueError, IndexError):
             self.log.error("Invalid ip4address: %s", ip4address)
             return
@@ -394,7 +503,7 @@ class Database():
                  *args: typing.Any, **kwargs: typing.Any) -> None:
         self.enter_step('set_zone_prepare')
         prep: typing.Dict[str, DbValue] = {
-            'val': self.prepare_zone(zone),
+            'val': self.pack_zone(zone),
         }
         self._set_generic(
             'zone',
@@ -409,7 +518,7 @@ class Database():
                        *args: typing.Any, **kwargs: typing.Any) -> None:
         self.enter_step('set_ip4net_prepare')
         try:
-            ip4network_prep = self.prepare_ip4network(ip4network)
+            ip4network_prep = self.pack_ip4network(ip4network)
         except (ValueError, IndexError):
             self.log.error("Invalid ip4network: %s", ip4network)
             return
@@ -439,8 +548,12 @@ if __name__ == '__main__':
         '-p', '--prune', action='store_true',
         help="Remove old entries from database")
     parser.add_argument(
-        '-e', '--expire', action='store_true',
-        help="Set the whole database as an old source")
+        '-b', '--prune-base', action='store_true',
+        help="TODO")
+    parser.add_argument(
+        '-s', '--prune-before', type=int,
+        default=(int(time.time()) - 60*60*24*31*6),
+        help="TODO")
     parser.add_argument(
         '-r', '--references', action='store_true',
         help="Update the reference count")
@@ -451,10 +564,8 @@ if __name__ == '__main__':
     if args.initialize:
         DB.initialize()
     if args.prune:
-        DB.prune()
-    if args.expire:
-        DB.expire()
-    if args.references and not args.prune:
+        DB.prune(before=args.prune_before, base_only=args.prune_base)
+    if args.references:
         DB.update_references()
 
     DB.close()
diff --git a/database_schema.sql b/database_schema.sql
index 9be81b0..3116a09 100644
--- a/database_schema.sql
+++ b/database_schema.sql
@@ -10,30 +10,37 @@ CREATE TABLE rules (
     level INTEGER, -- Level of recursion to the root source rule (used for source priority)
     FOREIGN KEY (source) REFERENCES rules(id) ON DELETE CASCADE
 );
+CREATE INDEX rules_source ON rules (source); -- for references recounting
+CREATE INDEX rules_updated ON rules (updated); -- for pruning
+CREATE INDEX rules_level_firstparty ON rules (level, first_party); -- for counting rules
 
 CREATE TABLE asn (
     val INTEGER PRIMARY KEY,
     entry INTEGER,
     FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
 );
+CREATE INDEX asn_entry ON asn (entry); -- for explanations
 
 CREATE TABLE hostname (
     val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for consistency with zone)
     entry INTEGER,
     FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
 );
+CREATE INDEX hostname_entry ON hostname (entry); -- for explanations
 
 CREATE TABLE zone (
     val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for easier matching)
     entry INTEGER,
     FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
 );
+CREATE INDEX zone_entry ON zone (entry); -- for explanations
 
 CREATE TABLE ip4address (
     val INTEGER PRIMARY KEY,
     entry INTEGER,
     FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
 );
+CREATE INDEX ip4address_entry ON ip4address (entry); -- for explanations
 
 CREATE TABLE ip4network (
 --     val TEXT PRIMARY KEY,
@@ -43,6 +50,7 @@ CREATE TABLE ip4network (
     FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
 );
 CREATE INDEX ip4network_minmax ON ip4network (mini, maxi);
+CREATE INDEX ip4network_entry ON ip4network (entry); -- for explanations
 
 -- Store various things
 CREATE TABLE meta (
diff --git a/export.py b/export.py
index 58b276b..886582c 100755
--- a/export.py
+++ b/export.py
@@ -19,12 +19,31 @@ if __name__ == '__main__':
     parser.add_argument(
         '-e', '--end-chain', action='store_true',
         help="TODO")
+    parser.add_argument(
+        '-x', '--explain', action='store_true',
+        help="TODO")
+    parser.add_argument(
+        '-r', '--rules', action='store_true',
+        help="TODO")
+    parser.add_argument(
+        '-c', '--count', action='store_true',
+        help="TODO")
     args = parser.parse_args()
 
     DB = database.Database()
 
-    for domain in DB.export(first_party_only=args.first_party,
-                            end_chain_only=args.end_chain):
-        print(domain, file=args.output)
+    if args.rules:
+        if not args.count:
+            raise NotImplementedError
+        print(DB.count_rules(first_party_only=args.first_party))
+    else:
+        if args.count:
+            raise NotImplementedError
+        for domain in DB.export(
+                first_party_only=args.first_party,
+                end_chain_only=args.end_chain,
+                explain=args.explain,
+                ):
+            print(domain, file=args.output)
 
     DB.close()
diff --git a/feed_asn.py b/feed_asn.py
new file mode 100755
index 0000000..098f931
--- /dev/null
+++ b/feed_asn.py
@@ -0,0 +1,53 @@
+#!/usr/bin/env python3
+
+import database
+import argparse
+import requests
+import typing
+import ipaddress
+import logging
+import time
+
+IPNetwork = typing.Union[ipaddress.IPv4Network, ipaddress.IPv6Network]
+
+
+def get_ranges(asn: str) -> typing.Iterable[str]:
+    req = requests.get(
+        'https://stat.ripe.net/data/as-routing-consistency/data.json',
+        params={'resource': asn}
+    )
+    data = req.json()
+    for pref in data['data']['prefixes']:
+        yield pref['prefix']
+
+
+if __name__ == '__main__':
+
+    log = logging.getLogger('feed_asn')
+
+    # Parsing arguments
+    parser = argparse.ArgumentParser(
+        description="TODO")
+    args = parser.parse_args()
+
+    DB = database.Database()
+    DBW = database.Database(write=True)
+
+    for asn, entry in DB.list_asn():
+        DB.enter_step('asn_get_ranges')
+        for prefix in get_ranges(asn):
+            parsed_prefix: IPNetwork = ipaddress.ip_network(prefix)
+            if parsed_prefix.version == 4:
+                DBW.set_ip4network(
+                    prefix,
+                    source=entry,
+                    updated=int(time.time())
+                )
+                log.info('Added %s from %s (id=%s)', prefix, asn, entry)
+            elif parsed_prefix.version == 6:
+                log.warning('Unimplemented prefix version: %s', prefix)
+            else:
+                log.error('Unknown prefix version: %s', prefix)
+
+    DB.close()
+    DBW.close()
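
feed_asn.py turns each stored AS number into its announced prefixes via RIPEstat's `as-routing-consistency` endpoint (the URL above). Run on its own it behaves like this sketch; both IPv4 and IPv6 prefixes come back, hence the `version` check before `set_ip4network`:

```python
import typing

import requests


def get_ranges(asn: str) -> typing.Iterable[str]:
    req = requests.get(
        'https://stat.ripe.net/data/as-routing-consistency/data.json',
        params={'resource': asn}  # accepts the 'AS...' form stored in the db
    )
    data = req.json()
    for pref in data['data']['prefixes']:
        yield pref['prefix']


for prefix in get_ranges('AS3215'):
    print(prefix)  # mixes v4 and v6, e.g. '2.0.0.0/12', '2a01:c000::/19'
```
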
diff --git a/feed_dns.py b/feed_dns.py
old mode 100644
new mode 100755
index fed322d..585a211
--- a/feed_dns.py
+++ b/feed_dns.py
@@ -1,23 +1,43 @@
 #!/usr/bin/env python3
 
-import database
 import argparse
-import sys
+import database
+import json
 import logging
-import threading
-import queue
+import sys
 import typing
+import multiprocessing
 
-NUMBER_THREADS = 8
+NUMBER_THREADS = 2
+BLOCK_SIZE = 100
+
+# select, confirm, write
+FUNCTION_MAP: typing.Any = {
+    'a': (
+        database.Database.get_ip4,
+        database.Database.get_domain_in_zone,
+        database.Database.set_hostname,
+    ),
+    'cname': (
+        database.Database.get_domain,
+        database.Database.get_domain_in_zone,
+        database.Database.set_hostname,
+    ),
+    'ptr': (
+        database.Database.get_domain,
+        database.Database.get_ip4_in_network,
+        database.Database.set_ip4address,
+    ),
+}
 
 
-class Worker(threading.Thread):
+class Reader(multiprocessing.Process):
     def __init__(self,
-                 lines_queue: queue.Queue,
-                 write_queue: queue.Queue,
+                 lines_queue: multiprocessing.Queue,
+                 write_queue: multiprocessing.Queue,
                  index: int = 0):
-        super(Worker, self).__init__()
-        self.log = logging.getLogger(f'worker{index:03d}')
+        super(Reader, self).__init__()
+        self.log = logging.getLogger(f'rd{index:03d}')
         self.lines_queue = lines_queue
         self.write_queue = write_queue
         self.index = index
@@ -25,45 +45,51 @@ class Worker(threading.Thread):
     def run(self) -> None:
         self.db = database.Database(write=False)
         self.db.log = logging.getLogger(f'db{self.index:03d}')
-        self.db.enter_step('wait_line')
-        line: str
-        for line in iter(self.lines_queue.get, None):
-            self.db.enter_step('feed_json_parse')
-            # split = line.split(b'"')
-            split = line.split('"')
-            try:
-                name = split[7]
-                dtype = split[11]
-                value = split[15]
-            except IndexError:
-                log.error("Invalid JSON: %s", line)
-                continue
-            # DB.enter_step('feed_json_assert')
-            # data = json.loads(line)
-            # assert dtype == data['type']
-            # assert name == data['name']
-            # assert value == data['value']
-
-            self.db.enter_step('feed_switch')
-            if dtype == 'a':
-                for rule in self.db.get_ip4(value):
-                    self.db.enter_step('wait_put')
-                    self.write_queue.put(
-                        (database.Database.set_hostname, name, rule))
-            elif dtype == 'cname':
-                for rule in self.db.get_domain(value):
-                    self.db.enter_step('wait_put')
-                    self.write_queue.put(
-                        (database.Database.set_hostname, name, rule))
-            elif dtype == 'ptr':
-                for rule in self.db.get_domain(value):
-                    self.db.enter_step('wait_put')
-                    self.write_queue.put(
-                        (database.Database.set_ip4address, name, rule))
-            self.db.enter_step('wait_line')
+        self.db.enter_step('line_wait')
+        block: typing.List[str]
+        try:
+            for block in iter(self.lines_queue.get, None):
+                for line in block:
+                    dtype, updated, name, value = line
+                    self.db.enter_step('feed_switch')
+                    select, confirm, write = FUNCTION_MAP[dtype]
+                    for rule in select(self.db, value):
+                        if not any(confirm(self.db, name)):
+                            self.db.enter_step('wait_put')
+                            self.write_queue.put((write, name, updated))
+                    self.db.enter_step('line_wait')
+        except KeyboardInterrupt:
+            self.log.error('Interrupted')
+
+        self.db.enter_step('end')
+        self.db.close()
+
+
+class Writer(multiprocessing.Process):
+    def __init__(self,
+                 write_queue: multiprocessing.Queue,
+                 ):
+        super(Writer, self).__init__()
+        self.log = logging.getLogger(f'wr ')
+        self.write_queue = write_queue
+
+    def run(self) -> None:
+        self.db = database.Database(write=True)
+        self.db.log = logging.getLogger(f'dbw ')
+        self.db.enter_step('line_wait')
+        block: typing.List[str]
+        try:
+            fun: typing.Callable
+            name: str
+            updated: int
+            for fun, name, updated in iter(self.write_queue.get, None):
+                self.db.enter_step('exec')
+                fun(self.db, name, updated)
+                self.db.enter_step('line_wait')
+        except KeyboardInterrupt:
+            self.log.error('Interrupted')
 
         self.db.enter_step('end')
-        self.write_queue.put(None)
         self.db.close()
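
The `__main__` block that follows feeds these processes in blocks of `BLOCK_SIZE` tuples rather than single lines, so every `Queue.put` amortizes pickling and locking over 100 records. The producer half of that pattern in isolation (sentinels close the readers, as in the script):

```python
import multiprocessing
import typing

BLOCK_SIZE = 100
NUMBER_THREADS = 2


def producer(lines: typing.Iterable[str],
             out_queue: multiprocessing.Queue) -> None:
    block: typing.List[str] = list()
    for line in lines:
        block.append(line)
        if len(block) >= BLOCK_SIZE:
            out_queue.put(block)   # one pickle + lock round-trip per block
            block = list()
    out_queue.put(block)           # flush the partial tail block
    for _ in range(NUMBER_THREADS):
        out_queue.put(None)        # one sentinel per Reader process
```

Consumers then drain with `iter(queue.get, None)`, exactly as `Reader.run` does above.
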
@@ -80,42 +106,52 @@ if __name__ == '__main__':
     args = parser.parse_args()
 
     DB = database.Database(write=False)  # Not needed, just for timing
-    DB.log = logging.getLogger('dbf')
-    DBW = database.Database(write=True)
-    DBW.log = logging.getLogger('dbw')
+    DB.log = logging.getLogger('db ')
 
-    lines_queue: queue.Queue = queue.Queue(maxsize=NUMBER_THREADS)
-    write_queue: queue.Queue = queue.Queue(maxsize=NUMBER_THREADS)
+    lines_queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=100)
+    write_queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=100)
 
-    def fill_lines_queue() -> None:
+    DB.enter_step('proc_create')
+    readers: typing.List[Reader] = list()
+    for w in range(NUMBER_THREADS):
+        readers.append(Reader(lines_queue, write_queue, w))
+    writer = Writer(write_queue)
+
+    DB.enter_step('proc_start')
+    for reader in readers:
+        reader.start()
+    writer.start()
+
+    try:
+        block: typing.List[str] = list()
         DB.enter_step('iowait')
         for line in args.input:
-            DB.enter_step('wait_put')
-            lines_queue.put(line)
+            DB.enter_step('block_append')
+            DB.enter_step('feed_json_parse')
+            data = json.loads(line)
+            line = (data['type'],
+                    int(data['timestamp']),
+                    data['name'],
+                    data['value'])
+            block.append(line)
+            if len(block) >= BLOCK_SIZE:
+                DB.enter_step('wait_put')
+                lines_queue.put(block)
+                block = list()
             DB.enter_step('iowait')
+        DB.enter_step('wait_put')
+        lines_queue.put(block)
 
         DB.enter_step('end_put')
         for _ in range(NUMBER_THREADS):
             lines_queue.put(None)
+        write_queue.put(None)
 
-    for w in range(NUMBER_THREADS):
-        Worker(lines_queue, write_queue, w).start()
+        DB.enter_step('proc_join')
+        for reader in readers:
+            reader.join()
+        writer.join()
+    except KeyboardInterrupt:
+        log.error('Interrupted')
 
-    threading.Thread(target=fill_lines_queue).start()
-
-    for _ in range(NUMBER_THREADS):
-        fun: typing.Callable
-        name: str
-        source: int
-        DBW.enter_step('wait_fun')
-        for fun, name, source in iter(write_queue.get, None):
-            DBW.enter_step('exec_fun')
-            fun(DBW, name, source=source)
-            DBW.enter_step('commit')
-            DBW.conn.commit()
-            DBW.enter_step('wait_fun')
-
-        DBW.enter_step('end')
-
-    DBW.close()
     DB.close()
diff --git a/feed_rules.py b/feed_rules.py
index 72888f5..715126e 100755
--- a/feed_rules.py
+++ b/feed_rules.py
@@ -3,6 +3,7 @@
 import database
 import argparse
 import sys
+import time
 
 FUNCTION_MAP = {
     'zone': database.Database.set_zone,
@@ -32,6 +33,10 @@ if __name__ == '__main__':
         fun = FUNCTION_MAP[args.type]
 
     for rule in args.input:
-        fun(DB, rule.strip(), is_first_party=args.first_party)
+        fun(DB,
+            rule.strip(),
+            is_first_party=args.first_party,
+            updated=int(time.time()),
+            )
 
     DB.close()
diff --git a/fetch_resources.sh b/fetch_resources.sh
index 01121d8..e799729 100755
--- a/fetch_resources.sh
+++ b/fetch_resources.sh
@@ -18,7 +18,7 @@ log "Retrieving rules…"
 rm -f rules*/*.cache.*
 dl https://easylist.to/easylist/easyprivacy.txt rules_adblock/easyprivacy.cache.txt
 # From firebog.net Tracking & Telemetry Lists
-dl https://v.firebog.net/hosts/Prigent-Ads.txt rules/prigent-ads.cache.list
+# dl https://v.firebog.net/hosts/Prigent-Ads.txt rules/prigent-ads.cache.list
 # dl https://gitlab.com/quidsup/notrack-blocklists/raw/master/notrack-blocklist.txt rules/notrack-blocklist.cache.list
 # False positives: https://github.com/WaLLy3K/wally3k.github.io/issues/73 -> 69.media.tumblr.com chicdn.net
 dl https://raw.githubusercontent.com/StevenBlack/hosts/master/data/add.2o7Net/hosts rules_hosts/add2o7.cache.txt
diff --git a/filter_subdomains.sh b/filter_subdomains.sh
index 67783e8..d4b90ae 100755
--- a/filter_subdomains.sh
+++ b/filter_subdomains.sh
@@ -4,6 +4,12 @@ function log() {
     echo -e "\033[33m$@\033[0m"
 }
 
+log "Pruning old data…"
+./database.py --prune
+
+log "Recounting references…"
+./database.py --references
+
 log "Exporting lists…"
 ./export.py --first-party --output dist/firstparty-trackers.txt
 ./export.py --first-party --end-chain --output dist/firstparty-only-trackers.txt
@@ -11,6 +17,8 @@ log "Exporting lists…"
 ./export.py --end-chain --output dist/multiparty-only-trackers.txt
 
 log "Generating hosts lists…"
+./export.py --rules --count --first-party > temp/count_rules_firstparty.txt
+./export.py --rules --count > temp/count_rules_multiparty.txt
 function generate_hosts {
     basename="$1"
     description="$2"
@@ -36,15 +44,16 @@ function generate_hosts {
         echo "#"
         echo "# Generation date: $(date -Isec)"
         echo "# Generation software: eulaurarien $(git describe --tags)"
-        echo "# Number of source websites: TODO"
-        echo "# Number of source subdomains: TODO"
+        echo "# Number of source websites: $(wc -l temp/all_websites.list | cut -d' ' -f1)"
+        echo "# Number of source subdomains: $(wc -l temp/all_subdomains.list | cut -d' ' -f1)"
+        echo "# Number of source DNS records: ~2M + $(wc -l temp/all_resolved.json | cut -d' ' -f1)"
         echo "#"
-        echo "# Number of known first-party trackers: TODO"
-        echo "# Number of first-party subdomains: $(wc -l dist/firstparty-trackers.txt | cut -d' ' -f1)"
+        echo "# Known first-party trackers: $(cat temp/count_rules_firstparty.txt)"
+        echo "# Number of first-party hostnames: $(wc -l dist/firstparty-trackers.txt | cut -d' ' -f1)"
         echo "# … excluding redirected: $(wc -l dist/firstparty-only-trackers.txt | cut -d' ' -f1)"
         echo "#"
-        echo "# Number of known multi-party trackers: TODO"
-        echo "# Number of multi-party subdomains: $(wc -l dist/multiparty-trackers.txt | cut -d' ' -f1)"
+        echo "# Known multi-party trackers: $(cat temp/count_rules_multiparty.txt)"
+        echo "# Number of multi-party hostnames: $(wc -l dist/multiparty-trackers.txt | cut -d' ' -f1)"
        echo "# … excluding redirected: $(wc -l dist/multiparty-only-trackers.txt | cut -d' ' -f1)"
         echo
         sed 's|^|0.0.0.0 |' "dist/$basename.txt"
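
import_rules.sh below snapshots `BEFORE="$(date +%s)"` before feeding and prunes afterwards: every re-imported rule gets a fresh `updated` stamp and survives, while rules that disappeared from the sources fall below the cutoff and are deleted. The mark-and-sweep idea in miniature (toy schema, not the real one):

```python
import sqlite3
import time

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE rules (val TEXT PRIMARY KEY, updated INTEGER)')


def feed(val: str) -> None:
    now = int(time.time())
    conn.execute('INSERT INTO rules VALUES (?, ?) '
                 'ON CONFLICT (val) DO UPDATE SET updated=?',
                 (val, now, now))


feed('old.example.com')
conn.execute('UPDATE rules SET updated=updated-10')  # pretend it aged

before = int(time.time())      # BEFORE="$(date +%s)"
feed('new.example.com')        # only this rule gets re-imported
conn.execute('DELETE FROM rules WHERE updated<?', (before,))  # --prune-before
assert [v for v, in conn.execute('SELECT val FROM rules')] \
    == ['new.example.com']
```
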
diff --git a/import_rules.sh b/import_rules.sh
index 358155c..33c4fbd 100755
--- a/import_rules.sh
+++ b/import_rules.sh
@@ -5,6 +5,7 @@ function log() {
 }
 
 log "Importing rules…"
+BEFORE="$(date +%s)"
 cat rules_adblock/*.txt | grep -v '^!' | grep -v '^\[Adblock' | ./adblock_to_domain_list.py | ./feed_rules.py zone
 cat rules_hosts/*.txt | grep -v '^#' | grep -v '^$' | cut -d ' ' -f2 | ./feed_rules.py zone
 cat rules/*.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone
@@ -17,3 +18,5 @@ cat rules_asn/first-party.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py asn
 
 ./feed_asn.py
+
+log "Pruning old rules…"
+./database.py --prune --prune-before "$BEFORE" --prune-base
diff --git a/json_to_csv.py b/json_to_csv.py
new file mode 100755
index 0000000..39ca1b7
--- /dev/null
+++ b/json_to_csv.py
@@ -0,0 +1,36 @@
+#!/usr/bin/env python3
+
+import argparse
+import sys
+import logging
+import json
+import csv
+
+if __name__ == '__main__':
+
+    # Parsing arguments
+    log = logging.getLogger('json_to_csv')
+    parser = argparse.ArgumentParser(
+        description="TODO")
+    parser.add_argument(
+        # '-i', '--input', type=argparse.FileType('rb'), default=sys.stdin.buffer,
+        '-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
+        help="TODO")
+    parser.add_argument(
+        # '-i', '--output', type=argparse.FileType('wb'), default=sys.stdout.buffer,
+        '-o', '--output', type=argparse.FileType('w'), default=sys.stdout,
+        help="TODO")
+    args = parser.parse_args()
+
+    writer = csv.writer(args.output)
+    for line in args.input:
+        data = json.loads(line)
+        try:
+            writer.writerow([
+                data['type'][0],  # First letter, will need to do something special for AAAA
+                data['timestamp'],
+                data['name'],
+                data['value']])
+        except (KeyError, json.decoder.JSONDecodeError):
+            log.error('Could not parse line: %s', line)
+            pass
diff --git a/new_workflow.sh b/new_workflow.sh
index bc2a78b..e21b426 100755
--- a/new_workflow.sh
+++ b/new_workflow.sh
@@ -4,18 +4,16 @@ function log() {
     echo -e "\033[33m$@\033[0m"
 }
 
-log "Preparing database…"
-./database.py --expire
-
+./fetch_resources.sh
 ./import_rules.sh
 
 # TODO Fetch 'em
 log "Reading PTR records…"
-pv ptr.json.gz | gunzip | ./feed_dns.py
+pv ptr.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
 log "Reading A records…"
-pv a.json.gz | gunzip | ./feed_dns.py
+pv a.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
 log "Reading CNAME records…"
-pv cname.json.gz | gunzip | ./feed_dns.py
+pv cname.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
 
 log "Pruning old data…"
 ./database.py --prune
diff --git a/regexes.py b/regexes.py
deleted file mode 100644
index 0e48441..0000000
--- a/regexes.py
+++ /dev/null
@@ -1,21 +0,0 @@
-#!/usr/bin/env python3
-
-"""
-List of regex matching first-party trackers.
-""" - -# Syntax: https://docs.python.org/3/library/re.html#regular-expression-syntax - -REGEXES = [ - r'^.+\.eulerian\.net\.$', # Eulerian - r'^.+\.criteo\.com\.$', # Criteo - r'^.+\.dnsdelegation\.io\.$', # Criteo - r'^.+\.keyade\.com\.$', # Keyade - r'^.+\.omtrdc\.net\.$', # Adobe Experience Cloud - r'^.+\.bp01\.net\.$', # NP6 - r'^.+\.ati-host\.net\.$', # Xiti (AT Internet) - r'^.+\.at-o\.net\.$', # Xiti (AT Internet) - r'^.+\.edgkey\.net\.$', # Edgekey (Akamai) - r'^.+\.akaimaiedge\.net\.$', # Edgekey (Akamai) - r'^.+\.storetail\.io\.$', # Storetail (Criteo) -] diff --git a/resolve_subdomains.py b/resolve_subdomains.py index ec10c47..bc26e34 100755 --- a/resolve_subdomains.py +++ b/resolve_subdomains.py @@ -12,22 +12,15 @@ import queue import sys import threading import typing -import csv +import time import coloredlogs import dns.exception import dns.resolver -import progressbar DNS_TIMEOUT = 5.0 -NUMBER_THREADS = 512 NUMBER_TRIES = 5 -# TODO All the domains don't get treated, -# so it leaves with 4-5 subdomains not resolved - -glob = None - class Worker(threading.Thread): """ @@ -59,9 +52,9 @@ class Worker(threading.Thread): self.change_nameserver() def resolve_subdomain(self, subdomain: str) -> typing.Optional[ - typing.List[ - str - ] + typing.List[ + dns.rrset.RRset + ] ]: """ Returns the resolution chain of the subdomain to an A record, @@ -93,18 +86,7 @@ class Worker(threading.Thread): except dns.name.EmptyLabel: self.log.warning("Empty label for %s", subdomain) return None - resolved = list() - last = len(query.response.answer) - 1 - for a, answer in enumerate(query.response.answer): - if answer.rdtype == dns.rdatatype.CNAME: - assert a < last - resolved.append(answer.items[0].to_text()[:-1]) - elif answer.rdtype == dns.rdatatype.A: - assert a == last - resolved.append(answer.items[0].address) - else: - assert False - return resolved + return query.response.answer def run(self) -> None: self.log.info("Started") @@ -124,7 +106,6 @@ class Worker(threading.Thread): self.log.error("Gave up on %s", subdomain) resolved = [] - resolved.insert(0, subdomain) assert isinstance(resolved, list) self.orchestrator.results_queue.put(resolved) @@ -150,15 +131,17 @@ class Orchestrator(): def __init__(self, subdomains: typing.Iterable[str], nameservers: typing.List[str] = None, + nb_workers: int = 1, ): self.log = logging.getLogger('orchestrator') self.subdomains = subdomains + self.nb_workers = nb_workers # Use interal resolver by default self.nameservers = nameservers or dns.resolver.Resolver().nameservers self.subdomains_queue: queue.Queue = queue.Queue( - maxsize=NUMBER_THREADS) + maxsize=self.nb_workers) self.results_queue: queue.Queue = queue.Queue() self.nameservers_queue: queue.Queue = queue.Queue() @@ -179,16 +162,31 @@ class Orchestrator(): self.log.info("Finished reading subdomains") # Send sentinel to each worker # sentinel = None ~= EOF - for _ in range(NUMBER_THREADS): + for _ in range(self.nb_workers): self.subdomains_queue.put(None) - def run(self) -> typing.Iterable[typing.List[str]]: + @staticmethod + def format_rrset(rrset: dns.rrset.RRset) -> typing.Iterable[str]: + if rrset.rdtype == dns.rdatatype.CNAME: + dtype = 'c' + elif rrset.rdtype == dns.rdatatype.A: + dtype = 'a' + else: + raise NotImplementedError + name = rrset.name.to_text()[:-1] + for item in rrset.items: + value = item.to_text() + if rrset.rdtype == dns.rdatatype.CNAME: + value = value[:-1] + yield f'{dtype},{int(time.time())},{name},{value}\n' + + def run(self) -> typing.Iterable[str]: """ Yield the results. 
""" # Create workers self.log.info("Creating workers") - for i in range(NUMBER_THREADS): + for i in range(self.nb_workers): Worker(self, i).start() fill_thread = threading.Thread(target=self.fill_subdomain_queue) @@ -196,10 +194,11 @@ class Orchestrator(): # Wait for one sentinel per worker # In the meantime output results - for _ in range(NUMBER_THREADS): - result: typing.List[str] - for result in iter(self.results_queue.get, None): - yield result + for _ in range(self.nb_workers): + resolved: typing.List[dns.rrset.RRset] + for resolved in iter(self.results_queue.get, None): + for rrset in resolved: + yield from self.format_rrset(rrset) self.log.info("Waiting for reader thread") fill_thread.join() @@ -214,11 +213,9 @@ def main() -> None: the last CNAME resolved and the IP adress it resolves to. Takes as an input a filename (or nothing, for stdin), and as an output a filename (or nothing, for stdout). - The input must be a subdomain per line, the output is a comma-sep - file with the columns source CNAME and A. + The input must be a subdomain per line, the output is a TODO Use the file `nameservers` as the list of nameservers to use, or else it will use the system defaults. - Also shows a nice progressbar. """ # Initialization @@ -236,28 +233,14 @@ def main() -> None: parser.add_argument( '-o', '--output', type=argparse.FileType('w'), default=sys.stdout, help="Outptut file with DNS chains") - # parser.add_argument( - # '-n', '--nameserver', type=argparse.FileType('r'), - # default='nameservers', help="File with one nameserver per line") - # parser.add_argument( - # '-j', '--workers', type=int, default=512, - # help="Number of threads to use") + parser.add_argument( + '-n', '--nameservers', default='nameservers', + help="File with one nameserver per line") + parser.add_argument( + '-j', '--workers', type=int, default=512, + help="Number of threads to use") args = parser.parse_args() - # Progress bar - widgets = [ - progressbar.Percentage(), - ' ', progressbar.SimpleProgress(), - ' ', progressbar.Bar(), - ' ', progressbar.Timer(), - ' ', progressbar.AdaptiveTransferSpeed(unit='req'), - ' ', progressbar.AdaptiveETA(), - ] - progress = progressbar.ProgressBar(widgets=widgets) - if args.input.seekable(): - progress.max_value = len(args.input.readlines()) - args.input.seek(0) - # Cleaning input iterator = iter(args.input) iterator = map(str.strip, iterator) @@ -265,19 +248,16 @@ def main() -> None: # Reading nameservers servers: typing.List[str] = list() - if os.path.isfile('nameservers'): - servers = open('nameservers').readlines() + if os.path.isfile(args.nameservers): + servers = open(args.nameservers).readlines() servers = list(filter(None, map(str.strip, servers))) - writer = csv.writer(args.output) - - progress.start() - global glob - glob = Orchestrator(iterator, servers) - for resolved in glob.run(): - progress.update(progress.value + 1) - writer.writerow(resolved) - progress.finish() + for resolved in Orchestrator( + iterator, + servers, + nb_workers=args.workers + ).run(): + args.output.write(resolved) if __name__ == '__main__': diff --git a/resolve_subdomains.sh b/resolve_subdomains.sh index ed7af79..e37ddeb 100755 --- a/resolve_subdomains.sh +++ b/resolve_subdomains.sh @@ -4,11 +4,9 @@ function log() { echo -e "\033[33m$@\033[0m" } -# Resolve the CNAME chain of all the known subdomains for later analysis -log "Compiling subdomain lists..." 
diff --git a/resolve_subdomains.sh b/resolve_subdomains.sh
index ed7af79..e37ddeb 100755
--- a/resolve_subdomains.sh
+++ b/resolve_subdomains.sh
@@ -4,11 +4,9 @@ function log() {
     echo -e "\033[33m$@\033[0m"
 }
 
-# Resolve the CNAME chain of all the known subdomains for later analysis
-log "Compiling subdomain lists..."
-pv subdomains/*.list | sort -u > temp/all_subdomains.list
+log "Compiling locally known subdomains…"
 # Sort by last character to utilize the DNS server caching mechanism
-pv temp/all_subdomains.list | rev | sort | rev > temp/all_subdomains_reversort.list
-./resolve_subdomains.py --input temp/all_subdomains_reversort.list --output temp/all_resolved.csv
-sort -u temp/all_resolved.csv > temp/all_resolved_sorted.csv
+pv subdomains/*.list | sed 's/\r$//' | rev | sort -u | rev > temp/all_subdomains.list
+log "Resolving locally known subdomains…"
+pv temp/all_subdomains.list | ./resolve_subdomains.py --output temp/all_resolved.csv
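
The `rev | sort -u | rev` pipeline orders names by their reversed spelling so that hosts of the same zone are resolved back-to-back and hit warm resolver caches (the `sed` strips CR from CRLF input first). The equivalent ordering in Python:

```python
subdomains = ['a.example.com', 'b.tracker.net', 'www.example.com']
for subdomain in sorted(set(subdomains), key=lambda d: d[::-1]):
    print(subdomain)
# a.example.com, www.example.com, b.tracker.net — the *.example.com
# names come out adjacent, just like `rev | sort -u | rev`
```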