Geoffrey Frogeye
57416b6e2c
Mostly for performance reasons. The first change is meant to allow implementing threading later. The second one is meant to speed up the dichotomy, but it doesn't seem to be much better so far.
416 lines
14 KiB
Python
Executable file
416 lines
14 KiB
Python
Executable file
#!/usr/bin/env python3
|
||
|
||
"""
|
||
Utility functions to interact with the database.
|
||
"""
|
||
|
||
import sqlite3
|
||
import typing
|
||
import time
|
||
import os
|
||
import logging
|
||
import argparse
|
||
import coloredlogs
|
||
import ipaddress
|
||
import ctypes
|
||
|
||
# Install the colored log handler at DEBUG level; every record shows its
# timestamp, logger name, level and message.
_LOG_FORMAT = '%(asctime)s %(name)s %(levelname)s %(message)s'
coloredlogs.install(level='DEBUG', fmt=_LOG_FORMAT)
|
||
|
||
# Any value that can be bound to a parameter of an SQL statement.
DbValue = typing.Union[None, int, float, str, bytes]


class Database:
    """
    Wrapper around the SQLite database holding the blocking rules.

    A single connection and a single cursor are shared by every method.
    The instance also keeps a small step profiler (see enter_step) whose
    results are logged when the database is closed.
    """
    # Schema version expected in the 'meta' table.
    VERSION = 3
    # Path of the SQLite database file.
    PATH = "blocking.db"

    def open(self) -> None:
        """Connect to the file at PATH and enable foreign keys."""
        self.conn = sqlite3.connect(self.PATH)
        self.cursor = self.conn.cursor()
        self.execute("PRAGMA foreign_keys = ON")

    def execute(self, cmd: str, args: typing.Optional[typing.Union[
            typing.Tuple[DbValue, ...],
            typing.Dict[str, DbValue]]] = None) -> None:
        """
        Run an SQL statement on the shared cursor.

        args holds the positional (tuple) or named (dict) parameters;
        when omitted or empty the statement runs without parameters.
        Results, if any, are read from self.cursor afterwards.
        """
        self.cursor.execute(cmd, args or tuple())

    def get_meta(self, key: str) -> typing.Optional[int]:
        """
        Return the value stored for key in the 'meta' table.

        Returns None when the key is absent or when the query fails with
        sqlite3.OperationalError (e.g. the table does not exist yet).
        """
        try:
            self.execute("SELECT value FROM meta WHERE key=?", (key,))
        except sqlite3.OperationalError:
            return None
        row = self.cursor.fetchone()
        if row is None:
            return None
        return row[0]

    def set_meta(self, key: str, val: int) -> None:
        """Insert key/val in the 'meta' table, overwriting any old value."""
        self.execute("INSERT INTO meta VALUES (?, ?) "
                     "ON CONFLICT (key) DO "
                     "UPDATE set value=?",
                     (key, val, val))

    def close(self) -> None:
        """Commit, close the connection and log the profiling report."""
        self.enter_step('close_commit')
        self.conn.commit()
        self.enter_step('close')
        self.conn.close()
        self.profile()

    def initialize(self) -> None:
        """
        Rebuild the database from scratch.

        Closes the connection, deletes the database file, reopens it,
        runs the statements of database_schema.sql (read from the current
        working directory) and records the schema version.
        """
        self.enter_step('initialize')
        self.close()
        try:
            os.unlink(self.PATH)
        except FileNotFoundError:
            # Nothing to delete: the database will simply be created.
            pass
        self.open()
        self.log.info("Creating database version %d.", self.VERSION)
        with open("database_schema.sql", 'r') as db_schema:
            self.cursor.executescript(db_schema.read())
        self.set_meta('version', self.VERSION)
        self.conn.commit()

    def __init__(self) -> None:
        """
        Open the database, rebuilding it when its version is not VERSION,
        and load the current update counter into self.updated.
        """
        self.log = logging.getLogger('db')
        # Profiler state, see enter_step()
        self.time_last = time.perf_counter()
        self.time_step = 'init'
        self.time_dict: typing.Dict[str, float] = dict()
        self.step_dict: typing.Dict[str, int] = dict()
        # Not read anywhere in this class; kept so the attribute
        # remains available to external code.
        self.accel_ip4_buf = ctypes.create_unicode_buffer('Z'*32, 32)

        self.open()
        version = self.get_meta('version')
        if version != self.VERSION:
            if version is not None:
                self.log.warning(
                    "Outdated database version: %d found, will be rebuilt.",
                    version)
            self.initialize()

        updated = self.get_meta('updated')
        if updated is None:
            # No counter stored: fall back on the newest rule
            self.execute('SELECT max(updated) FROM rules')
            data = self.cursor.fetchone()
            updated, = data
        self.updated = updated or 1

    def enter_step(self, name: str) -> None:
        """
        Mark the beginning of the profiling step called name.

        The time elapsed since the previous call is added to the step
        that was running until now, and its call count is incremented.
        """
        now = time.perf_counter()
        elapsed = now - self.time_last
        previous = self.time_step
        self.time_dict[previous] = self.time_dict.get(previous, 0.0) + elapsed
        self.step_dict[previous] = self.step_dict.get(previous, 0) + 1
        self.time_step = name
        self.time_last = time.perf_counter()

    def profile(self) -> None:
        """Log, at debug level, the time spent in every profiling step."""
        self.enter_step('profile')
        total = sum(self.time_dict.values())
        for key, secs in sorted(self.time_dict.items(), key=lambda t: t[1]):
            times = self.step_dict[key]
            # Guard against a zero total (very coarse clock)
            share = secs / total if total else 0.0
            self.log.debug(f"{key:<20}: {times:9d} × {secs/times:5.3e} "
                           f"= {secs:9.2f} s ({share:7.2%}) ")
        self.log.debug(f"{'total':<20}: "
                       f"{total:9.2f} s ({1:7.2%})")

    def prepare_hostname(self, hostname: str) -> str:
        """
        Return hostname as stored in the database: reversed character by
        character and followed by a dot ('a.b' becomes 'b.a.').
        """
        return hostname[::-1] + '.'

    def prepare_zone(self, zone: str) -> str:
        """Return zone as stored in the database (same as a hostname)."""
        return self.prepare_hostname(zone)

    @staticmethod
    def prepare_ip4address(address: str) -> int:
        """
        Convert a dotted-quad IPv4 address to its 32-bit integer value.

        Raises ValueError when address is not made of exactly four
        integers between 0 and 255 separated by dots.
        """
        octets = address.split('.')
        if len(octets) != 4:
            raise ValueError(f"Invalid IPv4 address: {address!r}")
        total = 0
        for octet in octets:
            value = int(octet)  # ValueError if not an integer
            if not 0 <= value <= 255:
                raise ValueError(f"Invalid IPv4 address: {address!r}")
            total = (total << 8) | value
        return total

    def prepare_ip4network(self, network: str) -> typing.Tuple[int, int]:
        """
        Convert an IPv4 network (e.g. '10.0.0.0/8') to the integer values
        of its first and last addresses.

        Raises ValueError when network is not a valid IPv4 network
        (this includes networks with host bits set).
        """
        net = ipaddress.IPv4Network(network)
        return int(net.network_address), int(net.broadcast_address)

    def expire(self) -> None:
        """
        Increment the update counter and store it in the 'meta' table.

        Rules whose 'updated' value is below the counter are the ones
        removed by prune().
        """
        self.enter_step('expire')
        self.updated += 1
        self.set_meta('updated', self.updated)

    def update_references(self) -> None:
        """Recompute 'refs' of each rule: the number of rules having it
        as source."""
        self.enter_step('update_refs')
        self.execute('UPDATE rules AS r SET refs='
                     '(SELECT count(*) FROM rules '
                     'WHERE source=r.id)')

    def prune(self) -> None:
        """Delete the rules older than the current update counter."""
        self.enter_step('prune')
        self.execute('DELETE FROM rules WHERE updated<?', (self.updated,))

    def export(self, first_party_only: bool = False,
               end_chain_only: bool = False) -> typing.Iterable[str]:
        """
        Yield the hostnames of the database, in their original form.

        first_party_only keeps the rules with first_party = 1,
        end_chain_only keeps the rules with refs = 0.
        Rows are read lazily from the shared cursor: do not run another
        query on this object while iterating.
        """
        command = 'SELECT val FROM rules ' \
            'INNER JOIN hostname ON rules.id = hostname.entry'
        restrictions: typing.List[str] = list()
        if first_party_only:
            restrictions.append('rules.first_party = 1')
        if end_chain_only:
            restrictions.append('rules.refs = 0')
        if restrictions:
            command += ' WHERE ' + ' AND '.join(restrictions)
        self.execute(command)
        for val, in self.cursor:
            # Undo prepare_hostname: drop the final dot, reverse back
            yield val[:-1][::-1]

    def get_domain(self, domain: str) -> typing.Iterable[int]:
        """
        Yield the 'entry' value of the hostname row equal to domain and
        of the zone row that is a prefix of the prepared domain, if any.
        """
        self.enter_step('get_domain_prepare')
        domain_prep = self.prepare_hostname(domain)
        self.enter_step('get_domain_select')
        self.execute(
            'SELECT null, entry FROM hostname '
            'WHERE val=:d '
            'UNION '
            'SELECT * FROM ('
            'SELECT val, entry FROM zone '
            'WHERE val<=:d '
            'ORDER BY val DESC LIMIT 1'
            ')',
            {'d': domain_prep}
        )
        # Fetch everything now: the cursor is shared, so a caller running
        # another query between two yields would lose the remaining rows.
        rows = self.cursor.fetchall()
        for val, entry in rows:
            self.enter_step('get_domain_confirm')
            # The zone candidate is only the closest value below the
            # domain: check that it really is a prefix of it.
            if val is not None and not domain_prep.startswith(val):
                continue
            self.enter_step('get_domain_yield')
            yield entry

    def get_ip4(self, address: str) -> typing.Iterable[int]:
        """
        Yield the 'entry' value of the ip4address rows equal to address
        and of the ip4network rows whose range contains it.

        An invalid address is logged as an error and yields nothing.
        """
        self.enter_step('get_ip4_prepare')
        try:
            address_prep = self.prepare_ip4address(address)
        except (ValueError, IndexError):
            self.log.error("Invalid ip4address: %s", address)
            return
        self.enter_step('get_ip4_select')
        self.execute(
            'SELECT entry FROM ip4address '
            'WHERE val=:a '
            'UNION '
            'SELECT entry FROM ip4network '
            'WHERE :a BETWEEN mini AND maxi ',
            {'a': address_prep}
        )
        # Fetch everything now: the cursor is shared, so a caller running
        # another query between two yields would lose the remaining rows.
        rows = self.cursor.fetchall()
        # The query returns a single column
        for entry, in rows:
            self.enter_step('get_ip4_yield')
            yield entry

    def _set_generic(self,
                     table: str,
                     select_query: str,
                     insert_query: str,
                     prep: typing.Dict[str, DbValue],
                     is_first_party: bool = False,
                     source: typing.Optional[int] = None,
                     ) -> None:
        """
        Insert or refresh a rule and its value row.

        table is only used to name the profiling steps.
        select_query must return the 'entry' of the existing value row,
        insert_query must insert the value row; both receive prep as
        named parameters (insert_query also gets 'entry', which is added
        to prep).
        is_first_party is used only when source is None; otherwise
        first_party is inherited from the source rule and the level is
        the level of the source plus one.

        Raises ValueError if source is not the id of an existing rule.
        """
        # Since this isn't the bulk of the processing,
        # here abstraction > performance

        # Fields based on the source
        if source is None:
            first_party = int(is_first_party)
            level = 0
        else:
            self.enter_step(f'set_{table}_source')
            self.execute(
                'SELECT first_party, level FROM rules '
                'WHERE id=?',
                (source,)
            )
            source_row = self.cursor.fetchone()
            if source_row is None:
                raise ValueError(f"Unknown source rule: {source!r}")
            first_party, level = source_row
            level += 1

        self.enter_step(f'set_{table}_select')
        self.execute(select_query, prep)
        existing = self.cursor.fetchone()

        rules_prep: typing.Dict[str, DbValue] = {
            "source": source,
            "updated": self.updated,
            "first_party": first_party,
            "level": level,
        }

        # If the entry already exists
        if existing is not None:
            self.enter_step(f'set_{table}_update')
            rules_prep['entry'] = existing[0]
            # Only update if any of the following:
            # - the entry is outdated
            # - the entry was not a first_party but this is
            # - this is closer to the original rule
            # NOTE(review): `level<:level` matches when the stored level
            # is LOWER than the new one, which looks like the opposite of
            # "closer" - confirm the intended comparison.
            self.execute(
                'UPDATE rules SET '
                'source=:source, updated=:updated, '
                'first_party=:first_party, level=:level '
                'WHERE id=:entry AND (updated<:updated OR '
                'first_party<:first_party OR level<:level)',
                rules_prep
            )
            return

        # If it does not exist

        if source is not None:
            self.enter_step(f'set_{table}_incsrc')
            self.execute('UPDATE rules SET refs = refs + 1 WHERE id=?',
                         (source,))

        self.enter_step(f'set_{table}_insert')
        self.execute(
            'INSERT INTO rules '
            '(source, updated, first_party, refs, level) '
            'VALUES (:source, :updated, :first_party, 0, :level) ',
            rules_prep
        )
        self.execute('SELECT id FROM rules WHERE rowid=?',
                     (self.cursor.lastrowid,))
        inserted = self.cursor.fetchone()
        if inserted is None:
            # Explicit exception: an assert would vanish under -O
            raise RuntimeError("Could not read back the inserted rule")
        prep['entry'] = inserted[0]
        self.execute(insert_query, prep)

    def set_hostname(self, hostname: str,
                     *args: typing.Any, **kwargs: typing.Any) -> None:
        """
        Add or refresh the rule for hostname.

        Extra arguments (is_first_party, source) go to _set_generic.
        """
        self.enter_step('set_hostname_prepare')
        prep: typing.Dict[str, DbValue] = {
            'val': self.prepare_hostname(hostname),
        }
        self._set_generic(
            'hostname',
            'SELECT entry FROM hostname WHERE val=:val',
            'INSERT INTO hostname (val, entry) '
            'VALUES (:val, :entry)',
            prep,
            *args, **kwargs
        )

    def set_ip4address(self, ip4address: str,
                       *args: typing.Any, **kwargs: typing.Any) -> None:
        """
        Add or refresh the rule for the IPv4 address ip4address.

        An invalid address is logged as an error and ignored.
        Extra arguments (is_first_party, source) go to _set_generic.
        """
        self.enter_step('set_ip4add_prepare')
        try:
            ip4address_prep = self.prepare_ip4address(ip4address)
        except (ValueError, IndexError):
            self.log.error("Invalid ip4address: %s", ip4address)
            return
        prep: typing.Dict[str, DbValue] = {
            'val': ip4address_prep,
        }
        self._set_generic(
            'ip4add',
            'SELECT entry FROM ip4address WHERE val=:val',
            'INSERT INTO ip4address (val, entry) '
            'VALUES (:val, :entry)',
            prep,
            *args, **kwargs
        )

    def set_zone(self, zone: str,
                 *args: typing.Any, **kwargs: typing.Any) -> None:
        """
        Add or refresh the rule for zone.

        Extra arguments (is_first_party, source) go to _set_generic.
        """
        self.enter_step('set_zone_prepare')
        prep: typing.Dict[str, DbValue] = {
            'val': self.prepare_zone(zone),
        }
        self._set_generic(
            'zone',
            'SELECT entry FROM zone WHERE val=:val',
            'INSERT INTO zone (val, entry) '
            'VALUES (:val, :entry)',
            prep,
            *args, **kwargs
        )

    def set_ip4network(self, ip4network: str,
                       *args: typing.Any, **kwargs: typing.Any) -> None:
        """
        Add or refresh the rule for the IPv4 network ip4network.

        An invalid network is logged as an error and ignored.
        Extra arguments (is_first_party, source) go to _set_generic.
        """
        self.enter_step('set_ip4net_prepare')
        try:
            ip4network_prep = self.prepare_ip4network(ip4network)
        except (ValueError, IndexError):
            self.log.error("Invalid ip4network: %s", ip4network)
            return
        prep: typing.Dict[str, DbValue] = {
            'mini': ip4network_prep[0],
            'maxi': ip4network_prep[1],
        }
        self._set_generic(
            'ip4net',
            'SELECT entry FROM ip4network WHERE mini=:mini AND maxi=:maxi',
            'INSERT INTO ip4network (mini, maxi, entry) '
            'VALUES (:mini, :maxi, :entry)',
            prep,
            *args, **kwargs
        )
|
||
|
||
|
||
if __name__ == '__main__':

    # Parsing arguments: every option is an on/off flag selecting one
    # maintenance operation.
    parser = argparse.ArgumentParser(description="Database operations")
    for short_flag, long_flag, description in (
            ('-i', '--initialize', "Reconstruct the whole database"),
            ('-p', '--prune', "Remove old entries from database"),
            ('-e', '--expire', "Set the whole database as an old source"),
            ('-r', '--references', "Update the reference count"),
    ):
        parser.add_argument(short_flag, long_flag, action='store_true',
                            help=description)
    args = parser.parse_args()

    DB = Database()

    # Operations always run in this order, whatever the order of the
    # flags on the command line.
    if args.initialize:
        DB.initialize()
    if args.prune:
        DB.prune()
    if args.expire:
        DB.expire()
    # NOTE(review): the reference count is not refreshed when --prune is
    # also given - confirm this is intended.
    if args.references and not args.prune:
        DB.update_references()

    # Commits the changes and logs the profiling report
    DB.close()
|