|
|
@ -4,111 +4,59 @@ |
|
|
|
Utility functions to interact with the database. |
|
|
|
""" |
|
|
|
|
|
|
|
import sqlite3 |
|
|
|
import typing |
|
|
|
import time |
|
|
|
import os |
|
|
|
import logging |
|
|
|
import argparse |
|
|
|
import coloredlogs |
|
|
|
import ipaddress |
|
|
|
import math |
|
|
|
import pickle |
|
|
|
import enum |
|
|
|
|
|
|
|
coloredlogs.install( |
|
|
|
level='DEBUG', |
|
|
|
fmt='%(asctime)s %(name)s %(levelname)s %(message)s' |
|
|
|
) |
|
|
|
|
|
|
|
DbValue = typing.Union[None, int, float, str, bytes] |
|
|
|
|
|
|
|
|
|
|
|
class Database(): |
|
|
|
VERSION = 5 |
|
|
|
PATH = "blocking.db" |
|
|
|
|
|
|
|
def open(self) -> None:
    """
    Open the SQLite connection and register the SQL helper functions.

    The file at ``self.PATH`` is opened through a ``file:`` URI, in
    ``rwc`` mode when ``self.write`` is true and in ``ro`` mode otherwise.
    The resulting connection is stored in ``self.conn``.
    """
    access = 'rwc' if self.write else 'ro'
    self.conn = sqlite3.connect(f'file:{self.PATH}?mode={access}', uri=True)
    self.conn.cursor().execute("PRAGMA foreign_keys = ON")

    # (SQL name, number of SQL arguments, Python implementation)
    sql_functions = (
        ("unpack_asn", 1, self.unpack_asn),
        ("unpack_ip4address", 1, self.unpack_ip4address),
        ("unpack_ip4network", 2, self.unpack_ip4network),
        # Drop the last character, then reverse the string
        ("unpack_domain", 1, lambda s: s[:-1][::-1]),
        # Reverse the string and prefix it with '*'
        ("format_zone", 1, lambda s: '*' + s[::-1]),
    )
    for sql_name, arity, implementation in sql_functions:
        self.conn.create_function(sql_name, arity, implementation,
                                  deterministic=True)
|
|
|
|
|
|
|
def get_meta(self, key: str) -> typing.Optional[int]:
    """
    Return the value stored under ``key`` in the ``meta`` table.

    Returns ``None`` when the query raises ``sqlite3.OperationalError``
    or when no row matches the key.
    """
    cursor = self.conn.cursor()
    try:
        cursor.execute("SELECT value FROM meta WHERE key=?", (key,))
    except sqlite3.OperationalError:
        return None
    first_row = cursor.fetchone()
    if first_row is None:
        return None
    return first_row[0]
|
|
|
|
|
|
|
def set_meta(self, key: str, val: int) -> None:
    """
    Store ``val`` under ``key`` in the ``meta`` table.

    A new row is inserted; if a row with the same key already exists,
    its value is overwritten instead.
    """
    upsert = ("INSERT INTO meta VALUES (?, ?) "
              "ON CONFLICT (key) DO UPDATE set value=?")
    # `val` is bound twice: once for the insert, once for the update
    self.conn.cursor().execute(upsert, (key, val, val))
|
|
|
|
|
|
|
def close(self) -> None:
    """
    Commit pending changes, close the connection and print the profile.

    Each phase is announced through ``enter_step`` before it runs.
    """
    phases = (
        ('close_commit', self.conn.commit),
        ('close', self.conn.close),
    )
    for step_name, action in phases:
        self.enter_step(step_name)
        action()
    self.profile()
|
|
|
PathType = enum.Enum('PathType', 'Rule Hostname Zone Asn Ip4 Ip6') |
|
|
|
RulePath = typing.Union[None] |
|
|
|
Asn = int |
|
|
|
DomainPath = typing.List[str] |
|
|
|
Ip4Path = typing.List[int] |
|
|
|
Ip6Path = typing.List[int] |
|
|
|
Path = typing.Union[RulePath, DomainPath, Asn, Ip4Path, Ip6Path] |
|
|
|
TypedPath = typing.Tuple[PathType, Path] |
|
|
|
Timestamp = int |
|
|
|
Level = int |
|
|
|
Match = typing.Tuple[Timestamp, TypedPath, Level] |
|
|
|
|
|
|
|
def initialize(self) -> None:
    """
    Delete the database file and recreate it from ``database_schema.sql``.

    The current connection is closed first. The file at ``self.PATH`` is
    then removed, reopened, populated with the schema script, and stamped
    with ``self.VERSION`` in the ``meta`` table.

    Raises:
        RuntimeError: if the database was opened read-only.
    """
    self.close()
    self.enter_step('initialize')
    if not self.write:
        message = "Cannot initialize in read-only mode."
        self.log.error(message)
        # A bare `raise` here has no active exception to re-raise and
        # would only produce an opaque RuntimeError; raise it explicitly.
        raise RuntimeError(message)
    os.unlink(self.PATH)
    self.open()
    self.log.info("Creating database version %d.", self.VERSION)
    cursor = self.conn.cursor()
    # Inside a method body, `open` is the builtin, not the `open` method
    with open("database_schema.sql", 'r') as db_schema:
        cursor.executescript(db_schema.read())
    self.set_meta('version', self.VERSION)
    self.conn.commit()
|
|
|
|
|
|
|
def __init__(self, write: bool = False) -> None: |
|
|
|
self.log = logging.getLogger('db') |
|
|
|
DebugPath = (PathType.Rule, None) |
|
|
|
|
|
|
|
|
|
|
|
class DomainTreeNode():
    """
    Node of the domain tree.

    ``children`` maps a string key to the child node reached through it;
    ``match_zone`` and ``match_hostname`` hold the match recorded on this
    node as a zone or as a hostname, or ``None`` when there is none.
    """

    def __init__(self) -> None:
        # A fresh node has no child and no recorded match
        self.children: typing.Dict[str, DomainTreeNode] = {}
        self.match_zone: typing.Optional[Match] = None
        self.match_hostname: typing.Optional[Match] = None
|
|
|
|
|
|
|
|
|
|
|
class IpTreeNode():
    """
    Node of the binary IP tree.

    ``children`` is a two-slot list holding the child node for each index
    (``None`` while that child does not exist); ``match`` holds the match
    recorded on this node, or ``None`` when there is none.
    """

    def __init__(self) -> None:
        # Both branches start out empty and nothing is matched yet
        self.children: typing.List[typing.Optional[IpTreeNode]] = [None] * 2
        self.match: typing.Optional[Match] = None
|
|
|
|
|
|
|
|
|
|
|
class Profiler(): |
|
|
|
def __init__(self) -> None:
    """
    Set up the profiler state.

    The current step starts as ``'init'``, timed from now with
    ``time.perf_counter()``; both per-step dictionaries start empty.
    """
    self.log = logging.getLogger('profiler')
    # Reference point against which the current step is timed
    self.time_last = time.perf_counter()
    self.time_step = 'init'
    # Keyed by step name
    self.time_dict: typing.Dict[str, float] = {}
    self.step_dict: typing.Dict[str, int] = {}
|
|
|
self.write = write |
|
|
|
|
|
|
|
self.open() |
|
|
|
version = self.get_meta('version') |
|
|
|
if version != self.VERSION: |
|
|
|
if version is not None: |
|
|
|
self.log.warning( |
|
|
|
"Outdated database version: %d found, will be rebuilt.", |
|
|
|
version) |
|
|
|
self.initialize() |
|
|
|
|
|
|
|
def enter_step(self, name: str) -> None: |
|
|
|
now = time.perf_counter() |
|
|
|
try: |
|
|
|
self.time_dict[self.time_step] += now - self.time_last |
|
|
|
self.step_dict[self.time_step] += 1 |
|
|
|
self.step_dict[self.time_step] += int(name != self.time_step) |
|
|
|
except KeyError: |
|
|
|
self.time_dict[self.time_step] = now - self.time_last |
|
|
|
self.step_dict[self.time_step] = 1 |
|
|
@ -125,13 +73,58 @@ class Database(): |
|
|
|
self.log.debug(f"{'total':<20}: " |
|
|
|
f"{total:9.2f} s ({1:7.2%})") |
|
|
|
|
|
|
|
|
|
|
|
class Database(Profiler): |
|
|
|
VERSION = 8 |
|
|
|
PATH = "blocking.p" |
|
|
|
|
|
|
|
def initialize(self) -> None:
    """
    Reset the in-memory database to an empty state.

    Logs a warning with the database version, then replaces the domain
    tree, the ASN set and the IPv4 tree with fresh empty containers.
    """
    self.log.warning("Creating database version: %d ", Database.VERSION)
    # Empty root for the domain tree
    self.domtree = DomainTreeNode()
    # No ASN recorded yet
    self.asns: typing.Set[Asn] = set()
    # Empty root for the IPv4 tree
    self.ip4tree = IpTreeNode()
|
|
|
|
|
|
|
def load(self) -> None:
    """
    Load the database from the pickle file at ``self.PATH``.

    The file is expected to contain a ``(version, data)`` pair where
    ``data`` unpacks into ``(domtree, asns, ip4tree)``. When the file is
    missing, unreadable as a pickle, of unexpected shape, or written by
    another database version, ``self.initialize()`` is called instead so
    that the instance always ends up with usable structures.
    """
    self.enter_step('load')
    try:
        with open(self.PATH, 'rb') as db_fdsec:
            # NOTE(security): unpickling can execute arbitrary code; the
            # file at PATH must only ever come from a trusted `save()`.
            version, data = pickle.load(db_fdsec)
        # `self.VERSION` (not `Database.VERSION`) to agree with `save()`
        if version == self.VERSION:
            self.domtree, self.asns, self.ip4tree = data
            return
        self.log.warning(
            "Outdated database version found: %d, "
            "will be rebuilt.",
            version)
    except FileNotFoundError:
        # No database yet: silently start from scratch
        pass
    except (TypeError, AttributeError, EOFError, ImportError,
            IndexError, ValueError, pickle.UnpicklingError):
        # All of these can be raised while unpickling or unpacking a
        # damaged file; rebuilding is preferable to crashing at startup.
        self.log.error(
            "Corrupt database found, "
            "will be rebuilt.")
    self.initialize()
|
|
|
|
|
|
|
def save(self) -> None:
    """
    Write the database to the pickle file at ``self.PATH``.

    The file receives a ``(self.VERSION, (domtree, asns, ip4tree))`` pair.
    The data is first written to a sibling temporary file which then
    replaces the target, so an interrupted dump cannot leave a truncated
    database behind. The profile is printed afterwards.
    """
    import os  # local import: only needed for the atomic replace

    self.enter_step('save')
    data = self.domtree, self.asns, self.ip4tree
    tmp_path = f'{self.PATH}.tmp'
    with open(tmp_path, 'wb') as db_fdsec:
        pickle.dump((self.VERSION, data), db_fdsec)
    # Replace the previous database only once the dump fully succeeded
    os.replace(tmp_path, self.PATH)
    self.profile()
|
|
|
|
|
|
|
def __init__(self) -> None:
    """
    Create the database object and load its content.

    The profiler state is initialised first; the logger is then switched
    to ``'db'`` before ``load()`` reads (or rebuilds) the data.
    """
    Profiler.__init__(self)
    # Replaces the logger installed by the profiler initialisation
    db_logger = logging.getLogger('db')
    self.log = db_logger
    self.load()
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def pack_hostname(hostname: str) -> str:
    """Return ``hostname`` reversed character by character, plus a dot."""
    reversed_name = ''.join(reversed(hostname))
    return f'{reversed_name}.'
|
|
|
def pack_domain(domain: str) -> DomainPath:
    """
    Split ``domain`` on dots and return its parts in reverse order,
    e.g. ``'www.example.com'`` gives ``['com', 'example', 'www']``.
    """
    parts = domain.split('.')
    parts.reverse()
    return parts
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def pack_zone(zone: str) -> str:
    """Pack ``zone`` exactly as ``Database.pack_hostname`` packs a hostname."""
    packed_zone = Database.pack_hostname(zone)
    return packed_zone
|
|
|
def unpack_domain(domain: DomainPath) -> str:
    """
    Join the parts of ``domain`` with dots, last part first,
    e.g. ``['com', 'example', 'www']`` gives ``'www.example.com'``.
    """
    return '.'.join(reversed(domain))
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def pack_asn(asn: str) -> int: |
|
|
@ -145,431 +138,208 @@ class Database(): |
|
|
|
return f'AS{asn}' |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def pack_ip4address(address: str) -> int:
    """
    Convert a dotted IPv4 address into a single integer.

    The first dot-separated field lands in the highest byte. ``ValueError``
    is raised when a field is not an integer or when the result does not
    fit in 32 bits.
    """
    fields = address.split('.')
    packed = sum(int(field) << (3 - position) * 8
                 for position, field in enumerate(fields))
    if packed > 0xFFFFFFFF:
        raise ValueError
    return packed
|
|
|
# return '{:02x}{:02x}{:02x}{:02x}'.format( |
|
|
|
# *[int(c) for c in address.split('.')]) |
|
|
|
# return base64.b16encode(packed).decode() |
|
|
|
# return '{:08b}{:08b}{:08b}{:08b}'.format( |
|
|
|
# *[int(c) for c in address.split('.')]) |
|
|
|
# carg = ctypes.c_wchar_p(address) |
|
|
|
# ret = ACCEL.ip4_flat(carg, self.accel_ip4_buf) |
|
|
|
# if ret != 0: |
|
|
|
# raise ValueError |
|
|
|
# return self.accel_ip4_buf.value |
|
|
|
# packed = ipaddress.ip_address(address).packed |
|
|
|
# return packed |
|
|
|
def pack_ip4address(address: str) -> Ip4Path:
    """
    Convert a dotted IPv4 address into its list of 32 bits.

    Bits are listed most-significant first (network order), so that the
    first ``n`` items are exactly the ``/n`` prefix of the address and so
    that ``unpack_ip4address`` restores the original string.

    Raises:
        ValueError: if a field is not an integer, is outside 0-255, or if
            there are more than four fields.
        IndexError: if there are fewer than four fields.
    """
    octets = [int(octet) for octet in address.split('.')]
    if len(octets) > 4 or any(not 0 <= octet <= 0xFF for octet in octets):
        raise ValueError(f"Invalid IPv4 address: {address!r}")
    # Bit b belongs to octet b//8; within the octet, b%8 == 0 is the MSB
    return [(octets[b // 8] >> (7 - b % 8)) & 0b1 for b in range(32)]
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def unpack_ip4address(address: int) -> str:
    """
    Convert an integer into a dotted IPv4 string.

    The four lowest bytes of ``address`` are written highest byte first.
    """
    octets = []
    for shift in (24, 16, 8, 0):
        octets.append(str((address >> shift) & 0xFF))
    return '.'.join(octets)
|
|
|
def unpack_ip4address(address: Ip4Path) -> str:
    """
    Convert a list of bits into a dotted IPv4 string.

    Bits are consumed in order, eight per field, each new bit becoming
    the lowest bit of its field.
    """
    fields = [0, 0, 0, 0]
    for position, bit in enumerate(address):
        index = position // 8
        # Make room for the new bit on the right, then add it
        fields[index] = fields[index] * 2 + bit
    return '.'.join(str(field) for field in fields)
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def pack_ip4network(network: str) -> typing.Tuple[int, int]:
    """
    Convert an IPv4 network string into a pair of packed addresses.

    The network is parsed with ``ipaddress.ip_network``; the returned pair
    holds its network address and its broadcast address, each packed with
    ``Database.pack_ip4address``.
    """
    parsed = ipaddress.ip_network(network)
    bounds = (parsed.network_address, parsed.broadcast_address)
    mini, maxi = (Database.pack_ip4address(bound.exploded)
                  for bound in bounds)
    return mini, maxi
|
|
|
# return Database.pack_ip4address(net.network_address.exploded)[:net.prefixlen] |
|
|
|
def pack_ip4network(network: str) -> Ip4Path:
    """
    Convert an ``address/prefixlen`` string into the bits of its prefix.

    The address is packed with ``Database.pack_ip4address`` and cut down
    to its first ``prefixlen`` bits.

    Raises:
        ValueError: if the string is not of the form ``address/prefixlen``,
            if ``prefixlen`` is not an integer, or if it is outside 0-32.
    """
    address, prefixlen_str = network.split('/')
    prefixlen = int(prefixlen_str)
    # A negative length would slice from the end, and one above 32 would be
    # silently clamped; neither describes a real network.
    if not 0 <= prefixlen <= 32:
        raise ValueError(f"Invalid IPv4 prefix length: {prefixlen_str!r}")
    return Database.pack_ip4address(address)[:prefixlen]
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def unpack_ip4network(mini: int, maxi: int) -> str: |
|
|
|
addr = Database.unpack_ip4address(mini) |
|
|
|
prefixlen = 32-int(math.log2(maxi-mini+1)) |
|
|
|
def unpack_ip4network(network: Ip4Path) -> str:
    """
    Convert the bits of a network prefix into an ``address/prefixlen`` string.

    The prefix length is the number of bits given; a copy of the bits is
    padded with zeros up to 32 before being unpacked with
    ``Database.unpack_ip4address``.
    """
    prefixlen = len(network)
    padded = network.copy()
    # Multiplying by a negative count adds nothing, like an empty range
    padded.extend([0] * (32 - prefixlen))
    addr = Database.unpack_ip4address(padded)
    return f'{addr}/{prefixlen}'
|
|
|
|
|
|
|
def update_references(self) -> None: |
|
|
|
self.enter_step('update_refs') |
|
|
|
cursor = self.conn.cursor() |
|
|
|
cursor.execute('UPDATE rules AS r SET refs=' |
|
|
|
'(SELECT count(*) FROM rules ' |
|
|
|
'WHERE source=r.id)') |
|
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
def prune(self, before: int, base_only: bool = False) -> None: |
|
|
|
self.enter_step('prune') |
|
|
|
cursor = self.conn.cursor() |
|
|
|
cmd = 'DELETE FROM rules WHERE updated<?' |
|
|
|
if base_only: |
|
|
|
cmd += ' AND level=0' |
|
|
|
cursor.execute(cmd, (before,)) |
|
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
def explain(self, entry: int) -> str: |
|
|
|
# Format current |
|
|
|
string = '???' |
|
|
|
cursor = self.conn.cursor() |
|
|
|
cursor.execute( |
|
|
|
'SELECT unpack_asn(val) FROM asn WHERE entry=:entry ' |
|
|
|
'UNION ' |
|
|
|
'SELECT unpack_domain(val) FROM hostname WHERE entry=:entry ' |
|
|
|
'UNION ' |
|
|
|
'SELECT format_zone(val) FROM zone WHERE entry=:entry ' |
|
|
|
'UNION ' |
|
|
|
'SELECT unpack_ip4address(val) FROM ip4address WHERE entry=:entry ' |
|
|
|
'UNION ' |
|
|
|
'SELECT unpack_ip4network(mini, maxi) ' |
|
|
|
'FROM ip4network WHERE entry=:entry ', |
|
|
|
{"entry": entry} |
|
|
|
) |
|
|
|
for val, in cursor: # only one |
|
|
|
string = str(val) |
|
|
|
string += f' #{entry}' |
|
|
|
|
|
|
|
# Add source if any |
|
|
|
cursor.execute('SELECT source FROM rules WHERE id=?', (entry,)) |
|
|
|
for source, in cursor: |
|
|
|
if source: |
|
|
|
string += f' ← {self.explain(source)}' |
|
|
|
return string |
|
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
def export(self, |
|
|
|
first_party_only: bool = False, |
|
|
|
end_chain_only: bool = False, |
|
|
|
explain: bool = False, |
|
|
|
_dic: DomainTreeNode = None, |
|
|
|
_par: DomainPath = None, |
|
|
|
) -> typing.Iterable[str]: |
|
|
|
selection = 'entry' if explain else 'unpack_domain(val)' |
|
|
|
command = f'SELECT {selection} FROM rules ' \ |
|
|
|
'INNER JOIN hostname ON rules.id = hostname.entry' |
|
|
|
restrictions: typing.List[str] = list() |
|
|
|
if first_party_only: |
|
|
|
restrictions.append('rules.first_party = 1') |
|
|
|
if end_chain_only: |
|
|
|
restrictions.append('rules.refs = 0') |
|
|
|
if restrictions: |
|
|
|
command += ' WHERE ' + ' AND '.join(restrictions) |
|
|
|
if not explain: |
|
|
|
command += ' ORDER BY unpack_domain(val) ASC' |
|
|
|
cursor = self.conn.cursor() |
|
|
|
cursor.execute(command) |
|
|
|
for val, in cursor: |
|
|
|
if explain: |
|
|
|
yield self.explain(val) |
|
|
|
else: |
|
|
|
yield val |
|
|
|
if first_party_only or end_chain_only or explain: |
|
|
|
raise NotImplementedError |
|
|
|
_dic = _dic or self.domtree |
|
|
|
_par = _par or list() |
|
|
|
if _dic.match_hostname: |
|
|
|
yield self.unpack_domain(_par) |
|
|
|
for part in _dic.children: |
|
|
|
dic = _dic.children[part] |
|
|
|
yield from self.export(_dic=dic, |
|
|
|
_par=_par + [part]) |
|
|
|
|
|
|
|
def count_rules(self, |
|
|
|
first_party_only: bool = False, |
|
|
|
) -> str: |
|
|
|
counts: typing.List[str] = list() |
|
|
|
cursor = self.conn.cursor() |
|
|
|
for table in ['asn', 'ip4network', 'ip4address', 'zone', 'hostname']: |
|
|
|
command = f'SELECT count(*) FROM rules ' \ |
|
|
|
f'INNER JOIN {table} ON rules.id = {table}.entry ' \ |
|
|
|
'WHERE rules.level = 0' |
|
|
|
if first_party_only: |
|
|
|
command += ' AND first_party=1' |
|
|
|
cursor.execute(command) |
|
|
|
count, = cursor.fetchone() |
|
|
|
if count > 0: |
|
|
|
counts.append(f'{table}: {count}') |
|
|
|
|
|
|
|
return ', '.join(counts) |
|
|
|
|
|
|
|
def get_domain(self, domain: str) -> typing.Iterable[int]: |
|
|
|
self.enter_step('get_domain_prepare') |
|
|
|
domain_prep = self.pack_hostname(domain) |
|
|
|
cursor = self.conn.cursor() |
|
|
|
self.enter_step('get_domain_select') |
|
|
|
cursor.execute( |
|
|
|
'SELECT null, entry FROM hostname ' |
|
|
|
'WHERE val=:d ' |
|
|
|
'UNION ' |
|
|
|
'SELECT * FROM (' |
|
|
|
'SELECT val, entry FROM zone ' |
|
|
|
# 'WHERE val>=:d ' |
|
|
|
# 'ORDER BY val ASC LIMIT 1' |
|
|
|
'WHERE val<=:d ' |
|
|
|
'AND instr(:d, val) = 1' |
|
|
|
')', |
|
|
|
{'d': domain_prep} |
|
|
|
) |
|
|
|
for val, entry in cursor: |
|
|
|
# print(293, val, entry) |
|
|
|
self.enter_step('get_domain_confirm') |
|
|
|
if not (val is None or domain_prep.startswith(val)): |
|
|
|
# print(297) |
|
|
|
continue |
|
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
def get_domain(self, domain_str: str) -> typing.Iterable[TypedPath]: |
|
|
|
self.enter_step('get_domain_pack') |
|
|
|
domain = self.pack_domain(domain_str) |
|
|
|
self.enter_step('get_domain_brws') |
|
|
|
dic = self.domtree |
|
|
|
depth = 0 |
|
|
|
for part in domain: |
|
|
|
if dic.match_zone: |
|
|
|
self.enter_step('get_domain_yield') |
|
|
|
yield (PathType.Zone, domain[:depth]) |
|
|
|
self.enter_step('get_domain_brws') |
|
|
|
if part not in dic.children: |
|
|
|
return |
|
|
|
dic = dic.children[part] |
|
|
|
depth += 1 |
|
|
|
if dic.match_zone: |
|
|
|
self.enter_step('get_domain_yield') |
|
|
|
yield entry |
|
|
|
|
|
|
|
def get_domain_in_zone(self, domain: str) -> typing.Iterable[int]:
    """
    Yield the entry of the zone that ``domain`` belongs to, if any.

    The domain is packed with ``pack_hostname``; the greatest zone value
    not above it is fetched, and its entry is yielded only when the packed
    domain starts with that value.
    """
    self.enter_step('get_domainiz_prepare')
    packed = self.pack_hostname(domain)
    cursor = self.conn.cursor()
    self.enter_step('get_domainiz_select')
    query = ('SELECT val, entry FROM zone '
             'WHERE val<=:d '
             'ORDER BY val DESC LIMIT 1')
    cursor.execute(query, {'d': packed})
    for zone_val, entry in cursor:
        self.enter_step('get_domainiz_confirm')
        # The closest zone in sort order is not necessarily a prefix
        if zone_val is not None and not packed.startswith(zone_val):
            continue
        self.enter_step('get_domainiz_yield')
        yield entry
|
|
|
|
|
|
|
def get_ip4(self, address: str) -> typing.Iterable[int]: |
|
|
|
self.enter_step('get_ip4_prepare') |
|
|
|
try: |
|
|
|
address_prep = self.pack_ip4address(address) |
|
|
|
except (ValueError, IndexError): |
|
|
|
self.log.error("Invalid ip4address: %s", address) |
|
|
|
return |
|
|
|
cursor = self.conn.cursor() |
|
|
|
self.enter_step('get_ip4_select') |
|
|
|
cursor.execute( |
|
|
|
'SELECT entry FROM ip4address ' |
|
|
|
# 'SELECT null, entry FROM ip4address ' |
|
|
|
'WHERE val=:a ' |
|
|
|
'UNION ' |
|
|
|
# 'SELECT * FROM (' |
|
|
|
# 'SELECT val, entry FROM ip4network ' |
|
|
|
# 'WHERE val<=:a ' |
|
|
|
# 'AND instr(:a, val) > 0 ' |
|
|
|
# 'ORDER BY val DESC' |
|
|
|
# ')' |
|
|
|
'SELECT entry FROM ip4network ' |
|
|
|
'WHERE :a BETWEEN mini AND maxi ', |
|
|
|
{'a': address_prep} |
|
|
|
) |
|
|
|
for entry, in cursor: |
|
|
|
# self.enter_step('get_ip4_confirm') |
|
|
|
# if not (val is None or val.startswith(address_prep)): |
|
|
|
# # PERF startswith but from the end |
|
|
|
# continue |
|
|
|
yield (PathType.Zone, domain) |
|
|
|
if dic.match_hostname: |
|
|
|
self.enter_step('get_domain_yield') |
|
|
|
yield (PathType.Hostname, domain) |
|
|
|
|
|
|
|
def get_ip4(self, ip4_str: str) -> typing.Iterable[TypedPath]:
    """
    Yield every stored IPv4 path that covers the address ``ip4_str``.

    The address is packed with ``pack_ip4address`` and ``self.ip4tree`` is
    walked along its items. Each node carrying a match yields
    ``(PathType.Ip4, prefix)``, where ``prefix`` is the part of the packed
    address consumed so far (the whole address for the last node). The
    walk stops as soon as the required child does not exist.
    """
    self.enter_step('get_ip4_pack')
    ip4 = self.pack_ip4address(ip4_str)
    self.enter_step('get_ip4_brws')
    dic = self.ip4tree
    depth = 0
    for part in ip4:
        if dic.match:
            self.enter_step('get_ip4_yield')
            yield (PathType.Ip4, ip4[:depth])
            self.enter_step('get_ip4_brws')
        next_dic = dic.children[part]
        if next_dic is None:
            return
        dic = next_dic
        depth += 1
    # The loop only reports nodes it leaves; report the final node here.
    # (A stray `yield entry` on an undefined name used to sit here and
    # raised NameError whenever the full address matched.)
    if dic.match:
        self.enter_step('get_ip4_yield')
        yield (PathType.Ip4, ip4)
|
|
|
|
|
|
|
def get_ip4_in_network(self, address: str) -> typing.Iterable[int]:
    """
    Yield the entry of every IPv4 network containing ``address``.

    The address is packed with ``pack_ip4address``; when packing raises
    ``ValueError`` or ``IndexError`` the problem is logged and nothing is
    yielded.
    """
    self.enter_step('get_ip4in_prepare')
    try:
        packed = self.pack_ip4address(address)
    except (ValueError, IndexError):
        self.log.error("Invalid ip4address: %s", address)
        return
    cursor = self.conn.cursor()
    self.enter_step('get_ip4in_select')
    query = ('SELECT entry FROM ip4network '
             'WHERE :a BETWEEN mini AND maxi ')
    cursor.execute(query, {'a': packed})
    for row in cursor:
        self.enter_step('get_ip4in_yield')
        # Each row holds a single column
        entry, = row
        yield entry
|
|
|
|
|
|
|
def list_asn(self) -> typing.Iterable[typing.Tuple[str, int]]:
    """
    Yield ``('AS<val>', entry)`` for every row of the ``asn`` table.
    """
    cursor = self.conn.cursor()
    self.enter_step('list_asn_select')
    cursor.execute('SELECT val, entry FROM asn')
    for row in cursor:
        number, entry = row
        yield f'AS{number}', entry
|
|
|
|
|
|
|
def _set_generic(self, |
|
|
|
table: str, |
|
|
|
select_query: str, |
|
|
|
insert_query: str, |
|
|
|
prep: typing.Dict[str, DbValue], |
|
|
|
def list_asn(self) -> typing.Iterable[TypedPath]:
    """Yield ``(PathType.Asn, asn)`` for every item of ``self.asns``."""
    yield from ((PathType.Asn, number) for number in self.asns)
|
|
|
|
|
|
|
def set_hostname(self, |
|
|
|
hostname_str: str, |
|
|
|
updated: int, |
|
|
|
is_first_party: bool = False, |
|
|
|
source: int = None, |
|
|
|
) -> None: |
|
|
|
# Since this isn't the bulk of the processing, |
|
|
|
# here abstraction > performance |
|
|
|
|
|
|
|
# Fields based on the source |
|
|
|
self.enter_step(f'set_{table}_prepare') |
|
|
|
cursor = self.conn.cursor() |
|
|
|
if source is None: |
|
|
|
first_party = int(is_first_party) |
|
|
|
level = 0 |
|
|
|
else: |
|
|
|
self.enter_step(f'set_{table}_source') |
|
|
|
cursor.execute( |
|
|
|
'SELECT first_party, level FROM rules ' |
|
|
|
'WHERE id=?', |
|
|
|
(source,) |
|
|
|
) |
|
|
|
first_party, level = cursor.fetchone() |
|
|
|
level += 1 |
|
|
|
|
|
|
|
self.enter_step(f'set_{table}_select') |
|
|
|
cursor.execute(select_query, prep) |
|
|
|
|
|
|
|
rules_prep: typing.Dict[str, DbValue] = { |
|
|
|
"source": source, |
|
|
|
"updated": updated, |
|
|
|
"first_party": first_party, |
|
|
|
"level": level, |
|
|
|
} |
|
|
|
|
|
|
|
# If the entry already exists |
|
|
|
for entry, in cursor: # only one |
|
|
|
self.enter_step(f'set_{table}_update') |
|
|
|
rules_prep['entry'] = entry |
|
|
|
cursor.execute( |
|
|
|
'UPDATE rules SET ' |
|
|
|
'source=:source, updated=:updated, ' |
|
|
|
'first_party=:first_party, level=:level ' |
|
|
|
'WHERE id=:entry AND (updated<:updated OR ' |
|
|
|
'first_party<:first_party OR level<:level)', |
|
|
|
rules_prep |
|
|
|
) |
|
|
|
# Only update if any of the following: |
|
|
|
# - the entry is outdated |
|
|
|
# - the entry was not a first_party but this is |
|
|
|
# - this is closer to the original rule |
|
|
|
return |
|
|
|
|
|
|
|
# If it does not exist |
|
|
|
|
|
|
|
self.enter_step(f'set_{table}_insert') |
|
|
|
cursor.execute( |
|
|
|
'INSERT INTO rules ' |
|
|
|
'(source, updated, first_party, level) ' |
|
|
|
'VALUES (:source, :updated, :first_party, :level) ', |
|
|
|
rules_prep |
|
|
|
) |
|
|
|
cursor.execute('SELECT id FROM rules WHERE rowid=?', |
|
|
|
(cursor.lastrowid,)) |
|
|
|
for entry, in cursor: # only one |
|
|
|
prep['entry'] = entry |
|
|
|
cursor.execute(insert_query, prep) |
|
|
|
return |
|
|
|
assert False |
|
|
|
|
|
|
|
def set_hostname(self, hostname: str,
                 *args: typing.Any, **kwargs: typing.Any) -> None:
    """
    Record ``hostname`` through ``_set_generic``.

    The hostname is packed with ``pack_hostname``; any extra positional
    or keyword argument is forwarded unchanged to ``_set_generic``.
    """
    self.enter_step('set_hostname_prepare')
    select_query = 'SELECT entry FROM hostname WHERE val=:val'
    insert_query = ('INSERT INTO hostname (val, entry) '
                    'VALUES (:val, :entry)')
    values: typing.Dict[str, DbValue] = {'val': self.pack_hostname(hostname)}
    self._set_generic('hostname', select_query, insert_query, values,
                      *args, **kwargs)
|
|
|
|
|
|
|
def set_asn(self, asn: str,
            *args: typing.Any, **kwargs: typing.Any) -> None:
    """
    Record ``asn`` through ``_set_generic``.

    The value is packed with ``pack_asn``; when packing raises
    ``ValueError`` the problem is logged and nothing is recorded. Extra
    arguments are forwarded unchanged to ``_set_generic``.
    """
    self.enter_step('set_asn_prepare')
    try:
        packed = self.pack_asn(asn)
    except ValueError:
        self.log.error("Invalid asn: %s", asn)
        return
    select_query = 'SELECT entry FROM asn WHERE val=:val'
    insert_query = ('INSERT INTO asn (val, entry) '
                    'VALUES (:val, :entry)')
    values: typing.Dict[str, DbValue] = {'val': packed}
    self._set_generic('asn', select_query, insert_query, values,
                      *args, **kwargs)
|
|
|
|
|
|
|
def set_ip4address(self, ip4address: str,
                   *args: typing.Any, **kwargs: typing.Any) -> None:
    """
    Record ``ip4address`` through ``_set_generic``.

    The address is packed with ``pack_ip4address``; when packing raises
    ``ValueError`` or ``IndexError`` the problem is logged and nothing is
    recorded. Extra arguments are forwarded unchanged to ``_set_generic``.
    """
    self.enter_step('set_ip4add_prepare')
    try:
        packed = self.pack_ip4address(ip4address)
    except (ValueError, IndexError):
        self.log.error("Invalid ip4address: %s", ip4address)
        return
    select_query = 'SELECT entry FROM ip4address WHERE val=:val'
    insert_query = ('INSERT INTO ip4address (val, entry) '
                    'VALUES (:val, :entry)')
    values: typing.Dict[str, DbValue] = {'val': packed}
    # The step prefix is 'ip4add', not the table name
    self._set_generic('ip4add', select_query, insert_query, values,
                      *args, **kwargs)
|
|
|
|
|
|
|
def set_zone(self, zone: str,
             *args: typing.Any, **kwargs: typing.Any) -> None:
    """
    Record ``zone`` through ``_set_generic``.

    The zone is packed with ``pack_zone``; any extra positional or
    keyword argument is forwarded unchanged to ``_set_generic``.
    """
    self.enter_step('set_zone_prepare')
    select_query = 'SELECT entry FROM zone WHERE val=:val'
    insert_query = ('INSERT INTO zone (val, entry) '
                    'VALUES (:val, :entry)')
    values: typing.Dict[str, DbValue] = {'val': self.pack_zone(zone)}
    self._set_generic('zone', select_query, insert_query, values,
                      *args, **kwargs)
|
|
|
|
|
|
|
def set_ip4network(self, ip4network: str,
                   *args: typing.Any, **kwargs: typing.Any) -> None:
    """
    Record ``ip4network`` through ``_set_generic``.

    The network is packed with ``pack_ip4network``, whose first two items
    become the ``mini`` and ``maxi`` bounds; when packing raises
    ``ValueError`` or ``IndexError`` the problem is logged and nothing is
    recorded. Extra arguments are forwarded unchanged to ``_set_generic``.
    """
    self.enter_step('set_ip4net_prepare')
    try:
        packed = self.pack_ip4network(ip4network)
    except (ValueError, IndexError):
        self.log.error("Invalid ip4network: %s", ip4network)
        return
    select_query = ('SELECT entry FROM ip4network '
                    'WHERE mini=:mini AND maxi=:maxi')
    insert_query = ('INSERT INTO ip4network (mini, maxi, entry) '
                    'VALUES (:mini, :maxi, :entry)')
    values: typing.Dict[str, DbValue] = {
        'mini': packed[0],
        'maxi': packed[1],
    }
    # The step prefix is 'ip4net', not the table name
    self._set_generic('ip4net', select_query, insert_query, values,
                      *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':

    # Parsing arguments
    parser = argparse.ArgumentParser(description="Database operations")
    # On/off switches declared ahead of the pruning threshold
    for flags, description in (
            (('-i', '--initialize'), "Reconstruct the whole database"),
            (('-p', '--prune'), "Remove old entries from database"),
            (('-b', '--prune-base'), "TODO"),
    ):
        parser.add_argument(*flags, action='store_true', help=description)
    parser.add_argument(
        '-s', '--prune-before', type=int,
        # Default threshold: now minus 6 periods of 31 days, in seconds
        default=(int(time.time()) - 60*60*24*31*6),
        help="TODO")
    parser.add_argument(
        '-r', '--references', action='store_true',
        help="Update the reference count")
    args = parser.parse_args()

    DB = Database(write=True)

    # Requested operations always run in this order
    if args.initialize:
        DB.initialize()
    if args.prune:
        DB.prune(before=args.prune_before, base_only=args.prune_base)
    if args.references:
        DB.update_references()

    DB.close()
|
|
|
is_first_party: bool = None, |
|
|
|
source: TypedPath = None) -> None: |
|
|
|
self.enter_step('set_hostname_pack') |
|
|
|
if is_first_party or source: |
|
|
|
raise NotImplementedError |
|
|
|
self.enter_step('set_hostname_brws') |
|
|
|
hostname = self.pack_domain(hostname_str) |
|
|
|
dic = self.domtree |
|
|
|
for part in hostname: |
|
|
|
if dic.match_zone: |
|
|
|
# Refuse to add hostname whose zone is already matching |
|
|
|
return |
|
|
|
if part not in dic.children: |
|
|
|
dic.children[part] = DomainTreeNode() |
|
|
|
dic = dic.children[part] |
|
|
|
dic.match_hostname = (updated, DebugPath, 0) |
|
|
|
|
|
|
|
def set_zone(self,
             zone_str: str,
             updated: int,
             is_first_party: bool = None,
             source: TypedPath = None) -> None:
    """
    Mark the zone ``zone_str`` as matching in the domain tree.

    The zone is packed with ``pack_domain`` and the tree is descended one
    part at a time, creating missing nodes. Nothing is recorded when a
    node above the target already carries a zone match.

    Raises:
        NotImplementedError: if ``is_first_party`` or ``source`` is truthy.
    """
    self.enter_step('set_zone_pack')
    if is_first_party or source:
        raise NotImplementedError
    zone = self.pack_domain(zone_str)
    self.enter_step('set_zone_brws')
    node = self.domtree
    for label in zone:
        if node.match_zone:
            # Refuse to add zone whose parent zone is already matching
            return
        try:
            node = node.children[label]
        except KeyError:
            created = DomainTreeNode()
            node.children[label] = created
            node = created
    node.match_zone = (updated, DebugPath, 0)
|
|
|
|
|
|
|
def set_asn(self,
            asn_str: str,
            updated: int,
            is_first_party: bool = None,
            source: TypedPath = None) -> None:
    """
    Add the ASN ``asn_str`` to ``self.asns``.

    The value is packed with ``pack_asn`` before being stored; ``updated``
    is currently not used.

    Raises:
        NotImplementedError: if ``is_first_party`` or ``source`` is truthy.
    """
    self.enter_step('set_asn_pack')
    unsupported = bool(is_first_party or source)
    if unsupported:
        # TODO updated
        raise NotImplementedError
    packed = self.pack_asn(asn_str)
    self.enter_step('set_asn_brws')
    self.asns.add(packed)
|
|
|
|
|
|
|
def set_ip4address(self,
                   ip4address_str: str,
                   updated: int,
                   is_first_party: bool = None,
                   source: TypedPath = None) -> None:
    """
    Mark the address ``ip4address_str`` as matching in the IPv4 tree.

    The address is packed with ``pack_ip4address`` and the tree is
    descended one item at a time, creating missing nodes. Nothing is
    recorded when a node above the target already carries a match.

    Raises:
        NotImplementedError: if ``is_first_party`` or ``source`` is truthy.
    """
    self.enter_step('set_ip4add_pack')
    if is_first_party or source:
        raise NotImplementedError
    self.enter_step('set_ip4add_brws')
    path = self.pack_ip4address(ip4address_str)
    node = self.ip4tree
    for branch in path:
        if node.match:
            # Refuse to add ip4address whose network is already matching
            return
        child = node.children[branch]
        if child is None:
            child = node.children[branch] = IpTreeNode()
        node = child
    node.match = (updated, DebugPath, 0)
|
|
|
|
|
|
|
def set_ip4network(self,
                   ip4network_str: str,
                   updated: int,
                   is_first_party: bool = None,
                   source: TypedPath = None) -> None:
    """
    Mark the network ``ip4network_str`` as matching in the IPv4 tree.

    The network is packed with ``pack_ip4network`` and the tree is
    descended one item at a time, creating missing nodes. Nothing is
    recorded when a node above the target already carries a match.

    Raises:
        NotImplementedError: if ``is_first_party`` or ``source`` is truthy.
    """
    self.enter_step('set_ip4net_pack')
    if is_first_party or source:
        raise NotImplementedError
    self.enter_step('set_ip4net_brws')
    path = self.pack_ip4network(ip4network_str)
    node = self.ip4tree
    for branch in path:
        if node.match:
            # Refuse to add ip4network whose parent network
            # is already matching
            return
        child = node.children[branch]
        if child is None:
            child = node.children[branch] = IpTreeNode()
        node = child
    node.match = (updated, DebugPath, 0)