2019-12-09 08:12:48 +01:00
|
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
|
2019-12-13 00:11:21 +01:00
|
|
|
|
"""
|
|
|
|
|
Utility functions to interact with the database.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import typing
|
|
|
|
|
import time
|
|
|
|
|
import logging
|
|
|
|
|
import coloredlogs
|
2019-12-15 15:56:26 +01:00
|
|
|
|
import pickle
|
2019-12-09 08:12:48 +01:00
|
|
|
|
|
2019-12-17 19:53:05 +01:00
|
|
|
|
# Known top-level domains. Starts empty; Database.populate_tld_list() fills
# it the first time a domain has to be validated.
TLD_LIST: typing.Set[str] = set()

# Configure coloredlogs with DEBUG verbosity and the record layout below.
coloredlogs.install(
    fmt='%(asctime)s %(name)s %(levelname)s %(message)s',
    level='DEBUG',
)

# Readability aliases: all three are plain integers.
Asn = int
Timestamp = int
Level = int
|
2019-12-13 00:11:21 +01:00
|
|
|
|
|
2019-12-15 22:21:05 +01:00
|
|
|
|
|
|
|
|
|
class Path:
    """Base class of the path types defined in this module (rules,
    domains, ASNs and IPv4 ranges). Carries no state of its own."""
    # FP add boolean here
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RulePath(Path):
    """Path that stands for a rule rather than for a stored record."""

    def __str__(self) -> str:
        """Return the fixed display label of a generic rule."""
        label = '(rule)'
        return label
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RuleFirstPath(RulePath):
    """Rule path variant used for first-party rules."""

    def __str__(self) -> str:
        """Return the fixed display label of a first-party rule."""
        label = '(first-party rule)'
        return label
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RuleMultiPath(RulePath):
    """Rule path variant used for multi-party rules."""

    def __str__(self) -> str:
        """Return the fixed display label of a multi-party rule."""
        label = '(multi-party rule)'
        return label
|
2019-12-15 22:21:05 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DomainPath(Path):
    """Domain name stored as a list of labels.

    ``parts`` is kept in the order produced by ``Database.pack_domain``,
    i.e. the reverse of the dotted notation (top-level label first).
    """

    def __init__(self, parts: typing.List[str]):
        self.parts = parts

    def __str__(self) -> str:
        """Return the dotted domain prefixed with ``?.``."""
        return f'?.{Database.unpack_domain(self)}'
|
2019-12-15 22:21:05 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HostnamePath(DomainPath):
    """Domain path that designates one exact hostname."""

    def __str__(self) -> str:
        """Return the dotted domain name without any prefix."""
        dotted = Database.unpack_domain(self)
        return dotted
|
2019-12-15 22:21:05 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ZonePath(DomainPath):
    """Domain path that designates a whole zone."""

    def __str__(self) -> str:
        """Return the dotted domain prefixed with ``*.``."""
        return f'*.{Database.unpack_domain(self)}'
|
2019-12-15 22:21:05 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AsnPath(Path):
    """Path designating an autonomous system by its number."""

    def __init__(self, asn: Asn):
        self.asn = asn

    def __str__(self) -> str:
        """Return the textual form built by ``Database.unpack_asn``."""
        text = Database.unpack_asn(self)
        return text
|
|
|
|
|
|
2019-12-15 22:21:05 +01:00
|
|
|
|
|
|
|
|
|
class Ip4Path(Path):
    """IPv4 address or network: an integer value plus a prefix length."""

    def __init__(self, value: int, prefixlen: int):
        self.value, self.prefixlen = value, prefixlen

    def __str__(self) -> str:
        """Return the ``a.b.c.d/len`` form built by
        ``Database.unpack_ip4network``."""
        text = Database.unpack_ip4network(self)
        return text
|
|
|
|
|
|
2019-12-15 22:21:05 +01:00
|
|
|
|
|
2019-12-15 23:13:25 +01:00
|
|
|
|
class Match:
    """Record attached to a path: where it comes from and when it was set."""

    def __init__(self) -> None:
        # Path this match was derived from, None when it has none.
        self.source: typing.Optional[Path] = None
        # Value recorded at the last update; 0 means the match was never
        # set, which active() reports as inactive.
        self.updated: int = 0
        # Flag spelled "dupplicate" throughout the file; the name is kept
        # because other code reads and writes this attribute.
        self.dupplicate: bool = False

        # Cache
        self.level: int = 0
        self.first_party: bool = False
        self.references: int = 0

    def active(self, first_party: typing.Optional[bool] = None) -> bool:
        """Tell whether this match is currently in effect.

        Args:
            first_party: when truthy, the match additionally has to be a
                first-party one to count as active. ``None`` (default) and
                ``False`` apply no such restriction.

        Returns:
            False if the match was never updated, or if ``first_party`` is
            requested and this match is not first-party; True otherwise.
        """
        if self.updated == 0:
            return False
        if first_party and not self.first_party:
            return False
        return True
|
2019-12-15 23:13:25 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AsnNode(Match):
    """Match describing an autonomous system, with a display name."""

    def __init__(self) -> None:
        super().__init__()
        # Empty until something assigns a name to the ASN.
        self.name = ''
|
2019-12-15 15:56:26 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DomainTreeNode:
    """Node of the domain tree: one level per domain label.

    Each node owns two separate matches, one for the zone rooted at this
    node and one for the exact hostname it represents.
    """

    def __init__(self) -> None:
        # Child nodes indexed by the next label of the domain.
        self.children: typing.Dict[str, DomainTreeNode] = {}
        self.match_zone = Match()
        self.match_hostname = Match()
|
2019-12-15 15:56:26 +01:00
|
|
|
|
|
|
|
|
|
|
2019-12-16 09:31:29 +01:00
|
|
|
|
class IpTreeNode(Match):
    """Node of the binary IPv4 tree; the node itself is the match.

    ``zero`` and ``one`` are the children reached when the next address
    bit is 0 or 1 respectively; a missing child is None.
    """

    def __init__(self) -> None:
        super().__init__()
        self.zero: typing.Optional[IpTreeNode] = None
        self.one: typing.Optional[IpTreeNode] = None
|
2019-12-15 15:56:26 +01:00
|
|
|
|
|
|
|
|
|
|
2019-12-15 23:13:25 +01:00
|
|
|
|
# Any of the node types held by the database.
Node = typing.Union[DomainTreeNode, IpTreeNode, AsnNode]
# Signature of the callbacks taken by the Database.exec_each* methods:
# they receive a path and its match, and may return anything.
MatchCallable = typing.Callable[[Path, Match], typing.Any]
|
2019-12-15 22:21:05 +01:00
|
|
|
|
|
|
|
|
|
|
2019-12-15 15:56:26 +01:00
|
|
|
|
class Profiler:
    """Coarse wall-clock profiler working on named steps.

    Time is always charged to the step that is currently open; calling
    ``enter_step`` closes it and opens another one.
    """

    def __init__(self) -> None:
        self.log = logging.getLogger('profiler')
        self.time_last = time.perf_counter()
        self.time_step = 'init'
        # Seconds accumulated per step name.
        self.time_dict: typing.Dict[str, float] = {}
        # Number of times each step was counted.
        self.step_dict: typing.Dict[str, int] = {}

    def enter_step(self, name: str) -> None:
        """Close the current step and start timing the step ``name``.

        The elapsed time goes to the step being left. Its counter is set
        to 1 the first time it is left, and afterwards only grows when the
        step being entered has a different name.
        """
        now = time.perf_counter()
        leaving = self.time_step
        elapsed = now - self.time_last
        try:
            self.time_dict[leaving] += elapsed
            self.step_dict[leaving] += int(name != leaving)
        except KeyError:
            # First time this step is left: create both entries.
            self.time_dict[leaving] = elapsed
            self.step_dict[leaving] = 1
        self.time_step = name
        # Read the clock again so the bookkeeping above is not charged to
        # the step that starts now.
        self.time_last = time.perf_counter()

    def profile(self) -> None:
        """Log, at debug level, one line per step (ordered by increasing
        time spent) followed by a line with the total."""
        self.enter_step('profile')
        total = sum(self.time_dict.values())
        by_duration = sorted(self.time_dict.items(), key=lambda item: item[1])
        for key, secs in by_duration:
            times = self.step_dict[key]
            self.log.debug(f"{key:<20}: {times:9d} × {secs/times:5.3e} "
                           f"= {secs:9.2f} s ({secs/total:7.2%}) ")
        self.log.debug(f"{'total':<20}: "
                       f"{total:9.2f} s ({1:7.2%})")
|
|
|
|
|
|
2019-12-15 15:56:26 +01:00
|
|
|
|
|
|
|
|
|
class Database(Profiler):
    """Store of matches for domains, IPv4 ranges and ASNs.

    The state consists of the two rule matches, the domain tree, the ASN
    table and the IPv4 tree. It is pickled to ``PATH`` together with
    ``VERSION``; a file carrying another version, or one that cannot be
    decoded, is discarded and replaced by a fresh empty database.
    """

    VERSION = 18
    PATH = "blocking.p"

    def __init__(self) -> None:
        super().__init__()
        self.log = logging.getLogger('db')
        self.load()

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def initialize(self) -> None:
        """Reset the in-memory state to an empty database."""
        self.log.warning(
            "Creating database version: %d ",
            Database.VERSION)
        # Dummy match objects that everything refer to:
        # index 0 is the multi-party rule, index 1 the first-party rule
        # (see get_match).
        self.rules: typing.List[Match] = []
        for first_party in (False, True):
            rule = Match()
            rule.updated = 1
            rule.level = 0
            rule.first_party = first_party
            self.rules.append(rule)
        self.domtree = DomainTreeNode()
        self.asns: typing.Dict[Asn, AsnNode] = {}
        self.ip4tree = IpTreeNode()

    def load(self) -> None:
        """Load the state from ``PATH``, or start from scratch.

        A fresh database is created when the file is missing, when its
        version differs from ``VERSION``, or when it cannot be decoded.
        """
        self.enter_step('load')
        try:
            # NOTE(security): unpickling can run arbitrary code, so PATH
            # must only ever point to a file written by save().
            with open(self.PATH, 'rb') as db_fdsec:
                version, data = pickle.load(db_fdsec)
            if version == Database.VERSION:
                self.rules, self.domtree, self.asns, self.ip4tree = data
                return
            self.log.warning(
                "Outdated database version found: %d, "
                "it will be rebuilt.",
                version)
        except (TypeError, AttributeError, EOFError, ValueError,
                pickle.UnpicklingError):
            # UnpicklingError is what pickle raises on damaged data, and
            # ValueError covers a payload of the wrong shape; both must
            # lead to a rebuild rather than to a crash.
            self.log.error(
                "Corrupt (or heavily outdated) database found, "
                "it will be rebuilt.")
        except FileNotFoundError:
            pass
        self.initialize()

    def save(self) -> None:
        """Pickle the state to ``PATH`` and log the profiling report."""
        self.enter_step('save')
        with open(self.PATH, 'wb') as db_fdsec:
            data = self.rules, self.domtree, self.asns, self.ip4tree
            pickle.dump((self.VERSION, data), db_fdsec)
        self.profile()

    # ------------------------------------------------------------------
    # Validation and conversion helpers
    # ------------------------------------------------------------------

    @staticmethod
    def populate_tld_list() -> None:
        """Fill ``TLD_LIST`` with the stripped lines of
        ``temp/all_tld.list``."""
        with open('temp/all_tld.list', 'r') as tld_fdesc:
            TLD_LIST.update(line.strip() for line in tld_fdesc)

    @staticmethod
    def validate_domain(path: str) -> bool:
        """Tell whether ``path`` looks like a valid domain name.

        The name must be at most 255 characters long, every label must be
        1 to 63 characters long, and the last label must be listed in
        ``TLD_LIST`` (which is loaded on first use).
        """
        if len(path) > 255:
            return False
        splits = path.split('.')
        if not TLD_LIST:
            Database.populate_tld_list()
        # The top-level domain is the rightmost label of the dotted name.
        if splits[-1] not in TLD_LIST:
            return False
        return all(1 <= len(split) <= 63 for split in splits)

    @staticmethod
    def pack_domain(domain: str) -> DomainPath:
        """Split a dotted domain into labels, top-level label first."""
        return DomainPath(domain.split('.')[::-1])

    @staticmethod
    def unpack_domain(domain: DomainPath) -> str:
        """Rebuild the dotted domain name out of a DomainPath."""
        return '.'.join(domain.parts[::-1])

    @staticmethod
    def pack_asn(asn: str) -> AsnPath:
        """Parse an ASN given as a number with an optional, case
        insensitive ``AS`` prefix. Raises ValueError if the remainder is
        not an integer."""
        asn = asn.upper()
        if asn.startswith('AS'):
            asn = asn[2:]
        return AsnPath(int(asn))

    @staticmethod
    def unpack_asn(asn: AsnPath) -> str:
        """Return the ``AS<number>`` form of an AsnPath."""
        return f'AS{asn.asn}'

    @staticmethod
    def validate_ip4address(path: str) -> bool:
        """Tell whether ``path`` is four dot-separated integers, each in
        the 0-255 range."""
        splits = path.split('.')
        if len(splits) != 4:
            return False
        for split in splits:
            try:
                if not 0 <= int(split) <= 255:
                    return False
            except ValueError:
                return False
        return True

    @staticmethod
    def pack_ip4address(address: str) -> Ip4Path:
        """Convert a dotted IPv4 address into an Ip4Path with a prefix
        length of 32. The input is not validated here."""
        addr = 0
        for split in address.split('.'):
            addr = (addr << 8) + int(split)
        return Ip4Path(addr, 32)

    @staticmethod
    def _format_ip4_value(value: int) -> str:
        """Return the dotted notation of the low 32 bits of ``value``."""
        return '.'.join(
            str((value >> shift) & 0xFF) for shift in (24, 16, 8, 0))

    @staticmethod
    def unpack_ip4address(address: Ip4Path) -> str:
        """Return the dotted notation of a single address.

        Asserts that the prefix length of ``address`` is 32.
        """
        assert address.prefixlen == 32
        return Database._format_ip4_value(address.value)

    @staticmethod
    def validate_ip4network(path: str) -> bool:
        """Tell whether ``path`` is ``<ip4address>/<prefix length>`` with
        a prefix length between 0 and 32."""
        # A bit generous but ok for our usage
        splits = path.split('/')
        if len(splits) != 2:
            return False
        if not Database.validate_ip4address(splits[0]):
            return False
        try:
            if not 0 <= int(splits[1]) <= 32:
                return False
        except ValueError:
            return False
        return True

    @staticmethod
    def pack_ip4network(network: str) -> Ip4Path:
        """Convert ``a.b.c.d/len`` into an Ip4Path. The address bits are
        stored as given; the input is not validated here."""
        address, prefixlen_str = network.split('/')
        prefixlen = int(prefixlen_str)
        addr = Database.pack_ip4address(address)
        addr.prefixlen = prefixlen
        return addr

    @staticmethod
    def unpack_ip4network(network: Ip4Path) -> str:
        """Return the ``a.b.c.d/len`` notation of an Ip4Path."""
        return (Database._format_ip4_value(network.value)
                + '/' + str(network.prefixlen))

    # ------------------------------------------------------------------
    # Lookup and traversal
    # ------------------------------------------------------------------

    def get_match(self, path: Path) -> Match:
        """Return the match stored for ``path``.

        Raises:
            KeyError: unknown ASN, or a domain label missing from the tree.
            IndexError: the IPv4 tree has no node for the given prefix.
            ValueError: ``path`` is of a type that holds no match (a bare
                RulePath, a bare DomainPath, or an unknown Path type).
        """
        if isinstance(path, RuleMultiPath):
            return self.rules[0]
        elif isinstance(path, RuleFirstPath):
            return self.rules[1]
        elif isinstance(path, AsnPath):
            return self.asns[path.asn]
        elif isinstance(path, DomainPath):
            dicd = self.domtree
            for part in path.parts:
                dicd = dicd.children[part]
            if isinstance(path, HostnamePath):
                return dicd.match_hostname
            elif isinstance(path, ZonePath):
                return dicd.match_zone
            else:
                raise ValueError
        elif isinstance(path, Ip4Path):
            dici = self.ip4tree
            # Walk from the most significant bit down, one bit per level.
            for i in range(31, 31-path.prefixlen, -1):
                bit = (path.value >> i) & 0b1
                dici_next = dici.one if bit else dici.zero
                if not dici_next:
                    raise IndexError
                dici = dici_next
            return dici
        else:
            raise ValueError

    @staticmethod
    def _iter_result(result: typing.Any) -> typing.Iterator[typing.Any]:
        """Return an iterator over what a callback returned.

        Callbacks may return a non-iterable value (typically None), in
        which case an empty iterator is returned. Only this probe is
        guarded: a TypeError raised later, while a callback's generator is
        running, propagates instead of being silently discarded.
        """
        try:
            return iter(result)
        except TypeError:  # not iterable
            return iter(())

    def exec_each_asn(self,
                      callback: MatchCallable,
                      ) -> typing.Any:
        """Call ``callback(path, match)`` for every active ASN and yield
        whatever the callbacks produce."""
        for asn, match in self.asns.items():
            if match.active():
                yield from Database._iter_result(
                    callback(AsnPath(asn), match))

    def exec_each_domain(self,
                         callback: MatchCallable,
                         _dic: typing.Optional[DomainTreeNode] = None,
                         _par: typing.Optional[DomainPath] = None,
                         ) -> typing.Any:
        """Call ``callback(path, match)`` for every active hostname and
        zone match of the domain tree, parents before children, and yield
        whatever the callbacks produce.

        ``_dic`` and ``_par`` carry the current node and its path through
        the recursion; callers leave them unset.
        """
        if _dic is None:
            _dic = self.domtree
        if _par is None:
            _par = DomainPath([])
        if _dic.match_hostname.active():
            yield from Database._iter_result(
                callback(HostnamePath(_par.parts), _dic.match_hostname))
        if _dic.match_zone.active():
            yield from Database._iter_result(
                callback(ZonePath(_par.parts), _dic.match_zone))
        for part, child in _dic.children.items():
            yield from self.exec_each_domain(
                callback,
                _dic=child,
                _par=DomainPath(_par.parts + [part])
            )

    def exec_each_ip4(self,
                      callback: MatchCallable,
                      _dic: typing.Optional[IpTreeNode] = None,
                      _par: typing.Optional[Ip4Path] = None,
                      ) -> typing.Any:
        """Call ``callback(path, match)`` for every active node of the
        IPv4 tree, parents before children, and yield whatever the
        callbacks produce.

        ``_dic`` and ``_par`` carry the current node and its path through
        the recursion; callers leave them unset.
        """
        if _dic is None:
            _dic = self.ip4tree
        if _par is None:
            _par = Ip4Path(0, 0)
        if _dic.active():
            yield from Database._iter_result(callback(_par, _dic))

        pref = _par.prefixlen + 1
        # 0
        dic = _dic.zero
        if dic is not None:
            # The bit added at this level has to be clear already.
            addr0 = _par.value & (0xFFFFFFFF ^ (1 << (32-pref)))
            assert addr0 == _par.value
            yield from self.exec_each_ip4(
                callback,
                _dic=dic,
                _par=Ip4Path(addr0, pref)
            )
        # 1
        dic = _dic.one
        if dic is not None:
            addr1 = _par.value | (1 << (32-pref))
            yield from self.exec_each_ip4(
                callback,
                _dic=dic,
                _par=Ip4Path(addr1, pref)
            )

    def exec_each(self,
                  callback: MatchCallable,
                  ) -> typing.Any:
        """Run ``callback`` over the domains, then the IPv4 tree, then
        the ASNs, yielding whatever the callbacks produce."""
        yield from self.exec_each_domain(callback)
        yield from self.exec_each_ip4(callback)
        yield from self.exec_each_asn(callback)

    def update_references(self) -> None:
        """Recompute ``references`` for the rule matches and for every
        active match."""
        # Should be correctly calculated normally,
        # keeping this just in case
        def reset_references_cb(path: Path,
                                match: Match
                                ) -> None:
            match.references = 0

        def increment_references_cb(path: Path,
                                    match: Match
                                    ) -> None:
            if match.source:
                source = self.get_match(match.source)
                source.references += 1

        # exec_each never visits the rule matches, yet they are counted as
        # sources below: reset them too or their counters grow on each run.
        for rule in self.rules:
            rule.references = 0
        # The exec_each generators are lazy: drain them to run callbacks.
        for _ in self.exec_each(reset_references_cb):
            pass
        for _ in self.exec_each(increment_references_cb):
            pass

    def prune(self, before: int, base_only: bool = False) -> None:
        """Not implemented yet."""
        raise NotImplementedError

    def explain(self, path: Path) -> str:
        """Describe ``path`` followed by its chain of sources.

        Each element shows the path, its reference count and, for ASNs,
        the name; elements are joined with `` ← ``.
        """
        match = self.get_match(path)
        if isinstance(match, AsnNode):
            string = f'{path} ({match.name}) #{match.references}'
        else:
            string = f'{path} #{match.references}'
        if match.source:
            string += f' ← {self.explain(match.source)}'
        return string

    # ------------------------------------------------------------------
    # Reports
    # ------------------------------------------------------------------

    def export(self,
               first_party_only: bool = False,
               end_chain_only: bool = False,
               no_dupplicates: bool = False,
               explain: bool = False,
               ) -> typing.Iterable[str]:
        """Yield every active hostname of the database.

        Args:
            first_party_only: skip matches that are not first-party.
            end_chain_only: skip matches referenced by other matches.
            no_dupplicates: skip matches flagged as duplicates.
            explain: yield the output of ``explain`` instead of the bare
                domain name.
        """

        def export_cb(path: Path, match: Match
                      ) -> typing.Iterable[str]:
            assert isinstance(path, DomainPath)
            if not isinstance(path, HostnamePath):
                return
            if first_party_only and not match.first_party:
                return
            if end_chain_only and match.references > 0:
                return
            if no_dupplicates and match.dupplicate:
                return
            if explain:
                yield self.explain(path)
            else:
                yield self.unpack_domain(path)

        yield from self.exec_each_domain(export_cb)

    def list_rules(self,
                   first_party_only: bool = False,
                   ) -> typing.Iterable[str]:
        """Yield the explanation of every active zone and of every active
        IPv4 network with a prefix shorter than 32 bits.

        Args:
            first_party_only: skip matches that are not first-party.
        """

        def list_rules_cb(path: Path, match: Match
                          ) -> typing.Iterable[str]:
            if first_party_only and not match.first_party:
                return
            if isinstance(path, ZonePath) \
                    or (isinstance(path, Ip4Path) and path.prefixlen < 32):
                # if match.level == 1:
                # It should be the latter condition but it is more
                # useful when using the former
                yield self.explain(path)

        yield from self.exec_each(list_rules_cb)

    def count_records(self,
                      first_party_only: bool = False,
                      rules_only: bool = False,
                      no_dupplicates: bool = False,
                      ) -> str:
        """Return the number of active records per path type, formatted
        as ``Type: count`` pairs sorted by type name.

        Args:
            first_party_only: ignore matches that are not first-party.
            rules_only: ignore matches whose level is above 1.
            no_dupplicates: ignore matches flagged as duplicates.
        """
        memo: typing.Dict[str, int] = {}

        def count_records_cb(path: Path, match: Match) -> None:
            if first_party_only and not match.first_party:
                return
            if rules_only and match.level > 1:
                return
            if no_dupplicates and match.dupplicate:
                return
            kind = path.__class__.__name__
            memo[kind] = memo.get(kind, 0) + 1

        for _ in self.exec_each(count_records_cb):
            pass
        # key[:-4] drops the trailing "Path" of the class name.
        return ', '.join(
            f'{key[:-4]}: {value}'
            for key, value in sorted(memo.items(), key=lambda s: s[0]))

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def get_domain(self, domain_str: str) -> typing.Iterable[DomainPath]:
        """Yield the active paths that cover ``domain_str``: every active
        zone met while walking down the tree, then the hostname itself if
        it is active. The walk stops at the first missing label."""
        self.enter_step('get_domain_pack')
        domain = self.pack_domain(domain_str)
        self.enter_step('get_domain_brws')
        dic = self.domtree
        depth = 0
        for part in domain.parts:
            if dic.match_zone.active():
                self.enter_step('get_domain_yield')
                yield ZonePath(domain.parts[:depth])
                self.enter_step('get_domain_brws')
            if part not in dic.children:
                return
            dic = dic.children[part]
            depth += 1
        if dic.match_zone.active():
            self.enter_step('get_domain_yield')
            yield ZonePath(domain.parts)
        if dic.match_hostname.active():
            self.enter_step('get_domain_yield')
            yield HostnamePath(domain.parts)

    def get_ip4(self, ip4_str: str) -> typing.Iterable[Path]:
        """Yield the active networks that contain ``ip4_str``, shortest
        prefix first, then the address itself if it is active. The walk
        stops at the first missing node."""
        self.enter_step('get_ip4_pack')
        ip4 = self.pack_ip4address(ip4_str)
        self.enter_step('get_ip4_brws')
        dic = self.ip4tree
        for i in range(31, 31-ip4.prefixlen, -1):
            bit = (ip4.value >> i) & 0b1
            if dic.active():
                self.enter_step('get_ip4_yield')
                # Keep only the 31-i leading bits already walked through.
                yield Ip4Path(ip4.value >> (i+1) << (i+1), 31-i)
                self.enter_step('get_ip4_brws')
            next_dic = dic.one if bit else dic.zero
            if next_dic is None:
                return
            dic = next_dic
        if dic.active():
            self.enter_step('get_ip4_yield')
            yield ip4

    # ------------------------------------------------------------------
    # Updates
    # ------------------------------------------------------------------

    def _set_match(self,
                   match: Match,
                   updated: int,
                   source: Path,
                   source_match: typing.Optional[Match] = None,
                   dupplicate: bool = False,
                   ) -> None:
        """Make ``match`` derive from ``source`` if that improves it.

        The match is rewritten when ``updated`` is more recent than its
        current value, when the new level is lower, or when the source is
        first-party and the match is not. Reference counters of the old
        and new sources are adjusted accordingly.
        """
        # source_match is in parameters because most of the time
        # its parent function needs it too,
        # so it can pass it to save a traversal
        if source_match is None:
            source_match = self.get_match(source)
        new_level = source_match.level + 1
        if updated > match.updated or new_level < match.level \
                or source_match.first_party > match.first_party:
            # NOTE FP and level of matches referencing this one
            # won't be updated until run or prune
            if match.source:
                old_source = self.get_match(match.source)
                old_source.references -= 1
            match.updated = updated
            match.level = new_level
            match.first_party = source_match.first_party
            match.source = source
            source_match.references += 1
            match.dupplicate = dupplicate

    def _set_domain(self,
                    hostname: bool,
                    domain_str: str,
                    updated: int,
                    source: Path) -> None:
        """Record ``domain_str`` as a hostname (``hostname`` true) or as a
        zone, creating the missing tree nodes.

        Raises:
            ValueError: ``domain_str`` is not a valid domain.
        """
        self.enter_step('set_domain_val')
        if not Database.validate_domain(domain_str):
            raise ValueError(f"Invalid domain: {domain_str}")
        self.enter_step('set_domain_pack')
        domain = self.pack_domain(domain_str)
        self.enter_step('set_domain_fp')
        source_match = self.get_match(source)
        is_first_party = source_match.first_party
        self.enter_step('set_domain_brws')
        dic = self.domtree
        dupplicate = False
        for part in domain.parts:
            if part not in dic.children:
                dic.children[part] = DomainTreeNode()
            dic = dic.children[part]
            # An active zone on the way already covers this entry.
            if dic.match_zone.active(is_first_party):
                dupplicate = True
        match = dic.match_hostname if hostname else dic.match_zone
        self._set_match(
            match,
            updated,
            source,
            source_match=source_match,
            dupplicate=dupplicate,
        )

    def set_hostname(self,
                     *args: typing.Any, **kwargs: typing.Any
                     ) -> None:
        """Record a hostname; arguments are those of ``_set_domain``
        after ``hostname``."""
        self._set_domain(True, *args, **kwargs)

    def set_zone(self,
                 *args: typing.Any, **kwargs: typing.Any
                 ) -> None:
        """Record a zone; arguments are those of ``_set_domain`` after
        ``hostname``."""
        self._set_domain(False, *args, **kwargs)

    def set_asn(self,
                asn_str: str,
                updated: int,
                source: Path) -> None:
        """Record an ASN, creating its node when it is not known yet."""
        self.enter_step('set_asn')
        path = self.pack_asn(asn_str)
        if path.asn in self.asns:
            match = self.asns[path.asn]
        else:
            match = AsnNode()
            self.asns[path.asn] = match
        self._set_match(
            match,
            updated,
            source,
        )

    def _set_ip4(self,
                 ip4: Ip4Path,
                 updated: int,
                 source: Path) -> None:
        """Record an IPv4 address or network, creating the missing tree
        nodes along its prefix."""
        self.enter_step('set_ip4_fp')
        source_match = self.get_match(source)
        is_first_party = source_match.first_party
        self.enter_step('set_ip4_brws')
        dic = self.ip4tree
        dupplicate = False
        for i in range(31, 31-ip4.prefixlen, -1):
            bit = (ip4.value >> i) & 0b1
            next_dic = dic.one if bit else dic.zero
            if next_dic is None:
                next_dic = IpTreeNode()
                if bit:
                    dic.one = next_dic
                else:
                    dic.zero = next_dic
            dic = next_dic
            # An active node on the way already covers this entry.
            if dic.active(is_first_party):
                dupplicate = True
        self._set_match(
            dic,
            updated,
            source,
            source_match=source_match,
            dupplicate=dupplicate,
        )

    def set_ip4address(self,
                       ip4address_str: str,
                       *args: typing.Any, **kwargs: typing.Any
                       ) -> None:
        """Record a single IPv4 address; remaining arguments are those of
        ``_set_ip4`` after ``ip4``.

        Raises:
            ValueError: ``ip4address_str`` is not a valid address.
        """
        self.enter_step('set_ip4add_val')
        if not Database.validate_ip4address(ip4address_str):
            raise ValueError(f"Invalid ip4address: {ip4address_str}")
        self.enter_step('set_ip4add_pack')
        ip4 = self.pack_ip4address(ip4address_str)
        self._set_ip4(ip4, *args, **kwargs)

    def set_ip4network(self,
                       ip4network_str: str,
                       *args: typing.Any, **kwargs: typing.Any
                       ) -> None:
        """Record an IPv4 network; remaining arguments are those of
        ``_set_ip4`` after ``ip4``.

        Raises:
            ValueError: ``ip4network_str`` is not a valid network.
        """
        self.enter_step('set_ip4net_val')
        if not Database.validate_ip4network(ip4network_str):
            raise ValueError(f"Invalid ip4network: {ip4network_str}")
        self.enter_step('set_ip4net_pack')
        ip4 = self.pack_ip4network(ip4network_str)
        self._set_ip4(ip4, *args, **kwargs)
|