|
|
@ -12,6 +12,7 @@ import logging |
|
|
|
import argparse |
|
|
|
import coloredlogs |
|
|
|
import ipaddress |
|
|
|
import math |
|
|
|
|
|
|
|
coloredlogs.install( |
|
|
|
level='DEBUG', |
|
|
@ -22,43 +23,47 @@ DbValue = typing.Union[None, int, float, str, bytes] |
|
|
|
|
|
|
|
|
|
|
|
class Database(): |
|
|
|
VERSION = 3 |
|
|
|
VERSION = 4 |
|
|
|
PATH = "blocking.db" |
|
|
|
|
|
|
|
def open(self) -> None: |
|
|
|
mode = 'rwc' if self.write else 'ro' |
|
|
|
uri = f'file:{self.PATH}?mode={mode}' |
|
|
|
self.conn = sqlite3.connect(uri, uri=True) |
|
|
|
self.cursor = self.conn.cursor() |
|
|
|
self.execute("PRAGMA foreign_keys = ON") |
|
|
|
# self.conn.create_function("prepare_ip4address", 1, |
|
|
|
# Database.prepare_ip4address, |
|
|
|
# deterministic=True) |
|
|
|
cursor = self.conn.cursor() |
|
|
|
cursor.execute("PRAGMA foreign_keys = ON") |
|
|
|
self.conn.create_function("unpack_asn", 1, |
|
|
|
self.unpack_asn, |
|
|
|
deterministic=True) |
|
|
|
self.conn.create_function("unpack_ip4address", 1, |
|
|
|
self.unpack_ip4address, |
|
|
|
deterministic=True) |
|
|
|
self.conn.create_function("unpack_ip4network", 2, |
|
|
|
self.unpack_ip4network, |
|
|
|
deterministic=True) |
|
|
|
self.conn.create_function("unpack_domain", 1, |
|
|
|
lambda s: s[:-1][::-1], |
|
|
|
deterministic=True) |
|
|
|
|
|
|
|
def execute(self, cmd: str, args: typing.Union[ |
|
|
|
typing.Tuple[DbValue, ...], |
|
|
|
typing.Dict[str, DbValue]] = None) -> None: |
|
|
|
# self.log.debug(cmd) |
|
|
|
# self.log.debug(args) |
|
|
|
self.cursor.execute(cmd, args or tuple()) |
|
|
|
self.conn.create_function("format_zone", 1, |
|
|
|
lambda s: '*' + s[::-1], |
|
|
|
deterministic=True) |
|
|
|
|
|
|
|
def get_meta(self, key: str) -> typing.Optional[int]: |
|
|
|
cursor = self.conn.cursor() |
|
|
|
try: |
|
|
|
self.execute("SELECT value FROM meta WHERE key=?", (key,)) |
|
|
|
cursor.execute("SELECT value FROM meta WHERE key=?", (key,)) |
|
|
|
except sqlite3.OperationalError: |
|
|
|
return None |
|
|
|
for ver, in self.cursor: |
|
|
|
for ver, in cursor: |
|
|
|
return ver |
|
|
|
return None |
|
|
|
|
|
|
|
def set_meta(self, key: str, val: int) -> None: |
|
|
|
self.execute("INSERT INTO meta VALUES (?, ?) " |
|
|
|
"ON CONFLICT (key) DO " |
|
|
|
"UPDATE set value=?", |
|
|
|
(key, val, val)) |
|
|
|
cursor = self.conn.cursor() |
|
|
|
cursor.execute("INSERT INTO meta VALUES (?, ?) " |
|
|
|
"ON CONFLICT (key) DO " |
|
|
|
"UPDATE set value=?", |
|
|
|
(key, val, val)) |
|
|
|
|
|
|
|
def close(self) -> None: |
|
|
|
self.enter_step('close_commit') |
|
|
@ -76,8 +81,9 @@ class Database(): |
|
|
|
os.unlink(self.PATH) |
|
|
|
self.open() |
|
|
|
self.log.info("Creating database version %d.", self.VERSION) |
|
|
|
cursor = self.conn.cursor() |
|
|
|
with open("database_schema.sql", 'r') as db_schema: |
|
|
|
self.cursor.executescript(db_schema.read()) |
|
|
|
cursor.executescript(db_schema.read()) |
|
|
|
self.set_meta('version', self.VERSION) |
|
|
|
self.conn.commit() |
|
|
|
|
|
|
@ -119,21 +125,27 @@ class Database(): |
|
|
|
self.log.debug(f"{'total':<20}: " |
|
|
|
f"{total:9.2f} s ({1:7.2%})") |
|
|
|
|
|
|
|
def prepare_hostname(self, hostname: str) -> str: |
|
|
|
@staticmethod |
|
|
|
def pack_hostname(hostname: str) -> str: |
|
|
|
return hostname[::-1] + '.' |
|
|
|
|
|
|
|
def prepare_zone(self, zone: str) -> str: |
|
|
|
return self.prepare_hostname(zone) |
|
|
|
@staticmethod |
|
|
|
def pack_zone(zone: str) -> str: |
|
|
|
return Database.pack_hostname(zone) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def prepare_asn(asn: str) -> int: |
|
|
|
def pack_asn(asn: str) -> int: |
|
|
|
asn = asn.upper() |
|
|
|
if asn.startswith('AS'): |
|
|
|
asn = asn[2:] |
|
|
|
return int(asn) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def prepare_ip4address(address: str) -> int: |
|
|
|
def unpack_asn(asn: int) -> str: |
|
|
|
return f'AS{asn}' |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def pack_ip4address(address: str) -> int: |
|
|
|
total = 0 |
|
|
|
for i, octet in enumerate(address.split('.')): |
|
|
|
total += int(octet) << (3-i)*8 |
|
|
@ -151,29 +163,75 @@ class Database(): |
|
|
|
# packed = ipaddress.ip_address(address).packed |
|
|
|
# return packed |
|
|
|
|
|
|
|
def prepare_ip4network(self, network: str) -> typing.Tuple[int, int]: |
|
|
|
# def prepare_ip4network(network: str) -> str: |
|
|
|
@staticmethod |
|
|
|
def unpack_ip4address(address: int) -> str: |
|
|
|
return '.'.join(str((address >> (i * 8)) & 0xFF) |
|
|
|
for i in reversed(range(4))) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def pack_ip4network(network: str) -> typing.Tuple[int, int]: |
|
|
|
# def pack_ip4network(network: str) -> str: |
|
|
|
net = ipaddress.ip_network(network) |
|
|
|
mini = self.prepare_ip4address(net.network_address.exploded) |
|
|
|
maxi = self.prepare_ip4address(net.broadcast_address.exploded) |
|
|
|
mini = Database.pack_ip4address(net.network_address.exploded) |
|
|
|
maxi = Database.pack_ip4address(net.broadcast_address.exploded) |
|
|
|
# mini = net.network_address.packed |
|
|
|
# maxi = net.broadcast_address.packed |
|
|
|
return mini, maxi |
|
|
|
# return Database.prepare_ip4address(net.network_address.exploded)[:net.prefixlen] |
|
|
|
# return Database.pack_ip4address(net.network_address.exploded)[:net.prefixlen] |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def unpack_ip4network(mini: int, maxi: int) -> str: |
|
|
|
addr = Database.unpack_ip4address(mini) |
|
|
|
prefixlen = 32-int(math.log2(maxi-mini+1)) |
|
|
|
return f'{addr}/{prefixlen}' |
|
|
|
|
|
|
|
def update_references(self) -> None: |
|
|
|
self.enter_step('update_refs') |
|
|
|
self.execute('UPDATE rules AS r SET refs=' |
|
|
|
'(SELECT count(*) FROM rules ' |
|
|
|
'WHERE source=r.id)') |
|
|
|
cursor = self.conn.cursor() |
|
|
|
cursor.execute('UPDATE rules AS r SET refs=' |
|
|
|
'(SELECT count(*) FROM rules ' |
|
|
|
'WHERE source=r.id)') |
|
|
|
|
|
|
|
def prune(self, before: int) -> None: |
|
|
|
self.enter_step('prune') |
|
|
|
self.execute('DELETE FROM rules WHERE updated<?', (before,)) |
|
|
|
|
|
|
|
def export(self, first_party_only: bool = False, |
|
|
|
end_chain_only: bool = False) -> typing.Iterable[str]: |
|
|
|
command = 'SELECT unpack_domain(val) FROM rules ' \ |
|
|
|
cursor = self.conn.cursor() |
|
|
|
cursor.execute('DELETE FROM rules WHERE updated<?', (before,)) |
|
|
|
|
|
|
|
def explain(self, entry: int) -> str: |
|
|
|
# Format current |
|
|
|
string = '???' |
|
|
|
cursor = self.conn.cursor() |
|
|
|
cursor.execute( |
|
|
|
'SELECT unpack_asn(val) FROM asn WHERE entry=:entry ' |
|
|
|
'UNION ' |
|
|
|
'SELECT unpack_domain(val) FROM hostname WHERE entry=:entry ' |
|
|
|
'UNION ' |
|
|
|
'SELECT format_zone(val) FROM zone WHERE entry=:entry ' |
|
|
|
'UNION ' |
|
|
|
'SELECT unpack_ip4address(val) FROM ip4address WHERE entry=:entry ' |
|
|
|
'UNION ' |
|
|
|
'SELECT unpack_ip4network(mini, maxi) ' |
|
|
|
'FROM ip4network WHERE entry=:entry ', |
|
|
|
{"entry": entry} |
|
|
|
) |
|
|
|
for val, in cursor: # only one |
|
|
|
string = str(val) |
|
|
|
string += f' #{entry}' |
|
|
|
|
|
|
|
# Add source if any |
|
|
|
cursor.execute('SELECT source FROM rules WHERE id=?', (entry,)) |
|
|
|
for source, in cursor: |
|
|
|
if source: |
|
|
|
string += f' ← {self.explain(source)}' |
|
|
|
return string |
|
|
|
|
|
|
|
def export(self, |
|
|
|
first_party_only: bool = False, |
|
|
|
end_chain_only: bool = False, |
|
|
|
explain: bool = False, |
|
|
|
) -> typing.Iterable[str]: |
|
|
|
selection = 'entry' if explain else 'unpack_domain(val)' |
|
|
|
command = f'SELECT {selection} FROM rules ' \ |
|
|
|
'INNER JOIN hostname ON rules.id = hostname.entry' |
|
|
|
restrictions: typing.List[str] = list() |
|
|
|
if first_party_only: |
|
|
@ -182,16 +240,22 @@ class Database(): |
|
|
|
restrictions.append('rules.refs = 0') |
|
|
|
if restrictions: |
|
|
|
command += ' WHERE ' + ' AND '.join(restrictions) |
|
|
|
command += ' ORDER BY unpack_domain(val) ASC' |
|
|
|
self.execute(command) |
|
|
|
for val, in self.cursor: |
|
|
|
yield val |
|
|
|
if not explain: |
|
|
|
command += ' ORDER BY unpack_domain(val) ASC' |
|
|
|
cursor = self.conn.cursor() |
|
|
|
cursor.execute(command) |
|
|
|
for val, in cursor: |
|
|
|
if explain: |
|
|
|
yield self.explain(val) |
|
|
|
else: |
|
|
|
yield val |
|
|
|
|
|
|
|
def get_domain(self, domain: str) -> typing.Iterable[int]: |
|
|
|
self.enter_step('get_domain_prepare') |
|
|
|
domain_prep = self.prepare_hostname(domain) |
|
|
|
domain_prep = self.pack_hostname(domain) |
|
|
|
cursor = self.conn.cursor() |
|
|
|
self.enter_step('get_domain_select') |
|
|
|
self.execute( |
|
|
|
cursor.execute( |
|
|
|
'SELECT null, entry FROM hostname ' |
|
|
|
'WHERE val=:d ' |
|
|
|
'UNION ' |
|
|
@ -202,22 +266,41 @@ class Database(): |
|
|
|
')', |
|
|
|
{'d': domain_prep} |
|
|
|
) |
|
|
|
for val, entry in self.cursor: |
|
|
|
for val, entry in cursor: |
|
|
|
self.enter_step('get_domain_confirm') |
|
|
|
if not (val is None or domain_prep.startswith(val)): |
|
|
|
continue |
|
|
|
self.enter_step('get_domain_yield') |
|
|
|
yield entry |
|
|
|
|
|
|
|
def get_domain_in_zone(self, domain: str) -> typing.Iterable[int]: |
|
|
|
self.enter_step('get_domainiz_prepare') |
|
|
|
domain_prep = self.pack_hostname(domain) |
|
|
|
cursor = self.conn.cursor() |
|
|
|
self.enter_step('get_domainiz_select') |
|
|
|
cursor.execute( |
|
|
|
'SELECT val, entry FROM zone ' |
|
|
|
'WHERE val<=:d ' |
|
|
|
'ORDER BY val DESC LIMIT 1', |
|
|
|
{'d': domain_prep} |
|
|
|
) |
|
|
|
for val, entry in cursor: |
|
|
|
self.enter_step('get_domainiz_confirm') |
|
|
|
if not (val is None or domain_prep.startswith(val)): |
|
|
|
continue |
|
|
|
self.enter_step('get_domainiz_yield') |
|
|
|
yield entry |
|
|
|
|
|
|
|
def get_ip4(self, address: str) -> typing.Iterable[int]: |
|
|
|
self.enter_step('get_ip4_prepare') |
|
|
|
try: |
|
|
|
address_prep = self.prepare_ip4address(address) |
|
|
|
address_prep = self.pack_ip4address(address) |
|
|
|
except (ValueError, IndexError): |
|
|
|
self.log.error("Invalid ip4address: %s", address) |
|
|
|
return |
|
|
|
cursor = self.conn.cursor() |
|
|
|
self.enter_step('get_ip4_select') |
|
|
|
self.execute( |
|
|
|
cursor.execute( |
|
|
|
'SELECT entry FROM ip4address ' |
|
|
|
# 'SELECT null, entry FROM ip4address ' |
|
|
|
'WHERE val=:a ' |
|
|
@ -232,7 +315,7 @@ class Database(): |
|
|
|
'WHERE :a BETWEEN mini AND maxi ', |
|
|
|
{'a': address_prep} |
|
|
|
) |
|
|
|
for val, entry in self.cursor: |
|
|
|
for entry, in cursor: |
|
|
|
# self.enter_step('get_ip4_confirm') |
|
|
|
# if not (val is None or val.startswith(address_prep)): |
|
|
|
# # PERF startswith but from the end |
|
|
@ -240,11 +323,29 @@ class Database(): |
|
|
|
self.enter_step('get_ip4_yield') |
|
|
|
yield entry |
|
|
|
|
|
|
|
def get_ip4_in_network(self, address: str) -> typing.Iterable[int]: |
|
|
|
self.enter_step('get_ip4in_prepare') |
|
|
|
try: |
|
|
|
address_prep = self.pack_ip4address(address) |
|
|
|
except (ValueError, IndexError): |
|
|
|
self.log.error("Invalid ip4address: %s", address) |
|
|
|
return |
|
|
|
cursor = self.conn.cursor() |
|
|
|
self.enter_step('get_ip4in_select') |
|
|
|
cursor.execute( |
|
|
|
'SELECT entry FROM ip4network ' |
|
|
|
'WHERE :a BETWEEN mini AND maxi ', |
|
|
|
{'a': address_prep} |
|
|
|
) |
|
|
|
for entry, in cursor: |
|
|
|
self.enter_step('get_ip4in_yield') |
|
|
|
yield entry |
|
|
|
|
|
|
|
def list_asn(self) -> typing.Iterable[typing.Tuple[str, int]]: |
|
|
|
cursor = self.conn.cursor() |
|
|
|
self.enter_step('list_asn_select') |
|
|
|
self.enter_step('get_domain_select') |
|
|
|
self.execute('SELECT val, entry FROM asn') |
|
|
|
for val, entry in self.cursor: |
|
|
|
cursor.execute('SELECT val, entry FROM asn') |
|
|
|
for val, entry in cursor: |
|
|
|
yield f'AS{val}', entry |
|
|
|
|
|
|
|
def _set_generic(self, |
|
|
@ -260,21 +361,23 @@ class Database(): |
|
|
|
# here abstraction > performaces |
|
|
|
|
|
|
|
# Fields based on the source |
|
|
|
self.enter_step(f'set_{table}_prepare') |
|
|
|
cursor = self.conn.cursor() |
|
|
|
if source is None: |
|
|
|
first_party = int(is_first_party) |
|
|
|
level = 0 |
|
|
|
else: |
|
|
|
self.enter_step(f'set_{table}_source') |
|
|
|
self.execute( |
|
|
|
cursor.execute( |
|
|
|
'SELECT first_party, level FROM rules ' |
|
|
|
'WHERE id=?', |
|
|
|
(source,) |
|
|
|
) |
|
|
|
first_party, level = self.cursor.fetchone() |
|
|
|
first_party, level = cursor.fetchone() |
|
|
|
level += 1 |
|
|
|
|
|
|
|
self.enter_step(f'set_{table}_select') |
|
|
|
self.execute(select_query, prep) |
|
|
|
cursor.execute(select_query, prep) |
|
|
|
|
|
|
|
rules_prep: typing.Dict[str, DbValue] = { |
|
|
|
"source": source, |
|
|
@ -284,10 +387,10 @@ class Database(): |
|
|
|
} |
|
|
|
|
|
|
|
# If the entry already exists |
|
|
|
for entry, in self.cursor: # only one |
|
|
|
for entry, in cursor: # only one |
|
|
|
self.enter_step(f'set_{table}_update') |
|
|
|
rules_prep['entry'] = entry |
|
|
|
self.execute( |
|
|
|
cursor.execute( |
|
|
|
'UPDATE rules SET ' |
|
|
|
'source=:source, updated=:updated, ' |
|
|
|
'first_party=:first_party, level=:level ' |
|
|
@ -303,23 +406,18 @@ class Database(): |
|
|
|
|
|
|
|
# If it does not exist |
|
|
|
|
|
|
|
if source is not None: |
|
|
|
self.enter_step(f'set_{table}_incsrc') |
|
|
|
self.execute('UPDATE rules SET refs = refs + 1 WHERE id=?', |
|
|
|
(source,)) |
|
|
|
|
|
|
|
self.enter_step(f'set_{table}_insert') |
|
|
|
self.execute( |
|
|
|
cursor.execute( |
|
|
|
'INSERT INTO rules ' |
|
|
|
'(source, updated, first_party, refs, level) ' |
|
|
|
'VALUES (:source, :updated, :first_party, 0, :level) ', |
|
|
|
'(source, updated, first_party, level) ' |
|
|
|
'VALUES (:source, :updated, :first_party, :level) ', |
|
|
|
rules_prep |
|
|
|
) |
|
|
|
self.execute('SELECT id FROM rules WHERE rowid=?', |
|
|
|
(self.cursor.lastrowid,)) |
|
|
|
for entry, in self.cursor: # only one |
|
|
|
cursor.execute('SELECT id FROM rules WHERE rowid=?', |
|
|
|
(cursor.lastrowid,)) |
|
|
|
for entry, in cursor: # only one |
|
|
|
prep['entry'] = entry |
|
|
|
self.execute(insert_query, prep) |
|
|
|
cursor.execute(insert_query, prep) |
|
|
|
return |
|
|
|
assert False |
|
|
|
|
|
|
@ -327,7 +425,7 @@ class Database(): |
|
|
|
*args: typing.Any, **kwargs: typing.Any) -> None: |
|
|
|
self.enter_step('set_hostname_prepare') |
|
|
|
prep: typing.Dict[str, DbValue] = { |
|
|
|
'val': self.prepare_hostname(hostname), |
|
|
|
'val': self.pack_hostname(hostname), |
|
|
|
} |
|
|
|
self._set_generic( |
|
|
|
'hostname', |
|
|
@ -342,7 +440,7 @@ class Database(): |
|
|
|
*args: typing.Any, **kwargs: typing.Any) -> None: |
|
|
|
self.enter_step('set_asn_prepare') |
|
|
|
try: |
|
|
|
asn_prep = self.prepare_asn(asn) |
|
|
|
asn_prep = self.pack_asn(asn) |
|
|
|
except ValueError: |
|
|
|
self.log.error("Invalid asn: %s", asn) |
|
|
|
return |
|
|
@ -360,10 +458,9 @@ class Database(): |
|
|
|
|
|
|
|
def set_ip4address(self, ip4address: str, |
|
|
|
*args: typing.Any, **kwargs: typing.Any) -> None: |
|
|
|
# TODO Do not add if already in ip4network |
|
|
|
self.enter_step('set_ip4add_prepare') |
|
|
|
try: |
|
|
|
ip4address_prep = self.prepare_ip4address(ip4address) |
|
|
|
ip4address_prep = self.pack_ip4address(ip4address) |
|
|
|
except (ValueError, IndexError): |
|
|
|
self.log.error("Invalid ip4address: %s", ip4address) |
|
|
|
return |
|
|
@ -383,7 +480,7 @@ class Database(): |
|
|
|
*args: typing.Any, **kwargs: typing.Any) -> None: |
|
|
|
self.enter_step('set_zone_prepare') |
|
|
|
prep: typing.Dict[str, DbValue] = { |
|
|
|
'val': self.prepare_zone(zone), |
|
|
|
'val': self.pack_zone(zone), |
|
|
|
} |
|
|
|
self._set_generic( |
|
|
|
'zone', |
|
|
@ -398,7 +495,7 @@ class Database(): |
|
|
|
*args: typing.Any, **kwargs: typing.Any) -> None: |
|
|
|
self.enter_step('set_ip4net_prepare') |
|
|
|
try: |
|
|
|
ip4network_prep = self.prepare_ip4network(ip4network) |
|
|
|
ip4network_prep = self.pack_ip4network(ip4network) |
|
|
|
except (ValueError, IndexError): |
|
|
|
self.log.error("Invalid ip4network: %s", ip4network) |
|
|
|
return |
|
|
|