eulaurarien/database.py

600 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Utility functions to interact with the database.
"""
import typing
import time
import logging
import coloredlogs
import pickle
# Install the coloredlogs handler: DEBUG level, with timestamp,
# logger name and level in front of every message.
coloredlogs.install(level='DEBUG',
                    fmt='%(asctime)s %(name)s %(levelname)s %(message)s')

# Readability aliases: the three of them are plain integers.
Asn = int
Timestamp = int
Level = int
class Path:
    """
    Common base class of every kind of path handled by the database.

    It holds no data and defines no behaviour by itself.
    """
    # TODO (original author's note): "FP add boolean here"
    # NOTE(review): FP presumably means "first party" - to be confirmed
class RulePath(Path):
    """
    Path without any data, displayed as the fixed text ``(rules)``.

    The ``Database.set_*`` methods fall back on it when they are
    not given a source.
    """

    def __str__(self) -> str:
        label = '(rules)'
        return label
class DomainPath(Path):
    """
    Path made of the labels of a domain name.

    ``Database.pack_domain`` builds it with the labels in reversed
    order (top-level label first).
    """

    def __init__(self, parts: typing.List[str]):
        # The list is stored as given, it is not copied
        self.parts = parts

    def __str__(self) -> str:
        # Generic domain paths are shown with a '?.' prefix
        dotted = Database.unpack_domain(self)
        return '?.' + dotted
class HostnamePath(DomainPath):
    """Domain path displayed as the bare dotted name, without prefix."""

    def __str__(self) -> str:
        dotted = Database.unpack_domain(self)
        return dotted
class ZonePath(DomainPath):
    """Domain path displayed as the dotted name prefixed with ``*.``."""

    def __str__(self) -> str:
        dotted = Database.unpack_domain(self)
        return '*.' + dotted
class AsnPath(Path):
    """Path designating an autonomous system by its number."""

    def __init__(self, asn: Asn):
        # Stored as given; Database.pack_asn passes an int
        self.asn = asn

    def __str__(self) -> str:
        # Formatting is delegated to the database helper ('AS<number>')
        text = Database.unpack_asn(self)
        return text
class Ip4Path(Path):
    """
    Path designating an IPv4 address or network.

    ``value`` is the address as a single integer and ``prefixlen``
    the number of leading bits that are significant
    (``Database.pack_ip4address`` uses 32 for a single address).
    """

    def __init__(self, value: int, prefixlen: int):
        self.value = value
        self.prefixlen = prefixlen

    def __str__(self) -> str:
        # Always shown in network notation: dotted address, '/', prefix
        text = Database.unpack_ip4network(self)
        return text
class Match:
    """
    Record attached to an entry of the database.

    A match counts as active once ``updated`` is strictly positive.
    """

    def __init__(self) -> None:
        # Value given by the last update; 0 means never set
        self.updated: int = 0
        # One more than the level of the source (see Database.set_match)
        self.level: int = 0
        # Path of the entry that caused this one, if any
        self.source: typing.Optional[Path] = None
        # Number of matches having this one as source
        self.references: int = 0
        # TODO (original author's note): "FP duplicate args"

    def active(self) -> bool:
        """Tell whether the match has been set at least once."""
        is_set = self.updated > 0
        return is_set
class AsnNode(Match):
    """Match stored for an autonomous system; adds nothing to Match."""
class DomainTreeNode:
    """
    Node of the domain tree.

    Each node carries two independent matches, one for the zone and
    one for the hostname, plus its children indexed by label.
    """

    def __init__(self) -> None:
        # Child nodes, keyed by the next label of the domain
        self.children: typing.Dict[str, DomainTreeNode] = {}
        self.match_zone = Match()
        self.match_hostname = Match()
class IpTreeNode(Match):
    """
    Node of the binary IPv4 tree, which is also the match of the
    network it stands for.
    """

    def __init__(self) -> None:
        super().__init__()
        # Children for the next address bit being 0 / 1;
        # None until the branch gets created
        self.zero: typing.Optional[IpTreeNode] = None
        self.one: typing.Optional[IpTreeNode] = None
# Union of the three node classes of the database
Node = typing.Union[DomainTreeNode, IpTreeNode, AsnNode]

# Signature of the callbacks taken by the Database.exec_each* methods:
# (path, match, optional extra argument) -> anything
MatchCallable = typing.Callable[
    [Path, Match, typing.Optional[typing.Any]],
    typing.Any,
]
class Profiler:
    """
    Stopwatch splitting the elapsed time between named steps.

    Call ``enter_step`` whenever the program moves on to another
    step, and ``profile`` to log what was measured.
    """

    def __init__(self) -> None:
        self.log = logging.getLogger('profiler')
        # Moment the current step was entered
        self.time_last = time.perf_counter()
        # Name of the step currently running
        self.time_step = 'init'
        # Seconds spent in each step
        self.time_dict: typing.Dict[str, float] = {}
        # Counter kept for each step
        self.step_dict: typing.Dict[str, int] = {}

    def enter_step(self, name: str) -> None:
        """
        Charge the time elapsed to the step being left,
        then make ``name`` the current step.
        """
        now = time.perf_counter()
        leaving = self.time_step
        elapsed = now - self.time_last
        try:
            self.time_dict[leaving] += elapsed
            # Re-entering the very same step does not count as a new run
            self.step_dict[leaving] += int(name != leaving)
        except KeyError:
            # First time this step is left
            self.time_dict[leaving] = elapsed
            self.step_dict[leaving] = 1
        self.time_step = name
        # Read the clock again so the bookkeeping above is not charged
        # to the new step
        self.time_last = time.perf_counter()

    def profile(self) -> None:
        """
        Log (at debug level) one line per step, shortest first,
        followed by a line with the total.
        """
        self.enter_step('profile')
        total = sum(self.time_dict.values())
        for step in sorted(self.time_dict, key=self.time_dict.__getitem__):
            secs = self.time_dict[step]
            times = self.step_dict[step]
            average = secs / times
            share = secs / total
            self.log.debug(f"{step:<20}: {times:9d} × {average:5.3e} "
                           f"= {secs:9.2f} s ({share:7.2%}) ")
        self.log.debug(f"{'total':<20}: "
                       f"{total:9.2f} s ({1:7.2%})")
class Database(Profiler):
    """
    Blocking database: matches for domains, IPv4 networks and ASNs.

    The domain tree, the ASN dictionary and the IPv4 tree are pickled
    together in the file ``PATH``, along with ``VERSION``.
    As a Profiler, the database also measures the time spent in its
    own steps; the measures are logged when it is saved.
    """
    # Stored with the data: a file holding another version is rebuilt
    VERSION = 14
    # File the database is loaded from and saved to
    PATH = "blocking.p"

    def __init__(self) -> None:
        """Set up profiling and logging, then load the database."""
        Profiler.__init__(self)
        # Replaces the 'profiler' logger set by Profiler.__init__
        self.log = logging.getLogger('db')
        self.load()

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def initialize(self) -> None:
        """Start over with empty domain tree, ASN dict and IPv4 tree."""
        self.log.warning(
            "Creating database version: %d ",
            Database.VERSION)
        self.domtree = DomainTreeNode()
        self.asns: typing.Dict[Asn, AsnNode] = dict()
        self.ip4tree = IpTreeNode()

    def load(self) -> None:
        """
        Load the database from ``PATH``.

        An empty database is created instead when the file is missing,
        holds another version, or cannot be unpickled.
        """
        self.enter_step('load')
        try:
            # NOTE(security): unpickling can run arbitrary code,
            # the file at PATH must come from a trusted source.
            with open(self.PATH, 'rb') as db_fdsec:
                version, data = pickle.load(db_fdsec)
            if version == Database.VERSION:
                self.domtree, self.asns, self.ip4tree = data
                return
            self.log.warning(
                "Outdated database version found: %d, "
                "it will be rebuilt.",
                version)
        except (TypeError, AttributeError, EOFError,
                ValueError, ImportError, IndexError,
                pickle.UnpicklingError):
            # Everything a damaged or incompatible pickle may raise,
            # including a payload which is not a (version, data) pair
            self.log.error(
                "Corrupt (or heavily outdated) database found, "
                "it will be rebuilt.")
        except FileNotFoundError:
            pass
        self.initialize()

    def save(self) -> None:
        """Pickle the database to ``PATH``, then log the profile."""
        self.enter_step('save')
        with open(self.PATH, 'wb') as db_fdsec:
            data = self.domtree, self.asns, self.ip4tree
            pickle.dump((self.VERSION, data), db_fdsec)
        self.profile()

    # ------------------------------------------------------------------
    # Conversions between text and paths
    # ------------------------------------------------------------------

    @staticmethod
    def pack_domain(domain: str) -> DomainPath:
        """Split a dotted name into labels, last label first."""
        return DomainPath(domain.split('.')[::-1])

    @staticmethod
    def unpack_domain(domain: DomainPath) -> str:
        """Join the labels of a domain path back into a dotted name."""
        return '.'.join(domain.parts[::-1])

    @staticmethod
    def pack_asn(asn: str) -> AsnPath:
        """
        Parse an AS number, with or without a leading 'AS'
        (case-insensitive). Raises ValueError if the rest is not
        an integer.
        """
        asn = asn.upper()
        if asn.startswith('AS'):
            asn = asn[2:]
        return AsnPath(int(asn))

    @staticmethod
    def unpack_asn(asn: AsnPath) -> str:
        """Format an ASN path as 'AS<number>'."""
        return f'AS{asn.asn}'

    @staticmethod
    def _ip4_dotted(value: int) -> str:
        """Format the 32 lowest bits of ``value`` as four dotted octets."""
        return '.'.join(
            str((value >> shift) & 0xFF) for shift in (24, 16, 8, 0)
        )

    @staticmethod
    def pack_ip4address(address: str) -> Ip4Path:
        """
        Parse a dotted IPv4 address into a /32 path.

        Each dot-separated field is shifted in as one byte; nothing
        but ``int()`` validates the fields.
        """
        addr = 0
        for split in address.split('.'):
            addr = (addr << 8) + int(split)
        return Ip4Path(addr, 32)

    @staticmethod
    def unpack_ip4address(address: Ip4Path) -> str:
        """Format a /32 path as a dotted address (asserts it is a /32)."""
        addr = address.value
        assert address.prefixlen == 32
        return Database._ip4_dotted(addr)

    @staticmethod
    def pack_ip4network(network: str) -> Ip4Path:
        """
        Parse 'address/prefixlen' into a path.

        Raises ValueError when there is not exactly one '/' or when
        a field is not an integer.
        """
        address, prefixlen_str = network.split('/')
        prefixlen = int(prefixlen_str)
        addr = Database.pack_ip4address(address)
        addr.prefixlen = prefixlen
        return addr

    @staticmethod
    def unpack_ip4network(network: Ip4Path) -> str:
        """Format a path as 'dotted address/prefixlen'."""
        return Database._ip4_dotted(network.value) \
            + '/' + str(network.prefixlen)

    # ------------------------------------------------------------------
    # Browsing
    # ------------------------------------------------------------------

    def get_match(self, path: Path) -> Match:
        """
        Return the match stored at ``path``.

        A RulePath gets a brand new Match, which is not stored.
        Raises KeyError for an ASN or a domain label which is not in
        the database, IndexError for a missing IPv4 branch, and
        ValueError for a path that is neither a rule, an ASN,
        a hostname, a zone nor an IPv4 path.
        """
        if isinstance(path, RulePath):
            return Match()
        if isinstance(path, AsnPath):
            return self.asns[path.asn]
        if isinstance(path, DomainPath):
            dicd = self.domtree
            for part in path.parts:
                dicd = dicd.children[part]
            if isinstance(path, HostnamePath):
                return dicd.match_hostname
            if isinstance(path, ZonePath):
                return dicd.match_zone
            raise ValueError
        if isinstance(path, Ip4Path):
            dici = self.ip4tree
            # Follow the address bits, most significant first
            for i in range(31, 31-path.prefixlen, -1):
                bit = (path.value >> i) & 0b1
                dici_next = dici.one if bit else dici.zero
                if not dici_next:
                    raise IndexError
                dici = dici_next
            return dici
        raise ValueError

    @staticmethod
    def _yield_callback_result(result: typing.Any
                               ) -> typing.Iterator[typing.Any]:
        """
        Yield the items of what a callback returned, or nothing when
        it is not iterable (None for instance).

        Only the iterability test is guarded: a TypeError raised while
        the result is being iterated reaches the caller.
        """
        try:
            items = iter(result)
        except TypeError:  # not iterable
            return
        yield from items

    def exec_each_asn(self,
                      callback: MatchCallable,
                      arg: typing.Any = None,
                      ) -> typing.Any:
        """
        Call ``callback(path, match, arg)`` for every active ASN and
        yield whatever the callbacks yield.
        """
        for asn, match in self.asns.items():
            if match.active():
                yield from self._yield_callback_result(
                    callback(AsnPath(asn), match, arg)
                )

    def exec_each_domain(self,
                         callback: MatchCallable,
                         arg: typing.Any = None,
                         _dic: typing.Optional[DomainTreeNode] = None,
                         _par: typing.Optional[DomainPath] = None,
                         ) -> typing.Any:
        """
        Call ``callback(path, match, arg)`` for every active hostname
        and zone match of the domain tree, parents before children,
        and yield whatever the callbacks yield.

        ``_dic`` and ``_par`` are the node and path of the recursion,
        callers leave them out.
        """
        if _dic is None:
            _dic = self.domtree
        if _par is None:
            _par = DomainPath([])
        if _dic.match_hostname.active():
            yield from self._yield_callback_result(
                callback(HostnamePath(_par.parts), _dic.match_hostname, arg)
            )
        if _dic.match_zone.active():
            yield from self._yield_callback_result(
                callback(ZonePath(_par.parts), _dic.match_zone, arg)
            )
        for part, child in _dic.children.items():
            yield from self.exec_each_domain(
                callback,
                arg,
                _dic=child,
                _par=DomainPath(_par.parts + [part])
            )

    def exec_each_ip4(self,
                      callback: MatchCallable,
                      arg: typing.Any = None,
                      _dic: typing.Optional[IpTreeNode] = None,
                      _par: typing.Optional[Ip4Path] = None,
                      ) -> typing.Any:
        """
        Call ``callback(path, match, arg)`` for every active node of
        the IPv4 tree, parents first then the 0 branch before the 1
        branch, and yield whatever the callbacks yield.

        ``_dic`` and ``_par`` are the node and path of the recursion,
        callers leave them out.
        """
        if _dic is None:
            _dic = self.ip4tree
        if _par is None:
            _par = Ip4Path(0, 0)
        if _dic.active():
            yield from self._yield_callback_result(
                callback(_par, _dic, arg)
            )
        pref = _par.prefixlen + 1
        # 0
        child = _dic.zero
        if child:
            # The bit of the child is expected to be clear already
            addr0 = _par.value & (0xFFFFFFFF ^ (1 << (32-pref)))
            assert addr0 == _par.value
            yield from self.exec_each_ip4(
                callback,
                arg,
                _dic=child,
                _par=Ip4Path(addr0, pref)
            )
        # 1
        child = _dic.one
        if child:
            addr1 = _par.value | (1 << (32-pref))
            yield from self.exec_each_ip4(
                callback,
                arg,
                _dic=child,
                _par=Ip4Path(addr1, pref)
            )

    def exec_each(self,
                  callback: MatchCallable,
                  arg: typing.Any = None,
                  ) -> typing.Any:
        """
        Run ``callback(path, match, arg)`` on every active match:
        domains first, then IPv4 networks, then ASNs.
        """
        yield from self.exec_each_domain(callback, arg)
        yield from self.exec_each_ip4(callback, arg)
        yield from self.exec_each_asn(callback, arg)

    # ------------------------------------------------------------------
    # Maintenance
    # ------------------------------------------------------------------

    def update_references(self) -> None:
        """
        Recompute the ``references`` counter of the active matches.

        Should be correctly calculated normally,
        keeping this just in case.
        """
        def reset_references_cb(path: Path,
                                match: Match, _: typing.Any
                                ) -> None:
            match.references = 0

        def increment_references_cb(path: Path,
                                    match: Match, _: typing.Any
                                    ) -> None:
            if match.source:
                source = self.get_match(match.source)
                source.references += 1

        # exec_each is a generator: it has to be consumed to run
        for _ in self.exec_each(reset_references_cb, None):
            pass
        for _ in self.exec_each(increment_references_cb, None):
            pass

    def prune(self, before: int, base_only: bool = False) -> None:
        """Not implemented yet."""
        raise NotImplementedError

    # ------------------------------------------------------------------
    # Output
    # ------------------------------------------------------------------

    def explain(self, path: Path) -> str:
        """
        Describe ``path``: the path itself, ' #' and its reference
        count (except for rules), then the explanation of its source.
        """
        match = self.get_match(path)
        string = f'{path}'
        if not isinstance(path, RulePath):
            string += f' #{match.references}'
        if match.source:
            # NOTE(review): the source is appended without any
            # separator - confirm this is the intended output
            string += f'{self.explain(match.source)}'
        return string

    def export(self,
               first_party_only: bool = False,
               end_chain_only: bool = False,
               explain: bool = False,
               ) -> typing.Iterable[str]:
        """
        Yield the active hostnames of the database.

        With ``end_chain_only``, hostnames referenced by other matches
        are left out. With ``explain``, explanations are yielded
        instead of bare names. ``first_party_only`` is not implemented.
        """
        if first_party_only:
            raise NotImplementedError

        def export_cb(path: Path, match: Match, _: typing.Any
                      ) -> typing.Iterable[str]:
            assert isinstance(path, DomainPath)
            if not isinstance(path, HostnamePath):
                return
            if end_chain_only and match.references > 0:
                return
            if explain:
                yield self.explain(path)
            else:
                yield self.unpack_domain(path)

        yield from self.exec_each_domain(export_cb, None)

    def list_rules(self,
                   first_party_only: bool = False,
                   ) -> typing.Iterable[str]:
        """
        Yield the explanation of every active zone and of every active
        IPv4 network shorter than /32.
        ``first_party_only`` is not implemented.
        """
        if first_party_only:
            raise NotImplementedError

        def list_rules_cb(path: Path, match: Match, _: typing.Any
                          ) -> typing.Iterable[str]:
            is_zone = isinstance(path, ZonePath)
            is_network = isinstance(path, Ip4Path) and path.prefixlen < 32
            if is_zone or is_network:
                # TODO (original author): a filter on `match.level == 0`
                # was left commented out here
                yield self.explain(path)

        yield from self.exec_each(list_rules_cb, None)

    def count_rules(self,
                    first_party_only: bool = False,
                    ) -> str:
        """Not implemented yet."""
        raise NotImplementedError

    # ------------------------------------------------------------------
    # Lookups
    # ------------------------------------------------------------------

    def get_domain(self, domain_str: str) -> typing.Iterable[DomainPath]:
        """
        Yield the paths matching a domain name: the active zones it
        belongs to, from the root down to the name itself, then the
        name as a hostname if that match is active.
        """
        self.enter_step('get_domain_pack')
        domain = self.pack_domain(domain_str)
        self.enter_step('get_domain_brws')
        dic = self.domtree
        for depth, part in enumerate(domain.parts):
            if dic.match_zone.active():
                self.enter_step('get_domain_yield')
                yield ZonePath(domain.parts[:depth])
                self.enter_step('get_domain_brws')
            if part not in dic.children:
                # The tree does not go any deeper
                return
            dic = dic.children[part]
        if dic.match_zone.active():
            self.enter_step('get_domain_yield')
            yield ZonePath(domain.parts)
        if dic.match_hostname.active():
            self.enter_step('get_domain_yield')
            yield HostnamePath(domain.parts)

    def get_ip4(self, ip4_str: str) -> typing.Iterable[Path]:
        """
        Yield the paths matching an IPv4 address: the active networks
        containing it, shortest prefix first, then the address itself
        if its own node is active.
        """
        self.enter_step('get_ip4_pack')
        ip4 = self.pack_ip4address(ip4_str)
        self.enter_step('get_ip4_brws')
        dic = self.ip4tree
        for i in range(31, 31-ip4.prefixlen, -1):
            bit = (ip4.value >> i) & 0b1
            if dic.active():
                self.enter_step('get_ip4_yield')
                # Keep the 31-i bits walked so far, clear the others
                a = Ip4Path(ip4.value >> (i+1) << (i+1), 31-i)
                yield a
                self.enter_step('get_ip4_brws')
            next_dic = dic.one if bit else dic.zero
            if next_dic is None:
                return
            dic = next_dic
        if dic.active():
            self.enter_step('get_ip4_yield')
            yield ip4

    # ------------------------------------------------------------------
    # Insertion
    # ------------------------------------------------------------------

    def set_match(self,
                  match: Match,
                  updated: int,
                  source: Path,
                  ) -> None:
        """
        Make ``source`` the source of ``match``, when ``updated`` is
        greater than the stored one or when the resulting level
        (level of the source plus one) is greater than the stored one.

        The reference counters of the old and new sources are kept
        up to date. Nothing happens otherwise.
        """
        new_source = self.get_match(source)
        new_level = new_source.level + 1
        # NOTE(review): a *higher* level wins here - confirm this is
        # intended rather than a shorter chain being preferred
        if updated > match.updated or new_level > match.level:
            if match.source:
                old_source = self.get_match(match.source)
                old_source.references -= 1
            match.updated = updated
            match.level = new_level
            match.source = source
            new_source.references += 1

    # TODO (original author's note): "FP duplicate function"
    def _set_domain(self,
                    hostname: bool,
                    domain_str: str,
                    updated: int,
                    is_first_party: typing.Optional[bool] = None,
                    source: typing.Optional[Path] = None) -> None:
        """
        Set the hostname match (``hostname`` true) or the zone match
        (``hostname`` false) of a domain, creating the missing nodes.

        Nothing is set when a zone met on the way down, the one of the
        domain itself included, is already active.
        ``source`` defaults to a RulePath.
        ``is_first_party`` is not implemented.
        """
        self.enter_step('set_domain_pack')
        if is_first_party:
            raise NotImplementedError
        domain = self.pack_domain(domain_str)
        self.enter_step('set_domain_brws')
        dic = self.domtree
        for part in domain.parts:
            if part not in dic.children:
                dic.children[part] = DomainTreeNode()
            dic = dic.children[part]
            if dic.match_zone.active():
                # Refuse to add domain whose zone is already matching
                return
        match = dic.match_hostname if hostname else dic.match_zone
        self.set_match(
            match,
            updated,
            source or RulePath(),
        )

    def set_hostname(self,
                     *args: typing.Any, **kwargs: typing.Any
                     ) -> None:
        """Set a hostname match; arguments as for ``_set_domain``
        after ``hostname``."""
        self._set_domain(True, *args, **kwargs)

    def set_zone(self,
                 *args: typing.Any, **kwargs: typing.Any
                 ) -> None:
        """Set a zone match; arguments as for ``_set_domain``
        after ``hostname``."""
        self._set_domain(False, *args, **kwargs)

    def set_asn(self,
                asn_str: str,
                updated: int,
                is_first_party: typing.Optional[bool] = None,
                source: typing.Optional[Path] = None) -> None:
        """
        Set the match of an ASN, creating its node if needed.
        ``source`` defaults to a RulePath.
        ``is_first_party`` is not implemented.
        """
        self.enter_step('set_asn')
        if is_first_party:
            raise NotImplementedError
        path = self.pack_asn(asn_str)
        match = self.asns.get(path.asn)
        if match is None:
            match = AsnNode()
            self.asns[path.asn] = match
        self.set_match(
            match,
            updated,
            source or RulePath(),
        )

    def _set_ip4(self,
                 ip4: Ip4Path,
                 updated: int,
                 is_first_party: typing.Optional[bool] = None,
                 source: typing.Optional[Path] = None) -> None:
        """
        Set the match of an IPv4 address or network, creating the
        missing nodes.

        Nothing is set when a node met on the way down, the target
        included, is already active.
        ``source`` defaults to a RulePath.
        ``is_first_party`` is not implemented.
        """
        if is_first_party:
            raise NotImplementedError
        self.enter_step('set_ip4_brws')
        dic = self.ip4tree
        for i in range(31, 31-ip4.prefixlen, -1):
            bit = (ip4.value >> i) & 0b1
            next_dic = dic.one if bit else dic.zero
            if next_dic is None:
                next_dic = IpTreeNode()
                if bit:
                    dic.one = next_dic
                else:
                    dic.zero = next_dic
            dic = next_dic
            if dic.active():
                # Refuse to add ip4* whose network is already matching
                return
        self.set_match(
            dic,
            updated,
            source or RulePath(),
        )

    def set_ip4address(self,
                       ip4address_str: str,
                       *args: typing.Any, **kwargs: typing.Any
                       ) -> None:
        """Set the match of a dotted IPv4 address; other arguments
        as for ``_set_ip4`` after ``ip4``."""
        self.enter_step('set_ip4add_pack')
        ip4 = self.pack_ip4address(ip4address_str)
        self._set_ip4(ip4, *args, **kwargs)

    def set_ip4network(self,
                       ip4network_str: str,
                       *args: typing.Any, **kwargs: typing.Any
                       ) -> None:
        """Set the match of an 'address/prefixlen' network; other
        arguments as for ``_set_ip4`` after ``ip4``."""
        self.enter_step('set_ip4net_pack')
        ip4 = self.pack_ip4network(ip4network_str)
        self._set_ip4(ip4, *args, **kwargs)