Put packing in parsing thread

Why did I think this would be a good idea?
- value don't need to be packed most of the time, but we don't know that
early
- packed domain (it's one most of the time) is way larger than its
unpacked counterpart
This commit is contained in:
Geoffrey Frogeye 2019-12-16 10:38:37 +01:00
parent 03a4042238
commit dcf39c9582
Signed by: geoffrey
GPG key ID: D8A7ECA00A8CD3DD
3 changed files with 88 additions and 63 deletions

View file

@ -395,9 +395,7 @@ class Database(Profiler):
) -> str: ) -> str:
raise NotImplementedError raise NotImplementedError
def get_domain(self, domain_str: str) -> typing.Iterable[DomainPath]: def get_domain(self, domain: DomainPath) -> typing.Iterable[DomainPath]:
self.enter_step('get_domain_pack')
domain = self.pack_domain(domain_str)
self.enter_step('get_domain_brws') self.enter_step('get_domain_brws')
dic = self.domtree dic = self.domtree
depth = 0 depth = 0
@ -417,9 +415,7 @@ class Database(Profiler):
self.enter_step('get_domain_yield') self.enter_step('get_domain_yield')
yield HostnamePath(domain.parts) yield HostnamePath(domain.parts)
def get_ip4(self, ip4_str: str) -> typing.Iterable[Path]: def get_ip4(self, ip4: Ip4Path) -> typing.Iterable[Path]:
self.enter_step('get_ip4_pack')
ip4 = self.pack_ip4address(ip4_str)
self.enter_step('get_ip4_brws') self.enter_step('get_ip4_brws')
dic = self.ip4tree dic = self.ip4tree
for i in range(31, 31-ip4.prefixlen, -1): for i in range(31, 31-ip4.prefixlen, -1):
@ -443,14 +439,12 @@ class Database(Profiler):
def _set_domain(self, def _set_domain(self,
hostname: bool, hostname: bool,
domain_str: str, domain: DomainPath,
updated: int, updated: int,
is_first_party: bool = None, is_first_party: bool = None,
source: Path = None) -> None: source: Path = None) -> None:
self.enter_step('set_domain_pack')
if is_first_party: if is_first_party:
raise NotImplementedError raise NotImplementedError
domain = self.pack_domain(domain_str)
self.enter_step('set_domain_src') self.enter_step('set_domain_src')
if source is None: if source is None:
level = 0 level = 0
@ -488,7 +482,7 @@ class Database(Profiler):
self._set_domain(False, *args, **kwargs) self._set_domain(False, *args, **kwargs)
def set_asn(self, def set_asn(self,
asn_str: str, asn: AsnPath,
updated: int, updated: int,
is_first_party: bool = None, is_first_party: bool = None,
source: Path = None) -> None: source: Path = None) -> None:
@ -501,23 +495,22 @@ class Database(Profiler):
else: else:
match = self.get_match(source) match = self.get_match(source)
level = match.level + 1 level = match.level + 1
path = self.pack_asn(asn_str) if asn.asn in self.asns:
if path.asn in self.asns: match = self.asns[asn.asn]
match = self.asns[path.asn]
else: else:
match = AsnNode() match = AsnNode()
self.asns[path.asn] = match self.asns[asn.asn] = match
match.set( match.set(
updated, updated,
level, level,
source, source,
) )
def _set_ip4(self, def set_ip4network(self,
ip4: Ip4Path, ip4: Ip4Path,
updated: int, updated: int,
is_first_party: bool = None, is_first_party: bool = None,
source: Path = None) -> None: source: Path = None) -> None:
if is_first_party: if is_first_party:
raise NotImplementedError raise NotImplementedError
self.enter_step('set_ip4_src') self.enter_step('set_ip4_src')
@ -549,17 +542,8 @@ class Database(Profiler):
) )
def set_ip4address(self, def set_ip4address(self,
ip4address_str: str, ip4: Ip4Path,
*args: typing.Any, **kwargs: typing.Any *args: typing.Any, **kwargs: typing.Any
) -> None: ) -> None:
self.enter_step('set_ip4add_pack') assert ip4.prefixlen == 32
ip4 = self.pack_ip4address(ip4address_str) self.set_ip4network(ip4, *args, **kwargs)
self._set_ip4(ip4, *args, **kwargs)
def set_ip4network(self,
ip4network_str: str,
*args: typing.Any, **kwargs: typing.Any
) -> None:
self.enter_step('set_ip4net_pack')
ip4 = self.pack_ip4network(ip4network_str)
self._set_ip4(ip4, *args, **kwargs)

View file

@ -8,21 +8,28 @@ import typing
import multiprocessing import multiprocessing
import enum import enum
Record = typing.Tuple[typing.Callable, typing.Callable, int, str, str] Record = typing.Tuple[typing.Callable,
typing.Callable, int, database.Path, database.Path]
# select, write # select, write, name_packer, value_packer
FUNCTION_MAP: typing.Any = { FUNCTION_MAP: typing.Any = {
'a': ( 'a': (
database.Database.get_ip4, database.Database.get_ip4,
database.Database.set_hostname, database.Database.set_hostname,
database.Database.pack_domain,
database.Database.pack_ip4address,
), ),
'cname': ( 'cname': (
database.Database.get_domain, database.Database.get_domain,
database.Database.set_hostname, database.Database.set_hostname,
database.Database.pack_domain,
database.Database.pack_domain,
), ),
'ptr': ( 'ptr': (
database.Database.get_domain, database.Database.get_domain,
database.Database.set_ip4address, database.Database.set_ip4address,
database.Database.pack_ip4address,
database.Database.pack_domain,
), ),
} }
@ -49,11 +56,8 @@ class Writer(multiprocessing.Process):
select, write, updated, name, value = record select, write, updated, name, value = record
self.db.enter_step('feed_switch') self.db.enter_step('feed_switch')
try: for source in select(self.db, value):
for source in select(self.db, value): write(self.db, name, updated, source=source)
write(self.db, name, updated, source=source)
except ValueError:
self.log.exception("Cannot execute: %s", record)
self.db.enter_step('block_wait') self.db.enter_step('block_wait')
@ -76,8 +80,33 @@ class Parser():
self.prof = database.Profiler() self.prof = database.Profiler()
self.prof.log = logging.getLogger('pr') self.prof.log = logging.getLogger('pr')
def register(self, record: Record) -> None: def register(self,
self.prof.enter_step('register') rtype: str,
timestamp: int,
name_str: str,
value_str: str,
) -> None:
self.prof.enter_step('pack')
try:
select, write, name_packer, value_packer = FUNCTION_MAP[rtype]
except KeyError:
self.log.exception("Unknown record type")
return
try:
name = name_packer(name_str)
except ValueError:
self.log.exception("Cannot parse name ('%s' with %s)",
name_str, name_packer)
return
try:
value = value_packer(value_str)
except ValueError:
self.log.exception("Cannot parse value ('%s' with %s)",
value_str, value_packer)
return
record = (select, write, timestamp, name, value)
self.prof.enter_step('grow_block')
self.block.append(record) self.block.append(record)
if len(self.block) >= self.block_size: if len(self.block) >= self.block_size:
self.prof.enter_step('put_block') self.prof.enter_step('put_block')
@ -96,6 +125,7 @@ class Parser():
class Rapid7Parser(Parser): class Rapid7Parser(Parser):
def consume(self) -> None: def consume(self) -> None:
data = dict() data = dict()
self.prof.enter_step('iowait')
for line in self.buf: for line in self.buf:
self.prof.enter_step('parse_rapid7') self.prof.enter_step('parse_rapid7')
split = line.split('"') split = line.split('"')
@ -106,26 +136,25 @@ class Rapid7Parser(Parser):
val = split[k+2] val = split[k+2]
data[key] = val data[key] = val
select, writer = FUNCTION_MAP[data['type']] self.register(
record = ( data['type'],
select,
writer,
int(data['timestamp']), int(data['timestamp']),
data['name'], data['name'],
data['value'] data['value'],
) )
except IndexError: self.prof.enter_step('iowait')
except KeyError:
# Sometimes JSON records are off the place
self.log.exception("Cannot parse: %s", line) self.log.exception("Cannot parse: %s", line)
self.register(record)
class DnsMassParser(Parser): class DnsMassParser(Parser):
# dnsmass --output Snrql # dnsmass --output Snrql
# --retry REFUSED,SERVFAIL --resolvers nameservers-ipv4 # --retry REFUSED,SERVFAIL --resolvers nameservers-ipv4
TYPES = { TYPES = {
'A': (FUNCTION_MAP['a'][0], FUNCTION_MAP['a'][1], -1, None), 'A': ('a', -1, None),
# 'AAAA': (FUNCTION_MAP['aaaa'][0], FUNCTION_MAP['aaaa'][1], -1, None), # 'AAAA': ('aaaa', -1, None),
'CNAME': (FUNCTION_MAP['cname'][0], FUNCTION_MAP['cname'][1], -1, -1), 'CNAME': ('cname', -1, -1),
} }
def consume(self) -> None: def consume(self) -> None:
@ -144,19 +173,19 @@ class DnsMassParser(Parser):
timestamp = int(split[1]) timestamp = int(split[1])
header = False header = False
else: else:
select, write, name_offset, value_offset = \ rtype, name_offset, value_offset = \
DnsMassParser.TYPES[split[1]] DnsMassParser.TYPES[split[1]]
record = ( self.register(
select, rtype,
write,
timestamp, timestamp,
split[0][:name_offset], split[0][:name_offset],
split[2][:value_offset], split[2][:value_offset],
) )
self.register(record)
self.prof.enter_step('parse_dnsmass') self.prof.enter_step('parse_dnsmass')
except KeyError: except KeyError:
continue # Malformed records are less likely to happen,
# but we may never be sure
self.log.exception("Cannot parse: %s", line)
PARSERS = { PARSERS = {
@ -189,7 +218,7 @@ if __name__ == '__main__':
args = args_parser.parse_args() args = args_parser.parse_args()
recs_queue: multiprocessing.Queue = multiprocessing.Queue( recs_queue: multiprocessing.Queue = multiprocessing.Queue(
maxsize=args.queue_size) maxsize=args.queue_size)
writer = Writer(recs_queue) writer = Writer(recs_queue)
writer.start() writer.start()

View file

@ -4,11 +4,22 @@ import database
import argparse import argparse
import sys import sys
import time import time
import typing
FUNCTION_MAP = { FUNCTION_MAP: typing.Dict[str, typing.Tuple[
'zone': database.Database.set_zone, typing.Callable[[database.Database, database.Path, int], None],
'ip4network': database.Database.set_ip4network, typing.Callable[[str], database.Path],
'asn': database.Database.set_asn, ]] = {
'hostname': (database.Database.set_hostname,
database.Database.pack_domain),
'zone': (database.Database.set_zone,
database.Database.pack_domain),
'asn': (database.Database.set_asn,
database.Database.pack_asn),
'ip4address': (database.Database.set_ip4address,
database.Database.pack_ip4address),
'ip4network': (database.Database.set_ip4network,
database.Database.pack_ip4network),
} }
if __name__ == '__main__': if __name__ == '__main__':
@ -30,11 +41,12 @@ if __name__ == '__main__':
DB = database.Database() DB = database.Database()
fun = FUNCTION_MAP[args.type] fun, packer = FUNCTION_MAP[args.type]
for rule in args.input: for rule in args.input:
packed = packer(rule.strip())
fun(DB, fun(DB,
rule.strip(), packed,
# is_first_party=args.first_party, # is_first_party=args.first_party,
updated=int(time.time()), updated=int(time.time()),
) )