eulaurarien/database.py

576 lines
20 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""
Utility functions to interact with the database.
"""
import sqlite3
import typing
import time
import os
import logging
import argparse
import coloredlogs
import ipaddress
import math
coloredlogs.install(
level='DEBUG',
fmt='%(asctime)s %(name)s %(levelname)s %(message)s'
)
DbValue = typing.Union[None, int, float, str, bytes]
class Database():
2019-12-13 18:36:08 +01:00
VERSION = 5
PATH = "blocking.db"
def open(self) -> None:
2019-12-13 12:35:05 +01:00
mode = 'rwc' if self.write else 'ro'
uri = f'file:{self.PATH}?mode={mode}'
self.conn = sqlite3.connect(uri, uri=True)
cursor = self.conn.cursor()
cursor.execute("PRAGMA foreign_keys = ON")
self.conn.create_function("unpack_asn", 1,
self.unpack_asn,
deterministic=True)
self.conn.create_function("unpack_ip4address", 1,
self.unpack_ip4address,
deterministic=True)
self.conn.create_function("unpack_ip4network", 2,
self.unpack_ip4network,
deterministic=True)
self.conn.create_function("unpack_domain", 1,
lambda s: s[:-1][::-1],
deterministic=True)
self.conn.create_function("format_zone", 1,
lambda s: '*' + s[::-1],
deterministic=True)
def get_meta(self, key: str) -> typing.Optional[int]:
cursor = self.conn.cursor()
try:
cursor.execute("SELECT value FROM meta WHERE key=?", (key,))
except sqlite3.OperationalError:
return None
for ver, in cursor:
return ver
return None
def set_meta(self, key: str, val: int) -> None:
cursor = self.conn.cursor()
cursor.execute("INSERT INTO meta VALUES (?, ?) "
"ON CONFLICT (key) DO "
"UPDATE set value=?",
(key, val, val))
def close(self) -> None:
self.enter_step('close_commit')
self.conn.commit()
self.enter_step('close')
self.conn.close()
self.profile()
def initialize(self) -> None:
self.close()
2019-12-13 12:35:05 +01:00
self.enter_step('initialize')
if not self.write:
self.log.error("Cannot initialize in read-only mode.")
raise
os.unlink(self.PATH)
self.open()
self.log.info("Creating database version %d.", self.VERSION)
cursor = self.conn.cursor()
with open("database_schema.sql", 'r') as db_schema:
cursor.executescript(db_schema.read())
self.set_meta('version', self.VERSION)
self.conn.commit()
2019-12-13 12:35:05 +01:00
def __init__(self, write: bool = False) -> None:
self.log = logging.getLogger('db')
self.time_last = time.perf_counter()
self.time_step = 'init'
self.time_dict: typing.Dict[str, float] = dict()
self.step_dict: typing.Dict[str, int] = dict()
2019-12-13 12:35:05 +01:00
self.write = write
self.open()
version = self.get_meta('version')
if version != self.VERSION:
if version is not None:
self.log.warning(
"Outdated database version: %d found, will be rebuilt.",
version)
self.initialize()
def enter_step(self, name: str) -> None:
now = time.perf_counter()
try:
self.time_dict[self.time_step] += now - self.time_last
self.step_dict[self.time_step] += 1
except KeyError:
self.time_dict[self.time_step] = now - self.time_last
self.step_dict[self.time_step] = 1
self.time_step = name
self.time_last = time.perf_counter()
def profile(self) -> None:
self.enter_step('profile')
total = sum(self.time_dict.values())
for key, secs in sorted(self.time_dict.items(), key=lambda t: t[1]):
times = self.step_dict[key]
self.log.debug(f"{key:<20}: {times:9d} × {secs/times:5.3e} "
f"= {secs:9.2f} s ({secs/total:7.2%}) ")
self.log.debug(f"{'total':<20}: "
f"{total:9.2f} s ({1:7.2%})")
@staticmethod
def pack_hostname(hostname: str) -> str:
return hostname[::-1] + '.'
@staticmethod
def pack_zone(zone: str) -> str:
return Database.pack_hostname(zone)
@staticmethod
def pack_asn(asn: str) -> int:
asn = asn.upper()
if asn.startswith('AS'):
asn = asn[2:]
return int(asn)
@staticmethod
def unpack_asn(asn: int) -> str:
return f'AS{asn}'
@staticmethod
def pack_ip4address(address: str) -> int:
total = 0
for i, octet in enumerate(address.split('.')):
total += int(octet) << (3-i)*8
2019-12-14 16:04:19 +01:00
if total > 0xFFFFFFFF:
raise ValueError
return total
# return '{:02x}{:02x}{:02x}{:02x}'.format(
# *[int(c) for c in address.split('.')])
# return base64.b16encode(packed).decode()
# return '{:08b}{:08b}{:08b}{:08b}'.format(
# *[int(c) for c in address.split('.')])
# carg = ctypes.c_wchar_p(address)
# ret = ACCEL.ip4_flat(carg, self.accel_ip4_buf)
# if ret != 0:
# raise ValueError
# return self.accel_ip4_buf.value
# packed = ipaddress.ip_address(address).packed
# return packed
@staticmethod
def unpack_ip4address(address: int) -> str:
return '.'.join(str((address >> (i * 8)) & 0xFF)
for i in reversed(range(4)))
@staticmethod
def pack_ip4network(network: str) -> typing.Tuple[int, int]:
# def pack_ip4network(network: str) -> str:
net = ipaddress.ip_network(network)
mini = Database.pack_ip4address(net.network_address.exploded)
maxi = Database.pack_ip4address(net.broadcast_address.exploded)
# mini = net.network_address.packed
# maxi = net.broadcast_address.packed
return mini, maxi
# return Database.pack_ip4address(net.network_address.exploded)[:net.prefixlen]
@staticmethod
def unpack_ip4network(mini: int, maxi: int) -> str:
addr = Database.unpack_ip4address(mini)
prefixlen = 32-int(math.log2(maxi-mini+1))
return f'{addr}/{prefixlen}'
def update_references(self) -> None:
self.enter_step('update_refs')
cursor = self.conn.cursor()
cursor.execute('UPDATE rules AS r SET refs='
'(SELECT count(*) FROM rules '
'WHERE source=r.id)')
2019-12-14 16:04:19 +01:00
def prune(self, before: int, base_only: bool = False) -> None:
self.enter_step('prune')
cursor = self.conn.cursor()
2019-12-14 16:04:19 +01:00
cmd = 'DELETE FROM rules WHERE updated<?'
if base_only:
cmd += ' AND level=0'
cursor.execute(cmd, (before,))
def explain(self, entry: int) -> str:
# Format current
string = '???'
cursor = self.conn.cursor()
cursor.execute(
'SELECT unpack_asn(val) FROM asn WHERE entry=:entry '
'UNION '
'SELECT unpack_domain(val) FROM hostname WHERE entry=:entry '
'UNION '
'SELECT format_zone(val) FROM zone WHERE entry=:entry '
'UNION '
'SELECT unpack_ip4address(val) FROM ip4address WHERE entry=:entry '
'UNION '
'SELECT unpack_ip4network(mini, maxi) '
'FROM ip4network WHERE entry=:entry ',
{"entry": entry}
)
for val, in cursor: # only one
string = str(val)
string += f' #{entry}'
# Add source if any
cursor.execute('SELECT source FROM rules WHERE id=?', (entry,))
for source, in cursor:
if source:
string += f'{self.explain(source)}'
return string
def export(self,
first_party_only: bool = False,
end_chain_only: bool = False,
explain: bool = False,
) -> typing.Iterable[str]:
selection = 'entry' if explain else 'unpack_domain(val)'
command = f'SELECT {selection} FROM rules ' \
'INNER JOIN hostname ON rules.id = hostname.entry'
restrictions: typing.List[str] = list()
if first_party_only:
restrictions.append('rules.first_party = 1')
if end_chain_only:
restrictions.append('rules.refs = 0')
if restrictions:
command += ' WHERE ' + ' AND '.join(restrictions)
if not explain:
command += ' ORDER BY unpack_domain(val) ASC'
cursor = self.conn.cursor()
cursor.execute(command)
for val, in cursor:
if explain:
yield self.explain(val)
else:
yield val
2019-12-13 18:36:08 +01:00
def count_rules(self,
first_party_only: bool = False,
) -> str:
counts: typing.List[str] = list()
cursor = self.conn.cursor()
for table in ['asn', 'ip4network', 'ip4address', 'zone', 'hostname']:
command = f'SELECT count(*) FROM rules ' \
f'INNER JOIN {table} ON rules.id = {table}.entry ' \
'WHERE rules.level = 0'
if first_party_only:
command += ' AND first_party=1'
cursor.execute(command)
count, = cursor.fetchone()
if count > 0:
counts.append(f'{table}: {count}')
return ', '.join(counts)
def get_domain(self, domain: str) -> typing.Iterable[int]:
self.enter_step('get_domain_prepare')
domain_prep = self.pack_hostname(domain)
cursor = self.conn.cursor()
self.enter_step('get_domain_select')
cursor.execute(
'SELECT null, entry FROM hostname '
'WHERE val=:d '
'UNION '
'SELECT * FROM ('
'SELECT val, entry FROM zone '
# 'WHERE val>=:d '
# 'ORDER BY val ASC LIMIT 1'
'WHERE val<=:d '
'AND instr(:d, val) = 1'
')',
{'d': domain_prep}
)
for val, entry in cursor:
# print(293, val, entry)
self.enter_step('get_domain_confirm')
if not (val is None or domain_prep.startswith(val)):
# print(297)
continue
self.enter_step('get_domain_yield')
yield entry
def get_domain_in_zone(self, domain: str) -> typing.Iterable[int]:
self.enter_step('get_domainiz_prepare')
domain_prep = self.pack_hostname(domain)
cursor = self.conn.cursor()
self.enter_step('get_domainiz_select')
cursor.execute(
'SELECT val, entry FROM zone '
'WHERE val<=:d '
'ORDER BY val DESC LIMIT 1',
{'d': domain_prep}
)
for val, entry in cursor:
self.enter_step('get_domainiz_confirm')
if not (val is None or domain_prep.startswith(val)):
continue
self.enter_step('get_domainiz_yield')
yield entry
def get_ip4(self, address: str) -> typing.Iterable[int]:
self.enter_step('get_ip4_prepare')
try:
address_prep = self.pack_ip4address(address)
except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", address)
return
cursor = self.conn.cursor()
self.enter_step('get_ip4_select')
cursor.execute(
'SELECT entry FROM ip4address '
# 'SELECT null, entry FROM ip4address '
'WHERE val=:a '
'UNION '
# 'SELECT * FROM ('
# 'SELECT val, entry FROM ip4network '
# 'WHERE val<=:a '
# 'AND instr(:a, val) > 0 '
# 'ORDER BY val DESC'
# ')'
'SELECT entry FROM ip4network '
'WHERE :a BETWEEN mini AND maxi ',
{'a': address_prep}
)
for entry, in cursor:
# self.enter_step('get_ip4_confirm')
# if not (val is None or val.startswith(address_prep)):
# # PERF startswith but from the end
# continue
self.enter_step('get_ip4_yield')
yield entry
def get_ip4_in_network(self, address: str) -> typing.Iterable[int]:
self.enter_step('get_ip4in_prepare')
try:
address_prep = self.pack_ip4address(address)
except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", address)
return
cursor = self.conn.cursor()
self.enter_step('get_ip4in_select')
cursor.execute(
'SELECT entry FROM ip4network '
'WHERE :a BETWEEN mini AND maxi ',
{'a': address_prep}
)
for entry, in cursor:
self.enter_step('get_ip4in_yield')
yield entry
def list_asn(self) -> typing.Iterable[typing.Tuple[str, int]]:
cursor = self.conn.cursor()
self.enter_step('list_asn_select')
cursor.execute('SELECT val, entry FROM asn')
for val, entry in cursor:
yield f'AS{val}', entry
def _set_generic(self,
table: str,
select_query: str,
insert_query: str,
prep: typing.Dict[str, DbValue],
updated: int,
is_first_party: bool = False,
source: int = None,
) -> None:
# Since this isn't the bulk of the processing,
# here abstraction > performaces
# Fields based on the source
self.enter_step(f'set_{table}_prepare')
cursor = self.conn.cursor()
if source is None:
first_party = int(is_first_party)
level = 0
else:
self.enter_step(f'set_{table}_source')
cursor.execute(
'SELECT first_party, level FROM rules '
'WHERE id=?',
(source,)
)
first_party, level = cursor.fetchone()
level += 1
self.enter_step(f'set_{table}_select')
cursor.execute(select_query, prep)
rules_prep: typing.Dict[str, DbValue] = {
"source": source,
"updated": updated,
"first_party": first_party,
"level": level,
}
# If the entry already exists
for entry, in cursor: # only one
self.enter_step(f'set_{table}_update')
rules_prep['entry'] = entry
cursor.execute(
'UPDATE rules SET '
'source=:source, updated=:updated, '
'first_party=:first_party, level=:level '
'WHERE id=:entry AND (updated<:updated OR '
'first_party<:first_party OR level<:level)',
rules_prep
)
# Only update if any of the following:
# - the entry is outdataed
# - the entry was not a first_party but this is
# - this is closer to the original rule
return
# If it does not exist
self.enter_step(f'set_{table}_insert')
cursor.execute(
'INSERT INTO rules '
'(source, updated, first_party, level) '
'VALUES (:source, :updated, :first_party, :level) ',
rules_prep
)
cursor.execute('SELECT id FROM rules WHERE rowid=?',
(cursor.lastrowid,))
for entry, in cursor: # only one
prep['entry'] = entry
cursor.execute(insert_query, prep)
return
assert False
def set_hostname(self, hostname: str,
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_hostname_prepare')
prep: typing.Dict[str, DbValue] = {
'val': self.pack_hostname(hostname),
}
self._set_generic(
'hostname',
'SELECT entry FROM hostname WHERE val=:val',
'INSERT INTO hostname (val, entry) '
'VALUES (:val, :entry)',
prep,
*args, **kwargs
)
def set_asn(self, asn: str,
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_asn_prepare')
try:
asn_prep = self.pack_asn(asn)
except ValueError:
self.log.error("Invalid asn: %s", asn)
return
prep: typing.Dict[str, DbValue] = {
'val': asn_prep,
}
self._set_generic(
'asn',
'SELECT entry FROM asn WHERE val=:val',
'INSERT INTO asn (val, entry) '
'VALUES (:val, :entry)',
prep,
*args, **kwargs
)
def set_ip4address(self, ip4address: str,
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_ip4add_prepare')
try:
ip4address_prep = self.pack_ip4address(ip4address)
except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", ip4address)
return
prep: typing.Dict[str, DbValue] = {
'val': ip4address_prep,
}
self._set_generic(
'ip4add',
'SELECT entry FROM ip4address WHERE val=:val',
'INSERT INTO ip4address (val, entry) '
'VALUES (:val, :entry)',
prep,
*args, **kwargs
)
def set_zone(self, zone: str,
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_zone_prepare')
prep: typing.Dict[str, DbValue] = {
'val': self.pack_zone(zone),
}
self._set_generic(
'zone',
'SELECT entry FROM zone WHERE val=:val',
'INSERT INTO zone (val, entry) '
'VALUES (:val, :entry)',
prep,
*args, **kwargs
)
def set_ip4network(self, ip4network: str,
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_ip4net_prepare')
try:
ip4network_prep = self.pack_ip4network(ip4network)
except (ValueError, IndexError):
self.log.error("Invalid ip4network: %s", ip4network)
return
prep: typing.Dict[str, DbValue] = {
'mini': ip4network_prep[0],
'maxi': ip4network_prep[1],
}
self._set_generic(
'ip4net',
'SELECT entry FROM ip4network WHERE mini=:mini AND maxi=:maxi',
'INSERT INTO ip4network (mini, maxi, entry) '
'VALUES (:mini, :maxi, :entry)',
prep,
*args, **kwargs
)
if __name__ == '__main__':
# Parsing arguments
parser = argparse.ArgumentParser(
description="Database operations")
parser.add_argument(
'-i', '--initialize', action='store_true',
help="Reconstruct the whole database")
parser.add_argument(
'-p', '--prune', action='store_true',
2019-12-14 16:04:19 +01:00
help="Remove old entries from database")
parser.add_argument(
'-b', '--prune-base', action='store_true',
help="TODO")
parser.add_argument(
'-s', '--prune-before', type=int,
default=(int(time.time()) - 60*60*24*31*6),
help="TODO")
parser.add_argument(
'-r', '--references', action='store_true',
help="Update the reference count")
args = parser.parse_args()
2019-12-13 12:35:05 +01:00
DB = Database(write=True)
if args.initialize:
DB.initialize()
if args.prune:
2019-12-14 16:04:19 +01:00
DB.prune(before=args.prune_before, base_only=args.prune_base)
if args.references:
DB.update_references()
DB.close()