#!/usr/bin/env python3

"""
Utility functions to interact with the database.
"""

import argparse
import ipaddress
import logging
import math
import os
import sqlite3
import time
import typing

try:
    import coloredlogs
except ImportError:
    # coloredlogs only prettifies the output; do not make the whole module
    # unusable when it is missing.
    coloredlogs = None  # type: ignore

LOG_FORMAT = '%(asctime)s %(name)s %(levelname)s %(message)s'

if coloredlogs is not None:
    coloredlogs.install(
        level='DEBUG',
        fmt=LOG_FORMAT
    )
else:
    logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT)

DbValue = typing.Union[None, int, float, str, bytes]


class Database():
    """
    Wrapper around the SQLite rules database.

    Values are stored in a "packed" form (see the ``pack_*`` helpers):
    host names and zones are reversed and suffixed with a dot so that a
    zone is a string prefix of every name it contains, and IPv4 addresses
    are stored as integers so networks can be matched with ``BETWEEN``.

    The instance also keeps a small step-based profiler (``enter_step`` /
    ``profile``) that is printed when the database is closed.
    """

    VERSION = 5
    PATH = "blocking.db"

    def __init__(self, write: bool = False) -> None:
        """
        Open the database, rebuilding it if its version does not match.

        :param write: open read-write (creating the file if needed)
            instead of read-only.
        :raises RuntimeError: if the database must be (re)built but was
            opened read-only.
        """
        self.log = logging.getLogger('db')
        self.time_last = time.perf_counter()
        self.time_step = 'init'
        self.time_dict: typing.Dict[str, float] = dict()
        self.step_dict: typing.Dict[str, int] = dict()
        self.write = write
        self.open()
        version = self.get_meta('version')
        if version != self.VERSION:
            if version is not None:
                self.log.warning(
                    "Outdated database version: %d found, will be rebuilt.",
                    version)
            self.initialize()

    def open(self) -> None:
        """
        Connect to ``PATH`` and register the SQL helper functions.
        """
        mode = 'rwc' if self.write else 'ro'
        uri = f'file:{self.PATH}?mode={mode}'
        self.conn = sqlite3.connect(uri, uri=True)
        cursor = self.conn.cursor()
        cursor.execute("PRAGMA foreign_keys = ON")
        # SQL-side counterparts of the pack_* helpers, used by
        # explain() and export() to turn stored values back into text.
        self.conn.create_function("unpack_asn", 1,
                                  self.unpack_asn,
                                  deterministic=True)
        self.conn.create_function("unpack_ip4address", 1,
                                  self.unpack_ip4address,
                                  deterministic=True)
        self.conn.create_function("unpack_ip4network", 2,
                                  self.unpack_ip4network,
                                  deterministic=True)
        self.conn.create_function("unpack_domain", 1,
                                  lambda s: s[:-1][::-1],
                                  deterministic=True)
        self.conn.create_function("format_zone", 1,
                                  lambda s: '*' + s[::-1],
                                  deterministic=True)

    def get_meta(self, key: str) -> typing.Optional[int]:
        """
        Return the value stored for ``key`` in the ``meta`` table.

        Returns None when the key is absent or when the query fails with
        an OperationalError (e.g. the table does not exist yet).
        """
        cursor = self.conn.cursor()
        try:
            cursor.execute("SELECT value FROM meta WHERE key=?", (key,))
        except sqlite3.OperationalError:
            return None
        row = cursor.fetchone()
        if row is None:
            return None
        return row[0]

    def set_meta(self, key: str, val: int) -> None:
        """
        Insert or update the value stored for ``key`` in ``meta``.
        """
        cursor = self.conn.cursor()
        cursor.execute("INSERT INTO meta VALUES (?, ?) "
                       "ON CONFLICT (key) DO "
                       "UPDATE set value=?",
                       (key, val, val))

    def close(self) -> None:
        """
        Commit pending changes, close the connection and log the profile.
        """
        self.enter_step('close_commit')
        self.conn.commit()
        self.enter_step('close')
        self.conn.close()
        self.profile()

    def initialize(self) -> None:
        """
        Delete the database file and recreate it from
        ``database_schema.sql``.

        :raises RuntimeError: if the database is in read-only mode.
        """
        # Refuse before touching anything, and raise a real exception:
        # a bare `raise` here has no active exception to re-raise.
        if not self.write:
            self.log.error("Cannot initialize in read-only mode.")
            raise RuntimeError("Cannot initialize in read-only mode.")
        self.close()
        self.enter_step('initialize')
        try:
            os.unlink(self.PATH)
        except FileNotFoundError:
            # Nothing to delete: the file will simply be created below.
            pass
        self.open()
        self.log.info("Creating database version %d.", self.VERSION)
        cursor = self.conn.cursor()
        with open("database_schema.sql", 'r') as db_schema:
            cursor.executescript(db_schema.read())
        self.set_meta('version', self.VERSION)
        self.conn.commit()

    def enter_step(self, name: str) -> None:
        """
        Account the time elapsed since the previous call to the step
        that was running, then make ``name`` the running step.
        """
        now = time.perf_counter()
        elapsed = now - self.time_last
        step = self.time_step
        self.time_dict[step] = self.time_dict.get(step, 0.0) + elapsed
        self.step_dict[step] = self.step_dict.get(step, 0) + 1
        self.time_step = name
        self.time_last = time.perf_counter()

    def profile(self) -> None:
        """
        Log, at debug level, the time spent in every step, sorted by
        increasing total time.
        """
        self.enter_step('profile')
        total = sum(self.time_dict.values())
        for key, secs in sorted(self.time_dict.items(), key=lambda t: t[1]):
            times = self.step_dict[key]
            self.log.debug(f"{key:<20}: {times:9d} × {secs/times:5.3e} "
                           f"= {secs:9.2f} s ({secs/total:7.2%}) ")
        self.log.debug(f"{'total':<20}: "
                       f"{total:9.2f} s ({1:7.2%})")

    @staticmethod
    def pack_hostname(hostname: str) -> str:
        """
        Reverse the host name and append a dot
        (``example.com`` -> ``moc.elpmaxe.``).
        """
        return hostname[::-1] + '.'

    @staticmethod
    def pack_zone(zone: str) -> str:
        """
        Pack a zone; zones use the same representation as host names.
        """
        return Database.pack_hostname(zone)

    @staticmethod
    def pack_asn(asn: str) -> int:
        """
        Convert ``AS1234`` / ``as1234`` / ``1234`` to the integer 1234.

        :raises ValueError: if what follows the optional prefix is not
            an integer.
        """
        asn = asn.upper()
        if asn.startswith('AS'):
            asn = asn[2:]
        return int(asn)

    @staticmethod
    def unpack_asn(asn: int) -> str:
        """
        Convert an integer AS number to its ``AS1234`` form.
        """
        return f'AS{asn}'

    @staticmethod
    def pack_ip4address(address: str) -> int:
        """
        Convert a dotted-quad IPv4 address to its 32-bit integer value.

        :raises ValueError: if the address does not have exactly four
            integer octets in the 0-255 range.
        """
        octets = address.split('.')
        if len(octets) != 4:
            raise ValueError(f"Invalid IPv4 address: {address!r}")
        total = 0
        for octet in octets:
            value = int(octet)  # raises ValueError on non-numeric text
            if not 0 <= value <= 0xFF:
                raise ValueError(f"Invalid IPv4 address: {address!r}")
            total = (total << 8) | value
        return total

    @staticmethod
    def unpack_ip4address(address: int) -> str:
        """
        Convert a 32-bit integer to its dotted-quad IPv4 form.
        """
        return '.'.join(str((address >> (i * 8)) & 0xFF)
                        for i in reversed(range(4)))

    @staticmethod
    def pack_ip4network(network: str) -> typing.Tuple[int, int]:
        """
        Convert a CIDR network to the packed (first, last) addresses of
        the range it covers.

        :raises ValueError: if ``network`` is rejected by
            ``ipaddress.ip_network`` or is not an IPv4 network.
        """
        net = ipaddress.ip_network(network)
        mini = Database.pack_ip4address(net.network_address.exploded)
        maxi = Database.pack_ip4address(net.broadcast_address.exploded)
        return mini, maxi

    @staticmethod
    def unpack_ip4network(mini: int, maxi: int) -> str:
        """
        Convert a packed (first, last) address range back to CIDR form.
        """
        addr = Database.unpack_ip4address(mini)
        prefixlen = 32-int(math.log2(maxi-mini+1))
        return f'{addr}/{prefixlen}'

    def update_references(self) -> None:
        """
        Recompute ``refs`` of every rule as the number of rules that
        have it as their source.
        """
        self.enter_step('update_refs')
        cursor = self.conn.cursor()
        cursor.execute('UPDATE rules AS r SET refs='
                       '(SELECT count(*) FROM rules '
                       'WHERE source=r.id)')

    def prune(self, before: int) -> None:
        """
        Delete the rules whose ``updated`` value is lower than
        ``before``.
        """
        self.enter_step('prune')
        cursor = self.conn.cursor()
        cursor.execute('DELETE FROM rules WHERE updated<?', (before,))

    def explain(self, entry: int) -> str:
        """
        Describe a rule and, recursively, the chain of rules it was
        derived from, as ``value #id ← source_value #source_id ...``.

        The value is shown as ``???`` when no value table holds the
        entry.
        """
        # Format current
        string = '???'
        cursor = self.conn.cursor()
        cursor.execute(
            'SELECT unpack_asn(val) FROM asn WHERE entry=:entry '
            'UNION '
            'SELECT unpack_domain(val) FROM hostname WHERE entry=:entry '
            'UNION '
            'SELECT format_zone(val) FROM zone WHERE entry=:entry '
            'UNION '
            'SELECT unpack_ip4address(val) FROM ip4address WHERE entry=:entry '
            'UNION '
            'SELECT unpack_ip4network(mini, maxi) '
            'FROM ip4network WHERE entry=:entry ',
            {"entry": entry}
        )
        for val, in cursor:  # only one
            string = str(val)
        string += f' #{entry}'

        # Add source if any
        cursor.execute('SELECT source FROM rules WHERE id=?', (entry,))
        for source, in cursor:
            if source:
                string += f' ← {self.explain(source)}'
        return string

    def export(self,
               first_party_only: bool = False,
               end_chain_only: bool = False,
               explain: bool = False,
               ) -> typing.Iterable[str]:
        """
        Yield the host names that have a rule.

        :param first_party_only: only rules with ``first_party = 1``.
        :param end_chain_only: only rules with ``refs = 0``.
        :param explain: yield ``explain()`` descriptions (unsorted)
            instead of bare host names sorted alphabetically.
        """
        selection = 'entry' if explain else 'unpack_domain(val)'
        command = f'SELECT {selection} FROM rules ' \
            'INNER JOIN hostname ON rules.id = hostname.entry'
        restrictions: typing.List[str] = list()
        if first_party_only:
            restrictions.append('rules.first_party = 1')
        if end_chain_only:
            restrictions.append('rules.refs = 0')
        if restrictions:
            command += ' WHERE ' + ' AND '.join(restrictions)
        if not explain:
            command += ' ORDER BY unpack_domain(val) ASC'
        cursor = self.conn.cursor()
        cursor.execute(command)
        for val, in cursor:
            if explain:
                yield self.explain(val)
            else:
                yield val

    def count_rules(self,
                    first_party_only: bool = False,
                    ) -> str:
        """
        Return a ``table: count`` summary of the level-0 rules of each
        value table, omitting the tables that have none.
        """
        counts: typing.List[str] = list()
        cursor = self.conn.cursor()
        for table in ['asn', 'ip4network', 'ip4address', 'zone', 'hostname']:
            command = f'SELECT count(*) FROM rules ' \
                f'INNER JOIN {table} ON rules.id = {table}.entry ' \
                'WHERE rules.level = 0'
            if first_party_only:
                command += ' AND first_party=1'
            cursor.execute(command)
            count, = cursor.fetchone()
            if count > 0:
                counts.append(f'{table}: {count}')
        return ', '.join(counts)

    def get_domain(self, domain: str) -> typing.Iterable[int]:
        """
        Yield the entries matching ``domain``: the host name rule with
        that exact name, and the zone rule containing it.
        """
        self.enter_step('get_domain_prepare')
        domain_prep = self.pack_hostname(domain)
        cursor = self.conn.cursor()
        self.enter_step('get_domain_select')
        # Only the greatest zone value not above the packed domain is
        # fetched; whether it really is a prefix is checked below.
        cursor.execute(
            'SELECT null, entry FROM hostname '
            'WHERE val=:d '
            'UNION '
            'SELECT * FROM ('
            'SELECT val, entry FROM zone '
            'WHERE val<=:d '
            'ORDER BY val DESC LIMIT 1'
            ')',
            {'d': domain_prep}
        )
        for val, entry in cursor:
            self.enter_step('get_domain_confirm')
            # val is NULL for the exact host name match
            if not (val is None or domain_prep.startswith(val)):
                continue
            self.enter_step('get_domain_yield')
            yield entry

    def get_domain_in_zone(self, domain: str) -> typing.Iterable[int]:
        """
        Yield the entry of the zone rule containing ``domain``, if any.
        """
        self.enter_step('get_domainiz_prepare')
        domain_prep = self.pack_hostname(domain)
        cursor = self.conn.cursor()
        self.enter_step('get_domainiz_select')
        cursor.execute(
            'SELECT val, entry FROM zone '
            'WHERE val<=:d '
            'ORDER BY val DESC LIMIT 1',
            {'d': domain_prep}
        )
        for val, entry in cursor:
            self.enter_step('get_domainiz_confirm')
            if not (val is None or domain_prep.startswith(val)):
                continue
            self.enter_step('get_domainiz_yield')
            yield entry

    def get_ip4(self, address: str) -> typing.Iterable[int]:
        """
        Yield the entries matching ``address``: the rule for that exact
        address and the rules of the networks containing it.

        An invalid address is logged and yields nothing.
        """
        self.enter_step('get_ip4_prepare')
        try:
            address_prep = self.pack_ip4address(address)
        except (ValueError, IndexError):
            self.log.error("Invalid ip4address: %s", address)
            return
        cursor = self.conn.cursor()
        self.enter_step('get_ip4_select')
        cursor.execute(
            'SELECT entry FROM ip4address '
            'WHERE val=:a '
            'UNION '
            'SELECT entry FROM ip4network '
            'WHERE :a BETWEEN mini AND maxi ',
            {'a': address_prep}
        )
        for entry, in cursor:
            self.enter_step('get_ip4_yield')
            yield entry

    def get_ip4_in_network(self, address: str) -> typing.Iterable[int]:
        """
        Yield the entries of the network rules containing ``address``.

        An invalid address is logged and yields nothing.
        """
        self.enter_step('get_ip4in_prepare')
        try:
            address_prep = self.pack_ip4address(address)
        except (ValueError, IndexError):
            self.log.error("Invalid ip4address: %s", address)
            return
        cursor = self.conn.cursor()
        self.enter_step('get_ip4in_select')
        cursor.execute(
            'SELECT entry FROM ip4network '
            'WHERE :a BETWEEN mini AND maxi ',
            {'a': address_prep}
        )
        for entry, in cursor:
            self.enter_step('get_ip4in_yield')
            yield entry

    def list_asn(self) -> typing.Iterable[typing.Tuple[str, int]]:
        """
        Yield every stored AS as an ``('AS1234', entry)`` pair.
        """
        cursor = self.conn.cursor()
        self.enter_step('list_asn_select')
        cursor.execute('SELECT val, entry FROM asn')
        for val, entry in cursor:
            yield f'AS{val}', entry

    def _set_generic(self,
                     table: str,
                     select_query: str,
                     insert_query: str,
                     prep: typing.Dict[str, DbValue],
                     updated: int,
                     is_first_party: bool = False,
                     source: typing.Optional[int] = None,
                     ) -> None:
        """
        Create or refresh the rule attached to a value.

        :param table: label used in the profiling step names.
        :param select_query: query returning the entry of the value
            described by ``prep``, if it is already stored.
        :param insert_query: query storing the value; it receives
            ``prep`` completed with the new rule id under ``entry``
            (``prep`` is modified in place).
        :param prep: named parameters of both queries.
        :param updated: value to store in the rule's ``updated`` column.
        :param is_first_party: first-party flag, used only when there is
            no source (it is otherwise inherited from the source).
        :param source: id of the rule this one derives from, if any.
        :raises ValueError: if ``source`` is not an existing rule.
        """
        # Since this isn't the bulk of the processing,
        # here abstraction > performances

        # Fields based on the source
        self.enter_step(f'set_{table}_prepare')
        cursor = self.conn.cursor()
        if source is None:
            first_party = int(is_first_party)
            level = 0
        else:
            self.enter_step(f'set_{table}_source')
            cursor.execute(
                'SELECT first_party, level FROM rules '
                'WHERE id=?',
                (source,)
            )
            source_row = cursor.fetchone()
            if source_row is None:
                raise ValueError(f"Unknown source rule: {source}")
            first_party, level = source_row
            level += 1

        self.enter_step(f'set_{table}_select')
        cursor.execute(select_query, prep)

        rules_prep: typing.Dict[str, DbValue] = {
            "source": source,
            "updated": updated,
            "first_party": first_party,
            "level": level,
        }

        # If the entry already exists
        existing = cursor.fetchone()
        if existing is not None:
            self.enter_step(f'set_{table}_update')
            rules_prep['entry'] = existing[0]
            # Only update if any of the following:
            # - the entry is outdated
            # - the entry was not a first_party but this is
            # - the stored level is lower than this one
            cursor.execute(
                'UPDATE rules SET '
                'source=:source, updated=:updated, '
                'first_party=:first_party, level=:level '
                'WHERE id=:entry AND (updated<:updated OR '
                'first_party<:first_party OR level<:level)',
                rules_prep
            )
            return

        # If it does not exist
        self.enter_step(f'set_{table}_insert')
        cursor.execute(
            'INSERT INTO rules '
            '(source, updated, first_party, level) '
            'VALUES (:source, :updated, :first_party, :level) ',
            rules_prep
        )
        cursor.execute('SELECT id FROM rules WHERE rowid=?',
                       (cursor.lastrowid,))
        created = cursor.fetchone()
        if created is None:
            # Must not be an `assert`: those vanish under `python -O`
            # and the value would silently not be stored.
            raise RuntimeError("Inserted rule could not be read back.")
        prep['entry'] = created[0]
        cursor.execute(insert_query, prep)

    def set_hostname(self,
                     hostname: str,
                     *args: typing.Any, **kwargs: typing.Any) -> None:
        """
        Create or refresh the rule of a host name.
        Extra arguments are passed to ``_set_generic``.
        """
        self.enter_step('set_hostname_prepare')
        prep: typing.Dict[str, DbValue] = {
            'val': self.pack_hostname(hostname),
        }
        self._set_generic(
            'hostname',
            'SELECT entry FROM hostname WHERE val=:val',
            'INSERT INTO hostname (val, entry) '
            'VALUES (:val, :entry)',
            prep,
            *args, **kwargs
        )

    def set_asn(self,
                asn: str,
                *args: typing.Any, **kwargs: typing.Any) -> None:
        """
        Create or refresh the rule of an AS; an invalid AS is logged and
        ignored. Extra arguments are passed to ``_set_generic``.
        """
        self.enter_step('set_asn_prepare')
        try:
            asn_prep = self.pack_asn(asn)
        except ValueError:
            self.log.error("Invalid asn: %s", asn)
            return
        prep: typing.Dict[str, DbValue] = {
            'val': asn_prep,
        }
        self._set_generic(
            'asn',
            'SELECT entry FROM asn WHERE val=:val',
            'INSERT INTO asn (val, entry) '
            'VALUES (:val, :entry)',
            prep,
            *args, **kwargs
        )

    def set_ip4address(self,
                       ip4address: str,
                       *args: typing.Any, **kwargs: typing.Any) -> None:
        """
        Create or refresh the rule of an IPv4 address; an invalid
        address is logged and ignored.
        Extra arguments are passed to ``_set_generic``.
        """
        self.enter_step('set_ip4add_prepare')
        try:
            ip4address_prep = self.pack_ip4address(ip4address)
        except (ValueError, IndexError):
            self.log.error("Invalid ip4address: %s", ip4address)
            return
        prep: typing.Dict[str, DbValue] = {
            'val': ip4address_prep,
        }
        self._set_generic(
            'ip4add',
            'SELECT entry FROM ip4address WHERE val=:val',
            'INSERT INTO ip4address (val, entry) '
            'VALUES (:val, :entry)',
            prep,
            *args, **kwargs
        )

    def set_zone(self,
                 zone: str,
                 *args: typing.Any, **kwargs: typing.Any) -> None:
        """
        Create or refresh the rule of a zone.
        Extra arguments are passed to ``_set_generic``.
        """
        self.enter_step('set_zone_prepare')
        prep: typing.Dict[str, DbValue] = {
            'val': self.pack_zone(zone),
        }
        self._set_generic(
            'zone',
            'SELECT entry FROM zone WHERE val=:val',
            'INSERT INTO zone (val, entry) '
            'VALUES (:val, :entry)',
            prep,
            *args, **kwargs
        )

    def set_ip4network(self,
                       ip4network: str,
                       *args: typing.Any, **kwargs: typing.Any) -> None:
        """
        Create or refresh the rule of an IPv4 network; an invalid
        network is logged and ignored.
        Extra arguments are passed to ``_set_generic``.
        """
        self.enter_step('set_ip4net_prepare')
        try:
            ip4network_prep = self.pack_ip4network(ip4network)
        except (ValueError, IndexError):
            self.log.error("Invalid ip4network: %s", ip4network)
            return
        prep: typing.Dict[str, DbValue] = {
            'mini': ip4network_prep[0],
            'maxi': ip4network_prep[1],
        }
        self._set_generic(
            'ip4net',
            'SELECT entry FROM ip4network WHERE mini=:mini AND maxi=:maxi',
            'INSERT INTO ip4network (mini, maxi, entry) '
            'VALUES (:mini, :maxi, :entry)',
            prep,
            *args, **kwargs
        )


if __name__ == '__main__':

    # Parsing arguments
    parser = argparse.ArgumentParser(
        description="Database operations")
    parser.add_argument(
        '-i', '--initialize', action='store_true',
        help="Reconstruct the whole database")
    parser.add_argument(
        '-p', '--prune', action='store_true',
        help="Remove old (+6 months) entries from database")
    parser.add_argument(
        '-r', '--references', action='store_true',
        help="Update the reference count")
    args = parser.parse_args()

    DB = Database(write=True)

    if args.initialize:
        DB.initialize()
    if args.prune:
        DB.prune(before=int(time.time()) - 60*60*24*31*6)
    # NOTE(review): the reference count is skipped when pruning was also
    # requested; this looks deliberate but the reason is not visible
    # here. Confirm.
    if args.references and not args.prune:
        DB.update_references()

    DB.close()