eulaurarien/database.py

261 lines
7.1 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
import sqlite3
import os
import argparse
import typing
import ipaddress
import enum
import time
import pprint
"""
Utility functions to interact with the database.
"""
# Schema version expected in the `meta` table; open_db() rebuilds the
# database file when the stored version differs.
VERSION = 1
# Location of the SQLite database file (relative to the working directory).
PATH = "blocking.db"
# Shared connection and cursor; both stay None until open_db() is called.
CONN: typing.Optional[sqlite3.Connection] = None
C: typing.Optional[sqlite3.Cursor] = None  # Cursor
# Profiling state used by time_step() / time_print():
# accumulated seconds per step name, ...
TIME_DICT: typing.Dict[str, float] = {}
# ... the perf_counter reading taken when the current step began, ...
TIME_LAST = time.perf_counter()
# ... and the name of the step currently being timed.
TIME_STEP = 'start'
def time_step(step: str) -> None:
    """Close the profiling step in progress and start timing `step`.

    The time elapsed since the previous call is added to the entry of
    TIME_DICT named after the step that was running until now.
    """
    global TIME_LAST, TIME_STEP
    stamp = time.perf_counter()
    elapsed = stamp - TIME_LAST
    TIME_DICT[TIME_STEP] = TIME_DICT.get(TIME_STEP, 0.0) + elapsed
    TIME_STEP = step
    # The clock is read again so the bookkeeping above is not charged
    # to the step that starts now.
    TIME_LAST = time.perf_counter()
def time_print() -> None:
    """Print the time accumulated by each profiling step.

    Steps are listed from shortest to longest, each with its share of
    the total, followed by a summary line for the total itself.
    """
    # Flush the step in progress so its time is part of the report.
    time_step('postprint')
    total = sum(TIME_DICT.values())
    by_duration = sorted(TIME_DICT.items(), key=lambda entry: entry[1])
    for label, seconds in by_duration:
        share = seconds / total
        print(f"{label:<20}: {share:7.2%} = {seconds:.6f} s")
    print(f"{'total':<20}: {1:7.2%} = {total:.6f} s")
@enum.unique
class RowType(enum.Enum):
    """Kinds of rows stored in the `blocking` table.

    The integer values are embedded in the SQL statements of this
    module as the `type` column, so they must stay stable and distinct;
    `enum.unique` rejects an accidental duplicate at import time.
    """
    AS = 1
    DomainTree = 2
    Domain = 3
    IPv4Network = 4
    # NOTE(review): 5 is skipped here; keep 6, since the value may
    # already be present in existing databases.
    IPv6Network = 6
def open_db() -> None:
    """Open the database at PATH and bind the module-level CONN and C.

    The stored schema version is read from the `meta` table. When it is
    missing or differs from VERSION, the database file is deleted and
    recreated from `database_schema.sql`, then stamped with VERSION.

    Raises:
        OSError: if `database_schema.sql` cannot be read. The schema is
            read before the old file is removed, so in that case the
            existing database is left in place.
    """
    global CONN
    global C
    time_step('open_db')
    CONN = sqlite3.connect(PATH)
    C = CONN.cursor()
    # C.execute("PRAGMA foreign_keys = ON");
    initialized = False
    try:
        C.execute("SELECT value FROM meta WHERE key='version'")
        version_ex = C.fetchone()
    except sqlite3.OperationalError:
        # Presumably the `meta` table does not exist yet (new or foreign
        # file); treated the same as having no version recorded.
        version_ex = None
    if version_ex:
        if version_ex[0] == VERSION:
            initialized = True
        else:
            print(f"Database version {version_ex[0]} found, "
                  "it will be deleted.")
    if not initialized:
        time_step('init_db')
        print(f"Creating database version {VERSION}.")
        # Read the schema first: failing here must not cost us the
        # existing database file.
        with open("database_schema.sql", 'r') as db_schema:
            schema = db_schema.read()
        CONN.close()
        try:
            os.unlink(PATH)
        except FileNotFoundError:
            # Nothing to delete; a new file is created just below.
            pass
        CONN = sqlite3.connect(PATH)
        C = CONN.cursor()
        C.executescript(schema)
        C.execute("INSERT INTO meta VALUES ('version', ?)", (VERSION,))
        CONN.commit()
    time_step('other')
def close_db() -> None:
    """Commit pending changes, close the connection, print the timings.

    The commit and the close are timed as separate profiling steps.
    """
    assert CONN
    shutdown = (
        ('close_db_commit', CONN.commit),
        ('close_db', CONN.close),
    )
    for label, action in shutdown:
        time_step(label)
        action()
    time_step('other')
    time_print()
def refresh() -> None:
    """Reset the `updated` flag of every row in the `blocking` table.

    The fetch queries in this module only consider rows with updated=1,
    so after this call rows count again only once they are re-fed.
    """
    assert C
    # TODO PERF Use a meta value instead
    statement = 'UPDATE blocking SET updated = 0'
    C.execute(statement)
# Upsert used by feed_rule_subdomains().
# Parameters: (key, firstparty, firstparty) -- the flag is bound twice,
# once for the INSERT and once for the conflict UPDATE.
RULE_SUBDOMAIN_COMMAND = (
    'INSERT INTO blocking (key, type, updated, firstparty) '
    f'VALUES (?, {RowType.DomainTree.value}, 1, ?) '
    'ON CONFLICT(key)'
    f'DO UPDATE SET source=null, type={RowType.DomainTree.value}, '
    'updated=1, firstparty=?'
)
def feed_rule_subdomains(subdomain: str, first_party: bool = False) -> None:
    """Register `subdomain` as a DomainTree rule.

    The name is stored reversed, character by character, as the row key.
    `first_party` is stored as 0/1 in the `firstparty` column.
    """
    assert C
    reversed_name = subdomain[::-1]
    party = int(first_party)
    # Since regex type takes precedence over domain type,
    # and firstparty takes precedence over multiparty,
    # we can afford to replace the whole row without checking
    # previous values, as long as firstparty subdomains are
    # updated last.
    C.execute(RULE_SUBDOMAIN_COMMAND, (reversed_name, party, party))
def ip_get_bits(address: ipaddress.IPv4Address) -> typing.Iterator[int]:
    """Yield the bits of `address` as 0/1 integers.

    Bytes are taken in the order of `address.packed`, and each byte is
    emitted most significant bit first.
    """
    for octet in address.packed:
        yield from ((octet >> shift) & 1 for shift in reversed(range(8)))
def ip_flat(address: typing.Union[ipaddress.IPv4Address,
                                  ipaddress.IPv6Address]) -> str:
    """Return `address` as a fixed-width string of '0'/'1' characters.

    The string is the binary form of the address, most significant bit
    first, zero-padded to the full width of the address family
    (32 characters for IPv4, 128 for IPv6).
    """
    # A single integer format replaces a per-bit generator + str() join
    # and produces the same text.
    return format(int(address), f'0{address.max_prefixlen}b')
def ip4_flat(address: str) -> str:
    """Return a dotted-quad IPv4 string as 32 '0'/'1' characters.

    Each octet becomes 8 binary digits, most significant bit first.

    Raises:
        ValueError: if a part is not an integer, if there are more than
            four parts, or if an octet is outside 0..255. Such input
            used to be accepted and produced a key that was truncated,
            over-long or contained a '-' sign.
        IndexError: if there are fewer than four parts (unchanged).
    """
    octets = [int(part) for part in address.split('.')]
    if len(octets) > 4:
        raise ValueError(f"Too many octets in IPv4 address: {address!r}")
    if len(octets) == 4 and not all(0 <= octet <= 255 for octet in octets):
        raise ValueError(f"Octet out of range in IPv4 address: {address!r}")
    # With fewer than four octets, format() raises IndexError as before.
    return '{:08b}{:08b}{:08b}{:08b}'.format(*octets)
# Upsert used by feed_rule_ip4network().
# Parameters: (key, firstparty, firstparty) -- the flag is bound twice,
# once for the INSERT and once for the conflict UPDATE.
RULE_IP4NETWORK_COMMAND = (
    'INSERT INTO blocking (key, type, updated, firstparty) '
    f'VALUES (?, {RowType.IPv4Network.value}, 1, ?) '
    'ON CONFLICT(key)'
    f'DO UPDATE SET source=null, type={RowType.IPv4Network.value}, '
    'updated=1, firstparty=?'
)
def feed_rule_ip4network(network: ipaddress.IPv4Network,
                         first_party: bool = False) -> None:
    """Register `network` as an IPv4Network rule.

    The row key is the binary form of the network address cut down to
    its first `prefixlen` bits. `first_party` is stored as 0/1 in the
    `firstparty` column.
    """
    assert C
    all_bits = ip_flat(network.network_address)
    prefix_bits = all_bits[:network.prefixlen]
    party = int(first_party)
    C.execute(RULE_IP4NETWORK_COMMAND, (prefix_bits, party, party))
# Lookup used by feed_a(): the greatest up-to-date IPv4Network key that
# sorts at or before the given bit string. Parameters: (value,).
FEED_A_COMMAND_FETCH = (
    'SELECT key, firstparty FROM blocking '
    'WHERE key<=? '
    'AND updated=1 '
    f'AND type={RowType.IPv4Network.value} '
    'ORDER BY key DESC '
    'LIMIT 1'
)
# Upsert used by feed_a(). An existing row is only overwritten when it
# is stale (updated=0) or its firstparty value is lower than the new one.
# Parameters: (key, source, firstparty,      -- INSERT
#              source, firstparty, firstparty)  -- UPDATE and its WHERE
FEED_A_COMMAND_UPSERT = (
    'INSERT INTO blocking (key, source, type, updated, firstparty) '
    f'VALUES (?, ?, {RowType.Domain.value}, 1, ?)'
    'ON CONFLICT(key)'
    f'DO UPDATE SET source=?, type={RowType.Domain.value}, '
    'updated=1, firstparty=? '
    'WHERE updated=0 OR firstparty<?'
)
def feed_a(name: str, value_ip: str) -> None:
    """Record `name` as a Domain row if its A record hits a known network.

    `value_ip` is converted to a bit string and looked up against the
    up-to-date IPv4Network keys. When the key found is a prefix of that
    bit string, `name` (stored reversed) is upserted as a Domain row
    with that key as `source` and the network's `firstparty` value.
    Malformed IPs and lookups without a match are ignored.

    NOTE(review): only the single greatest key <= the address is checked;
    a shorter matching prefix hidden behind a longer non-matching key
    (keys '1' and '10', address starting with '11') would be missed --
    confirm this is acceptable.
    """
    assert C
    assert CONN
    try:
        time_step('a_flat')
        try:
            value = ip4_flat(value_ip)
        except (ValueError, IndexError):
            # Malformed IPs
            return
        time_step('a_fetch')
        C.execute(FEED_A_COMMAND_FETCH, (value,))
        base = C.fetchone()
        time_step('a_fetch_confirm')
        if not base:
            return
        b_key, b_firstparty = base
        # The query only orders by key; confirm the key really is a
        # prefix of the address.
        if not value.startswith(b_key):
            return
        name = name[::-1]
        time_step('a_upsert')
        C.execute(FEED_A_COMMAND_UPSERT,
                  (name, b_key, b_firstparty,  # Insert
                   b_key, b_firstparty, b_firstparty)  # Update
                  )
    finally:
        # Always hand the clock back to 'other', early returns included;
        # otherwise the caller's time is charged to the last step here.
        time_step('other')
# Lookup used by feed_cname(): the greatest up-to-date DomainTree or
# Domain key that sorts at or before the given (reversed) name.
# Parameters: (value,).
# Filtering per type instead, i.e.
#   f'WHERE ((type={RowType.DomainTree.value} AND key<=?) OR '
#   f'(type={RowType.Domain.value} AND key=?)) '
# was tried: this optimisation is counter productive.
FEED_CNAME_COMMAND_FETCH = (
    'SELECT key, type, firstparty FROM blocking '
    'WHERE key<=? '
    f'AND (type={RowType.DomainTree.value} OR type={RowType.Domain.value}) '
    'AND updated=1 '
    'ORDER BY key DESC '
    'LIMIT 1'
)
# Upsert used by feed_cname(). An existing row is only overwritten when
# it is stale (updated=0) or its firstparty value is lower than the new one.
# Parameters: (key, source, firstparty,      -- INSERT
#              source, firstparty, firstparty)  -- UPDATE and its WHERE
FEED_CNAME_COMMAND_UPSERT = (
    'INSERT INTO blocking (key, source, type, updated, firstparty) '
    f'VALUES (?, ?, {RowType.Domain.value}, 1, ?)'
    'ON CONFLICT(key)'
    f'DO UPDATE SET source=?, type={RowType.Domain.value}, '
    'updated=1, firstparty=? '
    'WHERE updated=0 OR firstparty<?'
)
def feed_cname(name: str, value: str) -> None:
    """Record `name` as a Domain row if its CNAME target is already known.

    `value` (the target) is reversed and looked up against the
    up-to-date DomainTree and Domain keys. The row found matches when
    its key equals the reversed target, or, for a DomainTree row, when
    the key is a prefix of it followed by a '.'. On a match, `name`
    (stored reversed) is upserted as a Domain row with the matched key
    as `source` and the matched row's `firstparty` value.

    NOTE(review): only the single greatest key <= the target is checked;
    a shorter matching key hidden behind a longer non-matching one would
    be missed -- confirm this is acceptable.
    """
    assert C
    assert CONN
    try:
        value = value[::-1]
        time_step('cname_fetch')
        C.execute(FEED_CNAME_COMMAND_FETCH, (value,))
        base = C.fetchone()
        time_step('cname_fetch_confirm')
        if not base:
            # Should only happen at an extremum of the database
            return
        b_key, b_type, b_firstparty = base
        # Exact match for any row type; for a DomainTree the key may
        # also be a prefix, provided it ends on a label boundary ('.').
        matching = value.startswith(b_key) and (
            len(value) == len(b_key)
            or (
                b_type == RowType.DomainTree.value
                and value[len(b_key)] == '.'
            )
        )
        if not matching:
            return
        name = name[::-1]
        time_step('cname_upsert')
        C.execute(FEED_CNAME_COMMAND_UPSERT,
                  (name, b_key, b_firstparty,  # Insert
                   b_key, b_firstparty, b_firstparty)  # Update
                  )
    finally:
        # Always hand the clock back to 'other', early returns included;
        # otherwise the caller's time is charged to the last step here.
        time_step('other')
if __name__ == '__main__':
    # Command-line entry point: open the database, optionally flag all
    # rows as outdated, then close it (which also prints the timings).
    cli = argparse.ArgumentParser(description="Database operations")
    cli.add_argument(
        '-r', '--refresh',
        action='store_true',
        help="Set the whole database as an old source",
    )
    options = cli.parse_args()

    open_db()
    if options.refresh:
        refresh()
    close_db()