Workflow: SQL -> Tree

Welp. All that for this.
This commit is contained in:
Geoffrey Frogeye 2019-12-15 15:56:26 +01:00
parent 040ce4c14e
commit 4d966371b2
Signed by: geoffrey
GPG key ID: D8A7ECA00A8CD3DD
8 changed files with 296 additions and 699 deletions

728
database.py Executable file → Normal file
View file

@ -4,111 +4,59 @@
Utility functions to interact with the database.
"""
import sqlite3
import typing
import time
import os
import logging
import argparse
import coloredlogs
import ipaddress
import math
import pickle
import enum
coloredlogs.install(
level='DEBUG',
fmt='%(asctime)s %(name)s %(levelname)s %(message)s'
)
DbValue = typing.Union[None, int, float, str, bytes]
PathType = enum.Enum('PathType', 'Rule Hostname Zone Asn Ip4 Ip6')
RulePath = typing.Union[None]
Asn = int
DomainPath = typing.List[str]
Ip4Path = typing.List[int]
Ip6Path = typing.List[int]
Path = typing.Union[RulePath, DomainPath, Asn, Ip4Path, Ip6Path]
TypedPath = typing.Tuple[PathType, Path]
Timestamp = int
Level = int
Match = typing.Tuple[Timestamp, TypedPath, Level]
DebugPath = (PathType.Rule, None)
class Database():
VERSION = 5
PATH = "blocking.db"
class DomainTreeNode():
def __init__(self) -> None:
self.children: typing.Dict[str, DomainTreeNode] = dict()
self.match_zone: typing.Optional[Match] = None
self.match_hostname: typing.Optional[Match] = None
def open(self) -> None:
mode = 'rwc' if self.write else 'ro'
uri = f'file:{self.PATH}?mode={mode}'
self.conn = sqlite3.connect(uri, uri=True)
cursor = self.conn.cursor()
cursor.execute("PRAGMA foreign_keys = ON")
self.conn.create_function("unpack_asn", 1,
self.unpack_asn,
deterministic=True)
self.conn.create_function("unpack_ip4address", 1,
self.unpack_ip4address,
deterministic=True)
self.conn.create_function("unpack_ip4network", 2,
self.unpack_ip4network,
deterministic=True)
self.conn.create_function("unpack_domain", 1,
lambda s: s[:-1][::-1],
deterministic=True)
self.conn.create_function("format_zone", 1,
lambda s: '*' + s[::-1],
deterministic=True)
def get_meta(self, key: str) -> typing.Optional[int]:
cursor = self.conn.cursor()
try:
cursor.execute("SELECT value FROM meta WHERE key=?", (key,))
except sqlite3.OperationalError:
return None
for ver, in cursor:
return ver
return None
class IpTreeNode():
def __init__(self) -> None:
self.children: typing.List[typing.Optional[IpTreeNode]] = [None, None]
self.match: typing.Optional[Match] = None
def set_meta(self, key: str, val: int) -> None:
cursor = self.conn.cursor()
cursor.execute("INSERT INTO meta VALUES (?, ?) "
"ON CONFLICT (key) DO "
"UPDATE set value=?",
(key, val, val))
def close(self) -> None:
self.enter_step('close_commit')
self.conn.commit()
self.enter_step('close')
self.conn.close()
self.profile()
def initialize(self) -> None:
self.close()
self.enter_step('initialize')
if not self.write:
self.log.error("Cannot initialize in read-only mode.")
raise
os.unlink(self.PATH)
self.open()
self.log.info("Creating database version %d.", self.VERSION)
cursor = self.conn.cursor()
with open("database_schema.sql", 'r') as db_schema:
cursor.executescript(db_schema.read())
self.set_meta('version', self.VERSION)
self.conn.commit()
def __init__(self, write: bool = False) -> None:
self.log = logging.getLogger('db')
class Profiler():
def __init__(self) -> None:
self.log = logging.getLogger('profiler')
self.time_last = time.perf_counter()
self.time_step = 'init'
self.time_dict: typing.Dict[str, float] = dict()
self.step_dict: typing.Dict[str, int] = dict()
self.write = write
self.open()
version = self.get_meta('version')
if version != self.VERSION:
if version is not None:
self.log.warning(
"Outdated database version: %d found, will be rebuilt.",
version)
self.initialize()
def enter_step(self, name: str) -> None:
now = time.perf_counter()
try:
self.time_dict[self.time_step] += now - self.time_last
self.step_dict[self.time_step] += 1
self.step_dict[self.time_step] += int(name != self.time_step)
except KeyError:
self.time_dict[self.time_step] = now - self.time_last
self.step_dict[self.time_step] = 1
@ -125,13 +73,58 @@ class Database():
self.log.debug(f"{'total':<20}: "
f"{total:9.2f} s ({1:7.2%})")
@staticmethod
def pack_hostname(hostname: str) -> str:
return hostname[::-1] + '.'
class Database(Profiler):
VERSION = 8
PATH = "blocking.p"
def initialize(self) -> None:
self.log.warning(
"Creating database version: %d ",
Database.VERSION)
self.domtree = DomainTreeNode()
self.asns: typing.Set[Asn] = set()
self.ip4tree = IpTreeNode()
def load(self) -> None:
self.enter_step('load')
try:
with open(self.PATH, 'rb') as db_fdsec:
version, data = pickle.load(db_fdsec)
if version == Database.VERSION:
self.domtree, self.asns, self.ip4tree = data
return
self.log.warning(
"Outdated database version found: %d, "
"will be rebuilt.",
version)
except (TypeError, AttributeError, EOFError):
self.log.error(
"Corrupt database found, "
"will be rebuilt.")
except FileNotFoundError:
pass
self.initialize()
def save(self) -> None:
self.enter_step('save')
with open(self.PATH, 'wb') as db_fdsec:
data = self.domtree, self.asns, self.ip4tree
pickle.dump((self.VERSION, data), db_fdsec)
self.profile()
def __init__(self) -> None:
Profiler.__init__(self)
self.log = logging.getLogger('db')
self.load()
@staticmethod
def pack_zone(zone: str) -> str:
return Database.pack_hostname(zone)
def pack_domain(domain: str) -> DomainPath:
return domain.split('.')[::-1]
@staticmethod
def unpack_domain(domain: DomainPath) -> str:
return '.'.join(domain[::-1])
@staticmethod
def pack_asn(asn: str) -> int:
@ -145,431 +138,208 @@ class Database():
return f'AS{asn}'
@staticmethod
def pack_ip4address(address: str) -> int:
total = 0
for i, octet in enumerate(address.split('.')):
total += int(octet) << (3-i)*8
if total > 0xFFFFFFFF:
raise ValueError
return total
# return '{:02x}{:02x}{:02x}{:02x}'.format(
# *[int(c) for c in address.split('.')])
# return base64.b16encode(packed).decode()
# return '{:08b}{:08b}{:08b}{:08b}'.format(
# *[int(c) for c in address.split('.')])
# carg = ctypes.c_wchar_p(address)
# ret = ACCEL.ip4_flat(carg, self.accel_ip4_buf)
# if ret != 0:
# raise ValueError
# return self.accel_ip4_buf.value
# packed = ipaddress.ip_address(address).packed
# return packed
def pack_ip4address(address: str) -> Ip4Path:
addr: Ip4Path = [0] * 32
octets = [int(octet) for octet in address.split('.')]
for b in range(32):
if (octets[b//8] >> b % 8) & 0b1:
addr[b] = 1
return addr
@staticmethod
def unpack_ip4address(address: int) -> str:
return '.'.join(str((address >> (i * 8)) & 0xFF)
for i in reversed(range(4)))
def unpack_ip4address(address: Ip4Path) -> str:
octets = [0] * 4
for b, bit in enumerate(address):
octets[b//8] = (octets[b//8] << 1) + bit
return '.'.join(map(str, octets))
@staticmethod
def pack_ip4network(network: str) -> typing.Tuple[int, int]:
# def pack_ip4network(network: str) -> str:
net = ipaddress.ip_network(network)
mini = Database.pack_ip4address(net.network_address.exploded)
maxi = Database.pack_ip4address(net.broadcast_address.exploded)
# mini = net.network_address.packed
# maxi = net.broadcast_address.packed
return mini, maxi
# return Database.pack_ip4address(net.network_address.exploded)[:net.prefixlen]
def pack_ip4network(network: str) -> Ip4Path:
address, prefixlen_str = network.split('/')
prefixlen = int(prefixlen_str)
return Database.pack_ip4address(address)[:prefixlen]
@staticmethod
def unpack_ip4network(mini: int, maxi: int) -> str:
addr = Database.unpack_ip4address(mini)
prefixlen = 32-int(math.log2(maxi-mini+1))
def unpack_ip4network(network: Ip4Path) -> str:
address = network.copy()
prefixlen = len(network)
for _ in range(32-prefixlen):
address.append(0)
addr = Database.unpack_ip4address(address)
return f'{addr}/{prefixlen}'
def update_references(self) -> None:
self.enter_step('update_refs')
cursor = self.conn.cursor()
cursor.execute('UPDATE rules AS r SET refs='
'(SELECT count(*) FROM rules '
'WHERE source=r.id)')
raise NotImplementedError
def prune(self, before: int, base_only: bool = False) -> None:
self.enter_step('prune')
cursor = self.conn.cursor()
cmd = 'DELETE FROM rules WHERE updated<?'
if base_only:
cmd += ' AND level=0'
cursor.execute(cmd, (before,))
raise NotImplementedError
def explain(self, entry: int) -> str:
# Format current
string = '???'
cursor = self.conn.cursor()
cursor.execute(
'SELECT unpack_asn(val) FROM asn WHERE entry=:entry '
'UNION '
'SELECT unpack_domain(val) FROM hostname WHERE entry=:entry '
'UNION '
'SELECT format_zone(val) FROM zone WHERE entry=:entry '
'UNION '
'SELECT unpack_ip4address(val) FROM ip4address WHERE entry=:entry '
'UNION '
'SELECT unpack_ip4network(mini, maxi) '
'FROM ip4network WHERE entry=:entry ',
{"entry": entry}
)
for val, in cursor: # only one
string = str(val)
string += f' #{entry}'
# Add source if any
cursor.execute('SELECT source FROM rules WHERE id=?', (entry,))
for source, in cursor:
if source:
string += f'{self.explain(source)}'
return string
raise NotImplementedError
def export(self,
first_party_only: bool = False,
end_chain_only: bool = False,
explain: bool = False,
_dic: DomainTreeNode = None,
_par: DomainPath = None,
) -> typing.Iterable[str]:
selection = 'entry' if explain else 'unpack_domain(val)'
command = f'SELECT {selection} FROM rules ' \
'INNER JOIN hostname ON rules.id = hostname.entry'
restrictions: typing.List[str] = list()
if first_party_only:
restrictions.append('rules.first_party = 1')
if end_chain_only:
restrictions.append('rules.refs = 0')
if restrictions:
command += ' WHERE ' + ' AND '.join(restrictions)
if not explain:
command += ' ORDER BY unpack_domain(val) ASC'
cursor = self.conn.cursor()
cursor.execute(command)
for val, in cursor:
if explain:
yield self.explain(val)
else:
yield val
if first_party_only or end_chain_only or explain:
raise NotImplementedError
_dic = _dic or self.domtree
_par = _par or list()
if _dic.match_hostname:
yield self.unpack_domain(_par)
for part in _dic.children:
dic = _dic.children[part]
yield from self.export(_dic=dic,
_par=_par + [part])
def count_rules(self,
first_party_only: bool = False,
) -> str:
counts: typing.List[str] = list()
cursor = self.conn.cursor()
for table in ['asn', 'ip4network', 'ip4address', 'zone', 'hostname']:
command = f'SELECT count(*) FROM rules ' \
f'INNER JOIN {table} ON rules.id = {table}.entry ' \
'WHERE rules.level = 0'
if first_party_only:
command += ' AND first_party=1'
cursor.execute(command)
count, = cursor.fetchone()
if count > 0:
counts.append(f'{table}: {count}')
raise NotImplementedError
return ', '.join(counts)
def get_domain(self, domain: str) -> typing.Iterable[int]:
self.enter_step('get_domain_prepare')
domain_prep = self.pack_hostname(domain)
cursor = self.conn.cursor()
self.enter_step('get_domain_select')
cursor.execute(
'SELECT null, entry FROM hostname '
'WHERE val=:d '
'UNION '
'SELECT * FROM ('
'SELECT val, entry FROM zone '
# 'WHERE val>=:d '
# 'ORDER BY val ASC LIMIT 1'
'WHERE val<=:d '
'AND instr(:d, val) = 1'
')',
{'d': domain_prep}
)
for val, entry in cursor:
# print(293, val, entry)
self.enter_step('get_domain_confirm')
if not (val is None or domain_prep.startswith(val)):
# print(297)
continue
def get_domain(self, domain_str: str) -> typing.Iterable[TypedPath]:
self.enter_step('get_domain_pack')
domain = self.pack_domain(domain_str)
self.enter_step('get_domain_brws')
dic = self.domtree
depth = 0
for part in domain:
if dic.match_zone:
self.enter_step('get_domain_yield')
yield (PathType.Zone, domain[:depth])
self.enter_step('get_domain_brws')
if part not in dic.children:
return
dic = dic.children[part]
depth += 1
if dic.match_zone:
self.enter_step('get_domain_yield')
yield entry
yield (PathType.Zone, domain)
if dic.match_hostname:
self.enter_step('get_domain_yield')
yield (PathType.Hostname, domain)
def get_domain_in_zone(self, domain: str) -> typing.Iterable[int]:
self.enter_step('get_domainiz_prepare')
domain_prep = self.pack_hostname(domain)
cursor = self.conn.cursor()
self.enter_step('get_domainiz_select')
cursor.execute(
'SELECT val, entry FROM zone '
'WHERE val<=:d '
'ORDER BY val DESC LIMIT 1',
{'d': domain_prep}
)
for val, entry in cursor:
self.enter_step('get_domainiz_confirm')
if not (val is None or domain_prep.startswith(val)):
continue
self.enter_step('get_domainiz_yield')
yield entry
def get_ip4(self, address: str) -> typing.Iterable[int]:
self.enter_step('get_ip4_prepare')
try:
address_prep = self.pack_ip4address(address)
except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", address)
return
cursor = self.conn.cursor()
self.enter_step('get_ip4_select')
cursor.execute(
'SELECT entry FROM ip4address '
# 'SELECT null, entry FROM ip4address '
'WHERE val=:a '
'UNION '
# 'SELECT * FROM ('
# 'SELECT val, entry FROM ip4network '
# 'WHERE val<=:a '
# 'AND instr(:a, val) > 0 '
# 'ORDER BY val DESC'
# ')'
'SELECT entry FROM ip4network '
'WHERE :a BETWEEN mini AND maxi ',
{'a': address_prep}
)
for entry, in cursor:
# self.enter_step('get_ip4_confirm')
# if not (val is None or val.startswith(address_prep)):
# # PERF startswith but from the end
# continue
def get_ip4(self, ip4_str: str) -> typing.Iterable[TypedPath]:
self.enter_step('get_ip4_pack')
ip4 = self.pack_ip4address(ip4_str)
self.enter_step('get_ip4_brws')
dic = self.ip4tree
depth = 0
for part in ip4:
if dic.match:
self.enter_step('get_ip4_yield')
yield (PathType.Ip4, ip4[:depth])
self.enter_step('get_ip4_brws')
next_dic = dic.children[part]
if next_dic is None:
return
dic = next_dic
depth += 1
if dic.match:
self.enter_step('get_ip4_yield')
yield entry
yield (PathType.Ip4, ip4)
def get_ip4_in_network(self, address: str) -> typing.Iterable[int]:
self.enter_step('get_ip4in_prepare')
try:
address_prep = self.pack_ip4address(address)
except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", address)
return
cursor = self.conn.cursor()
self.enter_step('get_ip4in_select')
cursor.execute(
'SELECT entry FROM ip4network '
'WHERE :a BETWEEN mini AND maxi ',
{'a': address_prep}
)
for entry, in cursor:
self.enter_step('get_ip4in_yield')
yield entry
def list_asn(self) -> typing.Iterable[TypedPath]:
for asn in self.asns:
yield (PathType.Asn, asn)
def list_asn(self) -> typing.Iterable[typing.Tuple[str, int]]:
cursor = self.conn.cursor()
self.enter_step('list_asn_select')
cursor.execute('SELECT val, entry FROM asn')
for val, entry in cursor:
yield f'AS{val}', entry
def _set_generic(self,
table: str,
select_query: str,
insert_query: str,
prep: typing.Dict[str, DbValue],
def set_hostname(self,
hostname_str: str,
updated: int,
is_first_party: bool = False,
source: int = None,
) -> None:
# Since this isn't the bulk of the processing,
# here abstraction > performaces
is_first_party: bool = None,
source: TypedPath = None) -> None:
self.enter_step('set_hostname_pack')
if is_first_party or source:
raise NotImplementedError
self.enter_step('set_hostname_brws')
hostname = self.pack_domain(hostname_str)
dic = self.domtree
for part in hostname:
if dic.match_zone:
# Refuse to add hostname whose zone is already matching
return
if part not in dic.children:
dic.children[part] = DomainTreeNode()
dic = dic.children[part]
dic.match_hostname = (updated, DebugPath, 0)
# Fields based on the source
self.enter_step(f'set_{table}_prepare')
cursor = self.conn.cursor()
if source is None:
first_party = int(is_first_party)
level = 0
else:
self.enter_step(f'set_{table}_source')
cursor.execute(
'SELECT first_party, level FROM rules '
'WHERE id=?',
(source,)
)
first_party, level = cursor.fetchone()
level += 1
def set_zone(self,
zone_str: str,
updated: int,
is_first_party: bool = None,
source: TypedPath = None) -> None:
self.enter_step('set_zone_pack')
if is_first_party or source:
raise NotImplementedError
zone = self.pack_domain(zone_str)
self.enter_step('set_zone_brws')
dic = self.domtree
for part in zone:
if dic.match_zone:
# Refuse to add zone whose parent zone is already matching
return
if part not in dic.children:
dic.children[part] = DomainTreeNode()
dic = dic.children[part]
dic.match_zone = (updated, DebugPath, 0)
self.enter_step(f'set_{table}_select')
cursor.execute(select_query, prep)
def set_asn(self,
asn_str: str,
updated: int,
is_first_party: bool = None,
source: TypedPath = None) -> None:
self.enter_step('set_asn_pack')
if is_first_party or source:
# TODO updated
raise NotImplementedError
asn = self.pack_asn(asn_str)
self.enter_step('set_asn_brws')
self.asns.add(asn)
rules_prep: typing.Dict[str, DbValue] = {
"source": source,
"updated": updated,
"first_party": first_party,
"level": level,
}
def set_ip4address(self,
ip4address_str: str,
updated: int,
is_first_party: bool = None,
source: TypedPath = None) -> None:
self.enter_step('set_ip4add_pack')
if is_first_party or source:
raise NotImplementedError
self.enter_step('set_ip4add_brws')
ip4address = self.pack_ip4address(ip4address_str)
dic = self.ip4tree
for part in ip4address:
if dic.match:
# Refuse to add ip4address whose network is already matching
return
next_dic = dic.children[part]
if next_dic is None:
next_dic = IpTreeNode()
dic.children[part] = next_dic
dic = next_dic
dic.match = (updated, DebugPath, 0)
# If the entry already exists
for entry, in cursor: # only one
self.enter_step(f'set_{table}_update')
rules_prep['entry'] = entry
cursor.execute(
'UPDATE rules SET '
'source=:source, updated=:updated, '
'first_party=:first_party, level=:level '
'WHERE id=:entry AND (updated<:updated OR '
'first_party<:first_party OR level<:level)',
rules_prep
)
# Only update if any of the following:
# - the entry is outdataed
# - the entry was not a first_party but this is
# - this is closer to the original rule
return
# If it does not exist
self.enter_step(f'set_{table}_insert')
cursor.execute(
'INSERT INTO rules '
'(source, updated, first_party, level) '
'VALUES (:source, :updated, :first_party, :level) ',
rules_prep
)
cursor.execute('SELECT id FROM rules WHERE rowid=?',
(cursor.lastrowid,))
for entry, in cursor: # only one
prep['entry'] = entry
cursor.execute(insert_query, prep)
return
assert False
def set_hostname(self, hostname: str,
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_hostname_prepare')
prep: typing.Dict[str, DbValue] = {
'val': self.pack_hostname(hostname),
}
self._set_generic(
'hostname',
'SELECT entry FROM hostname WHERE val=:val',
'INSERT INTO hostname (val, entry) '
'VALUES (:val, :entry)',
prep,
*args, **kwargs
)
def set_asn(self, asn: str,
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_asn_prepare')
try:
asn_prep = self.pack_asn(asn)
except ValueError:
self.log.error("Invalid asn: %s", asn)
return
prep: typing.Dict[str, DbValue] = {
'val': asn_prep,
}
self._set_generic(
'asn',
'SELECT entry FROM asn WHERE val=:val',
'INSERT INTO asn (val, entry) '
'VALUES (:val, :entry)',
prep,
*args, **kwargs
)
def set_ip4address(self, ip4address: str,
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_ip4add_prepare')
try:
ip4address_prep = self.pack_ip4address(ip4address)
except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", ip4address)
return
prep: typing.Dict[str, DbValue] = {
'val': ip4address_prep,
}
self._set_generic(
'ip4add',
'SELECT entry FROM ip4address WHERE val=:val',
'INSERT INTO ip4address (val, entry) '
'VALUES (:val, :entry)',
prep,
*args, **kwargs
)
def set_zone(self, zone: str,
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_zone_prepare')
prep: typing.Dict[str, DbValue] = {
'val': self.pack_zone(zone),
}
self._set_generic(
'zone',
'SELECT entry FROM zone WHERE val=:val',
'INSERT INTO zone (val, entry) '
'VALUES (:val, :entry)',
prep,
*args, **kwargs
)
def set_ip4network(self, ip4network: str,
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_ip4net_prepare')
try:
ip4network_prep = self.pack_ip4network(ip4network)
except (ValueError, IndexError):
self.log.error("Invalid ip4network: %s", ip4network)
return
prep: typing.Dict[str, DbValue] = {
'mini': ip4network_prep[0],
'maxi': ip4network_prep[1],
}
self._set_generic(
'ip4net',
'SELECT entry FROM ip4network WHERE mini=:mini AND maxi=:maxi',
'INSERT INTO ip4network (mini, maxi, entry) '
'VALUES (:mini, :maxi, :entry)',
prep,
*args, **kwargs
)
if __name__ == '__main__':
# Parsing arguments
parser = argparse.ArgumentParser(
description="Database operations")
parser.add_argument(
'-i', '--initialize', action='store_true',
help="Reconstruct the whole database")
parser.add_argument(
'-p', '--prune', action='store_true',
help="Remove old entries from database")
parser.add_argument(
'-b', '--prune-base', action='store_true',
help="TODO")
parser.add_argument(
'-s', '--prune-before', type=int,
default=(int(time.time()) - 60*60*24*31*6),
help="TODO")
parser.add_argument(
'-r', '--references', action='store_true',
help="Update the reference count")
args = parser.parse_args()
DB = Database(write=True)
if args.initialize:
DB.initialize()
if args.prune:
DB.prune(before=args.prune_before, base_only=args.prune_base)
if args.references:
DB.update_references()
DB.close()
def set_ip4network(self,
ip4network_str: str,
updated: int,
is_first_party: bool = None,
source: TypedPath = None) -> None:
self.enter_step('set_ip4net_pack')
if is_first_party or source:
raise NotImplementedError
self.enter_step('set_ip4net_brws')
ip4network = self.pack_ip4network(ip4network_str)
dic = self.ip4tree
for part in ip4network:
if dic.match:
# Refuse to add ip4network whose parent network
# is already matching
return
next_dic = dic.children[part]
if next_dic is None:
next_dic = IpTreeNode()
dic.children[part] = next_dic
dic = next_dic
dic.match = (updated, DebugPath, 0)