2019-12-09 08:12:48 +01:00
|
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
|
2019-12-13 00:11:21 +01:00
|
|
|
|
"""
|
|
|
|
|
Utility functions to interact with the database.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import typing
|
|
|
|
|
import time
|
|
|
|
|
import logging
|
|
|
|
|
import coloredlogs
|
2019-12-15 15:56:26 +01:00
|
|
|
|
import pickle
|
|
|
|
|
import enum
|
2019-12-09 08:12:48 +01:00
|
|
|
|
|
2019-12-13 00:11:21 +01:00
|
|
|
|
# Colourised logging for every logger in the process, verbose down to DEBUG.
coloredlogs.install(
    fmt='%(asctime)s %(name)s %(levelname)s %(message)s',
    level='DEBUG',
)
|
|
|
|
|
|
2019-12-15 15:56:26 +01:00
|
|
|
|
# Kind of rule a path designates; members are numbered 1..6 in this order.
PathType = enum.Enum(
    'PathType',
    ['Rule', 'Hostname', 'Zone', 'Asn', 'Ip4', 'Ip6'],
)

# Scalar aliases.
Timestamp = int
Level = int
Asn = int

# Per-kind path representations.
RulePath = typing.Union[None]
DomainPath = typing.List[str]  # labels, TLD first (see Database.pack_domain)
Ip4Path = typing.Tuple[int, int]  # value, prefixlen
Ip6Path = typing.List[int]

# Any path, and a path tagged with its kind.
Path = typing.Union[RulePath, DomainPath, Asn, Ip4Path, Ip6Path]
TypedPath = typing.Tuple[PathType, Path]

# What a tree node stores when it matches: (updated, source path, level).
Match = typing.Tuple[Timestamp, TypedPath, Level]

# Placeholder source path attached to every stored match for now.
DebugPath = (PathType.Rule, None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DomainTreeNode:
    """One node of the domain tree, keyed by domain labels.

    A node can match as a zone (covering itself and everything below it)
    and/or as an exact hostname.
    """

    def __init__(self) -> None:
        # Sub-nodes keyed by the next label.
        self.children: typing.Dict[str, DomainTreeNode] = {}
        # Match record when this node is a blocked zone, else None.
        self.match_zone: typing.Optional[Match] = None
        # Match record when this exact name is a blocked hostname, else None.
        self.match_hostname: typing.Optional[Match] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IpTreeNode:
    """One node of a binary tree indexed by successive IP address bits."""

    def __init__(self) -> None:
        # Slot 0 / 1 holds the sub-tree for a next bit of 0 / 1, or None
        # while nothing has been stored there.
        self.children: typing.List[typing.Optional[IpTreeNode]] = [None] * 2
        # Match record for the prefix this node represents, if any.
        self.match: typing.Optional[Match] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Profiler:
    """Accumulate wall-clock time spent in named steps and report it.

    The current step is switched with ``enter_step``. The time elapsed since
    the previous switch is charged to the step being left. ``profile``
    logs one debug line per step, followed by a total.
    """

    def __init__(self) -> None:
        self.log = logging.getLogger('profiler')
        self.time_last = time.perf_counter()
        self.time_step = 'init'
        # Seconds accumulated per step name.
        self.time_dict: typing.Dict[str, float] = dict()
        # How many times each step was left. The first departure always
        # counts as 1; later ones count only when switching to another name.
        self.step_dict: typing.Dict[str, int] = dict()

    def enter_step(self, name: str) -> None:
        """Charge elapsed time to the current step, then switch to *name*."""
        now = time.perf_counter()
        leaving = self.time_step
        elapsed = now - self.time_last
        if leaving in self.time_dict:
            self.time_dict[leaving] += elapsed
            self.step_dict[leaving] += int(name != leaving)
        else:
            self.time_dict[leaving] = elapsed
            self.step_dict[leaving] = 1
        self.time_step = name
        # Re-read the clock so this bookkeeping is not charged to *name*.
        self.time_last = time.perf_counter()

    def profile(self) -> None:
        """Log each step's timing, cheapest first, then the overall total."""
        self.enter_step('profile')
        total = sum(self.time_dict.values())
        ranking = sorted(self.time_dict.items(), key=lambda item: item[1])
        for step, secs in ranking:
            times = self.step_dict[step]
            self.log.debug(f"{step:<20}: {times:9d} × {secs/times:5.3e} "
                           f"= {secs:9.2f} s ({secs/total:7.2%}) ")
        self.log.debug(f"{'total':<20}: "
                       f"{total:9.2f} s ({1:7.2%})")
|
|
|
|
|
|
2019-12-15 15:56:26 +01:00
|
|
|
|
|
|
|
|
|
class Database(Profiler):
    """Pickle-backed store of blocking rules (domains, ASNs, IPv4).

    Domains live in a tree keyed by labels starting from the TLD. IPv4
    addresses and networks live in a binary tree keyed by address bits,
    most significant bit first. ASNs are kept in a set. The whole state is
    pickled to ``PATH`` together with ``VERSION``. A mismatching version or
    an unreadable file makes ``load`` start again from an empty database.
    """

    VERSION = 8
    PATH = "blocking.p"

    def initialize(self) -> None:
        """Reset to an empty database of the current version."""
        self.log.warning(
            "Creating database version: %d ",
            Database.VERSION)
        self.domtree = DomainTreeNode()
        self.asns: typing.Set[Asn] = set()
        self.ip4tree = IpTreeNode()

    def load(self) -> None:
        """Load the database from ``PATH``, or fall back to an empty one.

        A missing file silently yields an empty database. An outdated or
        unreadable file is logged, then replaced by an empty database.
        """
        self.enter_step('load')
        try:
            # NOTE: unpickling can execute arbitrary code; PATH must only
            # ever hold a file written by save().
            with open(self.PATH, 'rb') as db_fdsec:
                version, data = pickle.load(db_fdsec)
            if version == Database.VERSION:
                self.domtree, self.asns, self.ip4tree = data
                return
            self.log.warning(
                "Outdated database version found: %d, "
                "will be rebuilt.",
                version)
        except (TypeError, AttributeError, EOFError, ValueError,
                ImportError, pickle.UnpicklingError):
            # The file is truncated, is not a pickle, has an unexpected
            # shape, or references classes/modules that no longer exist.
            self.log.error(
                "Corrupt database found, "
                "will be rebuilt.")
        except FileNotFoundError:
            pass
        self.initialize()

    def save(self) -> None:
        """Pickle ``(VERSION, (domtree, asns, ip4tree))`` to ``PATH``.

        Logs the profiling report afterwards.
        """
        self.enter_step('save')
        with open(self.PATH, 'wb') as db_fdsec:
            data = self.domtree, self.asns, self.ip4tree
            pickle.dump((self.VERSION, data), db_fdsec)
        self.profile()

    def __init__(self) -> None:
        Profiler.__init__(self)
        self.log = logging.getLogger('db')
        self.load()

    @staticmethod
    def pack_domain(domain: str) -> DomainPath:
        """Split a domain into labels, TLD first: 'a.b.c' -> ['c', 'b', 'a']."""
        return domain.split('.')[::-1]

    @staticmethod
    def unpack_domain(domain: DomainPath) -> str:
        """Inverse of ``pack_domain``."""
        return '.'.join(domain[::-1])

    @staticmethod
    def pack_asn(asn: str) -> int:
        """Parse 'AS1234', 'as1234' or '1234' into the integer 1234."""
        asn = asn.upper()
        if asn.startswith('AS'):
            asn = asn[2:]
        return int(asn)

    @staticmethod
    def unpack_asn(asn: int) -> str:
        """Format an ASN number as 'AS<number>'."""
        return f'AS{asn}'

    @staticmethod
    def pack_ip4address(address: str) -> Ip4Path:
        """Pack a dotted-quad IPv4 address into ``(value, 32)``.

        Each octet contributes 8 bits, most significant first, so
        '1.2.3.4' packs to ``(0x01020304, 32)``.
        """
        addr = 0
        for split in address.split('.'):
            # Parenthesised on purpose: ``<<`` binds looser than ``+``.
            addr = (addr << 8) + int(split)
        return (addr, 32)

    @staticmethod
    def unpack_ip4address(address: Ip4Path) -> str:
        """Format a packed /32 address back into dotted-quad notation."""
        addr, prefixlen = address
        assert prefixlen == 32
        octets = [0] * 4
        for o in reversed(range(4)):
            octets[o] = addr & 0xFF
            addr >>= 8
        return '.'.join(map(str, octets))

    @staticmethod
    def pack_ip4network(network: str) -> Ip4Path:
        """Pack 'a.b.c.d/len' into ``(value, len)``."""
        address, prefixlen_str = network.split('/')
        prefixlen = int(prefixlen_str)
        addr, _ = Database.pack_ip4address(address)
        return (addr, prefixlen)

    @staticmethod
    def unpack_ip4network(network: Ip4Path) -> str:
        """Format a packed network back into 'a.b.c.d/len'."""
        address, prefixlen = network
        addr = Database.unpack_ip4address((address, 32))
        return f'{addr}/{prefixlen}'

    @staticmethod
    def _ip4_bit_indices(prefixlen: int) -> range:
        """Bit positions covered by an IPv4 /prefixlen, most significant first.

        Yields 31, 30, ..., 32 - prefixlen. The tree node reached before
        consuming bit ``i`` therefore represents the /(31 - i) prefix.
        """
        return range(31, 31 - prefixlen, -1)

    def update_references(self) -> None:
        raise NotImplementedError

    def prune(self, before: int, base_only: bool = False) -> None:
        raise NotImplementedError

    def explain(self, entry: int) -> str:
        raise NotImplementedError

    def export(self,
               first_party_only: bool = False,
               end_chain_only: bool = False,
               explain: bool = False,
               _dic: typing.Optional[DomainTreeNode] = None,
               _par: typing.Optional[DomainPath] = None,
               ) -> typing.Iterable[str]:
        """Yield every stored hostname, recursing depth-first through the tree.

        ``_dic`` and ``_par`` are internal recursion state: the current node
        and the label path leading to it.
        """
        if first_party_only or end_chain_only or explain:
            raise NotImplementedError
        if _dic is None:
            _dic = self.domtree
        if _par is None:
            _par = list()
        if _dic.match_hostname:
            yield self.unpack_domain(_par)
        for part, dic in _dic.children.items():
            yield from self.export(_dic=dic,
                                   _par=_par + [part])

    def count_rules(self,
                    first_party_only: bool = False,
                    ) -> str:
        raise NotImplementedError

    def get_domain(self, domain_str: str) -> typing.Iterable[TypedPath]:
        """Yield every stored rule that matches *domain_str*.

        First yields each strictly-enclosing matching zone, shortest first.
        Then, for the name itself, yields its zone match and/or its
        hostname match.
        """
        self.enter_step('get_domain_pack')
        domain = self.pack_domain(domain_str)
        self.enter_step('get_domain_brws')
        dic = self.domtree
        depth = 0
        for part in domain:
            if dic.match_zone:
                self.enter_step('get_domain_yield')
                yield (PathType.Zone, domain[:depth])
                self.enter_step('get_domain_brws')
            if part not in dic.children:
                return
            dic = dic.children[part]
            depth += 1
        if dic.match_zone:
            self.enter_step('get_domain_yield')
            yield (PathType.Zone, domain)
        if dic.match_hostname:
            self.enter_step('get_domain_yield')
            yield (PathType.Hostname, domain)

    def get_ip4(self, ip4_str: str) -> typing.Iterable[TypedPath]:
        """Yield every stored IPv4 rule that matches *ip4_str*.

        Each enclosing network is yielded as ``(network_value, prefixlen)``,
        shortest prefix first, with host bits cleared. A rule stored for the
        exact address is yielded last as ``(value, 32)``.
        """
        self.enter_step('get_ip4_pack')
        ip4, prefixlen = self.pack_ip4address(ip4_str)
        self.enter_step('get_ip4_brws')
        dic = self.ip4tree
        for i in self._ip4_bit_indices(prefixlen):
            if dic.match:
                # dic is the /(31 - i) node: keep only its top (31 - i) bits.
                self.enter_step('get_ip4_yield')
                yield (PathType.Ip4, ((ip4 >> (i + 1)) << (i + 1), 31 - i))
                self.enter_step('get_ip4_brws')
            part = (ip4 >> i) & 0b1
            next_dic = dic.children[part]
            if next_dic is None:
                return
            dic = next_dic
        if dic.match:
            self.enter_step('get_ip4_yield')
            yield (PathType.Ip4, (ip4, prefixlen))

    def list_asn(self) -> typing.Iterable[TypedPath]:
        """Yield every stored ASN as a typed path."""
        for asn in self.asns:
            yield (PathType.Asn, asn)

    def set_hostname(self,
                     hostname_str: str,
                     updated: int,
                     is_first_party: bool = None,
                     source: TypedPath = None) -> None:
        """Store *hostname_str* as an exact hostname rule.

        Does nothing if an enclosing zone already matches.
        """
        self.enter_step('set_hostname_pack')
        if is_first_party or source:
            raise NotImplementedError
        hostname = self.pack_domain(hostname_str)
        self.enter_step('set_hostname_brws')
        dic = self.domtree
        for part in hostname:
            if dic.match_zone:
                # Refuse to add hostname whose zone is already matching
                return
            if part not in dic.children:
                dic.children[part] = DomainTreeNode()
            dic = dic.children[part]
        dic.match_hostname = (updated, DebugPath, 0)

    def set_zone(self,
                 zone_str: str,
                 updated: int,
                 is_first_party: bool = None,
                 source: TypedPath = None) -> None:
        """Store *zone_str* as a zone rule (it also matches all subdomains).

        Does nothing if a parent zone already matches.
        """
        self.enter_step('set_zone_pack')
        if is_first_party or source:
            raise NotImplementedError
        zone = self.pack_domain(zone_str)
        self.enter_step('set_zone_brws')
        dic = self.domtree
        for part in zone:
            if dic.match_zone:
                # Refuse to add zone whose parent zone is already matching
                return
            if part not in dic.children:
                dic.children[part] = DomainTreeNode()
            dic = dic.children[part]
        dic.match_zone = (updated, DebugPath, 0)

    def set_asn(self,
                asn_str: str,
                updated: int,
                is_first_party: bool = None,
                source: TypedPath = None) -> None:
        """Store an ASN rule. *updated* is not recorded yet."""
        self.enter_step('set_asn_pack')
        if is_first_party or source:
            # TODO updated
            raise NotImplementedError
        asn = self.pack_asn(asn_str)
        self.enter_step('set_asn_brws')
        self.asns.add(asn)

    def set_ip4address(self,
                       ip4address_str: str,
                       updated: int,
                       is_first_party: bool = None,
                       source: TypedPath = None) -> None:
        """Store a single IPv4 address rule.

        Does nothing if an enclosing network already matches.
        """
        self.enter_step('set_ip4add_pack')
        if is_first_party or source:
            raise NotImplementedError
        ip4, prefixlen = self.pack_ip4address(ip4address_str)
        self.enter_step('set_ip4add_brws')
        dic = self.ip4tree
        for i in self._ip4_bit_indices(prefixlen):
            part = (ip4 >> i) & 0b1
            if dic.match:
                # Refuse to add ip4address whose network is already matching
                return
            next_dic = dic.children[part]
            if next_dic is None:
                next_dic = IpTreeNode()
                dic.children[part] = next_dic
            dic = next_dic
        dic.match = (updated, DebugPath, 0)

    def set_ip4network(self,
                       ip4network_str: str,
                       updated: int,
                       is_first_party: bool = None,
                       source: TypedPath = None) -> None:
        """Store an IPv4 network rule given as 'a.b.c.d/len'.

        Only the top *len* bits are walked, so the rule covers every address
        in the network. Does nothing if a parent network already matches.
        """
        self.enter_step('set_ip4net_pack')
        if is_first_party or source:
            raise NotImplementedError
        self.enter_step('set_ip4net_brws')
        ip4, prefixlen = self.pack_ip4network(ip4network_str)
        dic = self.ip4tree
        for i in self._ip4_bit_indices(prefixlen):
            part = (ip4 >> i) & 0b1
            if dic.match:
                # Refuse to add ip4network whose parent network
                # is already matching
                return
            next_dic = dic.children[part]
            if next_dic is None:
                next_dic = IpTreeNode()
                dic.children[part] = next_dic
            dic = next_dic
        dic.match = (updated, DebugPath, 0)
|