From 4d966371b2d696222c9bc625c31984baad3eefd9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20=E2=80=9CFrogeye=E2=80=9D=20Preud=27homme?= Date: Sun, 15 Dec 2019 15:56:26 +0100 Subject: [PATCH] Workflow: SQL -> Tree Welp. All that for this. --- .gitignore | 3 +- database.py | 728 +++++++++++++++----------------------------- database_schema.sql | 59 ---- export.py | 2 - feed_asn.py | 18 +- feed_dns.py | 167 ++-------- feed_rules.py | 6 +- import_rules.sh | 12 +- 8 files changed, 296 insertions(+), 699 deletions(-) mode change 100755 => 100644 database.py delete mode 100644 database_schema.sql diff --git a/.gitignore b/.gitignore index 188051c..c72635d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,4 @@ *.log -*.db -*.db-journal +*.p nameservers nameservers.head diff --git a/database.py b/database.py old mode 100755 new mode 100644 index 19fbe97..2d970e3 --- a/database.py +++ b/database.py @@ -4,111 +4,59 @@ Utility functions to interact with the database. """ -import sqlite3 import typing import time -import os import logging -import argparse import coloredlogs -import ipaddress -import math +import pickle +import enum coloredlogs.install( level='DEBUG', fmt='%(asctime)s %(name)s %(levelname)s %(message)s' ) -DbValue = typing.Union[None, int, float, str, bytes] +PathType = enum.Enum('PathType', 'Rule Hostname Zone Asn Ip4 Ip6') +RulePath = typing.Union[None] +Asn = int +DomainPath = typing.List[str] +Ip4Path = typing.List[int] +Ip6Path = typing.List[int] +Path = typing.Union[RulePath, DomainPath, Asn, Ip4Path, Ip6Path] +TypedPath = typing.Tuple[PathType, Path] +Timestamp = int +Level = int +Match = typing.Tuple[Timestamp, TypedPath, Level] + +DebugPath = (PathType.Rule, None) -class Database(): - VERSION = 5 - PATH = "blocking.db" +class DomainTreeNode(): + def __init__(self) -> None: + self.children: typing.Dict[str, DomainTreeNode] = dict() + self.match_zone: typing.Optional[Match] = None + self.match_hostname: typing.Optional[Match] = None - def open(self) -> None: - mode = 'rwc' if self.write else 'ro' - uri = f'file:{self.PATH}?mode={mode}' - self.conn = sqlite3.connect(uri, uri=True) - cursor = self.conn.cursor() - cursor.execute("PRAGMA foreign_keys = ON") - self.conn.create_function("unpack_asn", 1, - self.unpack_asn, - deterministic=True) - self.conn.create_function("unpack_ip4address", 1, - self.unpack_ip4address, - deterministic=True) - self.conn.create_function("unpack_ip4network", 2, - self.unpack_ip4network, - deterministic=True) - self.conn.create_function("unpack_domain", 1, - lambda s: s[:-1][::-1], - deterministic=True) - self.conn.create_function("format_zone", 1, - lambda s: '*' + s[::-1], - deterministic=True) - def get_meta(self, key: str) -> typing.Optional[int]: - cursor = self.conn.cursor() - try: - cursor.execute("SELECT value FROM meta WHERE key=?", (key,)) - except sqlite3.OperationalError: - return None - for ver, in cursor: - return ver - return None +class IpTreeNode(): + def __init__(self) -> None: + self.children: typing.List[typing.Optional[IpTreeNode]] = [None, None] + self.match: typing.Optional[Match] = None - def set_meta(self, key: str, val: int) -> None: - cursor = self.conn.cursor() - cursor.execute("INSERT INTO meta VALUES (?, ?) " - "ON CONFLICT (key) DO " - "UPDATE set value=?", - (key, val, val)) - def close(self) -> None: - self.enter_step('close_commit') - self.conn.commit() - self.enter_step('close') - self.conn.close() - self.profile() - - def initialize(self) -> None: - self.close() - self.enter_step('initialize') - if not self.write: - self.log.error("Cannot initialize in read-only mode.") - raise - os.unlink(self.PATH) - self.open() - self.log.info("Creating database version %d.", self.VERSION) - cursor = self.conn.cursor() - with open("database_schema.sql", 'r') as db_schema: - cursor.executescript(db_schema.read()) - self.set_meta('version', self.VERSION) - self.conn.commit() - - def __init__(self, write: bool = False) -> None: - self.log = logging.getLogger('db') +class Profiler(): + def __init__(self) -> None: + self.log = logging.getLogger('profiler') self.time_last = time.perf_counter() self.time_step = 'init' self.time_dict: typing.Dict[str, float] = dict() self.step_dict: typing.Dict[str, int] = dict() - self.write = write - - self.open() - version = self.get_meta('version') - if version != self.VERSION: - if version is not None: - self.log.warning( - "Outdated database version: %d found, will be rebuilt.", - version) - self.initialize() def enter_step(self, name: str) -> None: now = time.perf_counter() try: self.time_dict[self.time_step] += now - self.time_last - self.step_dict[self.time_step] += 1 + self.step_dict[self.time_step] += int(name != self.time_step) except KeyError: self.time_dict[self.time_step] = now - self.time_last self.step_dict[self.time_step] = 1 @@ -125,13 +73,58 @@ class Database(): self.log.debug(f"{'total':<20}: " f"{total:9.2f} s ({1:7.2%})") - @staticmethod - def pack_hostname(hostname: str) -> str: - return hostname[::-1] + '.' + +class Database(Profiler): + VERSION = 8 + PATH = "blocking.p" + + def initialize(self) -> None: + self.log.warning( + "Creating database version: %d ", + Database.VERSION) + self.domtree = DomainTreeNode() + self.asns: typing.Set[Asn] = set() + self.ip4tree = IpTreeNode() + + def load(self) -> None: + self.enter_step('load') + try: + with open(self.PATH, 'rb') as db_fdsec: + version, data = pickle.load(db_fdsec) + if version == Database.VERSION: + self.domtree, self.asns, self.ip4tree = data + return + self.log.warning( + "Outdated database version found: %d, " + "will be rebuilt.", + version) + except (TypeError, AttributeError, EOFError): + self.log.error( + "Corrupt database found, " + "will be rebuilt.") + except FileNotFoundError: + pass + self.initialize() + + def save(self) -> None: + self.enter_step('save') + with open(self.PATH, 'wb') as db_fdsec: + data = self.domtree, self.asns, self.ip4tree + pickle.dump((self.VERSION, data), db_fdsec) + self.profile() + + def __init__(self) -> None: + Profiler.__init__(self) + self.log = logging.getLogger('db') + self.load() @staticmethod - def pack_zone(zone: str) -> str: - return Database.pack_hostname(zone) + def pack_domain(domain: str) -> DomainPath: + return domain.split('.')[::-1] + + @staticmethod + def unpack_domain(domain: DomainPath) -> str: + return '.'.join(domain[::-1]) @staticmethod def pack_asn(asn: str) -> int: @@ -145,431 +138,208 @@ class Database(): return f'AS{asn}' @staticmethod - def pack_ip4address(address: str) -> int: - total = 0 - for i, octet in enumerate(address.split('.')): - total += int(octet) << (3-i)*8 - if total > 0xFFFFFFFF: - raise ValueError - return total - # return '{:02x}{:02x}{:02x}{:02x}'.format( - # *[int(c) for c in address.split('.')]) - # return base64.b16encode(packed).decode() - # return '{:08b}{:08b}{:08b}{:08b}'.format( - # *[int(c) for c in address.split('.')]) - # carg = ctypes.c_wchar_p(address) - # ret = ACCEL.ip4_flat(carg, self.accel_ip4_buf) - # if ret != 0: - # raise ValueError - # return self.accel_ip4_buf.value - # packed = ipaddress.ip_address(address).packed - # return packed + def pack_ip4address(address: str) -> Ip4Path: + addr: Ip4Path = [0] * 32 + octets = [int(octet) for octet in address.split('.')] + for b in range(32): + if (octets[b//8] >> b % 8) & 0b1: + addr[b] = 1 + return addr @staticmethod - def unpack_ip4address(address: int) -> str: - return '.'.join(str((address >> (i * 8)) & 0xFF) - for i in reversed(range(4))) + def unpack_ip4address(address: Ip4Path) -> str: + octets = [0] * 4 + for b, bit in enumerate(address): + octets[b//8] = (octets[b//8] << 1) + bit + return '.'.join(map(str, octets)) @staticmethod - def pack_ip4network(network: str) -> typing.Tuple[int, int]: - # def pack_ip4network(network: str) -> str: - net = ipaddress.ip_network(network) - mini = Database.pack_ip4address(net.network_address.exploded) - maxi = Database.pack_ip4address(net.broadcast_address.exploded) - # mini = net.network_address.packed - # maxi = net.broadcast_address.packed - return mini, maxi - # return Database.pack_ip4address(net.network_address.exploded)[:net.prefixlen] + def pack_ip4network(network: str) -> Ip4Path: + address, prefixlen_str = network.split('/') + prefixlen = int(prefixlen_str) + return Database.pack_ip4address(address)[:prefixlen] @staticmethod - def unpack_ip4network(mini: int, maxi: int) -> str: - addr = Database.unpack_ip4address(mini) - prefixlen = 32-int(math.log2(maxi-mini+1)) + def unpack_ip4network(network: Ip4Path) -> str: + address = network.copy() + prefixlen = len(network) + for _ in range(32-prefixlen): + address.append(0) + addr = Database.unpack_ip4address(address) return f'{addr}/{prefixlen}' def update_references(self) -> None: - self.enter_step('update_refs') - cursor = self.conn.cursor() - cursor.execute('UPDATE rules AS r SET refs=' - '(SELECT count(*) FROM rules ' - 'WHERE source=r.id)') + raise NotImplementedError def prune(self, before: int, base_only: bool = False) -> None: - self.enter_step('prune') - cursor = self.conn.cursor() - cmd = 'DELETE FROM rules WHERE updated str: - # Format current - string = '???' - cursor = self.conn.cursor() - cursor.execute( - 'SELECT unpack_asn(val) FROM asn WHERE entry=:entry ' - 'UNION ' - 'SELECT unpack_domain(val) FROM hostname WHERE entry=:entry ' - 'UNION ' - 'SELECT format_zone(val) FROM zone WHERE entry=:entry ' - 'UNION ' - 'SELECT unpack_ip4address(val) FROM ip4address WHERE entry=:entry ' - 'UNION ' - 'SELECT unpack_ip4network(mini, maxi) ' - 'FROM ip4network WHERE entry=:entry ', - {"entry": entry} - ) - for val, in cursor: # only one - string = str(val) - string += f' #{entry}' - - # Add source if any - cursor.execute('SELECT source FROM rules WHERE id=?', (entry,)) - for source, in cursor: - if source: - string += f' ← {self.explain(source)}' - return string + raise NotImplementedError def export(self, first_party_only: bool = False, end_chain_only: bool = False, explain: bool = False, + _dic: DomainTreeNode = None, + _par: DomainPath = None, ) -> typing.Iterable[str]: - selection = 'entry' if explain else 'unpack_domain(val)' - command = f'SELECT {selection} FROM rules ' \ - 'INNER JOIN hostname ON rules.id = hostname.entry' - restrictions: typing.List[str] = list() - if first_party_only: - restrictions.append('rules.first_party = 1') - if end_chain_only: - restrictions.append('rules.refs = 0') - if restrictions: - command += ' WHERE ' + ' AND '.join(restrictions) - if not explain: - command += ' ORDER BY unpack_domain(val) ASC' - cursor = self.conn.cursor() - cursor.execute(command) - for val, in cursor: - if explain: - yield self.explain(val) - else: - yield val + if first_party_only or end_chain_only or explain: + raise NotImplementedError + _dic = _dic or self.domtree + _par = _par or list() + if _dic.match_hostname: + yield self.unpack_domain(_par) + for part in _dic.children: + dic = _dic.children[part] + yield from self.export(_dic=dic, + _par=_par + [part]) def count_rules(self, first_party_only: bool = False, ) -> str: - counts: typing.List[str] = list() - cursor = self.conn.cursor() - for table in ['asn', 'ip4network', 'ip4address', 'zone', 'hostname']: - command = f'SELECT count(*) FROM rules ' \ - f'INNER JOIN {table} ON rules.id = {table}.entry ' \ - 'WHERE rules.level = 0' - if first_party_only: - command += ' AND first_party=1' - cursor.execute(command) - count, = cursor.fetchone() - if count > 0: - counts.append(f'{table}: {count}') + raise NotImplementedError - return ', '.join(counts) - - def get_domain(self, domain: str) -> typing.Iterable[int]: - self.enter_step('get_domain_prepare') - domain_prep = self.pack_hostname(domain) - cursor = self.conn.cursor() - self.enter_step('get_domain_select') - cursor.execute( - 'SELECT null, entry FROM hostname ' - 'WHERE val=:d ' - 'UNION ' - 'SELECT * FROM (' - 'SELECT val, entry FROM zone ' - # 'WHERE val>=:d ' - # 'ORDER BY val ASC LIMIT 1' - 'WHERE val<=:d ' - 'AND instr(:d, val) = 1' - ')', - {'d': domain_prep} - ) - for val, entry in cursor: - # print(293, val, entry) - self.enter_step('get_domain_confirm') - if not (val is None or domain_prep.startswith(val)): - # print(297) - continue + def get_domain(self, domain_str: str) -> typing.Iterable[TypedPath]: + self.enter_step('get_domain_pack') + domain = self.pack_domain(domain_str) + self.enter_step('get_domain_brws') + dic = self.domtree + depth = 0 + for part in domain: + if dic.match_zone: + self.enter_step('get_domain_yield') + yield (PathType.Zone, domain[:depth]) + self.enter_step('get_domain_brws') + if part not in dic.children: + return + dic = dic.children[part] + depth += 1 + if dic.match_zone: self.enter_step('get_domain_yield') - yield entry + yield (PathType.Zone, domain) + if dic.match_hostname: + self.enter_step('get_domain_yield') + yield (PathType.Hostname, domain) - def get_domain_in_zone(self, domain: str) -> typing.Iterable[int]: - self.enter_step('get_domainiz_prepare') - domain_prep = self.pack_hostname(domain) - cursor = self.conn.cursor() - self.enter_step('get_domainiz_select') - cursor.execute( - 'SELECT val, entry FROM zone ' - 'WHERE val<=:d ' - 'ORDER BY val DESC LIMIT 1', - {'d': domain_prep} - ) - for val, entry in cursor: - self.enter_step('get_domainiz_confirm') - if not (val is None or domain_prep.startswith(val)): - continue - self.enter_step('get_domainiz_yield') - yield entry - - def get_ip4(self, address: str) -> typing.Iterable[int]: - self.enter_step('get_ip4_prepare') - try: - address_prep = self.pack_ip4address(address) - except (ValueError, IndexError): - self.log.error("Invalid ip4address: %s", address) - return - cursor = self.conn.cursor() - self.enter_step('get_ip4_select') - cursor.execute( - 'SELECT entry FROM ip4address ' - # 'SELECT null, entry FROM ip4address ' - 'WHERE val=:a ' - 'UNION ' - # 'SELECT * FROM (' - # 'SELECT val, entry FROM ip4network ' - # 'WHERE val<=:a ' - # 'AND instr(:a, val) > 0 ' - # 'ORDER BY val DESC' - # ')' - 'SELECT entry FROM ip4network ' - 'WHERE :a BETWEEN mini AND maxi ', - {'a': address_prep} - ) - for entry, in cursor: - # self.enter_step('get_ip4_confirm') - # if not (val is None or val.startswith(address_prep)): - # # PERF startswith but from the end - # continue + def get_ip4(self, ip4_str: str) -> typing.Iterable[TypedPath]: + self.enter_step('get_ip4_pack') + ip4 = self.pack_ip4address(ip4_str) + self.enter_step('get_ip4_brws') + dic = self.ip4tree + depth = 0 + for part in ip4: + if dic.match: + self.enter_step('get_ip4_yield') + yield (PathType.Ip4, ip4[:depth]) + self.enter_step('get_ip4_brws') + next_dic = dic.children[part] + if next_dic is None: + return + dic = next_dic + depth += 1 + if dic.match: self.enter_step('get_ip4_yield') - yield entry + yield (PathType.Ip4, ip4) - def get_ip4_in_network(self, address: str) -> typing.Iterable[int]: - self.enter_step('get_ip4in_prepare') - try: - address_prep = self.pack_ip4address(address) - except (ValueError, IndexError): - self.log.error("Invalid ip4address: %s", address) - return - cursor = self.conn.cursor() - self.enter_step('get_ip4in_select') - cursor.execute( - 'SELECT entry FROM ip4network ' - 'WHERE :a BETWEEN mini AND maxi ', - {'a': address_prep} - ) - for entry, in cursor: - self.enter_step('get_ip4in_yield') - yield entry + def list_asn(self) -> typing.Iterable[TypedPath]: + for asn in self.asns: + yield (PathType.Asn, asn) - def list_asn(self) -> typing.Iterable[typing.Tuple[str, int]]: - cursor = self.conn.cursor() - self.enter_step('list_asn_select') - cursor.execute('SELECT val, entry FROM asn') - for val, entry in cursor: - yield f'AS{val}', entry - - def _set_generic(self, - table: str, - select_query: str, - insert_query: str, - prep: typing.Dict[str, DbValue], + def set_hostname(self, + hostname_str: str, updated: int, - is_first_party: bool = False, - source: int = None, - ) -> None: - # Since this isn't the bulk of the processing, - # here abstraction > performaces + is_first_party: bool = None, + source: TypedPath = None) -> None: + self.enter_step('set_hostname_pack') + if is_first_party or source: + raise NotImplementedError + self.enter_step('set_hostname_brws') + hostname = self.pack_domain(hostname_str) + dic = self.domtree + for part in hostname: + if dic.match_zone: + # Refuse to add hostname whose zone is already matching + return + if part not in dic.children: + dic.children[part] = DomainTreeNode() + dic = dic.children[part] + dic.match_hostname = (updated, DebugPath, 0) - # Fields based on the source - self.enter_step(f'set_{table}_prepare') - cursor = self.conn.cursor() - if source is None: - first_party = int(is_first_party) - level = 0 - else: - self.enter_step(f'set_{table}_source') - cursor.execute( - 'SELECT first_party, level FROM rules ' - 'WHERE id=?', - (source,) - ) - first_party, level = cursor.fetchone() - level += 1 + def set_zone(self, + zone_str: str, + updated: int, + is_first_party: bool = None, + source: TypedPath = None) -> None: + self.enter_step('set_zone_pack') + if is_first_party or source: + raise NotImplementedError + zone = self.pack_domain(zone_str) + self.enter_step('set_zone_brws') + dic = self.domtree + for part in zone: + if dic.match_zone: + # Refuse to add zone whose parent zone is already matching + return + if part not in dic.children: + dic.children[part] = DomainTreeNode() + dic = dic.children[part] + dic.match_zone = (updated, DebugPath, 0) - self.enter_step(f'set_{table}_select') - cursor.execute(select_query, prep) + def set_asn(self, + asn_str: str, + updated: int, + is_first_party: bool = None, + source: TypedPath = None) -> None: + self.enter_step('set_asn_pack') + if is_first_party or source: + # TODO updated + raise NotImplementedError + asn = self.pack_asn(asn_str) + self.enter_step('set_asn_brws') + self.asns.add(asn) - rules_prep: typing.Dict[str, DbValue] = { - "source": source, - "updated": updated, - "first_party": first_party, - "level": level, - } + def set_ip4address(self, + ip4address_str: str, + updated: int, + is_first_party: bool = None, + source: TypedPath = None) -> None: + self.enter_step('set_ip4add_pack') + if is_first_party or source: + raise NotImplementedError + self.enter_step('set_ip4add_brws') + ip4address = self.pack_ip4address(ip4address_str) + dic = self.ip4tree + for part in ip4address: + if dic.match: + # Refuse to add ip4address whose network is already matching + return + next_dic = dic.children[part] + if next_dic is None: + next_dic = IpTreeNode() + dic.children[part] = next_dic + dic = next_dic + dic.match = (updated, DebugPath, 0) - # If the entry already exists - for entry, in cursor: # only one - self.enter_step(f'set_{table}_update') - rules_prep['entry'] = entry - cursor.execute( - 'UPDATE rules SET ' - 'source=:source, updated=:updated, ' - 'first_party=:first_party, level=:level ' - 'WHERE id=:entry AND (updated<:updated OR ' - 'first_party<:first_party OR level<:level)', - rules_prep - ) - # Only update if any of the following: - # - the entry is outdataed - # - the entry was not a first_party but this is - # - this is closer to the original rule - return - - # If it does not exist - - self.enter_step(f'set_{table}_insert') - cursor.execute( - 'INSERT INTO rules ' - '(source, updated, first_party, level) ' - 'VALUES (:source, :updated, :first_party, :level) ', - rules_prep - ) - cursor.execute('SELECT id FROM rules WHERE rowid=?', - (cursor.lastrowid,)) - for entry, in cursor: # only one - prep['entry'] = entry - cursor.execute(insert_query, prep) - return - assert False - - def set_hostname(self, hostname: str, - *args: typing.Any, **kwargs: typing.Any) -> None: - self.enter_step('set_hostname_prepare') - prep: typing.Dict[str, DbValue] = { - 'val': self.pack_hostname(hostname), - } - self._set_generic( - 'hostname', - 'SELECT entry FROM hostname WHERE val=:val', - 'INSERT INTO hostname (val, entry) ' - 'VALUES (:val, :entry)', - prep, - *args, **kwargs - ) - - def set_asn(self, asn: str, - *args: typing.Any, **kwargs: typing.Any) -> None: - self.enter_step('set_asn_prepare') - try: - asn_prep = self.pack_asn(asn) - except ValueError: - self.log.error("Invalid asn: %s", asn) - return - prep: typing.Dict[str, DbValue] = { - 'val': asn_prep, - } - self._set_generic( - 'asn', - 'SELECT entry FROM asn WHERE val=:val', - 'INSERT INTO asn (val, entry) ' - 'VALUES (:val, :entry)', - prep, - *args, **kwargs - ) - - def set_ip4address(self, ip4address: str, - *args: typing.Any, **kwargs: typing.Any) -> None: - self.enter_step('set_ip4add_prepare') - try: - ip4address_prep = self.pack_ip4address(ip4address) - except (ValueError, IndexError): - self.log.error("Invalid ip4address: %s", ip4address) - return - prep: typing.Dict[str, DbValue] = { - 'val': ip4address_prep, - } - self._set_generic( - 'ip4add', - 'SELECT entry FROM ip4address WHERE val=:val', - 'INSERT INTO ip4address (val, entry) ' - 'VALUES (:val, :entry)', - prep, - *args, **kwargs - ) - - def set_zone(self, zone: str, - *args: typing.Any, **kwargs: typing.Any) -> None: - self.enter_step('set_zone_prepare') - prep: typing.Dict[str, DbValue] = { - 'val': self.pack_zone(zone), - } - self._set_generic( - 'zone', - 'SELECT entry FROM zone WHERE val=:val', - 'INSERT INTO zone (val, entry) ' - 'VALUES (:val, :entry)', - prep, - *args, **kwargs - ) - - def set_ip4network(self, ip4network: str, - *args: typing.Any, **kwargs: typing.Any) -> None: - self.enter_step('set_ip4net_prepare') - try: - ip4network_prep = self.pack_ip4network(ip4network) - except (ValueError, IndexError): - self.log.error("Invalid ip4network: %s", ip4network) - return - prep: typing.Dict[str, DbValue] = { - 'mini': ip4network_prep[0], - 'maxi': ip4network_prep[1], - } - self._set_generic( - 'ip4net', - 'SELECT entry FROM ip4network WHERE mini=:mini AND maxi=:maxi', - 'INSERT INTO ip4network (mini, maxi, entry) ' - 'VALUES (:mini, :maxi, :entry)', - prep, - *args, **kwargs - ) - - -if __name__ == '__main__': - - # Parsing arguments - parser = argparse.ArgumentParser( - description="Database operations") - parser.add_argument( - '-i', '--initialize', action='store_true', - help="Reconstruct the whole database") - parser.add_argument( - '-p', '--prune', action='store_true', - help="Remove old entries from database") - parser.add_argument( - '-b', '--prune-base', action='store_true', - help="TODO") - parser.add_argument( - '-s', '--prune-before', type=int, - default=(int(time.time()) - 60*60*24*31*6), - help="TODO") - parser.add_argument( - '-r', '--references', action='store_true', - help="Update the reference count") - args = parser.parse_args() - - DB = Database(write=True) - - if args.initialize: - DB.initialize() - if args.prune: - DB.prune(before=args.prune_before, base_only=args.prune_base) - if args.references: - DB.update_references() - - DB.close() + def set_ip4network(self, + ip4network_str: str, + updated: int, + is_first_party: bool = None, + source: TypedPath = None) -> None: + self.enter_step('set_ip4net_pack') + if is_first_party or source: + raise NotImplementedError + self.enter_step('set_ip4net_brws') + ip4network = self.pack_ip4network(ip4network_str) + dic = self.ip4tree + for part in ip4network: + if dic.match: + # Refuse to add ip4network whose parent network + # is already matching + return + next_dic = dic.children[part] + if next_dic is None: + next_dic = IpTreeNode() + dic.children[part] = next_dic + dic = next_dic + dic.match = (updated, DebugPath, 0) diff --git a/database_schema.sql b/database_schema.sql deleted file mode 100644 index 3116a09..0000000 --- a/database_schema.sql +++ /dev/null @@ -1,59 +0,0 @@ --- Remember to increment DB_VERSION --- in database.py on changes to this file - -CREATE TABLE rules ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - source INTEGER, -- The rule this one is based on - updated INTEGER, -- If the row was updated during last data import (0: No, 1: Yes) - first_party INTEGER, -- 1: this blocks a first party for sure, 0: maybe - refs INTEGER, -- Number of entries issued from this one - level INTEGER, -- Level of recursion to the root source rule (used for source priority) - FOREIGN KEY (source) REFERENCES rules(id) ON DELETE CASCADE -); -CREATE INDEX rules_source ON rules (source); -- for references recounting -CREATE INDEX rules_updated ON rules (updated); -- for pruning -CREATE INDEX rules_level_firstparty ON rules (level, first_party); -- for counting rules - -CREATE TABLE asn ( - val INTEGER PRIMARY KEY, - entry INTEGER, - FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE -); -CREATE INDEX asn_entry ON asn (entry); -- for explainations - -CREATE TABLE hostname ( - val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for consistency with zone) - entry INTEGER, - FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE -); -CREATE INDEX hostname_entry ON hostname (entry); -- for explainations - -CREATE TABLE zone ( - val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for easier matching) - entry INTEGER, - FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE -); -CREATE INDEX zone_entry ON zone (entry); -- for explainations - -CREATE TABLE ip4address ( - val INTEGER PRIMARY KEY, - entry INTEGER, - FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE -); -CREATE INDEX ip4address_entry ON ip4address (entry); -- for explainations - -CREATE TABLE ip4network ( - -- val TEXT PRIMARY KEY, - mini INTEGER, - maxi INTEGER, - entry INTEGER, - FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE -); -CREATE INDEX ip4network_minmax ON ip4network (mini, maxi); -CREATE INDEX ip4network_entry ON ip4network (entry); -- for explainations - --- Store various things -CREATE TABLE meta ( - key TEXT PRIMARY KEY, - value integer -); diff --git a/export.py b/export.py index 886582c..bca3281 100755 --- a/export.py +++ b/export.py @@ -45,5 +45,3 @@ if __name__ == '__main__': explain=args.explain, ): print(domain, file=args.output) - - DB.close() diff --git a/feed_asn.py b/feed_asn.py index 098f931..ead63fe 100755 --- a/feed_asn.py +++ b/feed_asn.py @@ -31,23 +31,25 @@ if __name__ == '__main__': args = parser.parse_args() DB = database.Database() - DBW = database.Database(write=True) - for asn, entry in DB.list_asn(): + for path in DB.list_asn(): + ptype, asn = path + assert ptype == database.PathType.Asn + assert isinstance(asn, int) + asn_str = database.Database.unpack_asn(asn) DB.enter_step('asn_get_ranges') - for prefix in get_ranges(asn): + for prefix in get_ranges(asn_str): parsed_prefix: IPNetwork = ipaddress.ip_network(prefix) if parsed_prefix.version == 4: - DBW.set_ip4network( + DB.set_ip4network( prefix, - source=entry, + # source=path, updated=int(time.time()) ) - log.info('Added %s from %s (id=%s)', prefix, asn, entry) + log.info('Added %s from %s (source=%s)', prefix, asn, path) elif parsed_prefix.version == 6: log.warning('Unimplemented prefix version: %s', prefix) else: log.error('Unknown prefix version: %s', prefix) - DB.close() - DBW.close() + DB.save() diff --git a/feed_dns.py b/feed_dns.py index 4b01814..3acad9a 100755 --- a/feed_dns.py +++ b/feed_dns.py @@ -6,126 +6,52 @@ import json import logging import sys import typing -import multiprocessing import enum RecordType = enum.Enum('RecordType', 'A AAAA CNAME PTR') Record = typing.Tuple[RecordType, int, str, str] -# select, confirm, write +# select, write FUNCTION_MAP: typing.Any = { RecordType.A: ( database.Database.get_ip4, - database.Database.get_domain_in_zone, database.Database.set_hostname, ), RecordType.CNAME: ( database.Database.get_domain, - database.Database.get_domain_in_zone, database.Database.set_hostname, ), RecordType.PTR: ( database.Database.get_domain, - database.Database.get_ip4_in_network, database.Database.set_ip4address, ), } -class Reader(multiprocessing.Process): - def __init__(self, - recs_queue: multiprocessing.Queue, - write_queue: multiprocessing.Queue, - index: int = 0): - super(Reader, self).__init__() - self.log = logging.getLogger(f'rd{index:03d}') - self.recs_queue = recs_queue - self.write_queue = write_queue - self.index = index - - def run(self) -> None: - self.db = database.Database(write=False) - self.db.log = logging.getLogger(f'db{self.index:03d}') - self.db.enter_step('line_wait') - block: typing.List[str] - try: - for block in iter(self.recs_queue.get, None): - record: Record - for record in block: - # print(55, record) - dtype, updated, name, value = record - self.db.enter_step('feed_switch') - select, confirm, write = FUNCTION_MAP[dtype] - for rule in select(self.db, value): - # print(60, rule, list(confirm(self.db, name))) - if not any(confirm(self.db, name)): - # print(62, write, name, updated, rule) - self.db.enter_step('wait_put') - self.write_queue.put((write, name, updated, rule)) - self.db.enter_step('line_wait') - except KeyboardInterrupt: - self.log.error('Interrupted') - - self.db.enter_step('end') - self.db.close() - - -class Writer(multiprocessing.Process): - def __init__(self, - write_queue: multiprocessing.Queue, - ): - super(Writer, self).__init__() - self.log = logging.getLogger(f'wr ') - self.write_queue = write_queue - - def run(self) -> None: - self.db = database.Database(write=True) - self.db.log = logging.getLogger(f'dbw ') - self.db.enter_step('line_wait') - block: typing.List[str] - try: - fun: typing.Callable - name: str - updated: int - source: int - for fun, name, updated, source in iter(self.write_queue.get, None): - self.db.enter_step('exec') - fun(self.db, name, updated, source=source) - self.db.enter_step('line_wait') - except KeyboardInterrupt: - self.log.error('Interrupted') - - self.db.enter_step('end') - self.db.close() - - class Parser(): - def __init__(self, - buf: typing.Any, - recs_queue: multiprocessing.Queue, - block_size: int, - ): - super(Parser, self).__init__() + def __init__(self, buf: typing.Any) -> None: self.buf = buf - self.log = logging.getLogger('pr ') - self.recs_queue = recs_queue - self.block: typing.List[Record] = list() - self.block_size = block_size - self.db = database.Database() # Just for timing - self.db.log = logging.getLogger('pr ') + self.log = logging.getLogger('parser') + self.db = database.Database() + + def end(self) -> None: + self.db.save() + + def register(self, + rtype: RecordType, + updated: int, + name: str, + value: str + ) -> None: - def register(self, record: Record) -> None: self.db.enter_step('register') - self.block.append(record) - if len(self.block) >= self.block_size: - self.db.enter_step('put_block') - self.recs_queue.put(self.block) - self.block = list() - - def run(self) -> None: - self.consume() - self.recs_queue.put(self.block) - self.db.close() + select, write = FUNCTION_MAP[rtype] + try: + for source in select(self.db, value): + # write(self.db, name, updated, source=source) + write(self.db, name, updated) + except NotImplementedError: + return # DEBUG def consume(self) -> None: raise NotImplementedError @@ -146,13 +72,12 @@ class Rapid7Parser(Parser): data = json.loads(line) except json.decoder.JSONDecodeError: continue - record = ( + self.register( Rapid7Parser.TYPES[data['type']], int(data['timestamp']), data['name'], data['value'] ) - self.register(record) class DnsMassParser(Parser): @@ -182,13 +107,12 @@ class DnsMassParser(Parser): else: dtype, name_offset, value_offset = \ DnsMassParser.TYPES[split[1]] - record = ( + self.register( dtype, timestamp, split[0][:name_offset], split[2][:value_offset], ) - self.register(record) self.db.enter_step('parse_dnsmass') except KeyError: continue @@ -212,49 +136,12 @@ if __name__ == '__main__': args_parser.add_argument( '-i', '--input', type=argparse.FileType('r'), default=sys.stdin, help="TODO") - args_parser.add_argument( - '-j', '--workers', type=int, default=4, - help="TODO") - args_parser.add_argument( - '-b', '--block-size', type=int, default=100, - help="TODO") args = args_parser.parse_args() - DB = database.Database(write=False) # Not needed, just for timing - DB.log = logging.getLogger('db ') - - recs_queue: multiprocessing.Queue = multiprocessing.Queue( - maxsize=10*args.workers) - write_queue: multiprocessing.Queue = multiprocessing.Queue( - maxsize=10*args.workers) - - DB.enter_step('proc_create') - readers: typing.List[Reader] = list() - for w in range(args.workers): - readers.append(Reader(recs_queue, write_queue, w)) - writer = Writer(write_queue) - parser = PARSERS[args.parser]( - args.input, recs_queue, args.block_size) - - DB.enter_step('proc_start') - for reader in readers: - reader.start() - writer.start() - + parser = PARSERS[args.parser](args.input) try: - DB.enter_step('parser_run') - parser.run() - - DB.enter_step('end_put') - for _ in range(args.workers): - recs_queue.put(None) - write_queue.put(None) - - DB.enter_step('proc_join') - for reader in readers: - reader.join() - writer.join() + parser.consume() except KeyboardInterrupt: - log.error('Interrupted') + pass + parser.end() - DB.close() diff --git a/feed_rules.py b/feed_rules.py index 715126e..cca1261 100755 --- a/feed_rules.py +++ b/feed_rules.py @@ -28,15 +28,15 @@ if __name__ == '__main__': help="The input only comes from verified first-party sources") args = parser.parse_args() - DB = database.Database(write=True) + DB = database.Database() fun = FUNCTION_MAP[args.type] for rule in args.input: fun(DB, rule.strip(), - is_first_party=args.first_party, + # is_first_party=args.first_party, updated=int(time.time()), ) - DB.close() + DB.save() diff --git a/import_rules.sh b/import_rules.sh index 33c4fbd..cdeec93 100755 --- a/import_rules.sh +++ b/import_rules.sh @@ -6,11 +6,11 @@ function log() { log "Importing rules…" BEFORE="$(date +%s)" -cat rules_adblock/*.txt | grep -v '^!' | grep -v '^\[Adblock' | ./adblock_to_domain_list.py | ./feed_rules.py zone -cat rules_hosts/*.txt | grep -v '^#' | grep -v '^$' | cut -d ' ' -f2 | ./feed_rules.py zone -cat rules/*.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone -cat rules_ip/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py ip4network -cat rules_asn/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py asn +# cat rules_adblock/*.txt | grep -v '^!' | grep -v '^\[Adblock' | ./adblock_to_domain_list.py | ./feed_rules.py zone +# cat rules_hosts/*.txt | grep -v '^#' | grep -v '^$' | cut -d ' ' -f2 | ./feed_rules.py zone +# cat rules/*.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone +# cat rules_ip/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py ip4network +# cat rules_asn/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py asn cat rules/first-party.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone --first-party cat rules_ip/first-party.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py ip4network --first-party @@ -19,4 +19,4 @@ cat rules_asn/first-party.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py as ./feed_asn.py log "Pruning old rules…" -./database.py --prune --prune-before "$BEFORE" --prune-base +./db.py --prune --prune-before "$BEFORE" --prune-base