#!/usr/bin/env python3

"""
Utility functions to interact with the database.
"""

import argparse
import ipaddress
import logging
import os
import sqlite3
import time
import typing

try:
    import coloredlogs
except ImportError:
    # Coloured output is purely cosmetic: the database helpers must stay
    # usable when the optional package is not installed.
    coloredlogs = None

_LOG_FORMAT = '%(asctime)s %(name)s %(levelname)s %(message)s'
if coloredlogs is not None:
    coloredlogs.install(level='DEBUG', fmt=_LOG_FORMAT)
else:
    logging.basicConfig(level=logging.DEBUG, format=_LOG_FORMAT)

DbValue = typing.Union[None, int, float, str, bytes]


class Database():
    """
    Wrapper around the SQLite database storing the rules.

    Every entry (hostname, zone, ASN, IPv4 address, IPv4 network) is tied
    to a row of the ``rules`` table. The class also keeps simple timing
    statistics of the steps it goes through (see ``enter_step`` and
    ``profile``).
    """
    VERSION = 3
    PATH = "blocking.db"

    def __init__(self, write: bool = False) -> None:
        """
        Open the database, (re)creating it if its version is not VERSION.

        :param write: open the database read-write instead of read-only.
        :raises RuntimeError: if the database needs to be rebuilt
            while in read-only mode.
        """
        self.log = logging.getLogger('db')
        self.time_last = time.perf_counter()
        self.time_step = 'init'
        self.time_dict: typing.Dict[str, float] = dict()
        self.step_dict: typing.Dict[str, int] = dict()
        self.write = write
        self.open()
        version = self.get_meta('version')
        if version != self.VERSION:
            if version is not None:
                self.log.warning(
                    "Outdated database version: %d found, will be rebuilt.",
                    version)
            self.initialize()

        updated = self.get_meta('updated')
        if updated is None:
            # No stored counter: fall back on the newest rule present.
            self.execute('SELECT max(updated) FROM rules')
            data = self.cursor.fetchone()
            updated, = data
        self.updated = updated or 1

    def open(self) -> None:
        """Connect to the database file and set the connection up."""
        mode = 'rwc' if self.write else 'ro'
        uri = f'file:{self.PATH}?mode={mode}'
        self.conn = sqlite3.connect(uri, uri=True)
        self.cursor = self.conn.cursor()
        self.execute("PRAGMA foreign_keys = ON")
        # SQL-side inverse of prepare_hostname:
        # drop the trailing dot then reverse the string.
        self.conn.create_function("unpack_domain", 1,
                                  lambda s: s[:-1][::-1],
                                  deterministic=True)

    def execute(self, cmd: str,
                args: typing.Optional[typing.Union[
                    typing.Tuple[DbValue, ...],
                    typing.Dict[str, DbValue]]] = None) -> None:
        """
        Run a SQL command on the shared cursor.

        :param cmd: SQL statement.
        :param args: positional (tuple) or named (dict) parameters.
        """
        self.cursor.execute(cmd, args or tuple())

    def get_meta(self, key: str) -> typing.Optional[int]:
        """
        Return the value stored under `key` in the meta table.

        Returns None if the key is absent or if the query fails with
        sqlite3.OperationalError (e.g. the table does not exist yet).
        """
        try:
            self.execute("SELECT value FROM meta WHERE key=?", (key,))
        except sqlite3.OperationalError:
            return None
        row = self.cursor.fetchone()
        if row is None:
            return None
        return row[0]

    def set_meta(self, key: str, val: int) -> None:
        """Insert or replace the value stored under `key` in the meta table."""
        self.execute("INSERT INTO meta VALUES (?, ?) "
                     "ON CONFLICT (key) DO "
                     "UPDATE set value=?",
                     (key, val, val))

    def close(self) -> None:
        """Commit, close the connection, then log the timing statistics."""
        self.enter_step('close_commit')
        self.conn.commit()
        self.enter_step('close')
        self.conn.close()
        self.profile()

    def initialize(self) -> None:
        """
        Delete the database file and recreate it from database_schema.sql.

        :raises RuntimeError: if the database was opened read-only.
        """
        # Check before closing, so that a refused initialization
        # leaves the current connection usable.
        if not self.write:
            self.log.error("Cannot initialize in read-only mode.")
            raise RuntimeError("Cannot initialize in read-only mode.")
        self.close()
        self.enter_step('initialize')
        os.unlink(self.PATH)
        self.open()
        self.log.info("Creating database version %d.", self.VERSION)
        with open("database_schema.sql", 'r') as db_schema:
            self.cursor.executescript(db_schema.read())
        self.set_meta('version', self.VERSION)
        self.conn.commit()

    def enter_step(self, name: str) -> None:
        """
        Account the time spent since the last call to the step
        that is ending, then start timing the step `name`.
        """
        now = time.perf_counter()
        elapsed = now - self.time_last
        previous = self.time_step
        self.time_dict[previous] = self.time_dict.get(previous, 0.0) + elapsed
        self.step_dict[previous] = self.step_dict.get(previous, 0) + 1
        self.time_step = name
        self.time_last = time.perf_counter()

    def profile(self) -> None:
        """Log (debug level) the time spent in each step, shortest first."""
        self.enter_step('profile')
        total = sum(self.time_dict.values())
        for key, secs in sorted(self.time_dict.items(), key=lambda t: t[1]):
            times = self.step_dict[key]
            self.log.debug(f"{key:<20}: {times:9d} × {secs/times:5.3e} "
                           f"= {secs:9.2f} s ({secs/total:7.2%}) ")
        self.log.debug(f"{'total':<20}: "
                       f"{total:9.2f} s ({1:7.2%})")

    def prepare_hostname(self, hostname: str) -> str:
        """Return the stored form of a hostname: reversed, plus a dot."""
        return hostname[::-1] + '.'

    def prepare_zone(self, zone: str) -> str:
        """Return the stored form of a zone (same as for a hostname)."""
        return self.prepare_hostname(zone)

    @staticmethod
    def prepare_asn(asn: str) -> int:
        """
        Return the number of an AS, with or without its 'AS' prefix.

        :raises ValueError: if what remains is not an integer.
        """
        asn = asn.upper()
        if asn.startswith('AS'):
            asn = asn[2:]
        return int(asn)

    @staticmethod
    def prepare_ip4address(address: str) -> int:
        """
        Return the integer value of a dotted-quad IPv4 address.

        :raises ValueError: if the address does not have exactly
            four octets, each an integer between 0 and 255.
        """
        octets = address.split('.')
        if len(octets) != 4:
            raise ValueError(f"Expected 4 octets in {address!r}")
        total = 0
        for octet in octets:
            value = int(octet)
            if not 0 <= value <= 255:
                raise ValueError(f"Octet out of range in {address!r}")
            total = (total << 8) | value
        return total

    def prepare_ip4network(self, network: str) -> typing.Tuple[int, int]:
        """
        Return the integer values of the first and last addresses
        of an IPv4 network given in CIDR notation.

        :raises ValueError: if `network` is not a valid IPv4 network.
        """
        net = ipaddress.IPv4Network(network)
        return int(net.network_address), int(net.broadcast_address)

    def expire(self) -> None:
        """Increment the update counter and store it in the meta table."""
        self.enter_step('expire')
        self.updated += 1
        self.set_meta('updated', self.updated)

    def update_references(self) -> None:
        """Recompute, for each rule, how many rules have it as source."""
        self.enter_step('update_refs')
        self.execute('UPDATE rules AS r SET refs='
                     '(SELECT count(*) FROM rules '
                     'WHERE source=r.id)')

    def prune(self) -> None:
        """Delete the rules older than the current update counter."""
        self.enter_step('prune')
        self.execute('DELETE FROM rules WHERE updated<?', (self.updated,))

    # NOTE(review): the end of prune() and the signature of this method
    # were truncated in the reviewed copy of this file. The name and the
    # parameter defaults are reconstructed from the body;
    # confirm against the callers.
    def export(self, first_party_only: bool = False,
               end_chain_only: bool = False) -> typing.Iterable[str]:
        """
        Yield the hostnames tied to a rule, in ascending order.

        :param first_party_only: only rules with first_party = 1.
        :param end_chain_only: only rules with refs = 0.
        """
        command = 'SELECT unpack_domain(val) FROM rules ' \
            'INNER JOIN hostname ON rules.id = hostname.entry'
        restrictions: typing.List[str] = list()
        if first_party_only:
            restrictions.append('rules.first_party = 1')
        if end_chain_only:
            restrictions.append('rules.refs = 0')
        if restrictions:
            command += ' WHERE ' + ' AND '.join(restrictions)
        command += ' ORDER BY unpack_domain(val) ASC'
        self.execute(command)
        for val, in self.cursor:
            yield val

    def get_domain(self, domain: str) -> typing.Iterable[int]:
        """
        Yield the rule ids matching a domain: the hostname equal to it,
        and the closest zone below it if it is a prefix of it.
        """
        self.enter_step('get_domain_prepare')
        domain_prep = self.prepare_hostname(domain)
        self.enter_step('get_domain_select')
        self.execute(
            'SELECT null, entry FROM hostname '
            'WHERE val=:d '
            'UNION '
            'SELECT * FROM ('
            'SELECT val, entry FROM zone '
            'WHERE val<=:d '
            'ORDER BY val DESC LIMIT 1'
            ')',
            {'d': domain_prep}
        )
        for val, entry in self.cursor:
            self.enter_step('get_domain_confirm')
            # val is null for exact hostname matches. A zone candidate
            # is only the greatest one sorting before the domain:
            # it still has to be confirmed as an actual prefix.
            if not (val is None or domain_prep.startswith(val)):
                continue
            self.enter_step('get_domain_yield')
            yield entry

    def get_ip4(self, address: str) -> typing.Iterable[int]:
        """
        Yield the rule ids matching an IPv4 address, either exactly
        or through a network containing it.

        Invalid addresses are logged and yield nothing.
        """
        self.enter_step('get_ip4_prepare')
        try:
            address_prep = self.prepare_ip4address(address)
        except (ValueError, IndexError):
            self.log.error("Invalid ip4address: %s", address)
            return
        self.enter_step('get_ip4_select')
        self.execute(
            'SELECT entry FROM ip4address '
            'WHERE val=:a '
            'UNION '
            'SELECT entry FROM ip4network '
            'WHERE :a BETWEEN mini AND maxi ',
            {'a': address_prep}
        )
        # The query returns a single column.
        for entry, in self.cursor:
            self.enter_step('get_ip4_yield')
            yield entry

    def list_asn(self) -> typing.Iterable[typing.Tuple[str, int]]:
        """Yield every stored AS as ('AS<number>', rule id)."""
        self.enter_step('list_asn_select')
        self.execute('SELECT val, entry FROM asn')
        for val, entry in self.cursor:
            yield f'AS{val}', entry

    def _set_generic(self,
                     table: str,
                     select_query: str,
                     insert_query: str,
                     prep: typing.Dict[str, DbValue],
                     is_first_party: bool = False,
                     source: typing.Optional[int] = None,
                     ) -> None:
        """
        Insert or refresh an entry and its rule.

        :param table: short name, only used to name the profiling steps.
        :param select_query: query returning the rule id of the entry
            described by `prep`, if it already exists.
        :param insert_query: query inserting the entry; it receives
            `prep` completed with the new rule id under 'entry'.
        :param prep: named parameters describing the entry
            (modified in place on insertion).
        :param is_first_party: used when there is no source.
        :param source: id of the rule this entry derives from, if any.
        """
        # Since this isn't the bulk of the processing,
        # here abstraction > performances

        # Fields based on the source
        if source is None:
            first_party = int(is_first_party)
            level = 0
        else:
            self.enter_step(f'set_{table}_source')
            self.execute(
                'SELECT first_party, level FROM rules '
                'WHERE id=?',
                (source,)
            )
            first_party, level = self.cursor.fetchone()
            level += 1

        self.enter_step(f'set_{table}_select')
        self.execute(select_query, prep)

        rules_prep: typing.Dict[str, DbValue] = {
            "source": source,
            "updated": self.updated,
            "first_party": first_party,
            "level": level,
        }

        # If the entry already exists
        existing = self.cursor.fetchone()
        if existing is not None:
            self.enter_step(f'set_{table}_update')
            rules_prep['entry'] = existing[0]
            # Only update if any of the following:
            # - the entry is outdated
            # - the entry was not a first_party but this is
            # - this is closer to the original rule
            self.execute(
                'UPDATE rules SET '
                'source=:source, updated=:updated, '
                'first_party=:first_party, level=:level '
                'WHERE id=:entry AND (updated<:updated OR '
                'first_party<:first_party OR level<:level)',
                rules_prep
            )
            return

        # If it does not exist
        if source is not None:
            self.enter_step(f'set_{table}_incsrc')
            self.execute('UPDATE rules SET refs = refs + 1 WHERE id=?',
                         (source,))

        self.enter_step(f'set_{table}_insert')
        self.execute(
            'INSERT INTO rules '
            '(source, updated, first_party, refs, level) '
            'VALUES (:source, :updated, :first_party, 0, :level) ',
            rules_prep
        )
        self.execute('SELECT id FROM rules WHERE rowid=?',
                     (self.cursor.lastrowid,))
        created = self.cursor.fetchone()
        if created is None:
            # Explicit error rather than an assert,
            # which would be skipped when running with -O.
            raise RuntimeError("Could not find the rule just inserted.")
        prep['entry'] = created[0]
        self.execute(insert_query, prep)

    def set_hostname(self, hostname: str,
                     *args: typing.Any, **kwargs: typing.Any) -> None:
        """Store a hostname. Extra arguments go to _set_generic."""
        self.enter_step('set_hostname_prepare')
        prep: typing.Dict[str, DbValue] = {
            'val': self.prepare_hostname(hostname),
        }
        self._set_generic(
            'hostname',
            'SELECT entry FROM hostname WHERE val=:val',
            'INSERT INTO hostname (val, entry) '
            'VALUES (:val, :entry)',
            prep,
            *args, **kwargs
        )

    def set_asn(self, asn: str,
                *args: typing.Any, **kwargs: typing.Any) -> None:
        """
        Store an AS. Extra arguments go to _set_generic.
        Invalid values are logged and ignored.
        """
        self.enter_step('set_asn_prepare')
        try:
            asn_prep = self.prepare_asn(asn)
        except ValueError:
            self.log.error("Invalid asn: %s", asn)
            return
        prep: typing.Dict[str, DbValue] = {
            'val': asn_prep,
        }
        self._set_generic(
            'asn',
            'SELECT entry FROM asn WHERE val=:val',
            'INSERT INTO asn (val, entry) '
            'VALUES (:val, :entry)',
            prep,
            *args, **kwargs
        )

    def set_ip4address(self, ip4address: str,
                       *args: typing.Any, **kwargs: typing.Any) -> None:
        """
        Store an IPv4 address. Extra arguments go to _set_generic.
        Invalid values are logged and ignored.
        """
        # TODO Do not add if already in ip4network
        self.enter_step('set_ip4add_prepare')
        try:
            ip4address_prep = self.prepare_ip4address(ip4address)
        except (ValueError, IndexError):
            self.log.error("Invalid ip4address: %s", ip4address)
            return
        prep: typing.Dict[str, DbValue] = {
            'val': ip4address_prep,
        }
        self._set_generic(
            'ip4add',
            'SELECT entry FROM ip4address WHERE val=:val',
            'INSERT INTO ip4address (val, entry) '
            'VALUES (:val, :entry)',
            prep,
            *args, **kwargs
        )

    def set_zone(self, zone: str,
                 *args: typing.Any, **kwargs: typing.Any) -> None:
        """Store a zone. Extra arguments go to _set_generic."""
        self.enter_step('set_zone_prepare')
        prep: typing.Dict[str, DbValue] = {
            'val': self.prepare_zone(zone),
        }
        self._set_generic(
            'zone',
            'SELECT entry FROM zone WHERE val=:val',
            'INSERT INTO zone (val, entry) '
            'VALUES (:val, :entry)',
            prep,
            *args, **kwargs
        )

    def set_ip4network(self, ip4network: str,
                       *args: typing.Any, **kwargs: typing.Any) -> None:
        """
        Store an IPv4 network. Extra arguments go to _set_generic.
        Invalid values are logged and ignored.
        """
        self.enter_step('set_ip4net_prepare')
        try:
            ip4network_prep = self.prepare_ip4network(ip4network)
        except (ValueError, IndexError):
            self.log.error("Invalid ip4network: %s", ip4network)
            return
        prep: typing.Dict[str, DbValue] = {
            'mini': ip4network_prep[0],
            'maxi': ip4network_prep[1],
        }
        self._set_generic(
            'ip4net',
            'SELECT entry FROM ip4network WHERE mini=:mini AND maxi=:maxi',
            'INSERT INTO ip4network (mini, maxi, entry) '
            'VALUES (:mini, :maxi, :entry)',
            prep,
            *args, **kwargs
        )


if __name__ == '__main__':

    # Parsing arguments
    parser = argparse.ArgumentParser(
        description="Database operations")
    parser.add_argument(
        '-i', '--initialize', action='store_true',
        help="Reconstruct the whole database")
    parser.add_argument(
        '-p', '--prune', action='store_true',
        help="Remove old entries from database")
    parser.add_argument(
        '-e', '--expire', action='store_true',
        help="Set the whole database as an old source")
    parser.add_argument(
        '-r', '--references', action='store_true',
        help="Update the reference count")
    args = parser.parse_args()

    DB = Database(write=True)

    if args.initialize:
        DB.initialize()
    if args.prune:
        DB.prune()
    if args.expire:
        DB.expire()
    if args.references and not args.prune:
        DB.update_references()

    DB.close()