Workflow: SQL -> Tree

Welp. All that for this.
Geoffrey Frogeye 2019-12-15 15:56:26 +01:00
parent 040ce4c14e
commit 4d966371b2
Signed by: geoffrey
GPG Key ID: D8A7ECA00A8CD3DD
8 changed files with 296 additions and 699 deletions

.gitignore (3 changed lines)

@@ -1,5 +1,4 @@
*.log
*.db
*.db-journal
*.p
nameservers
nameservers.head

database.py (718 changed lines; executable file changed to normal file)

@@ -4,111 +4,59 @@
Utility functions to interact with the database.
"""
import sqlite3
import typing
import time
import os
import logging
import argparse
import coloredlogs
import ipaddress
import math
import pickle
import enum
coloredlogs.install(
level='DEBUG',
fmt='%(asctime)s %(name)s %(levelname)s %(message)s'
)
DbValue = typing.Union[None, int, float, str, bytes]
PathType = enum.Enum('PathType', 'Rule Hostname Zone Asn Ip4 Ip6')
RulePath = typing.Union[None]
Asn = int
DomainPath = typing.List[str]
Ip4Path = typing.List[int]
Ip6Path = typing.List[int]
Path = typing.Union[RulePath, DomainPath, Asn, Ip4Path, Ip6Path]
TypedPath = typing.Tuple[PathType, Path]
Timestamp = int
Level = int
Match = typing.Tuple[Timestamp, TypedPath, Level]
DebugPath = (PathType.Rule, None)
class Database():
VERSION = 5
PATH = "blocking.db"
class DomainTreeNode():
def __init__(self) -> None:
self.children: typing.Dict[str, DomainTreeNode] = dict()
self.match_zone: typing.Optional[Match] = None
self.match_hostname: typing.Optional[Match] = None
def open(self) -> None:
mode = 'rwc' if self.write else 'ro'
uri = f'file:{self.PATH}?mode={mode}'
self.conn = sqlite3.connect(uri, uri=True)
cursor = self.conn.cursor()
cursor.execute("PRAGMA foreign_keys = ON")
self.conn.create_function("unpack_asn", 1,
self.unpack_asn,
deterministic=True)
self.conn.create_function("unpack_ip4address", 1,
self.unpack_ip4address,
deterministic=True)
self.conn.create_function("unpack_ip4network", 2,
self.unpack_ip4network,
deterministic=True)
self.conn.create_function("unpack_domain", 1,
lambda s: s[:-1][::-1],
deterministic=True)
self.conn.create_function("format_zone", 1,
lambda s: '*' + s[::-1],
deterministic=True)
def get_meta(self, key: str) -> typing.Optional[int]:
cursor = self.conn.cursor()
try:
cursor.execute("SELECT value FROM meta WHERE key=?", (key,))
except sqlite3.OperationalError:
return None
for ver, in cursor:
return ver
return None
class IpTreeNode():
def __init__(self) -> None:
self.children: typing.List[typing.Optional[IpTreeNode]] = [None, None]
self.match: typing.Optional[Match] = None
def set_meta(self, key: str, val: int) -> None:
cursor = self.conn.cursor()
cursor.execute("INSERT INTO meta VALUES (?, ?) "
"ON CONFLICT (key) DO "
"UPDATE set value=?",
(key, val, val))
def close(self) -> None:
self.enter_step('close_commit')
self.conn.commit()
self.enter_step('close')
self.conn.close()
self.profile()
def initialize(self) -> None:
self.close()
self.enter_step('initialize')
if not self.write:
self.log.error("Cannot initialize in read-only mode.")
raise
os.unlink(self.PATH)
self.open()
self.log.info("Creating database version %d.", self.VERSION)
cursor = self.conn.cursor()
with open("database_schema.sql", 'r') as db_schema:
cursor.executescript(db_schema.read())
self.set_meta('version', self.VERSION)
self.conn.commit()
def __init__(self, write: bool = False) -> None:
self.log = logging.getLogger('db')
class Profiler():
def __init__(self) -> None:
self.log = logging.getLogger('profiler')
self.time_last = time.perf_counter()
self.time_step = 'init'
self.time_dict: typing.Dict[str, float] = dict()
self.step_dict: typing.Dict[str, int] = dict()
self.write = write
self.open()
version = self.get_meta('version')
if version != self.VERSION:
if version is not None:
self.log.warning(
"Outdated database version: %d found, will be rebuilt.",
version)
self.initialize()
def enter_step(self, name: str) -> None:
now = time.perf_counter()
try:
self.time_dict[self.time_step] += now - self.time_last
self.step_dict[self.time_step] += 1
self.step_dict[self.time_step] += int(name != self.time_step)
except KeyError:
self.time_dict[self.time_step] = now - self.time_last
self.step_dict[self.time_step] = 1
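For context, a rough sketch of how this profiler is driven (step names invented for illustration): each enter_step() call charges the wall-clock time elapsed since the previous call to the step that was active, and profile() logs the per-step totals.

import time
from database import Profiler  # this module

prof = Profiler()
prof.enter_step('parse')   # everything since __init__ is charged to 'init'
time.sleep(0.1)            # ... parsing work ...
prof.enter_step('store')   # the ~0.1 s above is charged to 'parse'
time.sleep(0.2)            # ... storing work ...
prof.profile()             # logs each step's total time and share of the total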
@@ -125,13 +73,58 @@ class Database():
self.log.debug(f"{'total':<20}: "
f"{total:9.2f} s ({1:7.2%})")
@staticmethod
def pack_hostname(hostname: str) -> str:
return hostname[::-1] + '.'
class Database(Profiler):
VERSION = 8
PATH = "blocking.p"
def initialize(self) -> None:
self.log.warning(
"Creating database version: %d ",
Database.VERSION)
self.domtree = DomainTreeNode()
self.asns: typing.Set[Asn] = set()
self.ip4tree = IpTreeNode()
def load(self) -> None:
self.enter_step('load')
try:
with open(self.PATH, 'rb') as db_fdsec:
version, data = pickle.load(db_fdsec)
if version == Database.VERSION:
self.domtree, self.asns, self.ip4tree = data
return
self.log.warning(
"Outdated database version found: %d, "
"will be rebuilt.",
version)
except (TypeError, AttributeError, EOFError):
self.log.error(
"Corrupt database found, "
"will be rebuilt.")
except FileNotFoundError:
pass
self.initialize()
def save(self) -> None:
self.enter_step('save')
with open(self.PATH, 'wb') as db_fdsec:
data = self.domtree, self.asns, self.ip4tree
pickle.dump((self.VERSION, data), db_fdsec)
self.profile()
def __init__(self) -> None:
Profiler.__init__(self)
self.log = logging.getLogger('db')
self.load()
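With this commit the whole ruleset lives in memory and is persisted as one pickle file; a minimal usage sketch (hostname and timestamp invented for illustration):

import time
from database import Database

db = Database()                 # loads blocking.p, or builds a fresh tree if absent or outdated
db.set_hostname('ads.example.com', updated=int(time.time()))
db.save()                       # writes (VERSION, (domtree, asns, ip4tree)) back to blocking.p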
@staticmethod
def pack_zone(zone: str) -> str:
return Database.pack_hostname(zone)
    @staticmethod
    def pack_domain(domain: str) -> DomainPath:
return domain.split('.')[::-1]
@staticmethod
def unpack_domain(domain: DomainPath) -> str:
return '.'.join(domain[::-1])
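Domains are packed as reversed label lists, so a zone is literally a prefix of every hostname it covers; for example:

Database.pack_domain('www.example.com')             # ['com', 'example', 'www']
Database.pack_domain('example.com')                 # ['com', 'example'], a prefix of the above
Database.unpack_domain(['com', 'example', 'www'])   # 'www.example.com'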
@staticmethod
def pack_asn(asn: str) -> int:
@@ -145,431 +138,208 @@ class Database():
return f'AS{asn}'
@staticmethod
def pack_ip4address(address: str) -> int:
total = 0
for i, octet in enumerate(address.split('.')):
total += int(octet) << (3-i)*8
if total > 0xFFFFFFFF:
raise ValueError
return total
# return '{:02x}{:02x}{:02x}{:02x}'.format(
# *[int(c) for c in address.split('.')])
# return base64.b16encode(packed).decode()
# return '{:08b}{:08b}{:08b}{:08b}'.format(
# *[int(c) for c in address.split('.')])
# carg = ctypes.c_wchar_p(address)
# ret = ACCEL.ip4_flat(carg, self.accel_ip4_buf)
# if ret != 0:
# raise ValueError
# return self.accel_ip4_buf.value
# packed = ipaddress.ip_address(address).packed
# return packed
def pack_ip4address(address: str) -> Ip4Path:
addr: Ip4Path = [0] * 32
octets = [int(octet) for octet in address.split('.')]
        for b in range(32):
            # most-significant bit first, matching unpack_ip4address and making
            # a network prefix a true prefix of its addresses' bit lists
            if (octets[b//8] >> (7 - b % 8)) & 0b1:
                addr[b] = 1
return addr
@staticmethod
def unpack_ip4address(address: int) -> str:
return '.'.join(str((address >> (i * 8)) & 0xFF)
for i in reversed(range(4)))
def unpack_ip4address(address: Ip4Path) -> str:
octets = [0] * 4
for b, bit in enumerate(address):
octets[b//8] = (octets[b//8] << 1) + bit
return '.'.join(map(str, octets))
@staticmethod
def pack_ip4network(network: str) -> typing.Tuple[int, int]:
# def pack_ip4network(network: str) -> str:
net = ipaddress.ip_network(network)
mini = Database.pack_ip4address(net.network_address.exploded)
maxi = Database.pack_ip4address(net.broadcast_address.exploded)
# mini = net.network_address.packed
# maxi = net.broadcast_address.packed
return mini, maxi
# return Database.pack_ip4address(net.network_address.exploded)[:net.prefixlen]
def pack_ip4network(network: str) -> Ip4Path:
address, prefixlen_str = network.split('/')
prefixlen = int(prefixlen_str)
return Database.pack_ip4address(address)[:prefixlen]
@staticmethod
def unpack_ip4network(mini: int, maxi: int) -> str:
addr = Database.unpack_ip4address(mini)
prefixlen = 32-int(math.log2(maxi-mini+1))
def unpack_ip4network(network: Ip4Path) -> str:
address = network.copy()
prefixlen = len(network)
for _ in range(32-prefixlen):
address.append(0)
addr = Database.unpack_ip4address(address)
return f'{addr}/{prefixlen}'
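IPv4 addresses become 32-element bit lists and a network is just that list cut down to its prefix length, so subnet membership is a plain prefix test on the tree path. A worked example, assuming the most-significant-bit-first layout that unpack_ip4address expects:

Database.pack_ip4address('10.0.0.0')[:8]   # [0, 0, 0, 0, 1, 0, 1, 0]  (10 == 0b00001010)
Database.pack_ip4network('10.0.0.0/8')     # the same 8 elements: a /8 keeps only 8 bits of path
Database.unpack_ip4network(Database.pack_ip4network('10.0.0.0/8'))   # '10.0.0.0/8'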
def update_references(self) -> None:
self.enter_step('update_refs')
cursor = self.conn.cursor()
cursor.execute('UPDATE rules AS r SET refs='
'(SELECT count(*) FROM rules '
'WHERE source=r.id)')
raise NotImplementedError
def prune(self, before: int, base_only: bool = False) -> None:
self.enter_step('prune')
cursor = self.conn.cursor()
cmd = 'DELETE FROM rules WHERE updated<?'
if base_only:
cmd += ' AND level=0'
cursor.execute(cmd, (before,))
raise NotImplementedError
def explain(self, entry: int) -> str:
# Format current
string = '???'
cursor = self.conn.cursor()
cursor.execute(
'SELECT unpack_asn(val) FROM asn WHERE entry=:entry '
'UNION '
'SELECT unpack_domain(val) FROM hostname WHERE entry=:entry '
'UNION '
'SELECT format_zone(val) FROM zone WHERE entry=:entry '
'UNION '
'SELECT unpack_ip4address(val) FROM ip4address WHERE entry=:entry '
'UNION '
'SELECT unpack_ip4network(mini, maxi) '
'FROM ip4network WHERE entry=:entry ',
{"entry": entry}
)
for val, in cursor: # only one
string = str(val)
string += f' #{entry}'
# Add source if any
cursor.execute('SELECT source FROM rules WHERE id=?', (entry,))
for source, in cursor:
if source:
string += f'{self.explain(source)}'
return string
raise NotImplementedError
def export(self,
first_party_only: bool = False,
end_chain_only: bool = False,
explain: bool = False,
_dic: DomainTreeNode = None,
_par: DomainPath = None,
) -> typing.Iterable[str]:
selection = 'entry' if explain else 'unpack_domain(val)'
command = f'SELECT {selection} FROM rules ' \
'INNER JOIN hostname ON rules.id = hostname.entry'
restrictions: typing.List[str] = list()
if first_party_only:
restrictions.append('rules.first_party = 1')
if end_chain_only:
restrictions.append('rules.refs = 0')
if restrictions:
command += ' WHERE ' + ' AND '.join(restrictions)
if not explain:
command += ' ORDER BY unpack_domain(val) ASC'
cursor = self.conn.cursor()
cursor.execute(command)
for val, in cursor:
if explain:
yield self.explain(val)
else:
yield val
if first_party_only or end_chain_only or explain:
raise NotImplementedError
_dic = _dic or self.domtree
_par = _par or list()
if _dic.match_hostname:
yield self.unpack_domain(_par)
for part in _dic.children:
dic = _dic.children[part]
yield from self.export(_dic=dic,
_par=_par + [part])
def count_rules(self,
first_party_only: bool = False,
) -> str:
counts: typing.List[str] = list()
cursor = self.conn.cursor()
for table in ['asn', 'ip4network', 'ip4address', 'zone', 'hostname']:
command = f'SELECT count(*) FROM rules ' \
f'INNER JOIN {table} ON rules.id = {table}.entry ' \
'WHERE rules.level = 0'
if first_party_only:
command += ' AND first_party=1'
cursor.execute(command)
count, = cursor.fetchone()
if count > 0:
counts.append(f'{table}: {count}')
raise NotImplementedError
return ', '.join(counts)
def get_domain(self, domain: str) -> typing.Iterable[int]:
self.enter_step('get_domain_prepare')
domain_prep = self.pack_hostname(domain)
cursor = self.conn.cursor()
self.enter_step('get_domain_select')
cursor.execute(
'SELECT null, entry FROM hostname '
'WHERE val=:d '
'UNION '
'SELECT * FROM ('
'SELECT val, entry FROM zone '
# 'WHERE val>=:d '
# 'ORDER BY val ASC LIMIT 1'
'WHERE val<=:d '
'AND instr(:d, val) = 1'
')',
{'d': domain_prep}
)
for val, entry in cursor:
# print(293, val, entry)
self.enter_step('get_domain_confirm')
if not (val is None or domain_prep.startswith(val)):
# print(297)
continue
def get_domain(self, domain_str: str) -> typing.Iterable[TypedPath]:
self.enter_step('get_domain_pack')
domain = self.pack_domain(domain_str)
self.enter_step('get_domain_brws')
dic = self.domtree
depth = 0
for part in domain:
if dic.match_zone:
self.enter_step('get_domain_yield')
yield entry
def get_domain_in_zone(self, domain: str) -> typing.Iterable[int]:
self.enter_step('get_domainiz_prepare')
domain_prep = self.pack_hostname(domain)
cursor = self.conn.cursor()
self.enter_step('get_domainiz_select')
cursor.execute(
'SELECT val, entry FROM zone '
'WHERE val<=:d '
'ORDER BY val DESC LIMIT 1',
{'d': domain_prep}
)
for val, entry in cursor:
self.enter_step('get_domainiz_confirm')
if not (val is None or domain_prep.startswith(val)):
continue
self.enter_step('get_domainiz_yield')
yield entry
def get_ip4(self, address: str) -> typing.Iterable[int]:
self.enter_step('get_ip4_prepare')
try:
address_prep = self.pack_ip4address(address)
except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", address)
yield (PathType.Zone, domain[:depth])
self.enter_step('get_domain_brws')
if part not in dic.children:
return
cursor = self.conn.cursor()
self.enter_step('get_ip4_select')
cursor.execute(
'SELECT entry FROM ip4address '
# 'SELECT null, entry FROM ip4address '
'WHERE val=:a '
'UNION '
# 'SELECT * FROM ('
# 'SELECT val, entry FROM ip4network '
# 'WHERE val<=:a '
# 'AND instr(:a, val) > 0 '
# 'ORDER BY val DESC'
# ')'
'SELECT entry FROM ip4network '
'WHERE :a BETWEEN mini AND maxi ',
{'a': address_prep}
)
for entry, in cursor:
# self.enter_step('get_ip4_confirm')
# if not (val is None or val.startswith(address_prep)):
# # PERF startswith but from the end
# continue
dic = dic.children[part]
depth += 1
if dic.match_zone:
self.enter_step('get_domain_yield')
yield (PathType.Zone, domain)
if dic.match_hostname:
self.enter_step('get_domain_yield')
yield (PathType.Hostname, domain)
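Lookups are a walk down the label trie: a zone match is yielded at every ancestor that carries one, and a hostname match only on an exact hit. For instance (rule and names invented, fresh database assumed):

db = Database()
db.set_zone('tracker.example', updated=0)
list(db.get_domain('cdn.tracker.example'))
# [(PathType.Zone, ['example', 'tracker'])] : the zone rule covers the subdomain
list(db.get_domain('tracker.example'))
# [(PathType.Zone, ['example', 'tracker'])] : the zone name itself matches too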
def get_ip4(self, ip4_str: str) -> typing.Iterable[TypedPath]:
self.enter_step('get_ip4_pack')
ip4 = self.pack_ip4address(ip4_str)
self.enter_step('get_ip4_brws')
dic = self.ip4tree
depth = 0
for part in ip4:
if dic.match:
self.enter_step('get_ip4_yield')
yield entry
def get_ip4_in_network(self, address: str) -> typing.Iterable[int]:
self.enter_step('get_ip4in_prepare')
try:
address_prep = self.pack_ip4address(address)
except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", address)
yield (PathType.Ip4, ip4[:depth])
self.enter_step('get_ip4_brws')
next_dic = dic.children[part]
if next_dic is None:
return
cursor = self.conn.cursor()
self.enter_step('get_ip4in_select')
cursor.execute(
'SELECT entry FROM ip4network '
'WHERE :a BETWEEN mini AND maxi ',
{'a': address_prep}
)
for entry, in cursor:
self.enter_step('get_ip4in_yield')
yield entry
dic = next_dic
depth += 1
if dic.match:
self.enter_step('get_ip4_yield')
yield (PathType.Ip4, ip4)
def list_asn(self) -> typing.Iterable[typing.Tuple[str, int]]:
cursor = self.conn.cursor()
self.enter_step('list_asn_select')
cursor.execute('SELECT val, entry FROM asn')
for val, entry in cursor:
yield f'AS{val}', entry
def list_asn(self) -> typing.Iterable[TypedPath]:
for asn in self.asns:
yield (PathType.Asn, asn)
def _set_generic(self,
table: str,
select_query: str,
insert_query: str,
prep: typing.Dict[str, DbValue],
def set_hostname(self,
hostname_str: str,
updated: int,
is_first_party: bool = False,
source: int = None,
) -> None:
# Since this isn't the bulk of the processing,
        # here abstraction > performance
# Fields based on the source
self.enter_step(f'set_{table}_prepare')
cursor = self.conn.cursor()
if source is None:
first_party = int(is_first_party)
level = 0
else:
self.enter_step(f'set_{table}_source')
cursor.execute(
'SELECT first_party, level FROM rules '
'WHERE id=?',
(source,)
)
first_party, level = cursor.fetchone()
level += 1
self.enter_step(f'set_{table}_select')
cursor.execute(select_query, prep)
rules_prep: typing.Dict[str, DbValue] = {
"source": source,
"updated": updated,
"first_party": first_party,
"level": level,
}
# If the entry already exists
for entry, in cursor: # only one
self.enter_step(f'set_{table}_update')
rules_prep['entry'] = entry
cursor.execute(
'UPDATE rules SET '
'source=:source, updated=:updated, '
'first_party=:first_party, level=:level '
'WHERE id=:entry AND (updated<:updated OR '
'first_party<:first_party OR level<:level)',
rules_prep
)
# Only update if any of the following:
        # - the entry is outdated
# - the entry was not a first_party but this is
# - this is closer to the original rule
is_first_party: bool = None,
source: TypedPath = None) -> None:
self.enter_step('set_hostname_pack')
if is_first_party or source:
raise NotImplementedError
self.enter_step('set_hostname_brws')
hostname = self.pack_domain(hostname_str)
dic = self.domtree
for part in hostname:
if dic.match_zone:
# Refuse to add hostname whose zone is already matching
return
if part not in dic.children:
dic.children[part] = DomainTreeNode()
dic = dic.children[part]
dic.match_hostname = (updated, DebugPath, 0)
# If it does not exist
self.enter_step(f'set_{table}_insert')
cursor.execute(
'INSERT INTO rules '
'(source, updated, first_party, level) '
'VALUES (:source, :updated, :first_party, :level) ',
rules_prep
)
cursor.execute('SELECT id FROM rules WHERE rowid=?',
(cursor.lastrowid,))
for entry, in cursor: # only one
prep['entry'] = entry
cursor.execute(insert_query, prep)
def set_zone(self,
zone_str: str,
updated: int,
is_first_party: bool = None,
source: TypedPath = None) -> None:
self.enter_step('set_zone_pack')
if is_first_party or source:
raise NotImplementedError
zone = self.pack_domain(zone_str)
self.enter_step('set_zone_brws')
dic = self.domtree
for part in zone:
if dic.match_zone:
# Refuse to add zone whose parent zone is already matching
return
if part not in dic.children:
dic.children[part] = DomainTreeNode()
dic = dic.children[part]
dic.match_zone = (updated, DebugPath, 0)
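Insertion dedupes against broader rules on the way down: once a zone on the path is already matching, the narrower hostname or zone is silently dropped instead of being stored twice. A small illustration (names invented, fresh database assumed):

db = Database()
db.set_zone('ads.example', updated=0)
db.set_hostname('img.ads.example', updated=0)   # no-op: already covered by the ads.example zone
db.set_zone('eu.ads.example', updated=0)        # no-op as well: a parent zone already matches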
def set_hostname(self, hostname: str,
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_hostname_prepare')
prep: typing.Dict[str, DbValue] = {
'val': self.pack_hostname(hostname),
}
self._set_generic(
'hostname',
'SELECT entry FROM hostname WHERE val=:val',
'INSERT INTO hostname (val, entry) '
'VALUES (:val, :entry)',
prep,
*args, **kwargs
)
def set_asn(self,
asn_str: str,
updated: int,
is_first_party: bool = None,
source: TypedPath = None) -> None:
self.enter_step('set_asn_pack')
if is_first_party or source:
# TODO updated
raise NotImplementedError
asn = self.pack_asn(asn_str)
self.enter_step('set_asn_brws')
self.asns.add(asn)
def set_asn(self, asn: str,
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_asn_prepare')
try:
asn_prep = self.pack_asn(asn)
except ValueError:
self.log.error("Invalid asn: %s", asn)
def set_ip4address(self,
ip4address_str: str,
updated: int,
is_first_party: bool = None,
source: TypedPath = None) -> None:
self.enter_step('set_ip4add_pack')
if is_first_party or source:
raise NotImplementedError
self.enter_step('set_ip4add_brws')
ip4address = self.pack_ip4address(ip4address_str)
dic = self.ip4tree
for part in ip4address:
if dic.match:
# Refuse to add ip4address whose network is already matching
return
prep: typing.Dict[str, DbValue] = {
'val': asn_prep,
}
self._set_generic(
'asn',
'SELECT entry FROM asn WHERE val=:val',
'INSERT INTO asn (val, entry) '
'VALUES (:val, :entry)',
prep,
*args, **kwargs
)
next_dic = dic.children[part]
if next_dic is None:
next_dic = IpTreeNode()
dic.children[part] = next_dic
dic = next_dic
dic.match = (updated, DebugPath, 0)
def set_ip4address(self, ip4address: str,
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_ip4add_prepare')
try:
ip4address_prep = self.pack_ip4address(ip4address)
except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", ip4address)
def set_ip4network(self,
ip4network_str: str,
updated: int,
is_first_party: bool = None,
source: TypedPath = None) -> None:
self.enter_step('set_ip4net_pack')
if is_first_party or source:
raise NotImplementedError
self.enter_step('set_ip4net_brws')
ip4network = self.pack_ip4network(ip4network_str)
dic = self.ip4tree
for part in ip4network:
if dic.match:
# Refuse to add ip4network whose parent network
# is already matching
return
prep: typing.Dict[str, DbValue] = {
'val': ip4address_prep,
}
self._set_generic(
'ip4add',
'SELECT entry FROM ip4address WHERE val=:val',
'INSERT INTO ip4address (val, entry) '
'VALUES (:val, :entry)',
prep,
*args, **kwargs
)
def set_zone(self, zone: str,
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_zone_prepare')
prep: typing.Dict[str, DbValue] = {
'val': self.pack_zone(zone),
}
self._set_generic(
'zone',
'SELECT entry FROM zone WHERE val=:val',
'INSERT INTO zone (val, entry) '
'VALUES (:val, :entry)',
prep,
*args, **kwargs
)
def set_ip4network(self, ip4network: str,
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_ip4net_prepare')
try:
ip4network_prep = self.pack_ip4network(ip4network)
except (ValueError, IndexError):
self.log.error("Invalid ip4network: %s", ip4network)
return
prep: typing.Dict[str, DbValue] = {
'mini': ip4network_prep[0],
'maxi': ip4network_prep[1],
}
self._set_generic(
'ip4net',
'SELECT entry FROM ip4network WHERE mini=:mini AND maxi=:maxi',
'INSERT INTO ip4network (mini, maxi, entry) '
'VALUES (:mini, :maxi, :entry)',
prep,
*args, **kwargs
)
if __name__ == '__main__':
# Parsing arguments
parser = argparse.ArgumentParser(
description="Database operations")
parser.add_argument(
'-i', '--initialize', action='store_true',
help="Reconstruct the whole database")
parser.add_argument(
'-p', '--prune', action='store_true',
help="Remove old entries from database")
parser.add_argument(
'-b', '--prune-base', action='store_true',
help="TODO")
parser.add_argument(
'-s', '--prune-before', type=int,
default=(int(time.time()) - 60*60*24*31*6),
help="TODO")
parser.add_argument(
'-r', '--references', action='store_true',
help="Update the reference count")
args = parser.parse_args()
DB = Database(write=True)
if args.initialize:
DB.initialize()
if args.prune:
DB.prune(before=args.prune_before, base_only=args.prune_base)
if args.references:
DB.update_references()
DB.close()
next_dic = dic.children[part]
if next_dic is None:
next_dic = IpTreeNode()
dic.children[part] = next_dic
dic = next_dic
dic.match = (updated, DebugPath, 0)
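The IPv4 side mirrors the domain trie, except that children is a fixed two-slot list indexed by the next bit and nodes are only allocated when a rule needs them. Roughly (network and address invented for illustration):

db = Database()
db.set_ip4network('10.0.0.0/8', updated=0)   # allocates one IpTreeNode per bit of the prefix
db.set_ip4address('10.1.2.3', updated=0)     # no-op: the /8 node already matches on the way down
list(db.get_ip4('10.1.2.3'))                 # [(PathType.Ip4, <the 8-bit prefix path>)]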


@@ -1,59 +0,0 @@
-- Remember to increment DB_VERSION
-- in database.py on changes to this file
CREATE TABLE rules (
id INTEGER PRIMARY KEY AUTOINCREMENT,
source INTEGER, -- The rule this one is based on
updated INTEGER, -- If the row was updated during last data import (0: No, 1: Yes)
first_party INTEGER, -- 1: this blocks a first party for sure, 0: maybe
refs INTEGER, -- Number of entries issued from this one
level INTEGER, -- Level of recursion to the root source rule (used for source priority)
FOREIGN KEY (source) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX rules_source ON rules (source); -- for references recounting
CREATE INDEX rules_updated ON rules (updated); -- for pruning
CREATE INDEX rules_level_firstparty ON rules (level, first_party); -- for counting rules
CREATE TABLE asn (
val INTEGER PRIMARY KEY,
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX asn_entry ON asn (entry); -- for explanations
CREATE TABLE hostname (
val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for consistency with zone)
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX hostname_entry ON hostname (entry); -- for explanations
CREATE TABLE zone (
val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for easier matching)
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX zone_entry ON zone (entry); -- for explanations
CREATE TABLE ip4address (
val INTEGER PRIMARY KEY,
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX ip4address_entry ON ip4address (entry); -- for explanations
CREATE TABLE ip4network (
-- val TEXT PRIMARY KEY,
mini INTEGER,
maxi INTEGER,
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX ip4network_minmax ON ip4network (mini, maxi);
CREATE INDEX ip4network_entry ON ip4network (entry); -- for explanations
-- Store various things
CREATE TABLE meta (
key TEXT PRIMARY KEY,
value integer
);


@@ -45,5 +45,3 @@ if __name__ == '__main__':
explain=args.explain,
):
print(domain, file=args.output)
DB.close()


@@ -31,23 +31,25 @@ if __name__ == '__main__':
args = parser.parse_args()
DB = database.Database()
DBW = database.Database(write=True)
for asn, entry in DB.list_asn():
for path in DB.list_asn():
ptype, asn = path
assert ptype == database.PathType.Asn
assert isinstance(asn, int)
asn_str = database.Database.unpack_asn(asn)
DB.enter_step('asn_get_ranges')
for prefix in get_ranges(asn):
for prefix in get_ranges(asn_str):
parsed_prefix: IPNetwork = ipaddress.ip_network(prefix)
if parsed_prefix.version == 4:
DBW.set_ip4network(
DB.set_ip4network(
prefix,
source=entry,
# source=path,
updated=int(time.time())
)
log.info('Added %s from %s (id=%s)', prefix, asn, entry)
log.info('Added %s from %s (source=%s)', prefix, asn, path)
elif parsed_prefix.version == 6:
log.warning('Unimplemented prefix version: %s', prefix)
else:
log.error('Unknown prefix version: %s', prefix)
DB.close()
DBW.close()
DB.save()


@@ -6,126 +6,52 @@ import json
import logging
import sys
import typing
import multiprocessing
import enum
RecordType = enum.Enum('RecordType', 'A AAAA CNAME PTR')
Record = typing.Tuple[RecordType, int, str, str]
# select, confirm, write
# select, write
FUNCTION_MAP: typing.Any = {
RecordType.A: (
database.Database.get_ip4,
database.Database.get_domain_in_zone,
database.Database.set_hostname,
),
RecordType.CNAME: (
database.Database.get_domain,
database.Database.get_domain_in_zone,
database.Database.set_hostname,
),
RecordType.PTR: (
database.Database.get_domain,
database.Database.get_ip4_in_network,
database.Database.set_ip4address,
),
}
class Reader(multiprocessing.Process):
def __init__(self,
recs_queue: multiprocessing.Queue,
write_queue: multiprocessing.Queue,
index: int = 0):
super(Reader, self).__init__()
self.log = logging.getLogger(f'rd{index:03d}')
self.recs_queue = recs_queue
self.write_queue = write_queue
self.index = index
def run(self) -> None:
self.db = database.Database(write=False)
self.db.log = logging.getLogger(f'db{self.index:03d}')
self.db.enter_step('line_wait')
block: typing.List[str]
try:
for block in iter(self.recs_queue.get, None):
record: Record
for record in block:
# print(55, record)
dtype, updated, name, value = record
self.db.enter_step('feed_switch')
select, confirm, write = FUNCTION_MAP[dtype]
for rule in select(self.db, value):
# print(60, rule, list(confirm(self.db, name)))
if not any(confirm(self.db, name)):
# print(62, write, name, updated, rule)
self.db.enter_step('wait_put')
self.write_queue.put((write, name, updated, rule))
self.db.enter_step('line_wait')
except KeyboardInterrupt:
self.log.error('Interrupted')
self.db.enter_step('end')
self.db.close()
class Writer(multiprocessing.Process):
def __init__(self,
write_queue: multiprocessing.Queue,
):
super(Writer, self).__init__()
self.log = logging.getLogger(f'wr ')
self.write_queue = write_queue
def run(self) -> None:
self.db = database.Database(write=True)
self.db.log = logging.getLogger(f'dbw ')
self.db.enter_step('line_wait')
block: typing.List[str]
try:
fun: typing.Callable
name: str
updated: int
source: int
for fun, name, updated, source in iter(self.write_queue.get, None):
self.db.enter_step('exec')
fun(self.db, name, updated, source=source)
self.db.enter_step('line_wait')
except KeyboardInterrupt:
self.log.error('Interrupted')
self.db.enter_step('end')
self.db.close()
class Parser():
def __init__(self,
buf: typing.Any,
recs_queue: multiprocessing.Queue,
block_size: int,
):
super(Parser, self).__init__()
def __init__(self, buf: typing.Any) -> None:
self.buf = buf
self.log = logging.getLogger('pr ')
self.recs_queue = recs_queue
self.block: typing.List[Record] = list()
self.block_size = block_size
self.db = database.Database() # Just for timing
self.db.log = logging.getLogger('pr ')
self.log = logging.getLogger('parser')
self.db = database.Database()
def end(self) -> None:
self.db.save()
def register(self,
rtype: RecordType,
updated: int,
name: str,
value: str
) -> None:
def register(self, record: Record) -> None:
self.db.enter_step('register')
self.block.append(record)
if len(self.block) >= self.block_size:
self.db.enter_step('put_block')
self.recs_queue.put(self.block)
self.block = list()
def run(self) -> None:
self.consume()
self.recs_queue.put(self.block)
self.db.close()
select, write = FUNCTION_MAP[rtype]
try:
for source in select(self.db, value):
# write(self.db, name, updated, source=source)
write(self.db, name, updated)
except NotImplementedError:
return # DEBUG
def consume(self) -> None:
raise NotImplementedError
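The reader/writer processes are gone; the parser now applies every record synchronously through FUNCTION_MAP, where select finds the rules matching the record's value and write registers the record's name under them. A rough sketch of that flow for one CNAME record (names invented; assuming this file is importable as feed_dns):

import time
import database
from feed_dns import FUNCTION_MAP, RecordType

db = database.Database()
db.set_zone('tracker.example', updated=int(time.time()))   # an existing zone rule
# record from the feed:  CNAME  shop.example -> cdn.tracker.example
select, write = FUNCTION_MAP[RecordType.CNAME]              # (get_domain, set_hostname)
for source in select(db, 'cdn.tracker.example'):            # the zone rule matches the CNAME target
    write(db, 'shop.example', int(time.time()))             # so shop.example becomes a blocked hostname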
@@ -146,13 +72,12 @@ class Rapid7Parser(Parser):
data = json.loads(line)
except json.decoder.JSONDecodeError:
continue
record = (
self.register(
Rapid7Parser.TYPES[data['type']],
int(data['timestamp']),
data['name'],
data['value']
)
self.register(record)
class DnsMassParser(Parser):
@@ -182,13 +107,12 @@ class DnsMassParser(Parser):
else:
dtype, name_offset, value_offset = \
DnsMassParser.TYPES[split[1]]
record = (
self.register(
dtype,
timestamp,
split[0][:name_offset],
split[2][:value_offset],
)
self.register(record)
self.db.enter_step('parse_dnsmass')
except KeyError:
continue
@@ -212,49 +136,12 @@ if __name__ == '__main__':
args_parser.add_argument(
'-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
help="TODO")
args_parser.add_argument(
'-j', '--workers', type=int, default=4,
help="TODO")
args_parser.add_argument(
'-b', '--block-size', type=int, default=100,
help="TODO")
args = args_parser.parse_args()
DB = database.Database(write=False) # Not needed, just for timing
DB.log = logging.getLogger('db ')
recs_queue: multiprocessing.Queue = multiprocessing.Queue(
maxsize=10*args.workers)
write_queue: multiprocessing.Queue = multiprocessing.Queue(
maxsize=10*args.workers)
DB.enter_step('proc_create')
readers: typing.List[Reader] = list()
for w in range(args.workers):
readers.append(Reader(recs_queue, write_queue, w))
writer = Writer(write_queue)
parser = PARSERS[args.parser](
args.input, recs_queue, args.block_size)
DB.enter_step('proc_start')
for reader in readers:
reader.start()
writer.start()
parser = PARSERS[args.parser](args.input)
try:
DB.enter_step('parser_run')
parser.run()
DB.enter_step('end_put')
for _ in range(args.workers):
recs_queue.put(None)
write_queue.put(None)
DB.enter_step('proc_join')
for reader in readers:
reader.join()
writer.join()
parser.consume()
except KeyboardInterrupt:
log.error('Interrupted')
pass
parser.end()
DB.close()


@@ -28,15 +28,15 @@ if __name__ == '__main__':
help="The input only comes from verified first-party sources")
args = parser.parse_args()
DB = database.Database(write=True)
DB = database.Database()
fun = FUNCTION_MAP[args.type]
for rule in args.input:
fun(DB,
rule.strip(),
is_first_party=args.first_party,
# is_first_party=args.first_party,
updated=int(time.time()),
)
DB.close()
DB.save()


@@ -6,11 +6,11 @@ function log() {
log "Importing rules…"
BEFORE="$(date +%s)"
cat rules_adblock/*.txt | grep -v '^!' | grep -v '^\[Adblock' | ./adblock_to_domain_list.py | ./feed_rules.py zone
cat rules_hosts/*.txt | grep -v '^#' | grep -v '^$' | cut -d ' ' -f2 | ./feed_rules.py zone
cat rules/*.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone
cat rules_ip/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py ip4network
cat rules_asn/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py asn
# cat rules_adblock/*.txt | grep -v '^!' | grep -v '^\[Adblock' | ./adblock_to_domain_list.py | ./feed_rules.py zone
# cat rules_hosts/*.txt | grep -v '^#' | grep -v '^$' | cut -d ' ' -f2 | ./feed_rules.py zone
# cat rules/*.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone
# cat rules_ip/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py ip4network
# cat rules_asn/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py asn
cat rules/first-party.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone --first-party
cat rules_ip/first-party.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py ip4network --first-party
@@ -19,4 +19,4 @@ cat rules_asn/first-party.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py as
./feed_asn.py
log "Pruning old rules…"
./database.py --prune --prune-before "$BEFORE" --prune-base
./db.py --prune --prune-before "$BEFORE" --prune-base