eulaurarien/database.py
2019-12-15 16:48:17 +01:00

347 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Utility functions to interact with the database.
"""
import typing
import time
import logging
import coloredlogs
import pickle
import enum
# Configure coloredlogs at import time: DEBUG level, with the timestamp,
# logger name, level name and message on every line.
coloredlogs.install(
level='DEBUG',
fmt='%(asctime)s %(name)s %(levelname)s %(message)s'
)
# Kind tag of a path stored in, or returned by, the database.
PathType = enum.Enum('PathType', 'Rule Hostname Zone Asn Ip4 Ip6')
# A rule has no path payload: typing.Union[None] collapses to type(None).
RulePath = typing.Union[None]
# Autonomous system number, kept as a plain integer (see Database.pack_asn).
Asn = int
# Domain labels as produced by Database.pack_domain, i.e. in reversed order
# ('a.example.com' -> ['com', 'example', 'a']).
DomainPath = typing.List[str]
Ip4Path = typing.Tuple[int, int] # value, prefixlen
# NOTE(review): no IPv6 code is visible in this file; presumably a list of
# address bits or groups -- confirm before relying on it.
Ip6Path = typing.List[int]
# Any payload that can accompany a PathType tag.
Path = typing.Union[RulePath, DomainPath, Asn, Ip4Path, Ip6Path]
# A path together with the tag telling which kind of path it is.
TypedPath = typing.Tuple[PathType, Path]
Timestamp = int
Level = int
# What the tree nodes store for a matching entry; the Database.set_* methods
# write it as (updated, DebugPath, 0).
Match = typing.Tuple[Timestamp, TypedPath, Level]
# Placeholder TypedPath that every Database.set_* method stores in the
# TypedPath slot of the Match it records.
DebugPath = (PathType.Rule, None)
class DomainTreeNode:
    """One label of the domain tree.

    ``children`` maps the next label to its node.  ``match_zone`` and
    ``match_hostname`` hold the ``Match`` recorded for the zone or the exact
    hostname ending at this node, or ``None`` when nothing is recorded.
    """

    def __init__(self) -> None:
        # No entry is recorded on a freshly created node.
        self.match_hostname: typing.Optional[Match] = None
        self.match_zone: typing.Optional[Match] = None
        # Sub-labels, keyed by label text.
        self.children: typing.Dict[str, DomainTreeNode] = {}
class IpTreeNode:
    """One node of the binary IP tree.

    ``children`` is a two-slot list indexed by the next address bit (0 or 1);
    a slot is ``None`` until that branch is created.  ``match`` holds the
    ``Match`` recorded for the prefix ending at this node, or ``None``.
    """

    def __init__(self) -> None:
        # Nothing is recorded on a freshly created node.
        self.match: typing.Optional[Match] = None
        # Index 0 -> branch for bit 0, index 1 -> branch for bit 1.
        self.children: typing.List[typing.Optional[IpTreeNode]] = [None] * 2
class Profiler:
    """Accumulate the time spent in named steps and log a summary.

    Exactly one step is "current" at any time (initially ``'init'``).
    ``enter_step`` charges the time elapsed since the previous call to the
    step being left, then switches to the new one.
    """

    def __init__(self) -> None:
        self.log = logging.getLogger('profiler')
        # Seconds accumulated per step name.
        self.time_dict: typing.Dict[str, float] = {}
        # Counter per step name, see enter_step for how it is maintained.
        self.step_dict: typing.Dict[str, int] = {}
        self.time_step = 'init'
        self.time_last = time.perf_counter()

    def enter_step(self, name: str) -> None:
        """Close the current step and make ``name`` the current one.

        The first time a step is closed its counter is set to 1; afterwards
        the counter only grows when the step being entered has a different
        name than the one being closed.
        """
        now = time.perf_counter()
        leaving = self.time_step
        elapsed = now - self.time_last
        if leaving in self.time_dict and leaving in self.step_dict:
            self.time_dict[leaving] += elapsed
            self.step_dict[leaving] += int(name != leaving)
        else:
            self.time_dict[leaving] = elapsed
            self.step_dict[leaving] = 1
        self.time_step = name
        # Read the clock again so the bookkeeping above is not charged to
        # the step that is starting.
        self.time_last = time.perf_counter()

    def profile(self) -> None:
        """Log, at DEBUG level, one line per step (shortest first) and a total.

        Entering the ``'profile'`` step first flushes the time of whatever
        step was running into the tables.
        """
        self.enter_step('profile')
        total = sum(self.time_dict.values())
        # sorted() is stable, so ordering the keys by their accumulated time
        # gives the same sequence as ordering the (key, time) pairs by time.
        for key in sorted(self.time_dict, key=self.time_dict.get):
            secs = self.time_dict[key]
            times = self.step_dict[key]
            self.log.debug(f"{key:<20}: {times:9d} × {secs/times:5.3e} "
                           f"= {secs:9.2f} s ({secs/total:7.2%}) ")
        self.log.debug(f"{'total':<20}: "
                       f"{total:9.2f} s ({1:7.2%})")
class Database(Profiler):
    """Pickle-backed store of hostnames, zones, ASNs and IPv4 prefixes.

    State lives in three attributes, all (re)created by ``initialize`` or
    restored by ``load``:

    * ``domtree``: root ``DomainTreeNode``, walked label by label starting
      from the top-level domain;
    * ``asns``: set of AS numbers;
    * ``ip4tree``: root ``IpTreeNode``, walked bit by bit starting from the
      most significant bit of the address.

    The file at ``PATH`` contains the pickled tuple
    ``(VERSION, (domtree, asns, ip4tree))``.
    """
    # NOTE: files written before IPv4 networks were inserted along their
    # high-order bits (see _set_ip4) may hold misplaced network entries; the
    # on-disk layout itself is unchanged, so VERSION is not bumped here.
    VERSION = 8
    PATH = "blocking.p"

    def initialize(self) -> None:
        """Replace the in-memory state with an empty database."""
        self.log.warning(
            "Creating database version: %d ",
            self.VERSION)
        self.domtree = DomainTreeNode()
        self.asns: typing.Set[Asn] = set()
        self.ip4tree = IpTreeNode()

    def load(self) -> None:
        """Load the state from ``PATH`` or start from an empty database.

        A missing file, a file with another ``VERSION`` and a file that
        cannot be unpickled all end with ``initialize()``.
        """
        self.enter_step('load')
        try:
            with open(self.PATH, 'rb') as db_fdsec:
                # NOTE(security): unpickling can run arbitrary code; PATH
                # must only ever point to a file this program wrote itself.
                version, data = pickle.load(db_fdsec)
            if version == self.VERSION:
                self.domtree, self.asns, self.ip4tree = data
                return
            self.log.warning(
                "Outdated database version found: %d, "
                "will be rebuilt.",
                version)
        except FileNotFoundError:
            pass
        except (TypeError, AttributeError, EOFError, ValueError,
                ImportError, IndexError, pickle.UnpicklingError):
            # Truncated, garbled or structurally unexpected content.
            self.log.error(
                "Corrupt database found, "
                "will be rebuilt.")
        self.initialize()

    def save(self) -> None:
        """Pickle the current state to ``PATH`` and log the profile."""
        self.enter_step('save')
        with open(self.PATH, 'wb') as db_fdsec:
            data = self.domtree, self.asns, self.ip4tree
            pickle.dump((self.VERSION, data), db_fdsec)
        self.profile()

    def __init__(self) -> None:
        """Set up profiling, switch to the ``'db'`` logger and load ``PATH``."""
        super().__init__()
        self.log = logging.getLogger('db')
        self.load()

    @staticmethod
    def pack_domain(domain: str) -> DomainPath:
        """Split ``domain`` on dots and reverse it: 'a.b.c' -> ['c', 'b', 'a']."""
        return domain.split('.')[::-1]

    @staticmethod
    def unpack_domain(domain: DomainPath) -> str:
        """Inverse of ``pack_domain``: ['c', 'b', 'a'] -> 'a.b.c'."""
        return '.'.join(domain[::-1])

    @staticmethod
    def pack_asn(asn: str) -> int:
        """Parse 'AS1234', 'as1234' or '1234' into the integer 1234.

        Raises:
            ValueError: if what remains after the optional prefix is not an
                integer.
        """
        asn = asn.upper()
        if asn.startswith('AS'):
            asn = asn[2:]
        return int(asn)

    @staticmethod
    def unpack_asn(asn: int) -> str:
        """Format an AS number as 'AS<number>'."""
        return f'AS{asn}'

    @staticmethod
    def pack_ip4address(address: str) -> Ip4Path:
        """Convert a dotted IPv4 address to ``(value, 32)``.

        Each dot-separated field is shifted in as one byte, first field most
        significant.  No check is made on the number or range of the fields.

        Raises:
            ValueError: if a field is not an integer.
        """
        addr = 0
        for split in address.split('.'):
            addr = (addr << 8) + int(split)
        return (addr, 32)

    @staticmethod
    def unpack_ip4address(address: Ip4Path) -> str:
        """Convert ``(value, 32)`` back to dotted notation.

        The prefix length must be 32 (checked with ``assert``).
        """
        addr, prefixlen = address
        assert prefixlen == 32
        # Most significant byte first.
        octets = [(addr >> shift) & 0xFF for shift in (24, 16, 8, 0)]
        return '.'.join(map(str, octets))

    @staticmethod
    def pack_ip4network(network: str) -> Ip4Path:
        """Convert 'a.b.c.d/len' to ``(value, len)``.

        Raises:
            ValueError: if there is not exactly one '/' or a field is not an
                integer.
        """
        address, prefixlen_str = network.split('/')
        prefixlen = int(prefixlen_str)
        addr, _ = Database.pack_ip4address(address)
        return (addr, prefixlen)

    @staticmethod
    def unpack_ip4network(network: Ip4Path) -> str:
        """Convert ``(value, len)`` back to 'a.b.c.d/len'."""
        address, prefixlen = network
        addr = Database.unpack_ip4address((address, 32))
        return f'{addr}/{prefixlen}'

    def update_references(self) -> None:
        """Not implemented yet."""
        raise NotImplementedError

    def prune(self, before: int, base_only: bool = False) -> None:
        """Not implemented yet."""
        raise NotImplementedError

    def explain(self, entry: int) -> str:
        """Not implemented yet."""
        raise NotImplementedError

    def export(self,
               first_party_only: bool = False,
               end_chain_only: bool = False,
               explain: bool = False,
               _dic: typing.Optional[DomainTreeNode] = None,
               _par: typing.Optional[DomainPath] = None,
               ) -> typing.Iterable[str]:
        """Yield every domain recorded as a hostname, depth first.

        ``_dic`` and ``_par`` are the current node and its packed path; they
        are only meant for the recursive calls.  Only ``match_hostname`` is
        looked at, zones are not exported.

        Raises:
            NotImplementedError: when iteration starts with any of
                ``first_party_only``, ``end_chain_only`` or ``explain`` set.
        """
        if first_party_only or end_chain_only or explain:
            raise NotImplementedError
        if _dic is None:
            _dic = self.domtree
        if _par is None:
            _par = []
        if _dic.match_hostname:
            yield self.unpack_domain(_par)
        for part, child in _dic.children.items():
            yield from self.export(_dic=child,
                                   _par=_par + [part])

    def count_rules(self,
                    first_party_only: bool = False,
                    ) -> str:
        """Not implemented yet."""
        raise NotImplementedError

    def get_domain(self, domain_str: str) -> typing.Iterable[TypedPath]:
        """Yield the recorded zones and hostname that cover ``domain_str``.

        Walking from the top-level label down, yields ``(PathType.Zone,
        prefix)`` for every node on the way that has ``match_zone`` set, then
        for the node of the full domain a Zone and/or a Hostname entry.
        Stops silently as soon as a label is absent from the tree.
        """
        self.enter_step('get_domain_pack')
        domain = self.pack_domain(domain_str)
        self.enter_step('get_domain_brws')
        dic = self.domtree
        for depth, part in enumerate(domain):
            if dic.match_zone:
                self.enter_step('get_domain_yield')
                yield (PathType.Zone, domain[:depth])
                self.enter_step('get_domain_brws')
            if part not in dic.children:
                return
            dic = dic.children[part]
        if dic.match_zone:
            self.enter_step('get_domain_yield')
            yield (PathType.Zone, domain)
        if dic.match_hostname:
            self.enter_step('get_domain_yield')
            yield (PathType.Hostname, domain)

    def get_ip4(self, ip4_str: str) -> typing.Iterable[TypedPath]:
        """Yield every recorded prefix that contains the address ``ip4_str``.

        Each item is ``(PathType.Ip4, (network_value, prefixlen))`` where
        ``network_value`` has its host bits cleared, so it can be passed to
        ``unpack_ip4network``.  An exact address match is yielded with a
        prefix length of 32.
        """
        self.enter_step('get_ip4_pack')
        ip4, prefixlen = self.pack_ip4address(ip4_str)
        self.enter_step('get_ip4_brws')
        dic = self.ip4tree
        # Bit 31 is the most significant one; before consuming bit i the
        # current node sits at depth 31 - i, i.e. it stands for a /(31 - i).
        for i in range(31, 31 - prefixlen, -1):
            if dic.match:
                self.enter_step('get_ip4_yield')
                yield (PathType.Ip4, (ip4 >> (i + 1) << (i + 1), 31 - i))
                self.enter_step('get_ip4_brws')
            next_dic = dic.children[(ip4 >> i) & 0b1]
            if next_dic is None:
                return
            dic = next_dic
        if dic.match:
            self.enter_step('get_ip4_yield')
            yield (PathType.Ip4, (ip4, prefixlen))

    def list_asn(self) -> typing.Iterable[TypedPath]:
        """Yield ``(PathType.Asn, asn)`` for every recorded AS number."""
        for asn in self.asns:
            yield (PathType.Asn, asn)

    def _reach_domain_node(self,
                           path: DomainPath
                           ) -> typing.Optional[DomainTreeNode]:
        """Return the node for ``path``, creating missing nodes on the way.

        Returns ``None`` (possibly after having created some nodes) when a
        strict ancestor of ``path`` already has ``match_zone`` set.
        """
        dic = self.domtree
        for part in path:
            if dic.match_zone:
                return None
            child = dic.children.get(part)
            if child is None:
                child = DomainTreeNode()
                dic.children[part] = child
            dic = child
        return dic

    def _set_ip4(self, ip4: int, prefixlen: int, updated: int) -> None:
        """Record the prefix ``ip4/prefixlen`` in the IPv4 tree.

        The ``prefixlen`` most significant bits of ``ip4`` select the path,
        the same order ``get_ip4`` uses for lookups.  Nothing is recorded
        when a shorter prefix on that path already matches.

        Raises:
            ValueError: if ``prefixlen`` is outside 0..32.
        """
        if not 0 <= prefixlen <= 32:
            raise ValueError(f"Invalid IPv4 prefix length: {prefixlen}")
        dic = self.ip4tree
        for i in range(31, 31 - prefixlen, -1):
            if dic.match:
                # Refuse to add a prefix whose parent network
                # is already matching
                return
            part = (ip4 >> i) & 0b1
            next_dic = dic.children[part]
            if next_dic is None:
                next_dic = IpTreeNode()
                dic.children[part] = next_dic
            dic = next_dic
        dic.match = (updated, DebugPath, 0)

    def set_hostname(self,
                     hostname_str: str,
                     updated: int,
                     is_first_party: typing.Optional[bool] = None,
                     source: typing.Optional[TypedPath] = None) -> None:
        """Record ``hostname_str`` as a hostname stamped with ``updated``.

        Ignored when one of its parent zones is already recorded.

        Raises:
            NotImplementedError: if ``is_first_party`` or ``source`` is set.
        """
        self.enter_step('set_hostname_pack')
        if is_first_party or source:
            raise NotImplementedError
        hostname = self.pack_domain(hostname_str)
        self.enter_step('set_hostname_brws')
        dic = self._reach_domain_node(hostname)
        if dic is None:
            # Refuse to add hostname whose zone is already matching
            return
        dic.match_hostname = (updated, DebugPath, 0)

    def set_zone(self,
                 zone_str: str,
                 updated: int,
                 is_first_party: typing.Optional[bool] = None,
                 source: typing.Optional[TypedPath] = None) -> None:
        """Record ``zone_str`` as a zone stamped with ``updated``.

        Ignored when one of its parent zones is already recorded.

        Raises:
            NotImplementedError: if ``is_first_party`` or ``source`` is set.
        """
        self.enter_step('set_zone_pack')
        if is_first_party or source:
            raise NotImplementedError
        zone = self.pack_domain(zone_str)
        self.enter_step('set_zone_brws')
        dic = self._reach_domain_node(zone)
        if dic is None:
            # Refuse to add zone whose parent zone is already matching
            return
        dic.match_zone = (updated, DebugPath, 0)

    def set_asn(self,
                asn_str: str,
                updated: int,
                is_first_party: typing.Optional[bool] = None,
                source: typing.Optional[TypedPath] = None) -> None:
        """Record the AS number ``asn_str`` ('AS1234' or '1234').

        Raises:
            NotImplementedError: if ``is_first_party`` or ``source`` is set.
        """
        self.enter_step('set_asn_pack')
        if is_first_party or source:
            raise NotImplementedError
        asn = self.pack_asn(asn_str)
        self.enter_step('set_asn_brws')
        # TODO: `updated` is not stored for ASNs yet.
        self.asns.add(asn)

    def set_ip4address(self,
                       ip4address_str: str,
                       updated: int,
                       is_first_party: typing.Optional[bool] = None,
                       source: typing.Optional[TypedPath] = None) -> None:
        """Record the single address ``ip4address_str`` stamped ``updated``.

        Ignored when a network containing it is already recorded.

        Raises:
            NotImplementedError: if ``is_first_party`` or ``source`` is set.
        """
        self.enter_step('set_ip4add_pack')
        if is_first_party or source:
            raise NotImplementedError
        ip4, prefixlen = self.pack_ip4address(ip4address_str)
        self.enter_step('set_ip4add_brws')
        self._set_ip4(ip4, prefixlen, updated)

    def set_ip4network(self,
                       ip4network_str: str,
                       updated: int,
                       is_first_party: typing.Optional[bool] = None,
                       source: typing.Optional[TypedPath] = None) -> None:
        """Record the network ``ip4network_str`` ('a.b.c.d/len').

        Ignored when a network containing it is already recorded.

        Raises:
            NotImplementedError: if ``is_first_party`` or ``source`` is set.
            ValueError: if the prefix length is outside 0..32.
        """
        self.enter_step('set_ip4net_pack')
        if is_first_party or source:
            raise NotImplementedError
        ip4, prefixlen = self.pack_ip4network(ip4network_str)
        self.enter_step('set_ip4net_brws')
        self._set_ip4(ip4, prefixlen, updated)