eulaurarien/database.py

800 行
25 KiB
Python

此檔案含有易混淆的 Unicode 字元!

此檔案含有易混淆的 Unicode 字元,這些字元的處理方式可能和下面呈現的不同。若您是有意且合理的使用,您可以放心地忽略此警告。使用 Escape 按鈕標記這些字元。

#!/usr/bin/env python3
"""
Utility functions to interact with the database.
"""
import typing
import time
import logging
import coloredlogs
import pickle
import numpy
import math
import os
# Known top-level domains; filled lazily by Database.populate_tld_list()
# the first time Database.validate_domain() needs it.
TLD_LIST: typing.Set[str] = set()

coloredlogs.install(
    level="DEBUG",
    fmt="%(asctime)s %(name)s %(levelname)s %(message)s",
)

# Plain integer aliases, used only to make type hints more descriptive.
Asn = int
Timestamp = int
Level = int
class Path:
    """Base class of every record or rule identifier handled by Database."""
class RulePath(Path):
    """Base class of the paths that designate a rule list itself."""

    def __str__(self) -> str:
        return "(rule)"
class RuleFirstPath(RulePath):
    """Path of the first-party rule list."""

    def __str__(self) -> str:
        return "(first-party rule)"
class RuleMultiPath(RulePath):
    """Path of the multi-party rule list."""

    def __str__(self) -> str:
        return "(multi-party rule)"
class DomainPath(Path):
    """
    Domain name kept as its labels in reverse order (TLD first), the form
    produced by Database.pack_domain.
    """

    def __init__(self, parts: typing.List[str]):
        self.parts = parts

    def __str__(self) -> str:
        return f"?.{Database.unpack_domain(self)}"
class HostnamePath(DomainPath):
    """Exact hostname record; printed as the bare domain name."""

    def __str__(self) -> str:
        name = Database.unpack_domain(self)
        return name
class ZonePath(DomainPath):
    """Zone record (the name and everything below it); printed as "*.name"."""

    def __str__(self) -> str:
        return f"*.{Database.unpack_domain(self)}"
class AsnPath(Path):
    """Autonomous system number record; printed as "AS<number>"."""

    def __init__(self, asn: Asn):
        # Numeric ASN, without the "AS" prefix (see Database.pack_asn)
        self.asn = asn

    def __str__(self) -> str:
        text = Database.unpack_asn(self)
        return text
class Ip4Path(Path):
    """IPv4 network: address as a 32-bit integer plus its prefix length."""

    def __init__(self, value: int, prefixlen: int):
        self.value, self.prefixlen = value, prefixlen

    def __str__(self) -> str:
        text = Database.unpack_ip4network(self)
        return text
class Match:
    """
    State attached to a record: where it comes from and when it was last set.

    A match with ``updated == 0`` is inactive (never set, or disabled).
    """

    def __init__(self) -> None:
        # Path of the record/rule this one was derived from
        self.source: typing.Optional[Path] = None
        # Timestamp of the last update; 0 means inactive
        self.updated: int = 0
        # Attribute name misspelled on purpose: it is stored in pickled
        # databases and read by callers, so it cannot be renamed.
        self.dupplicate: bool = False
        # Cache
        self.level: int = 0
        self.first_party: bool = False
        self.references: int = 0

    def active(self, first_party: typing.Optional[bool] = None) -> bool:
        """
        Tell whether the match is set.

        :param first_party: when truthy, additionally require the match to
            be first-party.
        """
        return self.updated != 0 and not (first_party and not self.first_party)

    def disable(self) -> None:
        """Mark the match inactive; the other fields are left as they are."""
        self.updated = 0
class AsnNode(Match):
    """Match of an ASN record, with a human-readable name for explain()."""

    def __init__(self) -> None:
        super().__init__()
        self.name = ""
class DomainTreeNode:
    """
    Node of the domain tree. Children are keyed by the next label, walking
    from the TLD towards the leftmost label.
    """

    def __init__(self) -> None:
        # The zone ("*.name") and hostname ("name") records are independent
        self.match_zone = Match()
        self.match_hostname = Match()
        self.children: typing.Dict[str, DomainTreeNode] = dict()
class IpTreeNode(Match):
    """
    Node of the binary IPv4 tree. A node at depth d stands for the network
    made of the first d address bits, and is itself that network's Match.
    """

    def __init__(self) -> None:
        super().__init__()
        # Subtrees for the next address bit being 0 / 1
        self.zero: typing.Optional[IpTreeNode] = None
        self.one: typing.Optional[IpTreeNode] = None
# Callback signature used by Database.exec_each*: (record path, its match)
MatchCallable = typing.Callable[[Path, Match], typing.Any]
# Any kind of node the database stores matches in
Node = typing.Union[DomainTreeNode, IpTreeNode, AsnNode]
class Profiler:
    """
    Optional step profiler, enabled by a non-zero integer in the PROFILE
    environment variable.

    enter_step(name) charges the time elapsed since the previous call to the
    previous step; profile() logs a per-step report. When profiling is
    disabled both are no-ops.
    """

    def __init__(self) -> None:
        do_profile = int(os.environ.get("PROFILE", "0"))
        # Dedicated attribute: subclasses (Database) replace self.log with
        # their own logger, which used to hijack the profiling report.
        self._profile_log = logging.getLogger("profiler")
        if do_profile:
            self.log = self._profile_log
            self.time_last = time.perf_counter()
            self.time_step = "init"
            # Accumulated seconds / number of entries, per step name
            self.time_dict: typing.Dict[str, float] = dict()
            self.step_dict: typing.Dict[str, int] = dict()
            self.enter_step = self.enter_step_real
            self.profile = self.profile_real
        else:
            self.enter_step = self.enter_step_dummy
            self.profile = self.profile_dummy

    def enter_step_dummy(self, name: str) -> None:
        """No-op enter_step used when profiling is disabled."""
        return

    def enter_step_real(self, name: str) -> None:
        """Charge elapsed time to the current step, then switch to `name`."""
        now = time.perf_counter()
        try:
            self.time_dict[self.time_step] += now - self.time_last
            # Re-entering the same step does not count as a new entry
            self.step_dict[self.time_step] += int(name != self.time_step)
        except KeyError:
            self.time_dict[self.time_step] = now - self.time_last
            self.step_dict[self.time_step] = 1
        self.time_step = name
        # Re-read the clock so the bookkeeping above is not charged
        self.time_last = time.perf_counter()

    def profile_dummy(self) -> None:
        """No-op profile used when profiling is disabled."""
        return

    def profile_real(self) -> None:
        """Log time per step (ascending) and the total, at DEBUG level."""
        self.enter_step("profile")
        total = sum(self.time_dict.values())
        for key, secs in sorted(self.time_dict.items(), key=lambda t: t[1]):
            times = self.step_dict[key]
            # Guard against a zero total on coarse clocks
            share = secs / total if total else 0.0
            self._profile_log.debug(
                f"{key:<20}: {times:9d} × {secs/times:5.3e} "
                f"= {secs:9.2f} s ({share:7.2%}) "
            )
        self._profile_log.debug(
            f"{'total':<20}: " f"{total:9.2f} s ({1:7.2%})"
        )
class Database(Profiler):
    """
    In-memory blocking database, persisted with pickle to PATH.

    Records are domain names (hostnames and zones), IPv4 networks and ASNs.
    Each record's Match remembers the Path it was derived from (``source``),
    its number of hops from a rule (``level``; rules are level 0), whether
    that chain is first-party, and how many records reference it.
    """

    VERSION = 18
    PATH = "blocking.p"

    def initialize(self) -> None:
        """Reset the in-memory structures to an empty database."""
        self.log.warning("Creating database version: %d ", Database.VERSION)
        # Dummy match objects that everything refer to:
        # index 0 is the multi-party rule, index 1 the first-party rule
        # (see get_match).
        self.rules: typing.List[Match] = list()
        for first_party in (False, True):
            m = Match()
            m.updated = 1
            m.level = 0
            m.first_party = first_party
            self.rules.append(m)
        self.domtree = DomainTreeNode()
        self.asns: typing.Dict[Asn, AsnNode] = dict()
        self.ip4tree = IpTreeNode()

    def load(self) -> None:
        """
        Load the database from PATH. Start from an empty one if the file is
        missing, was written by another VERSION, or cannot be unpickled.
        """
        self.enter_step("load")
        # NOTE: pickle can execute arbitrary code while loading; PATH must
        # only ever be a file written by save().
        try:
            with open(self.PATH, "rb") as db_fdsec:
                version, data = pickle.load(db_fdsec)
            if version == Database.VERSION:
                self.rules, self.domtree, self.asns, self.ip4tree = data
                return
            self.log.warning(
                "Outdated database version found: %d, " "it will be rebuilt.",
                version,
            )
        except (
            TypeError,
            AttributeError,
            EOFError,
            ValueError,  # payload is not shaped like (version, data)
            pickle.UnpicklingError,  # garbage or truncated file
        ):
            self.log.error(
                "Corrupt (or heavily outdated) database found, " "it will be rebuilt."
            )
        except FileNotFoundError:
            pass
        self.initialize()

    def save(self) -> None:
        """
        Write the database to PATH, then log the profiling report.

        The data goes to a temporary file that then replaces PATH, so an
        interrupted save cannot leave a truncated database behind.
        """
        self.enter_step("save")
        tmp_path = self.PATH + ".tmp"
        with open(tmp_path, "wb") as db_fdsec:
            data = self.rules, self.domtree, self.asns, self.ip4tree
            pickle.dump((self.VERSION, data), db_fdsec)
        os.replace(tmp_path, self.PATH)
        self.profile()

    def __init__(self) -> None:
        Profiler.__init__(self)
        self.log = logging.getLogger("db")
        self.load()
        # Until fill_ip4cache() runs, the cache is one always-true cell:
        # any 32-bit address shifted right by 32 gives index 0.
        self.ip4cache_shift: int = 32
        self.ip4cache = numpy.ones(1)

    def _set_ip4cache(self, path: Path, _: Match) -> None:
        """
        Mark every IPv4 cache cell that the network `path` overlaps.

        Used as an exec_each_ip4 callback and after each _set_ip4.
        """
        assert isinstance(path, Ip4Path)
        self.enter_step("set_ip4cache")
        host_bits = 32 - path.prefixlen
        # Drop host bits: an unaligned value such as 10.0.5.0/16 would
        # otherwise skip the first cells of the network.
        base = (path.value >> host_bits) << host_bits
        first = base >> self.ip4cache_shift
        last = (base + (1 << host_bits) - 1) >> self.ip4cache_shift
        self.ip4cache[first : last + 1] = True

    def fill_ip4cache(self, max_size: int = 512 * 1024 ** 2) -> None:
        """
        Allocate the IPv4 lookup cache and fill it from the IPv4 tree.

        :param max_size: upper bound, in bytes, of the cache array. Each cell
            is one numpy bool (itemsize bytes), so the cache is indexed by
            the top floor(log2(max_size / itemsize)) address bits, at most
            32. Per-address resolution needs 2**32 cells.
        """
        entry_bytes = numpy.dtype(bool).itemsize
        full_bytes = 2 ** 32 * entry_bytes
        if max_size > full_bytes:
            self.log.warning(
                "Allocating more than %d MiB of RAM for "
                "the Ip4 cache is not necessary.",
                full_bytes // 1024 ** 2,
            )
        # Largest power-of-two number of cells fitting in the byte budget
        max_cache_width = int(math.log2(max(1, max_size // entry_bytes)))
        cache_width = min(32, max_cache_width)
        while True:
            try:
                self.ip4cache = numpy.zeros(2 ** cache_width, dtype=bool)
                break
            except MemoryError:
                self.log.exception("Could not allocate cache. Retrying a smaller one.")
                cache_width -= 1
        # address >> ip4cache_shift gives the cache index
        self.ip4cache_shift = 32 - cache_width
        for _ in self.exec_each_ip4(self._set_ip4cache):
            pass

    @staticmethod
    def populate_tld_list() -> None:
        """Add the TLDs listed in temp/all_tld.list (one per line) to TLD_LIST."""
        with open("temp/all_tld.list", "r") as tld_fdesc:
            for tld in tld_fdesc:
                tld = tld.strip()
                TLD_LIST.add(tld)

    @staticmethod
    def validate_domain(path: str) -> bool:
        """
        Check a domain name: at most 255 characters, a known TLD (list
        loaded on first use), and every label 1 to 63 characters long.
        """
        if len(path) > 255:
            return False
        splits = path.split(".")
        if not TLD_LIST:
            Database.populate_tld_list()
        if splits[-1] not in TLD_LIST:
            return False
        for split in splits:
            if not 1 <= len(split) <= 63:
                return False
        return True

    @staticmethod
    def pack_domain(domain: str) -> DomainPath:
        """Split a domain into labels, TLD first."""
        return DomainPath(domain.split(".")[::-1])

    @staticmethod
    def unpack_domain(domain: DomainPath) -> str:
        """Inverse of pack_domain."""
        return ".".join(domain.parts[::-1])

    @staticmethod
    def pack_asn(asn: str) -> AsnPath:
        """
        Parse "AS1234", "as1234" or "1234" into an AsnPath.

        :raises ValueError: if the remainder is not an integer.
        """
        asn = asn.upper()
        if asn.startswith("AS"):
            asn = asn[2:]
        return AsnPath(int(asn))

    @staticmethod
    def unpack_asn(asn: AsnPath) -> str:
        """Format an AsnPath as "AS<number>"."""
        return f"AS{asn.asn}"

    @staticmethod
    def validate_ip4address(path: str) -> bool:
        """Check for four dot-separated values that int() parses into 0-255."""
        splits = path.split(".")
        if len(splits) != 4:
            return False
        for split in splits:
            try:
                if not 0 <= int(split) <= 255:
                    return False
            except ValueError:
                return False
        return True

    @staticmethod
    def pack_ip4address_low(address: str) -> int:
        """Convert a dotted-quad address to its 32-bit integer value."""
        addr = 0
        for split in address.split("."):
            octet = int(split)
            addr = (addr << 8) + octet
        return addr

    @staticmethod
    def pack_ip4address(address: str) -> Ip4Path:
        """Convert a dotted-quad address to a /32 Ip4Path."""
        return Ip4Path(Database.pack_ip4address_low(address), 32)

    @staticmethod
    def _ip4value_to_str(value: int) -> str:
        """Format a 32-bit integer as a dotted-quad address."""
        return ".".join(str((value >> shift) & 0xFF) for shift in (24, 16, 8, 0))

    @staticmethod
    def unpack_ip4address(address: Ip4Path) -> str:
        """Format a /32 Ip4Path as a dotted-quad address."""
        assert address.prefixlen == 32
        return Database._ip4value_to_str(address.value)

    @staticmethod
    def validate_ip4network(path: str) -> bool:
        """Check "address/prefixlen" with a valid address and prefixlen 0-32."""
        # A bit generous but ok for our usage (e.g. host bits may be set)
        splits = path.split("/")
        if len(splits) != 2:
            return False
        if not Database.validate_ip4address(splits[0]):
            return False
        try:
            if not 0 <= int(splits[1]) <= 32:
                return False
        except ValueError:
            return False
        return True

    @staticmethod
    def pack_ip4network(network: str) -> Ip4Path:
        """Parse "address/prefixlen"; host bits are kept as given."""
        address, prefixlen_str = network.split("/")
        prefixlen = int(prefixlen_str)
        addr = Database.pack_ip4address(address)
        addr.prefixlen = prefixlen
        return addr

    @staticmethod
    def unpack_ip4network(network: Ip4Path) -> str:
        """Format an Ip4Path as "address/prefixlen"."""
        return Database._ip4value_to_str(network.value) + "/" + str(network.prefixlen)

    def get_match(self, path: Path) -> Match:
        """
        Return the Match stored for `path`.

        :raises KeyError: unknown ASN or domain label.
        :raises IndexError: IPv4 network not present in the tree.
        :raises ValueError: unsupported path type (including a bare
            DomainPath).
        """
        if isinstance(path, RuleMultiPath):
            return self.rules[0]
        elif isinstance(path, RuleFirstPath):
            return self.rules[1]
        elif isinstance(path, AsnPath):
            return self.asns[path.asn]
        elif isinstance(path, DomainPath):
            dicd = self.domtree
            for part in path.parts:
                dicd = dicd.children[part]
            if isinstance(path, HostnamePath):
                return dicd.match_hostname
            elif isinstance(path, ZonePath):
                return dicd.match_zone
            else:
                raise ValueError
        elif isinstance(path, Ip4Path):
            dici = self.ip4tree
            # Walk the first `prefixlen` bits, most significant first
            for i in range(31, 31 - path.prefixlen, -1):
                bit = (path.value >> i) & 0b1
                dici_next = dici.one if bit else dici.zero
                if not dici_next:
                    raise IndexError
                dici = dici_next
            return dici
        else:
            raise ValueError

    @staticmethod
    def _callback_results(result: typing.Any) -> typing.Iterator[typing.Any]:
        """
        Turn a MatchCallable return value into an iterator.

        Callbacks return either None (side effect only) or an iterable whose
        items are forwarded to the caller. Only the iter() call is guarded,
        so a TypeError raised while a callback's generator runs is no longer
        silently swallowed.
        """
        try:
            return iter(result)
        except TypeError:  # not iterable
            return iter(())

    def exec_each_asn(
        self,
        callback: MatchCallable,
    ) -> typing.Any:
        """Call `callback` on every active ASN; yield what it yields."""
        for asn, match in self.asns.items():
            if match.active():
                yield from self._callback_results(callback(AsnPath(asn), match))

    def exec_each_domain(
        self,
        callback: MatchCallable,
        _dic: typing.Optional[DomainTreeNode] = None,
        _par: typing.Optional[DomainPath] = None,
    ) -> typing.Any:
        """
        Call `callback` on every active hostname, then zone, depth-first;
        yield what it yields. `_dic` / `_par` are for recursion only.
        """
        dic = self.domtree if _dic is None else _dic
        par = DomainPath([]) if _par is None else _par
        if dic.match_hostname.active():
            yield from self._callback_results(
                callback(HostnamePath(par.parts), dic.match_hostname)
            )
        if dic.match_zone.active():
            yield from self._callback_results(
                callback(ZonePath(par.parts), dic.match_zone)
            )
        for part, child in dic.children.items():
            yield from self.exec_each_domain(
                callback, _dic=child, _par=DomainPath(par.parts + [part])
            )

    def exec_each_ip4(
        self,
        callback: MatchCallable,
        _dic: typing.Optional[IpTreeNode] = None,
        _par: typing.Optional[Ip4Path] = None,
    ) -> typing.Any:
        """
        Call `callback` on every active IPv4 network, depth-first with the
        0 branch before the 1 branch; yield what it yields. `_dic` / `_par`
        are for recursion only.
        """
        dic = self.ip4tree if _dic is None else _dic
        par = Ip4Path(0, 0) if _par is None else _par
        if dic.active():
            yield from self._callback_results(callback(par, dic))
        pref = par.prefixlen + 1
        # 0 branch: the new bit is 0, so the value is unchanged
        if dic.zero:
            yield from self.exec_each_ip4(
                callback, _dic=dic.zero, _par=Ip4Path(par.value, pref)
            )
        # 1 branch: set the bit at the new prefix position
        if dic.one:
            addr1 = par.value | (1 << (32 - pref))
            yield from self.exec_each_ip4(
                callback, _dic=dic.one, _par=Ip4Path(addr1, pref)
            )

    def exec_each(
        self,
        callback: MatchCallable,
    ) -> typing.Any:
        """Run `callback` over domains, then IPv4 networks, then ASNs."""
        yield from self.exec_each_domain(callback)
        yield from self.exec_each_ip4(callback)
        yield from self.exec_each_asn(callback)

    def update_references(self) -> None:
        """
        Recompute every reference counter from scratch.

        Counters are normally maintained by _set_match / _unset_match; this
        is kept just in case.
        """

        def reset_references_cb(path: Path, match: Match) -> None:
            match.references = 0

        # exec_each does not visit the rule matches: reset them explicitly,
        # otherwise their counters are added on top of stale values.
        for rule in self.rules:
            rule.references = 0
        for _ in self.exec_each(reset_references_cb):
            pass

        def increment_references_cb(path: Path, match: Match) -> None:
            if match.source:
                source = self.get_match(match.source)
                source.references += 1

        for _ in self.exec_each(increment_references_cb):
            pass

    def _clean_deps(self) -> None:
        """
        Disable the matches whose source is inactive, and downgrade the
        first-party flag of matches whose source lost it. Repeat full passes
        until one changes nothing, so this propagates down the chains.
        """
        did_something = True

        def clean_deps_cb(path: Path, match: Match) -> None:
            nonlocal did_something
            if not match.source:
                return
            source = self.get_match(match.source)
            if not source.active():
                self._unset_match(match)
            elif match.first_party > source.first_party:
                match.first_party = source.first_party
            else:
                return
            did_something = True

        while did_something:
            did_something = False
            self.enter_step("pass_clean_deps")
            for _ in self.exec_each(clean_deps_cb):
                pass

    def prune(self, before: int, base_only: bool = False) -> None:
        """
        Disable the records last updated at or before `before`, then their
        dependents.

        :param base_only: only consider records of level 0 or 1.
        """

        # Disable the matches targeted
        def prune_cb(path: Path, match: Match) -> None:
            if base_only and match.level > 1:
                return
            if match.updated > before:
                return
            self._unset_match(match)
            self.log.debug("Print: disabled %s", path)

        self.enter_step("pass_prune")
        for _ in self.exec_each(prune_cb):
            pass
        self._clean_deps()
        # Remove branches with no match
        # TODO

    def explain(self, path: Path) -> str:
        """
        Describe `path` and, recursively, its chain of sources.

        Each link reads "<path> [(<asn name>)] <level><F|M><D|_><references>":
        F/M = first/multi-party, D = duplicate.
        """
        match = self.get_match(path)
        string = str(path)
        if isinstance(match, AsnNode):
            string += f" ({match.name})"
        party_char = "F" if match.first_party else "M"
        dup_char = "D" if match.dupplicate else "_"
        string += f" {match.level}{party_char}{dup_char}{match.references}"
        if match.source:
            # Separator: links used to be glued together ("...1M_0?.x.com")
            string += f" ← {self.explain(match.source)}"
        return string

    def list_records(
        self,
        first_party_only: bool = False,
        end_chain_only: bool = False,
        no_dupplicates: bool = False,
        rules_only: bool = False,
        hostnames_only: bool = False,
        explain: bool = False,
    ) -> typing.Iterable[str]:
        """
        Yield active records as strings (or explanations), filtered by the
        flags: first-party only, unreferenced only, non-duplicates only,
        level <= 1 only, hostnames only.
        """

        def export_cb(path: Path, match: Match) -> typing.Iterable[str]:
            if first_party_only and not match.first_party:
                return
            if end_chain_only and match.references > 0:
                return
            if no_dupplicates and match.dupplicate:
                return
            if rules_only and match.level > 1:
                return
            if hostnames_only and not isinstance(path, HostnamePath):
                return
            if explain:
                yield self.explain(path)
            else:
                yield str(path)

        yield from self.exec_each(export_cb)

    def count_records(
        self,
        first_party_only: bool = False,
        end_chain_only: bool = False,
        no_dupplicates: bool = False,
        rules_only: bool = False,
        hostnames_only: bool = False,
    ) -> str:
        """
        Count active records per kind, with the same filters as
        list_records. Return e.g. "hostnames: 3, zones: 1".
        """
        memo: typing.Dict[str, int] = dict()

        def count_records_cb(path: Path, match: Match) -> None:
            if first_party_only and not match.first_party:
                return
            if end_chain_only and match.references > 0:
                return
            if no_dupplicates and match.dupplicate:
                return
            if rules_only and match.level > 1:
                return
            if hostnames_only and not isinstance(path, HostnamePath):
                return
            name = path.__class__.__name__
            memo[name] = memo.get(name, 0) + 1

        for _ in self.exec_each(count_records_cb):
            pass
        split: typing.List[str] = list()
        for key, value in sorted(memo.items(), key=lambda s: s[0]):
            # "HostnamePath" -> "hostnames"
            split.append(f"{key[:-4].lower()}s: {value}")
        return ", ".join(split)

    def get_domain(self, domain_str: str) -> typing.Iterable[DomainPath]:
        """
        Yield the active records matching `domain_str`: enclosing zones from
        the TLD down, then the zone and hostname on the name itself.
        """
        self.enter_step("get_domain_pack")
        domain = self.pack_domain(domain_str)
        self.enter_step("get_domain_brws")
        dic = self.domtree
        depth = 0
        for part in domain.parts:
            if dic.match_zone.active():
                self.enter_step("get_domain_yield")
                yield ZonePath(domain.parts[:depth])
                self.enter_step("get_domain_brws")
            if part not in dic.children:
                return
            dic = dic.children[part]
            depth += 1
        if dic.match_zone.active():
            self.enter_step("get_domain_yield")
            yield ZonePath(domain.parts)
        if dic.match_hostname.active():
            self.enter_step("get_domain_yield")
            yield HostnamePath(domain.parts)

    def get_ip4(self, ip4_str: str) -> typing.Iterable[Path]:
        """
        Yield the active networks containing `ip4_str`, shortest prefix
        first. Nothing is yielded when the address's cache cell is unset.
        """
        self.enter_step("get_ip4_pack")
        ip4val = self.pack_ip4address_low(ip4_str)
        self.enter_step("get_ip4_cache")
        if not self.ip4cache[ip4val >> self.ip4cache_shift]:
            return
        self.enter_step("get_ip4_brws")
        dic = self.ip4tree
        for i in range(31, -1, -1):
            bit = (ip4val >> i) & 0b1
            if dic.active():
                self.enter_step("get_ip4_yield")
                # Network of the 31 - i bits walked so far
                yield Ip4Path(ip4val >> (i + 1) << (i + 1), 31 - i)
                self.enter_step("get_ip4_brws")
            next_dic = dic.one if bit else dic.zero
            if next_dic is None:
                return
            dic = next_dic
        if dic.active():
            self.enter_step("get_ip4_yield")
            yield Ip4Path(ip4val, 32)

    def _unset_match(
        self,
        match: Match,
    ) -> None:
        """Disable `match` and release its reference on its source."""
        match.disable()
        if match.source:
            source_match = self.get_match(match.source)
            source_match.references -= 1

    def _set_match(
        self,
        match: Match,
        updated: int,
        source: Path,
        source_match: typing.Optional[Match] = None,
        dupplicate: bool = False,
    ) -> None:
        """
        Point `match` at `source`, but only if this is newer, shorter
        (lower level) or more first-party than what it holds.
        """
        # source_match is in parameters because most of the time
        # its parent function needs it too,
        # so it can pass it to save a traversal
        source_match = source_match or self.get_match(source)
        new_level = source_match.level + 1
        if (
            updated > match.updated
            or new_level < match.level
            or source_match.first_party > match.first_party
        ):
            # NOTE FP and level of matches referencing this one
            # won't be updated until run or prune
            if match.source:
                old_source = self.get_match(match.source)
                old_source.references -= 1
            match.updated = updated
            match.level = new_level
            match.first_party = source_match.first_party
            match.source = source
            source_match.references += 1
            match.dupplicate = dupplicate

    def _set_domain(
        self, hostname: bool, domain_str: str, updated: int, source: Path
    ) -> None:
        """
        Create or refresh a hostname (hostname=True) or zone record.

        The record is flagged as a duplicate when an active zone (first-party
        if the source is first-party) already covers it: a strictly
        enclosing zone, or, for a hostname, the zone on the same name.

        :raises ValueError: if domain_str is not a valid domain.
        """
        self.enter_step("set_domain_val")
        if not Database.validate_domain(domain_str):
            raise ValueError(f"Invalid domain: {domain_str}")
        self.enter_step("set_domain_pack")
        domain = self.pack_domain(domain_str)
        self.enter_step("set_domain_fp")
        source_match = self.get_match(source)
        is_first_party = source_match.first_party
        self.enter_step("set_domain_brws")
        dic = self.domtree
        dupplicate = False
        for part in domain.parts:
            # Checked before descending so the target's own zone never
            # counts: refreshing a zone must not flag it as a duplicate of
            # itself.
            if dic.match_zone.active(is_first_party):
                dupplicate = True
            if part not in dic.children:
                dic.children[part] = DomainTreeNode()
            dic = dic.children[part]
        if hostname:
            match = dic.match_hostname
            # A zone on the very same name covers the hostname
            if dic.match_zone.active(is_first_party):
                dupplicate = True
        else:
            match = dic.match_zone
        self._set_match(
            match,
            updated,
            source,
            source_match=source_match,
            dupplicate=dupplicate,
        )

    def set_hostname(self, *args: typing.Any, **kwargs: typing.Any) -> None:
        """Create or refresh a hostname record; see _set_domain."""
        self._set_domain(True, *args, **kwargs)

    def set_zone(self, *args: typing.Any, **kwargs: typing.Any) -> None:
        """Create or refresh a zone record; see _set_domain."""
        self._set_domain(False, *args, **kwargs)

    def set_asn(self, asn_str: str, updated: int, source: Path) -> None:
        """Create or refresh an ASN record ("AS1234" or "1234")."""
        self.enter_step("set_asn")
        path = self.pack_asn(asn_str)
        if path.asn in self.asns:
            match = self.asns[path.asn]
        else:
            match = AsnNode()
            self.asns[path.asn] = match
        self._set_match(
            match,
            updated,
            source,
        )

    def _set_ip4(self, ip4: Ip4Path, updated: int, source: Path) -> None:
        """
        Create or refresh the IPv4 network `ip4` and mark it in the cache.

        The record is flagged as a duplicate when a strictly enclosing
        network is active (first-party if the source is first-party).
        """
        self.enter_step("set_ip4_fp")
        source_match = self.get_match(source)
        is_first_party = source_match.first_party
        self.enter_step("set_ip4_brws")
        dic = self.ip4tree
        dupplicate = False
        for i in range(31, 31 - ip4.prefixlen, -1):
            # Checked before descending so the target node itself never
            # counts: refreshing a network must not flag it as a duplicate
            # of itself.
            if dic.active(is_first_party):
                dupplicate = True
            bit = (ip4.value >> i) & 0b1
            next_dic = dic.one if bit else dic.zero
            if next_dic is None:
                next_dic = IpTreeNode()
                if bit:
                    dic.one = next_dic
                else:
                    dic.zero = next_dic
            dic = next_dic
        self._set_match(
            dic,
            updated,
            source,
            source_match=source_match,
            dupplicate=dupplicate,
        )
        self._set_ip4cache(ip4, dic)

    def set_ip4address(
        self, ip4address_str: str, *args: typing.Any, **kwargs: typing.Any
    ) -> None:
        """
        Create or refresh a single-address (/32) record.

        :raises ValueError: if ip4address_str is not a valid address.
        """
        self.enter_step("set_ip4add_val")
        if not Database.validate_ip4address(ip4address_str):
            raise ValueError(f"Invalid ip4address: {ip4address_str}")
        self.enter_step("set_ip4add_pack")
        ip4 = self.pack_ip4address(ip4address_str)
        self._set_ip4(ip4, *args, **kwargs)

    def set_ip4network(
        self, ip4network_str: str, *args: typing.Any, **kwargs: typing.Any
    ) -> None:
        """
        Create or refresh an "address/prefixlen" network record.

        :raises ValueError: if ip4network_str is not a valid network.
        """
        self.enter_step("set_ip4net_val")
        if not Database.validate_ip4network(ip4network_str):
            raise ValueError(f"Invalid ip4network: {ip4network_str}")
        self.enter_step("set_ip4net_pack")
        ip4 = self.pack_ip4network(ip4network_str)
        self._set_ip4(ip4, *args, **kwargs)