Reworked how paths work

Get those tuples out of my eyes
This commit is contained in:
Geoffrey Frogeye 2019-12-15 22:21:05 +01:00
parent 7af2074c7a
commit aec8d3f8de
Signed by: geoffrey
GPG key ID: D8A7ECA00A8CD3DD
4 changed files with 354 additions and 114 deletions

View file

@ -16,19 +16,48 @@ coloredlogs.install(
fmt='%(asctime)s %(name)s %(levelname)s %(message)s' fmt='%(asctime)s %(name)s %(levelname)s %(message)s'
) )
PathType = enum.Enum('PathType', 'Rule Hostname Zone Asn Ip4 Ip6')
RulePath = typing.Union[None]
Asn = int Asn = int
DomainPath = typing.List[str]
Ip4Path = typing.Tuple[int, int] # value, prefixlen
Ip6Path = typing.List[int]
Path = typing.Union[RulePath, DomainPath, Asn, Ip4Path, Ip6Path]
TypedPath = typing.Tuple[PathType, Path]
Timestamp = int Timestamp = int
Level = int Level = int
Match = typing.Tuple[Timestamp, TypedPath, Level]
DebugPath = (PathType.Rule, None)
class Path():
pass
class RulePath(Path):
pass
class DomainPath(Path):
def __init__(self, path: typing.List[str]):
self.path = path
class HostnamePath(DomainPath):
pass
class ZonePath(DomainPath):
pass
class AsnPath(Path):
def __init__(self, asn: Asn):
self.asn = asn
class Ip4Path(Path):
def __init__(self, value: int, prefixlen: int):
self.value = value
self.prefixlen = prefixlen
Match = typing.Tuple[Timestamp, Path, Level]
# class AsnNode():
# def __init__(self, asn: int) -> None:
# self.asn = asn
class DomainTreeNode(): class DomainTreeNode():
@ -44,6 +73,13 @@ class IpTreeNode():
self.match: typing.Optional[Match] = None self.match: typing.Optional[Match] = None
Node = typing.Union[DomainTreeNode, IpTreeNode, Asn]
NodeCallable = typing.Callable[[Path,
Node,
typing.Optional[typing.Any]],
typing.Any]
class Profiler(): class Profiler():
def __init__(self) -> None: def __init__(self) -> None:
self.log = logging.getLogger('profiler') self.log = logging.getLogger('profiler')
@ -53,6 +89,7 @@ class Profiler():
self.step_dict: typing.Dict[str, int] = dict() self.step_dict: typing.Dict[str, int] = dict()
def enter_step(self, name: str) -> None: def enter_step(self, name: str) -> None:
return
now = time.perf_counter() now = time.perf_counter()
try: try:
self.time_dict[self.time_step] += now - self.time_last self.time_dict[self.time_step] += now - self.time_last
@ -75,7 +112,7 @@ class Profiler():
class Database(Profiler): class Database(Profiler):
VERSION = 8 VERSION = 9
PATH = "blocking.p" PATH = "blocking.p"
def initialize(self) -> None: def initialize(self) -> None:
@ -120,34 +157,34 @@ class Database(Profiler):
@staticmethod @staticmethod
def pack_domain(domain: str) -> DomainPath: def pack_domain(domain: str) -> DomainPath:
return domain.split('.')[::-1] return DomainPath(domain.split('.')[::-1])
@staticmethod @staticmethod
def unpack_domain(domain: DomainPath) -> str: def unpack_domain(domain: DomainPath) -> str:
return '.'.join(domain[::-1]) return '.'.join(domain.path[::-1])
@staticmethod @staticmethod
def pack_asn(asn: str) -> int: def pack_asn(asn: str) -> AsnPath:
asn = asn.upper() asn = asn.upper()
if asn.startswith('AS'): if asn.startswith('AS'):
asn = asn[2:] asn = asn[2:]
return int(asn) return AsnPath(int(asn))
@staticmethod @staticmethod
def unpack_asn(asn: int) -> str: def unpack_asn(asn: AsnPath) -> str:
return f'AS{asn}' return f'AS{asn.asn}'
@staticmethod @staticmethod
def pack_ip4address(address: str) -> Ip4Path: def pack_ip4address(address: str) -> Ip4Path:
addr = 0 addr = 0
for split in address.split('.'): for split in address.split('.'):
addr = (addr << 8) + int(split) addr = (addr << 8) + int(split)
return (addr, 32) return Ip4Path(addr, 32)
@staticmethod @staticmethod
def unpack_ip4address(address: Ip4Path) -> str: def unpack_ip4address(address: Ip4Path) -> str:
addr, prefixlen = address addr = address.value
assert prefixlen == 32 assert address.prefixlen == 32
octets: typing.List[int] = list() octets: typing.List[int] = list()
octets = [0] * 4 octets = [0] * 4
for o in reversed(range(4)): for o in reversed(range(4)):
@ -159,14 +196,76 @@ class Database(Profiler):
def pack_ip4network(network: str) -> Ip4Path: def pack_ip4network(network: str) -> Ip4Path:
address, prefixlen_str = network.split('/') address, prefixlen_str = network.split('/')
prefixlen = int(prefixlen_str) prefixlen = int(prefixlen_str)
addr, _ = Database.pack_ip4address(address) addr = Database.pack_ip4address(address)
return (addr, prefixlen) addr.prefixlen = prefixlen
return addr
@staticmethod @staticmethod
def unpack_ip4network(network: Ip4Path) -> str: def unpack_ip4network(network: Ip4Path) -> str:
address, prefixlen = network addr = network.value
addr = Database.unpack_ip4address((address, 32)) octets: typing.List[int] = list()
return f'{addr}/{prefixlen}' octets = [0] * 4
for o in reversed(range(4)):
octets[o] = addr & 0xFF
addr >>= 8
return '.'.join(map(str, octets)) + '/' + str(network.prefixlen)
def exec_each_domain(self,
callback: NodeCallable,
arg: typing.Any = None,
_dic: DomainTreeNode = None,
_par: DomainPath = None,
) -> typing.Any:
_dic = _dic or self.domtree
_par = _par or DomainPath([])
yield from callback(_par, _dic, arg)
for part in _dic.children:
dic = _dic.children[part]
yield from self.exec_each_domain(
callback,
arg,
_dic=dic,
_par=DomainPath(_par.path + [part])
)
def exec_each_ip4(self,
callback: NodeCallable,
arg: typing.Any = None,
_dic: IpTreeNode = None,
_par: Ip4Path = None,
) -> typing.Any:
_dic = _dic or self.ip4tree
_par = _par or Ip4Path(0, 0)
callback(_par, _dic, arg)
# 0
dic = _dic.children[0]
if dic:
addr0 = _par.value & (0xFFFFFFFF ^ (1 << (32-_par.prefixlen)))
assert addr0 == _par.value
yield from self.exec_each_ip4(
callback,
arg,
_dic=dic,
_par=Ip4Path(addr0, _par.prefixlen+1)
)
# 1
dic = _dic.children[1]
if dic:
addr1 = _par.value | (1 << (32-_par.prefixlen))
yield from self.exec_each_ip4(
callback,
arg,
_dic=dic,
_par=Ip4Path(addr1, _par.prefixlen+1)
)
def exec_each(self,
callback: NodeCallable,
arg: typing.Any = None,
) -> typing.Any:
yield from self.exec_each_domain(callback)
yield from self.exec_each_ip4(callback)
def update_references(self) -> None: def update_references(self) -> None:
raise NotImplementedError raise NotImplementedError
@ -181,35 +280,35 @@ class Database(Profiler):
first_party_only: bool = False, first_party_only: bool = False,
end_chain_only: bool = False, end_chain_only: bool = False,
explain: bool = False, explain: bool = False,
_dic: DomainTreeNode = None,
_par: DomainPath = None,
) -> typing.Iterable[str]: ) -> typing.Iterable[str]:
if first_party_only or end_chain_only or explain: if first_party_only or end_chain_only or explain:
raise NotImplementedError raise NotImplementedError
_dic = _dic or self.domtree
_par = _par or list() def export_cb(path: Path, node: Node, _: typing.Any
if _dic.match_hostname: ) -> typing.Iterable[str]:
yield self.unpack_domain(_par) assert isinstance(path, DomainPath)
for part in _dic.children: assert isinstance(node, DomainTreeNode)
dic = _dic.children[part] if node.match_hostname:
yield from self.export(_dic=dic, a = self.unpack_domain(path)
_par=_par + [part]) yield a
yield from self.exec_each_domain(export_cb, None)
def count_rules(self, def count_rules(self,
first_party_only: bool = False, first_party_only: bool = False,
) -> str: ) -> str:
raise NotImplementedError raise NotImplementedError
def get_domain(self, domain_str: str) -> typing.Iterable[TypedPath]: def get_domain(self, domain_str: str) -> typing.Iterable[DomainPath]:
self.enter_step('get_domain_pack') self.enter_step('get_domain_pack')
domain = self.pack_domain(domain_str) domain = self.pack_domain(domain_str)
self.enter_step('get_domain_brws') self.enter_step('get_domain_brws')
dic = self.domtree dic = self.domtree
depth = 0 depth = 0
for part in domain: for part in domain.path:
if dic.match_zone: if dic.match_zone:
self.enter_step('get_domain_yield') self.enter_step('get_domain_yield')
yield (PathType.Zone, domain[:depth]) yield ZonePath(domain.path[:depth])
self.enter_step('get_domain_brws') self.enter_step('get_domain_brws')
if part not in dic.children: if part not in dic.children:
return return
@ -217,21 +316,21 @@ class Database(Profiler):
depth += 1 depth += 1
if dic.match_zone: if dic.match_zone:
self.enter_step('get_domain_yield') self.enter_step('get_domain_yield')
yield (PathType.Zone, domain) yield ZonePath(domain.path)
if dic.match_hostname: if dic.match_hostname:
self.enter_step('get_domain_yield') self.enter_step('get_domain_yield')
yield (PathType.Hostname, domain) yield HostnamePath(domain.path)
def get_ip4(self, ip4_str: str) -> typing.Iterable[TypedPath]: def get_ip4(self, ip4_str: str) -> typing.Iterable[Path]:
self.enter_step('get_ip4_pack') self.enter_step('get_ip4_pack')
ip4, prefixlen = self.pack_ip4address(ip4_str) ip4 = self.pack_ip4address(ip4_str)
self.enter_step('get_ip4_brws') self.enter_step('get_ip4_brws')
dic = self.ip4tree dic = self.ip4tree
for i in reversed(range(prefixlen)): for i in reversed(range(ip4.prefixlen)):
part = (ip4 >> i) & 0b1 part = (ip4.value >> i) & 0b1
if dic.match: if dic.match:
self.enter_step('get_ip4_yield') self.enter_step('get_ip4_yield')
yield (PathType.Ip4, (ip4, 32-i)) yield Ip4Path(ip4.value, 32-i)
self.enter_step('get_ip4_brws') self.enter_step('get_ip4_brws')
next_dic = dic.children[part] next_dic = dic.children[part]
if next_dic is None: if next_dic is None:
@ -239,108 +338,99 @@ class Database(Profiler):
dic = next_dic dic = next_dic
if dic.match: if dic.match:
self.enter_step('get_ip4_yield') self.enter_step('get_ip4_yield')
yield (PathType.Ip4, ip4) yield ip4
def list_asn(self) -> typing.Iterable[TypedPath]: def list_asn(self) -> typing.Iterable[AsnPath]:
for asn in self.asns: for asn in self.asns:
yield (PathType.Asn, asn) yield AsnPath(asn)
def set_hostname(self, def set_hostname(self,
hostname_str: str, hostname_str: str,
updated: int, updated: int,
is_first_party: bool = None, is_first_party: bool = None,
source: TypedPath = None) -> None: source: Path = None) -> None:
self.enter_step('set_hostname_pack') self.enter_step('set_hostname_pack')
if is_first_party or source: if is_first_party:
raise NotImplementedError raise NotImplementedError
self.enter_step('set_hostname_brws') self.enter_step('set_hostname_brws')
hostname = self.pack_domain(hostname_str) hostname = self.pack_domain(hostname_str)
dic = self.domtree dic = self.domtree
for part in hostname: for part in hostname.path:
if dic.match_zone: if dic.match_zone:
# Refuse to add hostname whose zone is already matching # Refuse to add hostname whose zone is already matching
return return
if part not in dic.children: if part not in dic.children:
dic.children[part] = DomainTreeNode() dic.children[part] = DomainTreeNode()
dic = dic.children[part] dic = dic.children[part]
dic.match_hostname = (updated, DebugPath, 0) dic.match_hostname = (updated, source or RulePath(), 0)
def set_zone(self, def set_zone(self,
zone_str: str, zone_str: str,
updated: int, updated: int,
is_first_party: bool = None, is_first_party: bool = None,
source: TypedPath = None) -> None: source: Path = None) -> None:
self.enter_step('set_zone_pack') self.enter_step('set_zone_pack')
if is_first_party or source: if is_first_party:
raise NotImplementedError raise NotImplementedError
zone = self.pack_domain(zone_str) zone = self.pack_domain(zone_str)
self.enter_step('set_zone_brws') self.enter_step('set_zone_brws')
dic = self.domtree dic = self.domtree
for part in zone: for part in zone.path:
if dic.match_zone: if dic.match_zone:
# Refuse to add zone whose parent zone is already matching # Refuse to add zone whose parent zone is already matching
return return
if part not in dic.children: if part not in dic.children:
dic.children[part] = DomainTreeNode() dic.children[part] = DomainTreeNode()
dic = dic.children[part] dic = dic.children[part]
dic.match_zone = (updated, DebugPath, 0) dic.match_zone = (updated, source or RulePath(), 0)
def set_asn(self, def set_asn(self,
asn_str: str, asn_str: str,
updated: int, updated: int,
is_first_party: bool = None, is_first_party: bool = None,
source: TypedPath = None) -> None: source: Path = None) -> None:
self.enter_step('set_asn_pack') self.enter_step('set_asn_pack')
if is_first_party or source: if is_first_party or source:
# TODO updated # TODO updated
raise NotImplementedError raise NotImplementedError
asn = self.pack_asn(asn_str) asn = self.pack_asn(asn_str)
self.enter_step('set_asn_brws') self.enter_step('set_asn_brws')
self.asns.add(asn) self.asns.add(asn.asn)
def _set_ip4(self,
ip4: Ip4Path,
updated: int,
is_first_party: bool = None,
source: Path = None) -> None:
if is_first_party:
raise NotImplementedError
dic = self.ip4tree
for i in reversed(range(ip4.prefixlen)):
part = (ip4.value >> i) & 0b1
if dic.match:
# Refuse to add ip4* whose network is already matching
return
next_dic = dic.children[part]
if next_dic is None:
next_dic = IpTreeNode()
dic.children[part] = next_dic
dic = next_dic
dic.match = (updated, source or RulePath(), 0)
def set_ip4address(self, def set_ip4address(self,
ip4address_str: str, ip4address_str: str,
updated: int, *args: typing.Any, **kwargs: typing.Any
is_first_party: bool = None, ) -> None:
source: TypedPath = None) -> None:
self.enter_step('set_ip4add_pack') self.enter_step('set_ip4add_pack')
if is_first_party or source: ip4 = self.pack_ip4address(ip4address_str)
raise NotImplementedError
ip4, prefixlen = self.pack_ip4address(ip4address_str)
self.enter_step('set_ip4add_brws') self.enter_step('set_ip4add_brws')
dic = self.ip4tree self._set_ip4(ip4, *args, **kwargs)
for i in reversed(range(prefixlen)):
part = (ip4 >> i) & 0b1
if dic.match:
# Refuse to add ip4address whose network is already matching
return
next_dic = dic.children[part]
if next_dic is None:
next_dic = IpTreeNode()
dic.children[part] = next_dic
dic = next_dic
dic.match = (updated, DebugPath, 0)
def set_ip4network(self, def set_ip4network(self,
ip4network_str: str, ip4network_str: str,
updated: int, *args: typing.Any, **kwargs: typing.Any
is_first_party: bool = None, ) -> None:
source: TypedPath = None) -> None:
self.enter_step('set_ip4net_pack') self.enter_step('set_ip4net_pack')
if is_first_party or source: ip4 = self.pack_ip4network(ip4network_str)
raise NotImplementedError
self.enter_step('set_ip4net_brws') self.enter_step('set_ip4net_brws')
ip4, prefixlen = self.pack_ip4network(ip4network_str) self._set_ip4(ip4, *args, **kwargs)
dic = self.ip4tree
for i in reversed(range(prefixlen)):
part = (ip4 >> i) & 0b1
if dic.match:
# Refuse to add ip4network whose parent network
# is already matching
return
next_dic = dic.children[part]
if next_dic is None:
next_dic = IpTreeNode()
dic.children[part] = next_dic
dic = next_dic
dic.match = (updated, DebugPath, 0)

View file

@ -33,10 +33,7 @@ if __name__ == '__main__':
DB = database.Database() DB = database.Database()
for path in DB.list_asn(): for path in DB.list_asn():
ptype, asn = path asn_str = database.Database.unpack_asn(path)
assert ptype == database.PathType.Asn
assert isinstance(asn, int)
asn_str = database.Database.unpack_asn(asn)
DB.enter_step('asn_get_ranges') DB.enter_step('asn_get_ranges')
for prefix in get_ranges(asn_str): for prefix in get_ranges(asn_str):
parsed_prefix: IPNetwork = ipaddress.ip_network(prefix) parsed_prefix: IPNetwork = ipaddress.ip_network(prefix)
@ -46,7 +43,7 @@ if __name__ == '__main__':
# source=path, # source=path,
updated=int(time.time()) updated=int(time.time())
) )
log.info('Added %s from %s (source=%s)', prefix, asn, path) log.info('Added %s from %s (%s)', prefix, asn_str, path)
elif parsed_prefix.version == 6: elif parsed_prefix.version == 6:
log.warning('Unimplemented prefix version: %s', prefix) log.warning('Unimplemented prefix version: %s', prefix)
else: else:

147
feed_dns.old.py Executable file
View file

@ -0,0 +1,147 @@
#!/usr/bin/env python3
import argparse
import database
import logging
import sys
import typing
import enum
RecordType = enum.Enum('RecordType', 'A AAAA CNAME PTR')
Record = typing.Tuple[RecordType, int, str, str]
# select, write
FUNCTION_MAP: typing.Any = {
RecordType.A: (
database.Database.get_ip4,
database.Database.set_hostname,
),
RecordType.CNAME: (
database.Database.get_domain,
database.Database.set_hostname,
),
RecordType.PTR: (
database.Database.get_domain,
database.Database.set_ip4address,
),
}
class Parser():
def __init__(self, buf: typing.Any) -> None:
self.buf = buf
self.log = logging.getLogger('parser')
self.db = database.Database()
def end(self) -> None:
self.db.save()
def register(self,
rtype: RecordType,
updated: int,
name: str,
value: str
) -> None:
self.db.enter_step('register')
select, write = FUNCTION_MAP[rtype]
for source in select(self.db, value):
# write(self.db, name, updated, source=source)
write(self.db, name, updated)
def consume(self) -> None:
raise NotImplementedError
class Rapid7Parser(Parser):
TYPES = {
'a': RecordType.A,
'aaaa': RecordType.AAAA,
'cname': RecordType.CNAME,
'ptr': RecordType.PTR,
}
def consume(self) -> None:
data = dict()
for line in self.buf:
self.db.enter_step('parse_rapid7')
split = line.split('"')
for k in range(1, 14, 4):
key = split[k]
val = split[k+2]
data[key] = val
self.register(
Rapid7Parser.TYPES[data['type']],
int(data['timestamp']),
data['name'],
data['value']
)
class DnsMassParser(Parser):
# dnsmass --output Snrql
# --retry REFUSED,SERVFAIL --resolvers nameservers-ipv4
TYPES = {
'A': (RecordType.A, -1, None),
'AAAA': (RecordType.AAAA, -1, None),
'CNAME': (RecordType.CNAME, -1, -1),
}
def consume(self) -> None:
self.db.enter_step('parse_dnsmass')
timestamp = 0
header = True
for line in self.buf:
line = line[:-1]
if not line:
header = True
continue
split = line.split(' ')
try:
if header:
timestamp = int(split[1])
header = False
else:
dtype, name_offset, value_offset = \
DnsMassParser.TYPES[split[1]]
self.register(
dtype,
timestamp,
split[0][:name_offset],
split[2][:value_offset],
)
self.db.enter_step('parse_dnsmass')
except KeyError:
continue
PARSERS = {
'rapid7': Rapid7Parser,
'dnsmass': DnsMassParser,
}
if __name__ == '__main__':
# Parsing arguments
log = logging.getLogger('feed_dns')
args_parser = argparse.ArgumentParser(
description="TODO")
args_parser.add_argument(
'parser',
choices=PARSERS.keys(),
help="TODO")
args_parser.add_argument(
'-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
help="TODO")
args = args_parser.parse_args()
parser = PARSERS[args.parser](args.input)
try:
parser.consume()
except KeyboardInterrupt:
pass
parser.end()

View file

@ -49,9 +49,12 @@ class Writer(multiprocessing.Process):
select, write, updated, name, value = record select, write, updated, name, value = record
self.db.enter_step('feed_switch') self.db.enter_step('feed_switch')
try:
for source in select(self.db, value): for source in select(self.db, value):
# write(self.db, name, updated, source=source) # write(self.db, name, updated, source=source)
write(self.db, name, updated) write(self.db, name, updated)
except ValueError:
self.log.exception("Cannot execute: %s", record)
self.db.enter_step('block_wait') self.db.enter_step('block_wait')
@ -98,6 +101,7 @@ class Rapid7Parser(Parser):
self.prof.enter_step('parse_rapid7') self.prof.enter_step('parse_rapid7')
split = line.split('"') split = line.split('"')
try:
for k in range(1, 14, 4): for k in range(1, 14, 4):
key = split[k] key = split[k]
val = split[k+2] val = split[k+2]
@ -111,6 +115,8 @@ class Rapid7Parser(Parser):
data['name'], data['name'],
data['value'] data['value']
) )
except IndexError:
self.log.exception("Cannot parse: %s", line)
self.register(record) self.register(record)