2019-12-09 08:12:48 +01:00
|
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
|
2019-12-13 00:11:21 +01:00
|
|
|
|
"""
|
|
|
|
|
Utility functions to interact with the database.
|
|
|
|
|
"""
|
|
|
|
|
|
2019-12-09 08:12:48 +01:00
|
|
|
|
import sqlite3
|
2019-12-13 00:11:21 +01:00
|
|
|
|
import typing
|
|
|
|
|
import time
|
2019-12-09 08:12:48 +01:00
|
|
|
|
import os
|
2019-12-13 00:11:21 +01:00
|
|
|
|
import logging
|
2019-12-09 08:12:48 +01:00
|
|
|
|
import argparse
|
2019-12-13 00:11:21 +01:00
|
|
|
|
import coloredlogs
|
2019-12-09 08:12:48 +01:00
|
|
|
|
import ipaddress
|
2019-12-13 18:00:00 +01:00
|
|
|
|
import math
|
2019-12-09 08:12:48 +01:00
|
|
|
|
|
2019-12-13 00:11:21 +01:00
|
|
|
|
# Colored, timestamped log output at DEBUG level for all loggers; this runs
# as a side effect as soon as the module is imported.
coloredlogs.install(level='DEBUG',
                    fmt='%(asctime)s %(name)s %(levelname)s %(message)s')
|
|
|
|
|
|
|
|
|
|
# Python types that can be bound as SQLite query parameters.
DbValue = typing.Union[None, int, float, str, bytes]


class Database():
    """SQLite-backed store of blocking rules.

    Each rule is a row of the ``rules`` table (source, updated, first_party,
    level, refs); its value lives in one of the value tables ``hostname``,
    ``zone``, ``asn``, ``ip4address`` or ``ip4network``, linked back to the
    rule through their ``entry`` column.  Operations are timed with
    enter_step() and the totals are logged by profile() on close().
    """

    # Schema version stored in the meta table; a mismatch triggers a rebuild.
    VERSION = 5
    # SQLite file, relative to the current working directory.
    PATH = "blocking.db"

    def open(self) -> None:
        """Connect to PATH and register the SQL helper functions.

        In write mode the file is opened read-write and created if missing
        (SQLite URI ``mode=rwc``), otherwise read-only (``mode=ro``).
        Foreign-key enforcement is switched on for the new connection.
        """
        mode = 'rwc' if self.write else 'ro'
        uri = f'file:{self.PATH}?mode={mode}'
        self.conn = sqlite3.connect(uri, uri=True)
        cursor = self.conn.cursor()
        cursor.execute("PRAGMA foreign_keys = ON")
        # Decoders callable from SQL, used by explain() and export().
        self.conn.create_function("unpack_asn", 1,
                                  self.unpack_asn,
                                  deterministic=True)
        self.conn.create_function("unpack_ip4address", 1,
                                  self.unpack_ip4address,
                                  deterministic=True)
        self.conn.create_function("unpack_ip4network", 2,
                                  self.unpack_ip4network,
                                  deterministic=True)
        # Inverse of pack_hostname(): drop the trailing '.' and un-reverse.
        self.conn.create_function("unpack_domain", 1,
                                  lambda s: s[:-1][::-1],
                                  deterministic=True)
        # Render a packed zone as a wildcard:
        # 'moc.elpmaxe.' -> '*.example.com'.
        self.conn.create_function("format_zone", 1,
                                  lambda s: '*' + s[::-1],
                                  deterministic=True)

    def get_meta(self, key: str) -> typing.Optional[int]:
        """Return the meta value stored under key, or None.

        None is also returned when the query raises
        sqlite3.OperationalError, e.g. when the meta table does not exist
        yet in a freshly created file.
        """
        cursor = self.conn.cursor()
        try:
            cursor.execute("SELECT value FROM meta WHERE key=?", (key,))
        except sqlite3.OperationalError:
            return None
        for ver, in cursor:
            return ver
        return None

    def set_meta(self, key: str, val: int) -> None:
        """Insert, or overwrite, the meta value stored under key."""
        cursor = self.conn.cursor()
        cursor.execute("INSERT INTO meta VALUES (?, ?) "
                       "ON CONFLICT (key) DO "
                       "UPDATE set value=?",
                       (key, val, val))

    def close(self) -> None:
        """Commit pending changes, close the connection and log timings."""
        self.enter_step('close_commit')
        self.conn.commit()
        self.enter_step('close')
        self.conn.close()
        self.profile()

    def initialize(self) -> None:
        """Delete PATH and recreate the database from database_schema.sql.

        The schema file is read from the current working directory, and the
        meta 'version' key is set to VERSION.

        Raises:
            RuntimeError: if the database was opened read-only.
        """
        self.close()
        self.enter_step('initialize')
        if not self.write:
            self.log.error("Cannot initialize in read-only mode.")
            # There is no active exception here, so a bare ``raise`` would
            # only fail with "No active exception to reraise".
            raise RuntimeError("Cannot initialize in read-only mode.")
        os.unlink(self.PATH)
        self.open()
        self.log.info("Creating database version %d.", self.VERSION)
        cursor = self.conn.cursor()
        with open("database_schema.sql", 'r') as db_schema:
            cursor.executescript(db_schema.read())
        self.set_meta('version', self.VERSION)
        self.conn.commit()

    def __init__(self, write: bool = False) -> None:
        """Open the database and rebuild it if its version is not VERSION.

        Args:
            write: open read-write (creating the file if needed) instead of
                read-only.  A rebuild is only possible in write mode.
        """
        self.log = logging.getLogger('db')
        # Profiling state maintained by enter_step().
        self.time_last = time.perf_counter()
        self.time_step = 'init'
        self.time_dict: typing.Dict[str, float] = dict()
        self.step_dict: typing.Dict[str, int] = dict()
        self.write = write

        self.open()
        version = self.get_meta('version')
        if version != self.VERSION:
            if version is not None:
                self.log.warning(
                    "Outdated database version: %d found, will be rebuilt.",
                    version)
            self.initialize()

    def enter_step(self, name: str) -> None:
        """Charge the time elapsed since the last call to the current step
        and make name the current step."""
        now = time.perf_counter()
        try:
            self.time_dict[self.time_step] += now - self.time_last
            self.step_dict[self.time_step] += 1
        except KeyError:
            self.time_dict[self.time_step] = now - self.time_last
            self.step_dict[self.time_step] = 1
        self.time_step = name
        # Re-read the clock so the bookkeeping above is not charged.
        self.time_last = time.perf_counter()

    def profile(self) -> None:
        """Log, at DEBUG level, the count and time spent in every step."""
        self.enter_step('profile')
        total = sum(self.time_dict.values())
        for key, secs in sorted(self.time_dict.items(), key=lambda t: t[1]):
            times = self.step_dict[key]
            # Guard against a zero total (clock resolution) when dividing.
            share = secs / total if total else 0.0
            self.log.debug(f"{key:<20}: {times:9d} × {secs/times:5.3e} "
                           f"= {secs:9.2f} s ({share:7.2%}) ")
        self.log.debug(f"{'total':<20}: "
                       f"{total:9.2f} s ({1:7.2%})")

    @staticmethod
    def pack_hostname(hostname: str) -> str:
        """Reverse hostname and append '.', e.g. 'a.example.com' ->
        'moc.elpmaxe.a.'.

        Reversal turns "is inside zone" into a string-prefix relation.
        """
        return hostname[::-1] + '.'

    @staticmethod
    def pack_zone(zone: str) -> str:
        """Pack a zone the same way as a hostname."""
        return Database.pack_hostname(zone)

    @staticmethod
    def pack_asn(asn: str) -> int:
        """Convert 'AS123' / 'as123' / '123' to 123.

        Raises:
            ValueError: if the remainder is not an integer.
        """
        asn = asn.upper()
        if asn.startswith('AS'):
            asn = asn[2:]
        return int(asn)

    @staticmethod
    def unpack_asn(asn: int) -> str:
        """Format an AS number as 'AS<number>'."""
        return f'AS{asn}'

    @staticmethod
    def pack_ip4address(address: str) -> int:
        """Pack a dotted-quad IPv4 address into a 32-bit integer.

        Raises:
            ValueError: if address does not have exactly four octets, or an
                octet is not an integer between 0 and 255.
        """
        octets = address.split('.')
        if len(octets) != 4:
            raise ValueError(f"Invalid ip4address: {address!r}")
        total = 0
        for octet in octets:
            value = int(octet)
            # Out-of-range octets would otherwise bleed into their neighbours.
            if not 0 <= value <= 255:
                raise ValueError(f"Invalid ip4address: {address!r}")
            total = (total << 8) | value
        return total

    @staticmethod
    def unpack_ip4address(address: int) -> str:
        """Inverse of pack_ip4address()."""
        return '.'.join(str((address >> (i * 8)) & 0xFF)
                        for i in reversed(range(4)))

    @staticmethod
    def pack_ip4network(network: str) -> typing.Tuple[int, int]:
        """Return the packed (first, last) addresses of an IPv4 network.

        Raises:
            ValueError: for malformed networks, networks with host bits set,
                or IPv6 networks.
        """
        net = ipaddress.ip_network(network)
        mini = Database.pack_ip4address(net.network_address.exploded)
        maxi = Database.pack_ip4address(net.broadcast_address.exploded)
        return mini, maxi

    @staticmethod
    def unpack_ip4network(mini: int, maxi: int) -> str:
        """Format a packed range as 'a.b.c.d/prefixlen'.

        Assumes the range is a power-of-two block, as produced by
        pack_ip4network().
        """
        addr = Database.unpack_ip4address(mini)
        prefixlen = 32-int(math.log2(maxi-mini+1))
        return f'{addr}/{prefixlen}'

    def update_references(self) -> None:
        """Set rules.refs to the number of rules that use each rule as
        their source."""
        self.enter_step('update_refs')
        cursor = self.conn.cursor()
        cursor.execute('UPDATE rules AS r SET refs='
                       '(SELECT count(*) FROM rules '
                       'WHERE source=r.id)')

    def prune(self, before: int) -> None:
        """Delete rules whose updated value is lower than before."""
        self.enter_step('prune')
        cursor = self.conn.cursor()
        cursor.execute('DELETE FROM rules WHERE updated<?', (before,))

    def explain(self, entry: int) -> str:
        """Describe rule entry as '<value> #<entry>', followed by
        ' ← <explanation of its source>' when it has one.

        '<value>' is '???' when no value table has a row for entry.
        """
        # Format current
        string = '???'
        cursor = self.conn.cursor()
        cursor.execute(
            'SELECT unpack_asn(val) FROM asn WHERE entry=:entry '
            'UNION '
            'SELECT unpack_domain(val) FROM hostname WHERE entry=:entry '
            'UNION '
            'SELECT format_zone(val) FROM zone WHERE entry=:entry '
            'UNION '
            'SELECT unpack_ip4address(val) FROM ip4address WHERE entry=:entry '
            'UNION '
            'SELECT unpack_ip4network(mini, maxi) '
            'FROM ip4network WHERE entry=:entry ',
            {"entry": entry}
        )
        for val, in cursor:  # only one
            string = str(val)
        string += f' #{entry}'

        # Add source if any
        cursor.execute('SELECT source FROM rules WHERE id=?', (entry,))
        for source, in cursor:
            if source:
                string += f' ← {self.explain(source)}'
        return string

    def export(self,
               first_party_only: bool = False,
               end_chain_only: bool = False,
               explain: bool = False,
               ) -> typing.Iterable[str]:
        """Yield every hostname rule.

        Args:
            first_party_only: keep only rules with first_party = 1.
            end_chain_only: keep only rules with refs = 0.
            explain: yield explain() output instead of bare hostnames;
                hostnames are sorted, explanations are not.
        """
        selection = 'entry' if explain else 'unpack_domain(val)'
        command = f'SELECT {selection} FROM rules ' \
            'INNER JOIN hostname ON rules.id = hostname.entry'
        restrictions: typing.List[str] = list()
        if first_party_only:
            restrictions.append('rules.first_party = 1')
        if end_chain_only:
            restrictions.append('rules.refs = 0')
        if restrictions:
            command += ' WHERE ' + ' AND '.join(restrictions)
        if not explain:
            command += ' ORDER BY unpack_domain(val) ASC'
        cursor = self.conn.cursor()
        cursor.execute(command)
        for val, in cursor:
            if explain:
                yield self.explain(val)
            else:
                yield val

    def count_rules(self,
                    first_party_only: bool = False,
                    ) -> str:
        """Return e.g. 'zone: 10, hostname: 3': the number of level-0 rules
        per value table, omitting tables with none."""
        counts: typing.List[str] = list()
        cursor = self.conn.cursor()
        for table in ['asn', 'ip4network', 'ip4address', 'zone', 'hostname']:
            command = f'SELECT count(*) FROM rules ' \
                f'INNER JOIN {table} ON rules.id = {table}.entry ' \
                'WHERE rules.level = 0'
            if first_party_only:
                command += ' AND first_party=1'
            cursor.execute(command)
            count, = cursor.fetchone()
            if count > 0:
                counts.append(f'{table}: {count}')

        return ', '.join(counts)

    @staticmethod
    def _zone_params(domain_prep: str
                     ) -> typing.Tuple[str, typing.Dict[str, DbValue]]:
        """Return an SQL placeholder list and its named parameters covering
        every packed zone that contains the packed domain domain_prep.

        A packed zone contains a packed domain when it is a prefix of it
        that ends right after a '.', e.g. 'moc.' and 'moc.elpmaxe.' both
        contain 'moc.elpmaxe.b.'.  The keys are generated here, so the
        placeholder list is safe to splice into SQL.
        """
        params: typing.Dict[str, DbValue] = {
            f'z{i}': domain_prep[:i + 1]
            for i, char in enumerate(domain_prep) if char == '.'
        }
        placeholders = ', '.join(f':{name}' for name in params)
        return placeholders, params

    def get_domain(self, domain: str) -> typing.Iterable[int]:
        """Yield the ids of the rules matching domain: the hostname rule
        equal to it and the most specific zone rule containing it.

        Only ancestor zones are queried.  The former "greatest zone <=
        domain" lookup missed the enclosing zone whenever a sibling zone
        sorted in between (zone a.example.com hid example.com for
        b.example.com).
        """
        self.enter_step('get_domain_prepare')
        domain_prep = self.pack_hostname(domain)
        zone_list, params = self._zone_params(domain_prep)
        params['d'] = domain_prep
        cursor = self.conn.cursor()
        self.enter_step('get_domain_select')
        cursor.execute(
            'SELECT null, entry FROM hostname '
            'WHERE val=:d '
            'UNION '
            'SELECT * FROM ('
            'SELECT val, entry FROM zone '
            f'WHERE val IN ({zone_list}) '
            'ORDER BY val DESC LIMIT 1'
            ')',
            params
        )
        for _, entry in cursor:
            self.enter_step('get_domain_yield')
            yield entry

    def get_domain_in_zone(self, domain: str) -> typing.Iterable[int]:
        """Yield the id of the most specific zone rule containing domain,
        if any."""
        self.enter_step('get_domainiz_prepare')
        domain_prep = self.pack_hostname(domain)
        zone_list, params = self._zone_params(domain_prep)
        cursor = self.conn.cursor()
        self.enter_step('get_domainiz_select')
        cursor.execute(
            'SELECT val, entry FROM zone '
            f'WHERE val IN ({zone_list}) '
            'ORDER BY val DESC LIMIT 1',
            params
        )
        for _, entry in cursor:
            self.enter_step('get_domainiz_yield')
            yield entry

    def get_ip4(self, address: str) -> typing.Iterable[int]:
        """Yield the ids of the ip4address rule equal to address and of
        every ip4network rule containing it.

        Invalid addresses are logged and yield nothing.
        """
        self.enter_step('get_ip4_prepare')
        try:
            address_prep = self.pack_ip4address(address)
        except (ValueError, IndexError):
            self.log.error("Invalid ip4address: %s", address)
            return
        cursor = self.conn.cursor()
        self.enter_step('get_ip4_select')
        cursor.execute(
            'SELECT entry FROM ip4address '
            'WHERE val=:a '
            'UNION '
            'SELECT entry FROM ip4network '
            'WHERE :a BETWEEN mini AND maxi ',
            {'a': address_prep}
        )
        for entry, in cursor:
            self.enter_step('get_ip4_yield')
            yield entry

    def get_ip4_in_network(self, address: str) -> typing.Iterable[int]:
        """Yield the ids of every ip4network rule containing address.

        Invalid addresses are logged and yield nothing.
        """
        self.enter_step('get_ip4in_prepare')
        try:
            address_prep = self.pack_ip4address(address)
        except (ValueError, IndexError):
            self.log.error("Invalid ip4address: %s", address)
            return
        cursor = self.conn.cursor()
        self.enter_step('get_ip4in_select')
        cursor.execute(
            'SELECT entry FROM ip4network '
            'WHERE :a BETWEEN mini AND maxi ',
            {'a': address_prep}
        )
        for entry, in cursor:
            self.enter_step('get_ip4in_yield')
            yield entry

    def list_asn(self) -> typing.Iterable[typing.Tuple[str, int]]:
        """Yield ('AS<number>', rule id) for every ASN rule."""
        cursor = self.conn.cursor()
        self.enter_step('list_asn_select')
        cursor.execute('SELECT val, entry FROM asn')
        for val, entry in cursor:
            yield self.unpack_asn(val), entry

    def _set_generic(self,
                     table: str,
                     select_query: str,
                     insert_query: str,
                     prep: typing.Dict[str, DbValue],
                     updated: int,
                     is_first_party: bool = False,
                     source: typing.Optional[int] = None,
                     ) -> None:
        """Create the rule whose value is described by prep, or refresh it.

        Args:
            table: short label used only in profiling step names.
            select_query: returns the id (``entry``) of an existing rule
                with this value, using the parameters in prep.
            insert_query: inserts the value row from prep plus ``:entry``.
            prep: named parameters describing the value; not modified.
            updated: stored in rules.updated (compared with prune()'s
                ``before``).
            is_first_party: first-party flag of a rule without source.
            source: id of the rule this one derives from; the new rule
                inherits its first_party flag and gets its level + 1.
        """
        # Since this isn't the bulk of the processing,
        # here abstraction > performance

        # Fields based on the source
        self.enter_step(f'set_{table}_prepare')
        cursor = self.conn.cursor()
        if source is None:
            first_party = int(is_first_party)
            level = 0
        else:
            self.enter_step(f'set_{table}_source')
            cursor.execute(
                'SELECT first_party, level FROM rules '
                'WHERE id=?',
                (source,)
            )
            first_party, level = cursor.fetchone()
            level += 1

        self.enter_step(f'set_{table}_select')
        cursor.execute(select_query, prep)
        existing = cursor.fetchone()  # only one

        rules_prep: typing.Dict[str, DbValue] = {
            "source": source,
            "updated": updated,
            "first_party": first_party,
            "level": level,
        }

        # If the entry already exists
        if existing is not None:
            self.enter_step(f'set_{table}_update')
            rules_prep['entry'] = existing[0]
            # Only update if any of the following:
            # - the entry is outdated
            # - the entry was not a first_party but this is
            # - this is closer to the original rule, i.e. the stored level
            #   is higher than the new one
            cursor.execute(
                'UPDATE rules SET '
                'source=:source, updated=:updated, '
                'first_party=:first_party, level=:level '
                'WHERE id=:entry AND (updated<:updated OR '
                'first_party<:first_party OR level>:level)',
                rules_prep
            )
            return

        # If it does not exist
        self.enter_step(f'set_{table}_insert')
        cursor.execute(
            'INSERT INTO rules '
            '(source, updated, first_party, level) '
            'VALUES (:source, :updated, :first_party, :level) ',
            rules_prep
        )
        cursor.execute('SELECT id FROM rules WHERE rowid=?',
                       (cursor.lastrowid,))
        row = cursor.fetchone()
        if row is None:
            raise sqlite3.DatabaseError(
                f"Newly inserted rule for {table} could not be read back")
        # Copy so the caller's dict is left untouched.
        value_prep = dict(prep, entry=row[0])
        cursor.execute(insert_query, value_prep)

    def set_hostname(self, hostname: str,
                     *args: typing.Any, **kwargs: typing.Any) -> None:
        """Add or refresh a hostname rule; extra arguments go to
        _set_generic()."""
        self.enter_step('set_hostname_prepare')
        prep: typing.Dict[str, DbValue] = {
            'val': self.pack_hostname(hostname),
        }
        self._set_generic(
            'hostname',
            'SELECT entry FROM hostname WHERE val=:val',
            'INSERT INTO hostname (val, entry) '
            'VALUES (:val, :entry)',
            prep,
            *args, **kwargs
        )

    def set_asn(self, asn: str,
                *args: typing.Any, **kwargs: typing.Any) -> None:
        """Add or refresh an ASN rule; invalid ASNs are logged and
        skipped."""
        self.enter_step('set_asn_prepare')
        try:
            asn_prep = self.pack_asn(asn)
        except ValueError:
            self.log.error("Invalid asn: %s", asn)
            return
        prep: typing.Dict[str, DbValue] = {
            'val': asn_prep,
        }
        self._set_generic(
            'asn',
            'SELECT entry FROM asn WHERE val=:val',
            'INSERT INTO asn (val, entry) '
            'VALUES (:val, :entry)',
            prep,
            *args, **kwargs
        )

    def set_ip4address(self, ip4address: str,
                       *args: typing.Any, **kwargs: typing.Any) -> None:
        """Add or refresh an IPv4 address rule; invalid addresses are
        logged and skipped."""
        self.enter_step('set_ip4add_prepare')
        try:
            ip4address_prep = self.pack_ip4address(ip4address)
        except (ValueError, IndexError):
            self.log.error("Invalid ip4address: %s", ip4address)
            return
        prep: typing.Dict[str, DbValue] = {
            'val': ip4address_prep,
        }
        self._set_generic(
            'ip4add',
            'SELECT entry FROM ip4address WHERE val=:val',
            'INSERT INTO ip4address (val, entry) '
            'VALUES (:val, :entry)',
            prep,
            *args, **kwargs
        )

    def set_zone(self, zone: str,
                 *args: typing.Any, **kwargs: typing.Any) -> None:
        """Add or refresh a zone rule (the zone and every subdomain)."""
        self.enter_step('set_zone_prepare')
        prep: typing.Dict[str, DbValue] = {
            'val': self.pack_zone(zone),
        }
        self._set_generic(
            'zone',
            'SELECT entry FROM zone WHERE val=:val',
            'INSERT INTO zone (val, entry) '
            'VALUES (:val, :entry)',
            prep,
            *args, **kwargs
        )

    def set_ip4network(self, ip4network: str,
                       *args: typing.Any, **kwargs: typing.Any) -> None:
        """Add or refresh an IPv4 network rule; invalid networks are
        logged and skipped."""
        self.enter_step('set_ip4net_prepare')
        try:
            ip4network_prep = self.pack_ip4network(ip4network)
        except (ValueError, IndexError):
            self.log.error("Invalid ip4network: %s", ip4network)
            return
        prep: typing.Dict[str, DbValue] = {
            'mini': ip4network_prep[0],
            'maxi': ip4network_prep[1],
        }
        self._set_generic(
            'ip4net',
            'SELECT entry FROM ip4network WHERE mini=:mini AND maxi=:maxi',
            'INSERT INTO ip4network (mini, maxi, entry) '
            'VALUES (:mini, :maxi, :entry)',
            prep,
            *args, **kwargs
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':

    # Command-line interface: every option is an on/off switch.
    cli = argparse.ArgumentParser(description="Database operations")
    for short_flag, long_flag, description in (
            ('-i', '--initialize', "Reconstruct the whole database"),
            ('-p', '--prune', "Remove old (+6 months) entries from database"),
            ('-r', '--references', "Update the reference count"),
    ):
        cli.add_argument(short_flag, long_flag, action='store_true',
                         help=description)
    args = cli.parse_args()

    DB = Database(write=True)

    if args.initialize:
        DB.initialize()
    if args.prune:
        # "6 months" is approximated as 6 x 31 days, in seconds.
        six_months_ago = int(time.time()) - 6 * 31 * 24 * 3600
        DB.prune(before=six_months_ago)
    elif args.references:
        # NOTE(review): the reference count is not refreshed when --prune
        # is also given; confirm this is intended.
        DB.update_references()

    DB.close()
|