Workflow: Various optimisations and fixes
I forgot to close this one earlier, so: Closes #7
This commit is contained in:
parent
f3eedcba22
commit
ab7ef609dd
8 changed files with 214 additions and 117 deletions
243
database.py
243
database.py
|
@ -12,6 +12,7 @@ import logging
|
|||
import argparse
|
||||
import coloredlogs
|
||||
import ipaddress
|
||||
import math
|
||||
|
||||
coloredlogs.install(
|
||||
level='DEBUG',
|
||||
|
@ -22,43 +23,47 @@ DbValue = typing.Union[None, int, float, str, bytes]
|
|||
|
||||
|
||||
class Database():
|
||||
VERSION = 3
|
||||
VERSION = 4
|
||||
PATH = "blocking.db"
|
||||
|
||||
def open(self) -> None:
|
||||
mode = 'rwc' if self.write else 'ro'
|
||||
uri = f'file:{self.PATH}?mode={mode}'
|
||||
self.conn = sqlite3.connect(uri, uri=True)
|
||||
self.cursor = self.conn.cursor()
|
||||
self.execute("PRAGMA foreign_keys = ON")
|
||||
# self.conn.create_function("prepare_ip4address", 1,
|
||||
# Database.prepare_ip4address,
|
||||
# deterministic=True)
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys = ON")
|
||||
self.conn.create_function("unpack_asn", 1,
|
||||
self.unpack_asn,
|
||||
deterministic=True)
|
||||
self.conn.create_function("unpack_ip4address", 1,
|
||||
self.unpack_ip4address,
|
||||
deterministic=True)
|
||||
self.conn.create_function("unpack_ip4network", 2,
|
||||
self.unpack_ip4network,
|
||||
deterministic=True)
|
||||
self.conn.create_function("unpack_domain", 1,
|
||||
lambda s: s[:-1][::-1],
|
||||
deterministic=True)
|
||||
|
||||
def execute(self, cmd: str, args: typing.Union[
|
||||
typing.Tuple[DbValue, ...],
|
||||
typing.Dict[str, DbValue]] = None) -> None:
|
||||
# self.log.debug(cmd)
|
||||
# self.log.debug(args)
|
||||
self.cursor.execute(cmd, args or tuple())
|
||||
self.conn.create_function("format_zone", 1,
|
||||
lambda s: '*' + s[::-1],
|
||||
deterministic=True)
|
||||
|
||||
def get_meta(self, key: str) -> typing.Optional[int]:
|
||||
cursor = self.conn.cursor()
|
||||
try:
|
||||
self.execute("SELECT value FROM meta WHERE key=?", (key,))
|
||||
cursor.execute("SELECT value FROM meta WHERE key=?", (key,))
|
||||
except sqlite3.OperationalError:
|
||||
return None
|
||||
for ver, in self.cursor:
|
||||
for ver, in cursor:
|
||||
return ver
|
||||
return None
|
||||
|
||||
def set_meta(self, key: str, val: int) -> None:
|
||||
self.execute("INSERT INTO meta VALUES (?, ?) "
|
||||
"ON CONFLICT (key) DO "
|
||||
"UPDATE set value=?",
|
||||
(key, val, val))
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute("INSERT INTO meta VALUES (?, ?) "
|
||||
"ON CONFLICT (key) DO "
|
||||
"UPDATE set value=?",
|
||||
(key, val, val))
|
||||
|
||||
def close(self) -> None:
|
||||
self.enter_step('close_commit')
|
||||
|
@ -76,8 +81,9 @@ class Database():
|
|||
os.unlink(self.PATH)
|
||||
self.open()
|
||||
self.log.info("Creating database version %d.", self.VERSION)
|
||||
cursor = self.conn.cursor()
|
||||
with open("database_schema.sql", 'r') as db_schema:
|
||||
self.cursor.executescript(db_schema.read())
|
||||
cursor.executescript(db_schema.read())
|
||||
self.set_meta('version', self.VERSION)
|
||||
self.conn.commit()
|
||||
|
||||
|
@ -119,21 +125,27 @@ class Database():
|
|||
self.log.debug(f"{'total':<20}: "
|
||||
f"{total:9.2f} s ({1:7.2%})")
|
||||
|
||||
def prepare_hostname(self, hostname: str) -> str:
|
||||
@staticmethod
|
||||
def pack_hostname(hostname: str) -> str:
|
||||
return hostname[::-1] + '.'
|
||||
|
||||
def prepare_zone(self, zone: str) -> str:
|
||||
return self.prepare_hostname(zone)
|
||||
@staticmethod
|
||||
def pack_zone(zone: str) -> str:
|
||||
return Database.pack_hostname(zone)
|
||||
|
||||
@staticmethod
|
||||
def prepare_asn(asn: str) -> int:
|
||||
def pack_asn(asn: str) -> int:
|
||||
asn = asn.upper()
|
||||
if asn.startswith('AS'):
|
||||
asn = asn[2:]
|
||||
return int(asn)
|
||||
|
||||
@staticmethod
|
||||
def prepare_ip4address(address: str) -> int:
|
||||
def unpack_asn(asn: int) -> str:
|
||||
return f'AS{asn}'
|
||||
|
||||
@staticmethod
|
||||
def pack_ip4address(address: str) -> int:
|
||||
total = 0
|
||||
for i, octet in enumerate(address.split('.')):
|
||||
total += int(octet) << (3-i)*8
|
||||
|
@ -151,29 +163,75 @@ class Database():
|
|||
# packed = ipaddress.ip_address(address).packed
|
||||
# return packed
|
||||
|
||||
def prepare_ip4network(self, network: str) -> typing.Tuple[int, int]:
|
||||
# def prepare_ip4network(network: str) -> str:
|
||||
@staticmethod
|
||||
def unpack_ip4address(address: int) -> str:
|
||||
return '.'.join(str((address >> (i * 8)) & 0xFF)
|
||||
for i in reversed(range(4)))
|
||||
|
||||
@staticmethod
|
||||
def pack_ip4network(network: str) -> typing.Tuple[int, int]:
|
||||
# def pack_ip4network(network: str) -> str:
|
||||
net = ipaddress.ip_network(network)
|
||||
mini = self.prepare_ip4address(net.network_address.exploded)
|
||||
maxi = self.prepare_ip4address(net.broadcast_address.exploded)
|
||||
mini = Database.pack_ip4address(net.network_address.exploded)
|
||||
maxi = Database.pack_ip4address(net.broadcast_address.exploded)
|
||||
# mini = net.network_address.packed
|
||||
# maxi = net.broadcast_address.packed
|
||||
return mini, maxi
|
||||
# return Database.prepare_ip4address(net.network_address.exploded)[:net.prefixlen]
|
||||
# return Database.pack_ip4address(net.network_address.exploded)[:net.prefixlen]
|
||||
|
||||
@staticmethod
|
||||
def unpack_ip4network(mini: int, maxi: int) -> str:
|
||||
addr = Database.unpack_ip4address(mini)
|
||||
prefixlen = 32-int(math.log2(maxi-mini+1))
|
||||
return f'{addr}/{prefixlen}'
|
||||
|
||||
def update_references(self) -> None:
|
||||
self.enter_step('update_refs')
|
||||
self.execute('UPDATE rules AS r SET refs='
|
||||
'(SELECT count(*) FROM rules '
|
||||
'WHERE source=r.id)')
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute('UPDATE rules AS r SET refs='
|
||||
'(SELECT count(*) FROM rules '
|
||||
'WHERE source=r.id)')
|
||||
|
||||
def prune(self, before: int) -> None:
|
||||
self.enter_step('prune')
|
||||
self.execute('DELETE FROM rules WHERE updated<?', (before,))
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute('DELETE FROM rules WHERE updated<?', (before,))
|
||||
|
||||
def export(self, first_party_only: bool = False,
|
||||
end_chain_only: bool = False) -> typing.Iterable[str]:
|
||||
command = 'SELECT unpack_domain(val) FROM rules ' \
|
||||
def explain(self, entry: int) -> str:
|
||||
# Format current
|
||||
string = '???'
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute(
|
||||
'SELECT unpack_asn(val) FROM asn WHERE entry=:entry '
|
||||
'UNION '
|
||||
'SELECT unpack_domain(val) FROM hostname WHERE entry=:entry '
|
||||
'UNION '
|
||||
'SELECT format_zone(val) FROM zone WHERE entry=:entry '
|
||||
'UNION '
|
||||
'SELECT unpack_ip4address(val) FROM ip4address WHERE entry=:entry '
|
||||
'UNION '
|
||||
'SELECT unpack_ip4network(mini, maxi) '
|
||||
'FROM ip4network WHERE entry=:entry ',
|
||||
{"entry": entry}
|
||||
)
|
||||
for val, in cursor: # only one
|
||||
string = str(val)
|
||||
string += f' #{entry}'
|
||||
|
||||
# Add source if any
|
||||
cursor.execute('SELECT source FROM rules WHERE id=?', (entry,))
|
||||
for source, in cursor:
|
||||
if source:
|
||||
string += f' ← {self.explain(source)}'
|
||||
return string
|
||||
|
||||
def export(self,
|
||||
first_party_only: bool = False,
|
||||
end_chain_only: bool = False,
|
||||
explain: bool = False,
|
||||
) -> typing.Iterable[str]:
|
||||
selection = 'entry' if explain else 'unpack_domain(val)'
|
||||
command = f'SELECT {selection} FROM rules ' \
|
||||
'INNER JOIN hostname ON rules.id = hostname.entry'
|
||||
restrictions: typing.List[str] = list()
|
||||
if first_party_only:
|
||||
|
@ -182,16 +240,22 @@ class Database():
|
|||
restrictions.append('rules.refs = 0')
|
||||
if restrictions:
|
||||
command += ' WHERE ' + ' AND '.join(restrictions)
|
||||
command += ' ORDER BY unpack_domain(val) ASC'
|
||||
self.execute(command)
|
||||
for val, in self.cursor:
|
||||
yield val
|
||||
if not explain:
|
||||
command += ' ORDER BY unpack_domain(val) ASC'
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute(command)
|
||||
for val, in cursor:
|
||||
if explain:
|
||||
yield self.explain(val)
|
||||
else:
|
||||
yield val
|
||||
|
||||
def get_domain(self, domain: str) -> typing.Iterable[int]:
|
||||
self.enter_step('get_domain_prepare')
|
||||
domain_prep = self.prepare_hostname(domain)
|
||||
domain_prep = self.pack_hostname(domain)
|
||||
cursor = self.conn.cursor()
|
||||
self.enter_step('get_domain_select')
|
||||
self.execute(
|
||||
cursor.execute(
|
||||
'SELECT null, entry FROM hostname '
|
||||
'WHERE val=:d '
|
||||
'UNION '
|
||||
|
@ -202,22 +266,41 @@ class Database():
|
|||
')',
|
||||
{'d': domain_prep}
|
||||
)
|
||||
for val, entry in self.cursor:
|
||||
for val, entry in cursor:
|
||||
self.enter_step('get_domain_confirm')
|
||||
if not (val is None or domain_prep.startswith(val)):
|
||||
continue
|
||||
self.enter_step('get_domain_yield')
|
||||
yield entry
|
||||
|
||||
def get_domain_in_zone(self, domain: str) -> typing.Iterable[int]:
|
||||
self.enter_step('get_domainiz_prepare')
|
||||
domain_prep = self.pack_hostname(domain)
|
||||
cursor = self.conn.cursor()
|
||||
self.enter_step('get_domainiz_select')
|
||||
cursor.execute(
|
||||
'SELECT val, entry FROM zone '
|
||||
'WHERE val<=:d '
|
||||
'ORDER BY val DESC LIMIT 1',
|
||||
{'d': domain_prep}
|
||||
)
|
||||
for val, entry in cursor:
|
||||
self.enter_step('get_domainiz_confirm')
|
||||
if not (val is None or domain_prep.startswith(val)):
|
||||
continue
|
||||
self.enter_step('get_domainiz_yield')
|
||||
yield entry
|
||||
|
||||
def get_ip4(self, address: str) -> typing.Iterable[int]:
|
||||
self.enter_step('get_ip4_prepare')
|
||||
try:
|
||||
address_prep = self.prepare_ip4address(address)
|
||||
address_prep = self.pack_ip4address(address)
|
||||
except (ValueError, IndexError):
|
||||
self.log.error("Invalid ip4address: %s", address)
|
||||
return
|
||||
cursor = self.conn.cursor()
|
||||
self.enter_step('get_ip4_select')
|
||||
self.execute(
|
||||
cursor.execute(
|
||||
'SELECT entry FROM ip4address '
|
||||
# 'SELECT null, entry FROM ip4address '
|
||||
'WHERE val=:a '
|
||||
|
@ -232,7 +315,7 @@ class Database():
|
|||
'WHERE :a BETWEEN mini AND maxi ',
|
||||
{'a': address_prep}
|
||||
)
|
||||
for val, entry in self.cursor:
|
||||
for entry, in cursor:
|
||||
# self.enter_step('get_ip4_confirm')
|
||||
# if not (val is None or val.startswith(address_prep)):
|
||||
# # PERF startswith but from the end
|
||||
|
@ -240,11 +323,29 @@ class Database():
|
|||
self.enter_step('get_ip4_yield')
|
||||
yield entry
|
||||
|
||||
def get_ip4_in_network(self, address: str) -> typing.Iterable[int]:
|
||||
self.enter_step('get_ip4in_prepare')
|
||||
try:
|
||||
address_prep = self.pack_ip4address(address)
|
||||
except (ValueError, IndexError):
|
||||
self.log.error("Invalid ip4address: %s", address)
|
||||
return
|
||||
cursor = self.conn.cursor()
|
||||
self.enter_step('get_ip4in_select')
|
||||
cursor.execute(
|
||||
'SELECT entry FROM ip4network '
|
||||
'WHERE :a BETWEEN mini AND maxi ',
|
||||
{'a': address_prep}
|
||||
)
|
||||
for entry, in cursor:
|
||||
self.enter_step('get_ip4in_yield')
|
||||
yield entry
|
||||
|
||||
def list_asn(self) -> typing.Iterable[typing.Tuple[str, int]]:
|
||||
cursor = self.conn.cursor()
|
||||
self.enter_step('list_asn_select')
|
||||
self.enter_step('get_domain_select')
|
||||
self.execute('SELECT val, entry FROM asn')
|
||||
for val, entry in self.cursor:
|
||||
cursor.execute('SELECT val, entry FROM asn')
|
||||
for val, entry in cursor:
|
||||
yield f'AS{val}', entry
|
||||
|
||||
def _set_generic(self,
|
||||
|
@ -260,21 +361,23 @@ class Database():
|
|||
# here abstraction > performaces
|
||||
|
||||
# Fields based on the source
|
||||
self.enter_step(f'set_{table}_prepare')
|
||||
cursor = self.conn.cursor()
|
||||
if source is None:
|
||||
first_party = int(is_first_party)
|
||||
level = 0
|
||||
else:
|
||||
self.enter_step(f'set_{table}_source')
|
||||
self.execute(
|
||||
cursor.execute(
|
||||
'SELECT first_party, level FROM rules '
|
||||
'WHERE id=?',
|
||||
(source,)
|
||||
)
|
||||
first_party, level = self.cursor.fetchone()
|
||||
first_party, level = cursor.fetchone()
|
||||
level += 1
|
||||
|
||||
self.enter_step(f'set_{table}_select')
|
||||
self.execute(select_query, prep)
|
||||
cursor.execute(select_query, prep)
|
||||
|
||||
rules_prep: typing.Dict[str, DbValue] = {
|
||||
"source": source,
|
||||
|
@ -284,10 +387,10 @@ class Database():
|
|||
}
|
||||
|
||||
# If the entry already exists
|
||||
for entry, in self.cursor: # only one
|
||||
for entry, in cursor: # only one
|
||||
self.enter_step(f'set_{table}_update')
|
||||
rules_prep['entry'] = entry
|
||||
self.execute(
|
||||
cursor.execute(
|
||||
'UPDATE rules SET '
|
||||
'source=:source, updated=:updated, '
|
||||
'first_party=:first_party, level=:level '
|
||||
|
@ -303,23 +406,18 @@ class Database():
|
|||
|
||||
# If it does not exist
|
||||
|
||||
if source is not None:
|
||||
self.enter_step(f'set_{table}_incsrc')
|
||||
self.execute('UPDATE rules SET refs = refs + 1 WHERE id=?',
|
||||
(source,))
|
||||
|
||||
self.enter_step(f'set_{table}_insert')
|
||||
self.execute(
|
||||
cursor.execute(
|
||||
'INSERT INTO rules '
|
||||
'(source, updated, first_party, refs, level) '
|
||||
'VALUES (:source, :updated, :first_party, 0, :level) ',
|
||||
'(source, updated, first_party, level) '
|
||||
'VALUES (:source, :updated, :first_party, :level) ',
|
||||
rules_prep
|
||||
)
|
||||
self.execute('SELECT id FROM rules WHERE rowid=?',
|
||||
(self.cursor.lastrowid,))
|
||||
for entry, in self.cursor: # only one
|
||||
cursor.execute('SELECT id FROM rules WHERE rowid=?',
|
||||
(cursor.lastrowid,))
|
||||
for entry, in cursor: # only one
|
||||
prep['entry'] = entry
|
||||
self.execute(insert_query, prep)
|
||||
cursor.execute(insert_query, prep)
|
||||
return
|
||||
assert False
|
||||
|
||||
|
@ -327,7 +425,7 @@ class Database():
|
|||
*args: typing.Any, **kwargs: typing.Any) -> None:
|
||||
self.enter_step('set_hostname_prepare')
|
||||
prep: typing.Dict[str, DbValue] = {
|
||||
'val': self.prepare_hostname(hostname),
|
||||
'val': self.pack_hostname(hostname),
|
||||
}
|
||||
self._set_generic(
|
||||
'hostname',
|
||||
|
@ -342,7 +440,7 @@ class Database():
|
|||
*args: typing.Any, **kwargs: typing.Any) -> None:
|
||||
self.enter_step('set_asn_prepare')
|
||||
try:
|
||||
asn_prep = self.prepare_asn(asn)
|
||||
asn_prep = self.pack_asn(asn)
|
||||
except ValueError:
|
||||
self.log.error("Invalid asn: %s", asn)
|
||||
return
|
||||
|
@ -360,10 +458,9 @@ class Database():
|
|||
|
||||
def set_ip4address(self, ip4address: str,
|
||||
*args: typing.Any, **kwargs: typing.Any) -> None:
|
||||
# TODO Do not add if already in ip4network
|
||||
self.enter_step('set_ip4add_prepare')
|
||||
try:
|
||||
ip4address_prep = self.prepare_ip4address(ip4address)
|
||||
ip4address_prep = self.pack_ip4address(ip4address)
|
||||
except (ValueError, IndexError):
|
||||
self.log.error("Invalid ip4address: %s", ip4address)
|
||||
return
|
||||
|
@ -383,7 +480,7 @@ class Database():
|
|||
*args: typing.Any, **kwargs: typing.Any) -> None:
|
||||
self.enter_step('set_zone_prepare')
|
||||
prep: typing.Dict[str, DbValue] = {
|
||||
'val': self.prepare_zone(zone),
|
||||
'val': self.pack_zone(zone),
|
||||
}
|
||||
self._set_generic(
|
||||
'zone',
|
||||
|
@ -398,7 +495,7 @@ class Database():
|
|||
*args: typing.Any, **kwargs: typing.Any) -> None:
|
||||
self.enter_step('set_ip4net_prepare')
|
||||
try:
|
||||
ip4network_prep = self.prepare_ip4network(ip4network)
|
||||
ip4network_prep = self.pack_ip4network(ip4network)
|
||||
except (ValueError, IndexError):
|
||||
self.log.error("Invalid ip4network: %s", ip4network)
|
||||
return
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue