Workflow: SQL -> Tree

Welp. All that for this.
This commit is contained in:
Geoffrey Frogeye 2019-12-15 15:56:26 +01:00
parent 040ce4c14e
commit 4d966371b2
Signed by: geoffrey
GPG key ID: D8A7ECA00A8CD3DD
8 changed files with 296 additions and 699 deletions

3
.gitignore vendored
View file

@ -1,5 +1,4 @@
*.log *.log
*.db *.p
*.db-journal
nameservers nameservers
nameservers.head nameservers.head

718
database.py Executable file → Normal file
View file

@ -4,111 +4,59 @@
Utility functions to interact with the database. Utility functions to interact with the database.
""" """
import sqlite3
import typing import typing
import time import time
import os
import logging import logging
import argparse
import coloredlogs import coloredlogs
import ipaddress import pickle
import math import enum
coloredlogs.install( coloredlogs.install(
level='DEBUG', level='DEBUG',
fmt='%(asctime)s %(name)s %(levelname)s %(message)s' fmt='%(asctime)s %(name)s %(levelname)s %(message)s'
) )
DbValue = typing.Union[None, int, float, str, bytes] PathType = enum.Enum('PathType', 'Rule Hostname Zone Asn Ip4 Ip6')
RulePath = typing.Union[None]
Asn = int
DomainPath = typing.List[str]
Ip4Path = typing.List[int]
Ip6Path = typing.List[int]
Path = typing.Union[RulePath, DomainPath, Asn, Ip4Path, Ip6Path]
TypedPath = typing.Tuple[PathType, Path]
Timestamp = int
Level = int
Match = typing.Tuple[Timestamp, TypedPath, Level]
DebugPath = (PathType.Rule, None)
class Database(): class DomainTreeNode():
VERSION = 5 def __init__(self) -> None:
PATH = "blocking.db" self.children: typing.Dict[str, DomainTreeNode] = dict()
self.match_zone: typing.Optional[Match] = None
self.match_hostname: typing.Optional[Match] = None
def open(self) -> None:
mode = 'rwc' if self.write else 'ro'
uri = f'file:{self.PATH}?mode={mode}'
self.conn = sqlite3.connect(uri, uri=True)
cursor = self.conn.cursor()
cursor.execute("PRAGMA foreign_keys = ON")
self.conn.create_function("unpack_asn", 1,
self.unpack_asn,
deterministic=True)
self.conn.create_function("unpack_ip4address", 1,
self.unpack_ip4address,
deterministic=True)
self.conn.create_function("unpack_ip4network", 2,
self.unpack_ip4network,
deterministic=True)
self.conn.create_function("unpack_domain", 1,
lambda s: s[:-1][::-1],
deterministic=True)
self.conn.create_function("format_zone", 1,
lambda s: '*' + s[::-1],
deterministic=True)
def get_meta(self, key: str) -> typing.Optional[int]: class IpTreeNode():
cursor = self.conn.cursor() def __init__(self) -> None:
try: self.children: typing.List[typing.Optional[IpTreeNode]] = [None, None]
cursor.execute("SELECT value FROM meta WHERE key=?", (key,)) self.match: typing.Optional[Match] = None
except sqlite3.OperationalError:
return None
for ver, in cursor:
return ver
return None
def set_meta(self, key: str, val: int) -> None:
cursor = self.conn.cursor()
cursor.execute("INSERT INTO meta VALUES (?, ?) "
"ON CONFLICT (key) DO "
"UPDATE set value=?",
(key, val, val))
def close(self) -> None: class Profiler():
self.enter_step('close_commit') def __init__(self) -> None:
self.conn.commit() self.log = logging.getLogger('profiler')
self.enter_step('close')
self.conn.close()
self.profile()
def initialize(self) -> None:
self.close()
self.enter_step('initialize')
if not self.write:
self.log.error("Cannot initialize in read-only mode.")
raise
os.unlink(self.PATH)
self.open()
self.log.info("Creating database version %d.", self.VERSION)
cursor = self.conn.cursor()
with open("database_schema.sql", 'r') as db_schema:
cursor.executescript(db_schema.read())
self.set_meta('version', self.VERSION)
self.conn.commit()
def __init__(self, write: bool = False) -> None:
self.log = logging.getLogger('db')
self.time_last = time.perf_counter() self.time_last = time.perf_counter()
self.time_step = 'init' self.time_step = 'init'
self.time_dict: typing.Dict[str, float] = dict() self.time_dict: typing.Dict[str, float] = dict()
self.step_dict: typing.Dict[str, int] = dict() self.step_dict: typing.Dict[str, int] = dict()
self.write = write
self.open()
version = self.get_meta('version')
if version != self.VERSION:
if version is not None:
self.log.warning(
"Outdated database version: %d found, will be rebuilt.",
version)
self.initialize()
def enter_step(self, name: str) -> None: def enter_step(self, name: str) -> None:
now = time.perf_counter() now = time.perf_counter()
try: try:
self.time_dict[self.time_step] += now - self.time_last self.time_dict[self.time_step] += now - self.time_last
self.step_dict[self.time_step] += 1 self.step_dict[self.time_step] += int(name != self.time_step)
except KeyError: except KeyError:
self.time_dict[self.time_step] = now - self.time_last self.time_dict[self.time_step] = now - self.time_last
self.step_dict[self.time_step] = 1 self.step_dict[self.time_step] = 1
@ -125,13 +73,58 @@ class Database():
self.log.debug(f"{'total':<20}: " self.log.debug(f"{'total':<20}: "
f"{total:9.2f} s ({1:7.2%})") f"{total:9.2f} s ({1:7.2%})")
@staticmethod
def pack_hostname(hostname: str) -> str: class Database(Profiler):
return hostname[::-1] + '.' VERSION = 8
PATH = "blocking.p"
def initialize(self) -> None:
self.log.warning(
"Creating database version: %d ",
Database.VERSION)
self.domtree = DomainTreeNode()
self.asns: typing.Set[Asn] = set()
self.ip4tree = IpTreeNode()
def load(self) -> None:
self.enter_step('load')
try:
with open(self.PATH, 'rb') as db_fdsec:
version, data = pickle.load(db_fdsec)
if version == Database.VERSION:
self.domtree, self.asns, self.ip4tree = data
return
self.log.warning(
"Outdated database version found: %d, "
"will be rebuilt.",
version)
except (TypeError, AttributeError, EOFError):
self.log.error(
"Corrupt database found, "
"will be rebuilt.")
except FileNotFoundError:
pass
self.initialize()
def save(self) -> None:
self.enter_step('save')
with open(self.PATH, 'wb') as db_fdsec:
data = self.domtree, self.asns, self.ip4tree
pickle.dump((self.VERSION, data), db_fdsec)
self.profile()
def __init__(self) -> None:
Profiler.__init__(self)
self.log = logging.getLogger('db')
self.load()
@staticmethod @staticmethod
def pack_zone(zone: str) -> str: def pack_domain(domain: str) -> DomainPath:
return Database.pack_hostname(zone) return domain.split('.')[::-1]
@staticmethod
def unpack_domain(domain: DomainPath) -> str:
return '.'.join(domain[::-1])
@staticmethod @staticmethod
def pack_asn(asn: str) -> int: def pack_asn(asn: str) -> int:
@ -145,431 +138,208 @@ class Database():
return f'AS{asn}' return f'AS{asn}'
@staticmethod @staticmethod
def pack_ip4address(address: str) -> int: def pack_ip4address(address: str) -> Ip4Path:
total = 0 addr: Ip4Path = [0] * 32
for i, octet in enumerate(address.split('.')): octets = [int(octet) for octet in address.split('.')]
total += int(octet) << (3-i)*8 for b in range(32):
if total > 0xFFFFFFFF: if (octets[b//8] >> b % 8) & 0b1:
raise ValueError addr[b] = 1
return total return addr
# return '{:02x}{:02x}{:02x}{:02x}'.format(
# *[int(c) for c in address.split('.')])
# return base64.b16encode(packed).decode()
# return '{:08b}{:08b}{:08b}{:08b}'.format(
# *[int(c) for c in address.split('.')])
# carg = ctypes.c_wchar_p(address)
# ret = ACCEL.ip4_flat(carg, self.accel_ip4_buf)
# if ret != 0:
# raise ValueError
# return self.accel_ip4_buf.value
# packed = ipaddress.ip_address(address).packed
# return packed
@staticmethod @staticmethod
def unpack_ip4address(address: int) -> str: def unpack_ip4address(address: Ip4Path) -> str:
return '.'.join(str((address >> (i * 8)) & 0xFF) octets = [0] * 4
for i in reversed(range(4))) for b, bit in enumerate(address):
octets[b//8] = (octets[b//8] << 1) + bit
return '.'.join(map(str, octets))
@staticmethod @staticmethod
def pack_ip4network(network: str) -> typing.Tuple[int, int]: def pack_ip4network(network: str) -> Ip4Path:
# def pack_ip4network(network: str) -> str: address, prefixlen_str = network.split('/')
net = ipaddress.ip_network(network) prefixlen = int(prefixlen_str)
mini = Database.pack_ip4address(net.network_address.exploded) return Database.pack_ip4address(address)[:prefixlen]
maxi = Database.pack_ip4address(net.broadcast_address.exploded)
# mini = net.network_address.packed
# maxi = net.broadcast_address.packed
return mini, maxi
# return Database.pack_ip4address(net.network_address.exploded)[:net.prefixlen]
@staticmethod @staticmethod
def unpack_ip4network(mini: int, maxi: int) -> str: def unpack_ip4network(network: Ip4Path) -> str:
addr = Database.unpack_ip4address(mini) address = network.copy()
prefixlen = 32-int(math.log2(maxi-mini+1)) prefixlen = len(network)
for _ in range(32-prefixlen):
address.append(0)
addr = Database.unpack_ip4address(address)
return f'{addr}/{prefixlen}' return f'{addr}/{prefixlen}'
def update_references(self) -> None: def update_references(self) -> None:
self.enter_step('update_refs') raise NotImplementedError
cursor = self.conn.cursor()
cursor.execute('UPDATE rules AS r SET refs='
'(SELECT count(*) FROM rules '
'WHERE source=r.id)')
def prune(self, before: int, base_only: bool = False) -> None: def prune(self, before: int, base_only: bool = False) -> None:
self.enter_step('prune') raise NotImplementedError
cursor = self.conn.cursor()
cmd = 'DELETE FROM rules WHERE updated<?'
if base_only:
cmd += ' AND level=0'
cursor.execute(cmd, (before,))
def explain(self, entry: int) -> str: def explain(self, entry: int) -> str:
# Format current raise NotImplementedError
string = '???'
cursor = self.conn.cursor()
cursor.execute(
'SELECT unpack_asn(val) FROM asn WHERE entry=:entry '
'UNION '
'SELECT unpack_domain(val) FROM hostname WHERE entry=:entry '
'UNION '
'SELECT format_zone(val) FROM zone WHERE entry=:entry '
'UNION '
'SELECT unpack_ip4address(val) FROM ip4address WHERE entry=:entry '
'UNION '
'SELECT unpack_ip4network(mini, maxi) '
'FROM ip4network WHERE entry=:entry ',
{"entry": entry}
)
for val, in cursor: # only one
string = str(val)
string += f' #{entry}'
# Add source if any
cursor.execute('SELECT source FROM rules WHERE id=?', (entry,))
for source, in cursor:
if source:
string += f'{self.explain(source)}'
return string
def export(self, def export(self,
first_party_only: bool = False, first_party_only: bool = False,
end_chain_only: bool = False, end_chain_only: bool = False,
explain: bool = False, explain: bool = False,
_dic: DomainTreeNode = None,
_par: DomainPath = None,
) -> typing.Iterable[str]: ) -> typing.Iterable[str]:
selection = 'entry' if explain else 'unpack_domain(val)' if first_party_only or end_chain_only or explain:
command = f'SELECT {selection} FROM rules ' \ raise NotImplementedError
'INNER JOIN hostname ON rules.id = hostname.entry' _dic = _dic or self.domtree
restrictions: typing.List[str] = list() _par = _par or list()
if first_party_only: if _dic.match_hostname:
restrictions.append('rules.first_party = 1') yield self.unpack_domain(_par)
if end_chain_only: for part in _dic.children:
restrictions.append('rules.refs = 0') dic = _dic.children[part]
if restrictions: yield from self.export(_dic=dic,
command += ' WHERE ' + ' AND '.join(restrictions) _par=_par + [part])
if not explain:
command += ' ORDER BY unpack_domain(val) ASC'
cursor = self.conn.cursor()
cursor.execute(command)
for val, in cursor:
if explain:
yield self.explain(val)
else:
yield val
def count_rules(self, def count_rules(self,
first_party_only: bool = False, first_party_only: bool = False,
) -> str: ) -> str:
counts: typing.List[str] = list() raise NotImplementedError
cursor = self.conn.cursor()
for table in ['asn', 'ip4network', 'ip4address', 'zone', 'hostname']:
command = f'SELECT count(*) FROM rules ' \
f'INNER JOIN {table} ON rules.id = {table}.entry ' \
'WHERE rules.level = 0'
if first_party_only:
command += ' AND first_party=1'
cursor.execute(command)
count, = cursor.fetchone()
if count > 0:
counts.append(f'{table}: {count}')
return ', '.join(counts) def get_domain(self, domain_str: str) -> typing.Iterable[TypedPath]:
self.enter_step('get_domain_pack')
def get_domain(self, domain: str) -> typing.Iterable[int]: domain = self.pack_domain(domain_str)
self.enter_step('get_domain_prepare') self.enter_step('get_domain_brws')
domain_prep = self.pack_hostname(domain) dic = self.domtree
cursor = self.conn.cursor() depth = 0
self.enter_step('get_domain_select') for part in domain:
cursor.execute( if dic.match_zone:
'SELECT null, entry FROM hostname '
'WHERE val=:d '
'UNION '
'SELECT * FROM ('
'SELECT val, entry FROM zone '
# 'WHERE val>=:d '
# 'ORDER BY val ASC LIMIT 1'
'WHERE val<=:d '
'AND instr(:d, val) = 1'
')',
{'d': domain_prep}
)
for val, entry in cursor:
# print(293, val, entry)
self.enter_step('get_domain_confirm')
if not (val is None or domain_prep.startswith(val)):
# print(297)
continue
self.enter_step('get_domain_yield') self.enter_step('get_domain_yield')
yield entry yield (PathType.Zone, domain[:depth])
self.enter_step('get_domain_brws')
def get_domain_in_zone(self, domain: str) -> typing.Iterable[int]: if part not in dic.children:
self.enter_step('get_domainiz_prepare')
domain_prep = self.pack_hostname(domain)
cursor = self.conn.cursor()
self.enter_step('get_domainiz_select')
cursor.execute(
'SELECT val, entry FROM zone '
'WHERE val<=:d '
'ORDER BY val DESC LIMIT 1',
{'d': domain_prep}
)
for val, entry in cursor:
self.enter_step('get_domainiz_confirm')
if not (val is None or domain_prep.startswith(val)):
continue
self.enter_step('get_domainiz_yield')
yield entry
def get_ip4(self, address: str) -> typing.Iterable[int]:
self.enter_step('get_ip4_prepare')
try:
address_prep = self.pack_ip4address(address)
except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", address)
return return
cursor = self.conn.cursor() dic = dic.children[part]
self.enter_step('get_ip4_select') depth += 1
cursor.execute( if dic.match_zone:
'SELECT entry FROM ip4address ' self.enter_step('get_domain_yield')
# 'SELECT null, entry FROM ip4address ' yield (PathType.Zone, domain)
'WHERE val=:a ' if dic.match_hostname:
'UNION ' self.enter_step('get_domain_yield')
# 'SELECT * FROM (' yield (PathType.Hostname, domain)
# 'SELECT val, entry FROM ip4network '
# 'WHERE val<=:a ' def get_ip4(self, ip4_str: str) -> typing.Iterable[TypedPath]:
# 'AND instr(:a, val) > 0 ' self.enter_step('get_ip4_pack')
# 'ORDER BY val DESC' ip4 = self.pack_ip4address(ip4_str)
# ')' self.enter_step('get_ip4_brws')
'SELECT entry FROM ip4network ' dic = self.ip4tree
'WHERE :a BETWEEN mini AND maxi ', depth = 0
{'a': address_prep} for part in ip4:
) if dic.match:
for entry, in cursor:
# self.enter_step('get_ip4_confirm')
# if not (val is None or val.startswith(address_prep)):
# # PERF startswith but from the end
# continue
self.enter_step('get_ip4_yield') self.enter_step('get_ip4_yield')
yield entry yield (PathType.Ip4, ip4[:depth])
self.enter_step('get_ip4_brws')
def get_ip4_in_network(self, address: str) -> typing.Iterable[int]: next_dic = dic.children[part]
self.enter_step('get_ip4in_prepare') if next_dic is None:
try:
address_prep = self.pack_ip4address(address)
except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", address)
return return
cursor = self.conn.cursor() dic = next_dic
self.enter_step('get_ip4in_select') depth += 1
cursor.execute( if dic.match:
'SELECT entry FROM ip4network ' self.enter_step('get_ip4_yield')
'WHERE :a BETWEEN mini AND maxi ', yield (PathType.Ip4, ip4)
{'a': address_prep}
)
for entry, in cursor:
self.enter_step('get_ip4in_yield')
yield entry
def list_asn(self) -> typing.Iterable[typing.Tuple[str, int]]: def list_asn(self) -> typing.Iterable[TypedPath]:
cursor = self.conn.cursor() for asn in self.asns:
self.enter_step('list_asn_select') yield (PathType.Asn, asn)
cursor.execute('SELECT val, entry FROM asn')
for val, entry in cursor:
yield f'AS{val}', entry
def _set_generic(self, def set_hostname(self,
table: str, hostname_str: str,
select_query: str,
insert_query: str,
prep: typing.Dict[str, DbValue],
updated: int, updated: int,
is_first_party: bool = False, is_first_party: bool = None,
source: int = None, source: TypedPath = None) -> None:
) -> None: self.enter_step('set_hostname_pack')
# Since this isn't the bulk of the processing, if is_first_party or source:
# here abstraction > performances
self.enter_step('set_hostname_brws')
# Fields based on the source hostname = self.pack_domain(hostname_str)
self.enter_step(f'set_{table}_prepare') dic = self.domtree
cursor = self.conn.cursor() for part in hostname:
if source is None: if dic.match_zone:
first_party = int(is_first_party) # Refuse to add hostname whose zone is already matching
level = 0
else:
self.enter_step(f'set_{table}_source')
cursor.execute(
'SELECT first_party, level FROM rules '
'WHERE id=?',
(source,)
)
first_party, level = cursor.fetchone()
level += 1
self.enter_step(f'set_{table}_select')
cursor.execute(select_query, prep)
rules_prep: typing.Dict[str, DbValue] = {
"source": source,
"updated": updated,
"first_party": first_party,
"level": level,
}
# If the entry already exists
for entry, in cursor: # only one
self.enter_step(f'set_{table}_update')
rules_prep['entry'] = entry
cursor.execute(
'UPDATE rules SET '
'source=:source, updated=:updated, '
'first_party=:first_party, level=:level '
'WHERE id=:entry AND (updated<:updated OR '
'first_party<:first_party OR level<:level)',
rules_prep
)
# Only update if any of the following:
# - the entry is outdated
# - the entry was not a first_party but this is
# - this is closer to the original rule
return return
if part not in dic.children:
dic.children[part] = DomainTreeNode()
dic = dic.children[part]
dic.match_hostname = (updated, DebugPath, 0)
# If it does not exist def set_zone(self,
zone_str: str,
self.enter_step(f'set_{table}_insert') updated: int,
cursor.execute( is_first_party: bool = None,
'INSERT INTO rules ' source: TypedPath = None) -> None:
'(source, updated, first_party, level) ' self.enter_step('set_zone_pack')
'VALUES (:source, :updated, :first_party, :level) ', if is_first_party or source:
rules_prep raise NotImplementedError
) zone = self.pack_domain(zone_str)
cursor.execute('SELECT id FROM rules WHERE rowid=?', self.enter_step('set_zone_brws')
(cursor.lastrowid,)) dic = self.domtree
for entry, in cursor: # only one for part in zone:
prep['entry'] = entry if dic.match_zone:
cursor.execute(insert_query, prep) # Refuse to add zone whose parent zone is already matching
return return
assert False if part not in dic.children:
dic.children[part] = DomainTreeNode()
dic = dic.children[part]
dic.match_zone = (updated, DebugPath, 0)
def set_hostname(self, hostname: str, def set_asn(self,
*args: typing.Any, **kwargs: typing.Any) -> None: asn_str: str,
self.enter_step('set_hostname_prepare') updated: int,
prep: typing.Dict[str, DbValue] = { is_first_party: bool = None,
'val': self.pack_hostname(hostname), source: TypedPath = None) -> None:
} self.enter_step('set_asn_pack')
self._set_generic( if is_first_party or source:
'hostname', # TODO updated
'SELECT entry FROM hostname WHERE val=:val', raise NotImplementedError
'INSERT INTO hostname (val, entry) ' asn = self.pack_asn(asn_str)
'VALUES (:val, :entry)', self.enter_step('set_asn_brws')
prep, self.asns.add(asn)
*args, **kwargs
)
def set_asn(self, asn: str, def set_ip4address(self,
*args: typing.Any, **kwargs: typing.Any) -> None: ip4address_str: str,
self.enter_step('set_asn_prepare') updated: int,
try: is_first_party: bool = None,
asn_prep = self.pack_asn(asn) source: TypedPath = None) -> None:
except ValueError: self.enter_step('set_ip4add_pack')
self.log.error("Invalid asn: %s", asn) if is_first_party or source:
raise NotImplementedError
self.enter_step('set_ip4add_brws')
ip4address = self.pack_ip4address(ip4address_str)
dic = self.ip4tree
for part in ip4address:
if dic.match:
# Refuse to add ip4address whose network is already matching
return return
prep: typing.Dict[str, DbValue] = { next_dic = dic.children[part]
'val': asn_prep, if next_dic is None:
} next_dic = IpTreeNode()
self._set_generic( dic.children[part] = next_dic
'asn', dic = next_dic
'SELECT entry FROM asn WHERE val=:val', dic.match = (updated, DebugPath, 0)
'INSERT INTO asn (val, entry) '
'VALUES (:val, :entry)',
prep,
*args, **kwargs
)
def set_ip4address(self, ip4address: str, def set_ip4network(self,
*args: typing.Any, **kwargs: typing.Any) -> None: ip4network_str: str,
self.enter_step('set_ip4add_prepare') updated: int,
try: is_first_party: bool = None,
ip4address_prep = self.pack_ip4address(ip4address) source: TypedPath = None) -> None:
except (ValueError, IndexError): self.enter_step('set_ip4net_pack')
self.log.error("Invalid ip4address: %s", ip4address) if is_first_party or source:
raise NotImplementedError
self.enter_step('set_ip4net_brws')
ip4network = self.pack_ip4network(ip4network_str)
dic = self.ip4tree
for part in ip4network:
if dic.match:
# Refuse to add ip4network whose parent network
# is already matching
return return
prep: typing.Dict[str, DbValue] = { next_dic = dic.children[part]
'val': ip4address_prep, if next_dic is None:
} next_dic = IpTreeNode()
self._set_generic( dic.children[part] = next_dic
'ip4add', dic = next_dic
'SELECT entry FROM ip4address WHERE val=:val', dic.match = (updated, DebugPath, 0)
'INSERT INTO ip4address (val, entry) '
'VALUES (:val, :entry)',
prep,
*args, **kwargs
)
def set_zone(self, zone: str,
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_zone_prepare')
prep: typing.Dict[str, DbValue] = {
'val': self.pack_zone(zone),
}
self._set_generic(
'zone',
'SELECT entry FROM zone WHERE val=:val',
'INSERT INTO zone (val, entry) '
'VALUES (:val, :entry)',
prep,
*args, **kwargs
)
def set_ip4network(self, ip4network: str,
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_ip4net_prepare')
try:
ip4network_prep = self.pack_ip4network(ip4network)
except (ValueError, IndexError):
self.log.error("Invalid ip4network: %s", ip4network)
return
prep: typing.Dict[str, DbValue] = {
'mini': ip4network_prep[0],
'maxi': ip4network_prep[1],
}
self._set_generic(
'ip4net',
'SELECT entry FROM ip4network WHERE mini=:mini AND maxi=:maxi',
'INSERT INTO ip4network (mini, maxi, entry) '
'VALUES (:mini, :maxi, :entry)',
prep,
*args, **kwargs
)
if __name__ == '__main__':
# Parsing arguments
parser = argparse.ArgumentParser(
description="Database operations")
parser.add_argument(
'-i', '--initialize', action='store_true',
help="Reconstruct the whole database")
parser.add_argument(
'-p', '--prune', action='store_true',
help="Remove old entries from database")
parser.add_argument(
'-b', '--prune-base', action='store_true',
help="TODO")
parser.add_argument(
'-s', '--prune-before', type=int,
default=(int(time.time()) - 60*60*24*31*6),
help="TODO")
parser.add_argument(
'-r', '--references', action='store_true',
help="Update the reference count")
args = parser.parse_args()
DB = Database(write=True)
if args.initialize:
DB.initialize()
if args.prune:
DB.prune(before=args.prune_before, base_only=args.prune_base)
if args.references:
DB.update_references()
DB.close()

View file

@ -1,59 +0,0 @@
-- Remember to increment DB_VERSION
-- in database.py on changes to this file
CREATE TABLE rules (
id INTEGER PRIMARY KEY AUTOINCREMENT,
source INTEGER, -- The rule this one is based on
updated INTEGER, -- If the row was updated during last data import (0: No, 1: Yes)
first_party INTEGER, -- 1: this blocks a first party for sure, 0: maybe
refs INTEGER, -- Number of entries issued from this one
level INTEGER, -- Level of recursion to the root source rule (used for source priority)
FOREIGN KEY (source) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX rules_source ON rules (source); -- for references recounting
CREATE INDEX rules_updated ON rules (updated); -- for pruning
CREATE INDEX rules_level_firstparty ON rules (level, first_party); -- for counting rules
CREATE TABLE asn (
val INTEGER PRIMARY KEY,
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX asn_entry ON asn (entry); -- for explanations
CREATE TABLE hostname (
val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for consistency with zone)
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX hostname_entry ON hostname (entry); -- for explanations
CREATE TABLE zone (
val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for easier matching)
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX zone_entry ON zone (entry); -- for explanations
CREATE TABLE ip4address (
val INTEGER PRIMARY KEY,
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX ip4address_entry ON ip4address (entry); -- for explanations
CREATE TABLE ip4network (
-- val TEXT PRIMARY KEY,
mini INTEGER,
maxi INTEGER,
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX ip4network_minmax ON ip4network (mini, maxi);
CREATE INDEX ip4network_entry ON ip4network (entry); -- for explanations
-- Store various things
CREATE TABLE meta (
key TEXT PRIMARY KEY,
value integer
);

View file

@ -45,5 +45,3 @@ if __name__ == '__main__':
explain=args.explain, explain=args.explain,
): ):
print(domain, file=args.output) print(domain, file=args.output)
DB.close()

View file

@ -31,23 +31,25 @@ if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
DB = database.Database() DB = database.Database()
DBW = database.Database(write=True)
for asn, entry in DB.list_asn(): for path in DB.list_asn():
ptype, asn = path
assert ptype == database.PathType.Asn
assert isinstance(asn, int)
asn_str = database.Database.unpack_asn(asn)
DB.enter_step('asn_get_ranges') DB.enter_step('asn_get_ranges')
for prefix in get_ranges(asn): for prefix in get_ranges(asn_str):
parsed_prefix: IPNetwork = ipaddress.ip_network(prefix) parsed_prefix: IPNetwork = ipaddress.ip_network(prefix)
if parsed_prefix.version == 4: if parsed_prefix.version == 4:
DBW.set_ip4network( DB.set_ip4network(
prefix, prefix,
source=entry, # source=path,
updated=int(time.time()) updated=int(time.time())
) )
log.info('Added %s from %s (id=%s)', prefix, asn, entry) log.info('Added %s from %s (source=%s)', prefix, asn, path)
elif parsed_prefix.version == 6: elif parsed_prefix.version == 6:
log.warning('Unimplemented prefix version: %s', prefix) log.warning('Unimplemented prefix version: %s', prefix)
else: else:
log.error('Unknown prefix version: %s', prefix) log.error('Unknown prefix version: %s', prefix)
DB.close() DB.save()
DBW.close()

View file

@ -6,126 +6,52 @@ import json
import logging import logging
import sys import sys
import typing import typing
import multiprocessing
import enum import enum
RecordType = enum.Enum('RecordType', 'A AAAA CNAME PTR') RecordType = enum.Enum('RecordType', 'A AAAA CNAME PTR')
Record = typing.Tuple[RecordType, int, str, str] Record = typing.Tuple[RecordType, int, str, str]
# select, confirm, write # select, write
FUNCTION_MAP: typing.Any = { FUNCTION_MAP: typing.Any = {
RecordType.A: ( RecordType.A: (
database.Database.get_ip4, database.Database.get_ip4,
database.Database.get_domain_in_zone,
database.Database.set_hostname, database.Database.set_hostname,
), ),
RecordType.CNAME: ( RecordType.CNAME: (
database.Database.get_domain, database.Database.get_domain,
database.Database.get_domain_in_zone,
database.Database.set_hostname, database.Database.set_hostname,
), ),
RecordType.PTR: ( RecordType.PTR: (
database.Database.get_domain, database.Database.get_domain,
database.Database.get_ip4_in_network,
database.Database.set_ip4address, database.Database.set_ip4address,
), ),
} }
class Reader(multiprocessing.Process):
def __init__(self,
recs_queue: multiprocessing.Queue,
write_queue: multiprocessing.Queue,
index: int = 0):
super(Reader, self).__init__()
self.log = logging.getLogger(f'rd{index:03d}')
self.recs_queue = recs_queue
self.write_queue = write_queue
self.index = index
def run(self) -> None:
self.db = database.Database(write=False)
self.db.log = logging.getLogger(f'db{self.index:03d}')
self.db.enter_step('line_wait')
block: typing.List[str]
try:
for block in iter(self.recs_queue.get, None):
record: Record
for record in block:
# print(55, record)
dtype, updated, name, value = record
self.db.enter_step('feed_switch')
select, confirm, write = FUNCTION_MAP[dtype]
for rule in select(self.db, value):
# print(60, rule, list(confirm(self.db, name)))
if not any(confirm(self.db, name)):
# print(62, write, name, updated, rule)
self.db.enter_step('wait_put')
self.write_queue.put((write, name, updated, rule))
self.db.enter_step('line_wait')
except KeyboardInterrupt:
self.log.error('Interrupted')
self.db.enter_step('end')
self.db.close()
class Writer(multiprocessing.Process):
def __init__(self,
write_queue: multiprocessing.Queue,
):
super(Writer, self).__init__()
self.log = logging.getLogger(f'wr ')
self.write_queue = write_queue
def run(self) -> None:
self.db = database.Database(write=True)
self.db.log = logging.getLogger(f'dbw ')
self.db.enter_step('line_wait')
block: typing.List[str]
try:
fun: typing.Callable
name: str
updated: int
source: int
for fun, name, updated, source in iter(self.write_queue.get, None):
self.db.enter_step('exec')
fun(self.db, name, updated, source=source)
self.db.enter_step('line_wait')
except KeyboardInterrupt:
self.log.error('Interrupted')
self.db.enter_step('end')
self.db.close()
class Parser(): class Parser():
def __init__(self, def __init__(self, buf: typing.Any) -> None:
buf: typing.Any,
recs_queue: multiprocessing.Queue,
block_size: int,
):
super(Parser, self).__init__()
self.buf = buf self.buf = buf
self.log = logging.getLogger('pr ') self.log = logging.getLogger('parser')
self.recs_queue = recs_queue self.db = database.Database()
self.block: typing.List[Record] = list()
self.block_size = block_size def end(self) -> None:
self.db = database.Database() # Just for timing self.db.save()
self.db.log = logging.getLogger('pr ')
def register(self,
rtype: RecordType,
updated: int,
name: str,
value: str
) -> None:
def register(self, record: Record) -> None:
self.db.enter_step('register') self.db.enter_step('register')
self.block.append(record) select, write = FUNCTION_MAP[rtype]
if len(self.block) >= self.block_size: try:
self.db.enter_step('put_block') for source in select(self.db, value):
self.recs_queue.put(self.block) # write(self.db, name, updated, source=source)
self.block = list() write(self.db, name, updated)
except NotImplementedError:
def run(self) -> None: return # DEBUG
self.consume()
self.recs_queue.put(self.block)
self.db.close()
def consume(self) -> None: def consume(self) -> None:
raise NotImplementedError raise NotImplementedError
@ -146,13 +72,12 @@ class Rapid7Parser(Parser):
data = json.loads(line) data = json.loads(line)
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
continue continue
record = ( self.register(
Rapid7Parser.TYPES[data['type']], Rapid7Parser.TYPES[data['type']],
int(data['timestamp']), int(data['timestamp']),
data['name'], data['name'],
data['value'] data['value']
) )
self.register(record)
class DnsMassParser(Parser): class DnsMassParser(Parser):
@ -182,13 +107,12 @@ class DnsMassParser(Parser):
else: else:
dtype, name_offset, value_offset = \ dtype, name_offset, value_offset = \
DnsMassParser.TYPES[split[1]] DnsMassParser.TYPES[split[1]]
record = ( self.register(
dtype, dtype,
timestamp, timestamp,
split[0][:name_offset], split[0][:name_offset],
split[2][:value_offset], split[2][:value_offset],
) )
self.register(record)
self.db.enter_step('parse_dnsmass') self.db.enter_step('parse_dnsmass')
except KeyError: except KeyError:
continue continue
@ -212,49 +136,12 @@ if __name__ == '__main__':
args_parser.add_argument( args_parser.add_argument(
'-i', '--input', type=argparse.FileType('r'), default=sys.stdin, '-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
help="TODO") help="TODO")
args_parser.add_argument(
'-j', '--workers', type=int, default=4,
help="TODO")
args_parser.add_argument(
'-b', '--block-size', type=int, default=100,
help="TODO")
args = args_parser.parse_args() args = args_parser.parse_args()
DB = database.Database(write=False) # Not needed, just for timing parser = PARSERS[args.parser](args.input)
DB.log = logging.getLogger('db ')
recs_queue: multiprocessing.Queue = multiprocessing.Queue(
maxsize=10*args.workers)
write_queue: multiprocessing.Queue = multiprocessing.Queue(
maxsize=10*args.workers)
DB.enter_step('proc_create')
readers: typing.List[Reader] = list()
for w in range(args.workers):
readers.append(Reader(recs_queue, write_queue, w))
writer = Writer(write_queue)
parser = PARSERS[args.parser](
args.input, recs_queue, args.block_size)
DB.enter_step('proc_start')
for reader in readers:
reader.start()
writer.start()
try: try:
DB.enter_step('parser_run') parser.consume()
parser.run()
DB.enter_step('end_put')
for _ in range(args.workers):
recs_queue.put(None)
write_queue.put(None)
DB.enter_step('proc_join')
for reader in readers:
reader.join()
writer.join()
except KeyboardInterrupt: except KeyboardInterrupt:
log.error('Interrupted') pass
parser.end()
DB.close()

View file

@ -28,15 +28,15 @@ if __name__ == '__main__':
help="The input only comes from verified first-party sources") help="The input only comes from verified first-party sources")
args = parser.parse_args() args = parser.parse_args()
DB = database.Database(write=True) DB = database.Database()
fun = FUNCTION_MAP[args.type] fun = FUNCTION_MAP[args.type]
for rule in args.input: for rule in args.input:
fun(DB, fun(DB,
rule.strip(), rule.strip(),
is_first_party=args.first_party, # is_first_party=args.first_party,
updated=int(time.time()), updated=int(time.time()),
) )
DB.close() DB.save()

View file

@ -6,11 +6,11 @@ function log() {
log "Importing rules…" log "Importing rules…"
BEFORE="$(date +%s)" BEFORE="$(date +%s)"
cat rules_adblock/*.txt | grep -v '^!' | grep -v '^\[Adblock' | ./adblock_to_domain_list.py | ./feed_rules.py zone # cat rules_adblock/*.txt | grep -v '^!' | grep -v '^\[Adblock' | ./adblock_to_domain_list.py | ./feed_rules.py zone
cat rules_hosts/*.txt | grep -v '^#' | grep -v '^$' | cut -d ' ' -f2 | ./feed_rules.py zone # cat rules_hosts/*.txt | grep -v '^#' | grep -v '^$' | cut -d ' ' -f2 | ./feed_rules.py zone
cat rules/*.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone # cat rules/*.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone
cat rules_ip/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py ip4network # cat rules_ip/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py ip4network
cat rules_asn/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py asn # cat rules_asn/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py asn
cat rules/first-party.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone --first-party cat rules/first-party.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone --first-party
cat rules_ip/first-party.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py ip4network --first-party cat rules_ip/first-party.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py ip4network --first-party
@ -19,4 +19,4 @@ cat rules_asn/first-party.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py as
./feed_asn.py ./feed_asn.py
log "Pruning old rules…" log "Pruning old rules…"
./database.py --prune --prune-before "$BEFORE" --prune-base ./db.py --prune --prune-before "$BEFORE" --prune-base