Compare commits


15 commits

Author SHA1 Message Date
Geoffrey Frogeye a0e68f0848
Reworked match and node system
For level, and first_party later
Next: add get_match to retrieve level of source and have correct levels

... am I going somewhere with all this?
2019-12-15 23:13:25 +01:00
Geoffrey Frogeye aec8d3f8de
Reworked how paths work
Get those tuples out of my eyes
2019-12-15 22:21:05 +01:00
Geoffrey Frogeye 7af2074c7a
Small optimisation of feed_switch 2019-12-15 17:12:44 +01:00
Geoffrey Frogeye 45325782d2
Multi-processed parser 2019-12-15 17:05:41 +01:00
Geoffrey Frogeye ce52897d30
Smol fixes 2019-12-15 16:48:17 +01:00
Geoffrey Frogeye 954b33b2a6
Slightly better Rapid7 parser 2019-12-15 16:38:01 +01:00
Geoffrey Frogeye d976752797
Store Ip4Path as int instead of List[int] 2019-12-15 16:26:18 +01:00
Geoffrey Frogeye 4d966371b2
Workflow: SQL -> Tree
Welp. All that for this.
2019-12-15 15:56:26 +01:00
Geoffrey Frogeye 040ce4c14e
Typo in source 2019-12-15 01:52:45 +01:00
Geoffrey Frogeye b50c01f740 Merge branch 'master' into newworkflow 2019-12-15 01:30:03 +01:00
Geoffrey Frogeye ddceed3d25
Workflow: Can now import DnsMass output
Well, in a specific format but DnsMass nonetheless
2019-12-15 00:28:08 +01:00
Geoffrey Frogeye 189deeb559
Workflow: Multiprocess
Still trying.
It's better than multithread though.

Merge branch 'newworkflow' into newworkflow_threaded
2019-12-14 17:27:46 +01:00
Geoffrey Frogeye d7c239a6f6 Workflow: Some modifications 2019-12-14 16:04:19 +01:00
Geoffrey Frogeye 231bb83667
Threaded feed_dns
Largely disappointing
2019-12-13 12:36:11 +01:00
Geoffrey Frogeye 12dcafe606
Added alternate source of Eulerian CNAMES
It was requested so.
It should be temporary, once I have a bigger subdomain list
that shouldn't be required.
2019-12-12 19:13:54 +01:00
13 changed files with 726 additions and 899 deletions

.gitignore (vendored): 5 changes

@@ -1,7 +1,4 @@
*.log
*.db
*.db-journal
*.p
nameservers
nameservers.head
*.o
*.so

database.py (executable file → normal file): 835 changes

@@ -4,111 +4,115 @@
Utility functions to interact with the database.
"""
import sqlite3
import typing
import time
import os
import logging
import argparse
import coloredlogs
import ipaddress
import math
import pickle
coloredlogs.install(
level='DEBUG',
fmt='%(asctime)s %(name)s %(levelname)s %(message)s'
)
DbValue = typing.Union[None, int, float, str, bytes]
Asn = int
Timestamp = int
Level = int
class Database():
VERSION = 5
PATH = "blocking.db"
class Path():
# FP add boolean here
pass
def open(self) -> None:
mode = 'rwc' if self.write else 'ro'
uri = f'file:{self.PATH}?mode={mode}'
self.conn = sqlite3.connect(uri, uri=True)
cursor = self.conn.cursor()
cursor.execute("PRAGMA foreign_keys = ON")
self.conn.create_function("unpack_asn", 1,
self.unpack_asn,
deterministic=True)
self.conn.create_function("unpack_ip4address", 1,
self.unpack_ip4address,
deterministic=True)
self.conn.create_function("unpack_ip4network", 2,
self.unpack_ip4network,
deterministic=True)
self.conn.create_function("unpack_domain", 1,
lambda s: s[:-1][::-1],
deterministic=True)
self.conn.create_function("format_zone", 1,
lambda s: '*' + s[::-1],
deterministic=True)
def get_meta(self, key: str) -> typing.Optional[int]:
cursor = self.conn.cursor()
try:
cursor.execute("SELECT value FROM meta WHERE key=?", (key,))
except sqlite3.OperationalError:
return None
for ver, in cursor:
return ver
return None
class RulePath(Path):
pass
def set_meta(self, key: str, val: int) -> None:
cursor = self.conn.cursor()
cursor.execute("INSERT INTO meta VALUES (?, ?) "
"ON CONFLICT (key) DO "
"UPDATE set value=?",
(key, val, val))
def close(self) -> None:
self.enter_step('close_commit')
self.conn.commit()
self.enter_step('close')
self.conn.close()
self.profile()
class DomainPath(Path):
def __init__(self, path: typing.List[str]):
self.path = path
def initialize(self) -> None:
self.close()
self.enter_step('initialize')
if not self.write:
self.log.error("Cannot initialize in read-only mode.")
raise
os.unlink(self.PATH)
self.open()
self.log.info("Creating database version %d.", self.VERSION)
cursor = self.conn.cursor()
with open("database_schema.sql", 'r') as db_schema:
cursor.executescript(db_schema.read())
self.set_meta('version', self.VERSION)
self.conn.commit()
def __init__(self, write: bool = False) -> None:
self.log = logging.getLogger('db')
class HostnamePath(DomainPath):
pass
class ZonePath(DomainPath):
pass
class AsnPath(Path):
def __init__(self, asn: Asn):
self.asn = asn
class Ip4Path(Path):
def __init__(self, value: int, prefixlen: int):
self.value = value
self.prefixlen = prefixlen
class Match():
def __init__(self) -> None:
self.updated: int = 0
self.level: int = 0
self.source: Path = RulePath()
# FP duplicate args
def set(self,
updated: int,
level: int,
source: Path,
) -> None:
if updated > self.updated or level > self.level:
self.updated = updated
self.level = level
self.source = source
# FP duplicate function
def active(self) -> bool:
return self.updated > 0
class AsnNode(Match):
pass
class DomainTreeNode():
def __init__(self) -> None:
self.children: typing.Dict[str, DomainTreeNode] = dict()
self.match_zone = Match()
self.match_hostname = Match()
class IpTreeNode():
def __init__(self) -> None:
self.children: typing.List[typing.Optional[IpTreeNode]] = [None, None]
self.match = Match()
Node = typing.Union[DomainTreeNode, IpTreeNode, AsnNode]
NodeCallable = typing.Callable[[Path,
Node,
typing.Optional[typing.Any]],
typing.Any]
class Profiler():
def __init__(self) -> None:
self.log = logging.getLogger('profiler')
self.time_last = time.perf_counter()
self.time_step = 'init'
self.time_dict: typing.Dict[str, float] = dict()
self.step_dict: typing.Dict[str, int] = dict()
self.write = write
self.open()
version = self.get_meta('version')
if version != self.VERSION:
if version is not None:
self.log.warning(
"Outdated database version: %d found, will be rebuilt.",
version)
self.initialize()
def enter_step(self, name: str) -> None:
return
now = time.perf_counter()
try:
self.time_dict[self.time_step] += now - self.time_last
self.step_dict[self.time_step] += 1
self.step_dict[self.time_step] += int(name != self.time_step)
except KeyError:
self.time_dict[self.time_step] = now - self.time_last
self.step_dict[self.time_step] = 1
@@ -125,435 +129,334 @@ class Database():
self.log.debug(f"{'total':<20}: "
f"{total:9.2f} s ({1:7.2%})")
@staticmethod
def pack_hostname(hostname: str) -> str:
return hostname[::-1] + '.'
class Database(Profiler):
VERSION = 10
PATH = "blocking.p"
def initialize(self) -> None:
self.log.warning(
"Creating database version: %d ",
Database.VERSION)
self.domtree = DomainTreeNode()
self.asns: typing.Dict[Asn, AsnNode] = dict()
self.ip4tree = IpTreeNode()
def load(self) -> None:
self.enter_step('load')
try:
with open(self.PATH, 'rb') as db_fdsec:
version, data = pickle.load(db_fdsec)
if version == Database.VERSION:
self.domtree, self.asns, self.ip4tree = data
return
self.log.warning(
"Outdated database version found: %d, "
"it will be rebuilt.",
version)
except (TypeError, AttributeError, EOFError):
self.log.error(
"Corrupt (or heavily outdated) database found, "
"it will be rebuilt.")
except FileNotFoundError:
pass
self.initialize()
def save(self) -> None:
self.enter_step('save')
with open(self.PATH, 'wb') as db_fdsec:
data = self.domtree, self.asns, self.ip4tree
pickle.dump((self.VERSION, data), db_fdsec)
self.profile()
def __init__(self) -> None:
Profiler.__init__(self)
self.log = logging.getLogger('db')
self.load()
@staticmethod
def pack_zone(zone: str) -> str:
return Database.pack_hostname(zone)
def pack_domain(domain: str) -> DomainPath:
return DomainPath(domain.split('.')[::-1])
@staticmethod
def pack_asn(asn: str) -> int:
def unpack_domain(domain: DomainPath) -> str:
return '.'.join(domain.path[::-1])
@staticmethod
def pack_asn(asn: str) -> AsnPath:
asn = asn.upper()
if asn.startswith('AS'):
asn = asn[2:]
return int(asn)
return AsnPath(int(asn))
@staticmethod
def unpack_asn(asn: int) -> str:
return f'AS{asn}'
def unpack_asn(asn: AsnPath) -> str:
return f'AS{asn.asn}'
@staticmethod
def pack_ip4address(address: str) -> int:
total = 0
for i, octet in enumerate(address.split('.')):
total += int(octet) << (3-i)*8
return total
# return '{:02x}{:02x}{:02x}{:02x}'.format(
# *[int(c) for c in address.split('.')])
# return base64.b16encode(packed).decode()
# return '{:08b}{:08b}{:08b}{:08b}'.format(
# *[int(c) for c in address.split('.')])
# carg = ctypes.c_wchar_p(address)
# ret = ACCEL.ip4_flat(carg, self.accel_ip4_buf)
# if ret != 0:
# raise ValueError
# return self.accel_ip4_buf.value
# packed = ipaddress.ip_address(address).packed
# return packed
def pack_ip4address(address: str) -> Ip4Path:
addr = 0
for split in address.split('.'):
addr = (addr << 8) + int(split)
return Ip4Path(addr, 32)
@staticmethod
def unpack_ip4address(address: int) -> str:
return '.'.join(str((address >> (i * 8)) & 0xFF)
for i in reversed(range(4)))
def unpack_ip4address(address: Ip4Path) -> str:
addr = address.value
assert address.prefixlen == 32
octets: typing.List[int] = list()
octets = [0] * 4
for o in reversed(range(4)):
octets[o] = addr & 0xFF
addr >>= 8
return '.'.join(map(str, octets))
@staticmethod
def pack_ip4network(network: str) -> typing.Tuple[int, int]:
# def pack_ip4network(network: str) -> str:
net = ipaddress.ip_network(network)
mini = Database.pack_ip4address(net.network_address.exploded)
maxi = Database.pack_ip4address(net.broadcast_address.exploded)
# mini = net.network_address.packed
# maxi = net.broadcast_address.packed
return mini, maxi
# return Database.pack_ip4address(net.network_address.exploded)[:net.prefixlen]
def pack_ip4network(network: str) -> Ip4Path:
address, prefixlen_str = network.split('/')
prefixlen = int(prefixlen_str)
addr = Database.pack_ip4address(address)
addr.prefixlen = prefixlen
return addr
@staticmethod
def unpack_ip4network(mini: int, maxi: int) -> str:
addr = Database.unpack_ip4address(mini)
prefixlen = 32-int(math.log2(maxi-mini+1))
return f'{addr}/{prefixlen}'
def unpack_ip4network(network: Ip4Path) -> str:
addr = network.value
octets: typing.List[int] = list()
octets = [0] * 4
for o in reversed(range(4)):
octets[o] = addr & 0xFF
addr >>= 8
return '.'.join(map(str, octets)) + '/' + str(network.prefixlen)
def exec_each_domain(self,
callback: NodeCallable,
arg: typing.Any = None,
_dic: DomainTreeNode = None,
_par: DomainPath = None,
) -> typing.Any:
_dic = _dic or self.domtree
_par = _par or DomainPath([])
yield from callback(_par, _dic, arg)
for part in _dic.children:
dic = _dic.children[part]
yield from self.exec_each_domain(
callback,
arg,
_dic=dic,
_par=DomainPath(_par.path + [part])
)
def exec_each_ip4(self,
callback: NodeCallable,
arg: typing.Any = None,
_dic: IpTreeNode = None,
_par: Ip4Path = None,
) -> typing.Any:
_dic = _dic or self.ip4tree
_par = _par or Ip4Path(0, 0)
callback(_par, _dic, arg)
# 0
dic = _dic.children[0]
if dic:
addr0 = _par.value & (0xFFFFFFFF ^ (1 << (32-_par.prefixlen)))
assert addr0 == _par.value
yield from self.exec_each_ip4(
callback,
arg,
_dic=dic,
_par=Ip4Path(addr0, _par.prefixlen+1)
)
# 1
dic = _dic.children[1]
if dic:
addr1 = _par.value | (1 << (32-_par.prefixlen))
yield from self.exec_each_ip4(
callback,
arg,
_dic=dic,
_par=Ip4Path(addr1, _par.prefixlen+1)
)
def exec_each(self,
callback: NodeCallable,
arg: typing.Any = None,
) -> typing.Any:
yield from self.exec_each_domain(callback)
yield from self.exec_each_ip4(callback)
def update_references(self) -> None:
self.enter_step('update_refs')
cursor = self.conn.cursor()
cursor.execute('UPDATE rules AS r SET refs='
'(SELECT count(*) FROM rules '
'WHERE source=r.id)')
raise NotImplementedError
def prune(self, before: int) -> None:
self.enter_step('prune')
cursor = self.conn.cursor()
cursor.execute('DELETE FROM rules WHERE updated<?', (before,))
def prune(self, before: int, base_only: bool = False) -> None:
raise NotImplementedError
def explain(self, entry: int) -> str:
# Format current
string = '???'
cursor = self.conn.cursor()
cursor.execute(
'SELECT unpack_asn(val) FROM asn WHERE entry=:entry '
'UNION '
'SELECT unpack_domain(val) FROM hostname WHERE entry=:entry '
'UNION '
'SELECT format_zone(val) FROM zone WHERE entry=:entry '
'UNION '
'SELECT unpack_ip4address(val) FROM ip4address WHERE entry=:entry '
'UNION '
'SELECT unpack_ip4network(mini, maxi) '
'FROM ip4network WHERE entry=:entry ',
{"entry": entry}
)
for val, in cursor: # only one
string = str(val)
string += f' #{entry}'
# Add source if any
cursor.execute('SELECT source FROM rules WHERE id=?', (entry,))
for source, in cursor:
if source:
string += f'{self.explain(source)}'
return string
raise NotImplementedError
def export(self,
first_party_only: bool = False,
end_chain_only: bool = False,
explain: bool = False,
) -> typing.Iterable[str]:
selection = 'entry' if explain else 'unpack_domain(val)'
command = f'SELECT {selection} FROM rules ' \
'INNER JOIN hostname ON rules.id = hostname.entry'
restrictions: typing.List[str] = list()
if first_party_only:
restrictions.append('rules.first_party = 1')
if end_chain_only:
restrictions.append('rules.refs = 0')
if restrictions:
command += ' WHERE ' + ' AND '.join(restrictions)
if not explain:
command += ' ORDER BY unpack_domain(val) ASC'
cursor = self.conn.cursor()
cursor.execute(command)
for val, in cursor:
if explain:
yield self.explain(val)
else:
yield val
if first_party_only or end_chain_only or explain:
raise NotImplementedError
def export_cb(path: Path, node: Node, _: typing.Any
) -> typing.Iterable[str]:
assert isinstance(path, DomainPath)
assert isinstance(node, DomainTreeNode)
if node.match_hostname:
a = self.unpack_domain(path)
yield a
yield from self.exec_each_domain(export_cb, None)
def count_rules(self,
first_party_only: bool = False,
) -> str:
counts: typing.List[str] = list()
cursor = self.conn.cursor()
for table in ['asn', 'ip4network', 'ip4address', 'zone', 'hostname']:
command = f'SELECT count(*) FROM rules ' \
f'INNER JOIN {table} ON rules.id = {table}.entry ' \
'WHERE rules.level = 0'
if first_party_only:
command += ' AND first_party=1'
cursor.execute(command)
count, = cursor.fetchone()
if count > 0:
counts.append(f'{table}: {count}')
raise NotImplementedError
return ', '.join(counts)
def get_domain(self, domain: str) -> typing.Iterable[int]:
self.enter_step('get_domain_prepare')
domain_prep = self.pack_hostname(domain)
cursor = self.conn.cursor()
self.enter_step('get_domain_select')
cursor.execute(
'SELECT null, entry FROM hostname '
'WHERE val=:d '
'UNION '
'SELECT * FROM ('
'SELECT val, entry FROM zone '
'WHERE val<=:d '
'ORDER BY val DESC LIMIT 1'
')',
{'d': domain_prep}
)
for val, entry in cursor:
self.enter_step('get_domain_confirm')
if not (val is None or domain_prep.startswith(val)):
continue
def get_domain(self, domain_str: str) -> typing.Iterable[DomainPath]:
self.enter_step('get_domain_pack')
domain = self.pack_domain(domain_str)
self.enter_step('get_domain_brws')
dic = self.domtree
depth = 0
for part in domain.path:
if dic.match_zone.active():
self.enter_step('get_domain_yield')
yield ZonePath(domain.path[:depth])
self.enter_step('get_domain_brws')
if part not in dic.children:
return
dic = dic.children[part]
depth += 1
if dic.match_zone.active():
self.enter_step('get_domain_yield')
yield entry
yield ZonePath(domain.path)
if dic.match_hostname.active():
self.enter_step('get_domain_yield')
yield HostnamePath(domain.path)
def get_domain_in_zone(self, domain: str) -> typing.Iterable[int]:
self.enter_step('get_domainiz_prepare')
domain_prep = self.pack_hostname(domain)
cursor = self.conn.cursor()
self.enter_step('get_domainiz_select')
cursor.execute(
'SELECT val, entry FROM zone '
'WHERE val<=:d '
'ORDER BY val DESC LIMIT 1',
{'d': domain_prep}
)
for val, entry in cursor:
self.enter_step('get_domainiz_confirm')
if not (val is None or domain_prep.startswith(val)):
continue
self.enter_step('get_domainiz_yield')
yield entry
def get_ip4(self, address: str) -> typing.Iterable[int]:
self.enter_step('get_ip4_prepare')
try:
address_prep = self.pack_ip4address(address)
except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", address)
return
cursor = self.conn.cursor()
self.enter_step('get_ip4_select')
cursor.execute(
'SELECT entry FROM ip4address '
# 'SELECT null, entry FROM ip4address '
'WHERE val=:a '
'UNION '
# 'SELECT * FROM ('
# 'SELECT val, entry FROM ip4network '
# 'WHERE val<=:a '
# 'AND instr(:a, val) > 0 '
# 'ORDER BY val DESC'
# ')'
'SELECT entry FROM ip4network '
'WHERE :a BETWEEN mini AND maxi ',
{'a': address_prep}
)
for entry, in cursor:
# self.enter_step('get_ip4_confirm')
# if not (val is None or val.startswith(address_prep)):
# # PERF startswith but from the end
# continue
def get_ip4(self, ip4_str: str) -> typing.Iterable[Path]:
self.enter_step('get_ip4_pack')
ip4 = self.pack_ip4address(ip4_str)
self.enter_step('get_ip4_brws')
dic = self.ip4tree
for i in reversed(range(ip4.prefixlen)):
part = (ip4.value >> i) & 0b1
if dic.match.active():
self.enter_step('get_ip4_yield')
yield Ip4Path(ip4.value, 32-i)
self.enter_step('get_ip4_brws')
next_dic = dic.children[part]
if next_dic is None:
return
dic = next_dic
if dic.match.active():
self.enter_step('get_ip4_yield')
yield entry
yield ip4
def get_ip4_in_network(self, address: str) -> typing.Iterable[int]:
self.enter_step('get_ip4in_prepare')
try:
address_prep = self.pack_ip4address(address)
except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", address)
return
cursor = self.conn.cursor()
self.enter_step('get_ip4in_select')
cursor.execute(
'SELECT entry FROM ip4network '
'WHERE :a BETWEEN mini AND maxi ',
{'a': address_prep}
)
for entry, in cursor:
self.enter_step('get_ip4in_yield')
yield entry
def list_asn(self) -> typing.Iterable[AsnPath]:
for asn in self.asns:
yield AsnPath(asn)
def list_asn(self) -> typing.Iterable[typing.Tuple[str, int]]:
cursor = self.conn.cursor()
self.enter_step('list_asn_select')
cursor.execute('SELECT val, entry FROM asn')
for val, entry in cursor:
yield f'AS{val}', entry
def _set_generic(self,
table: str,
select_query: str,
insert_query: str,
prep: typing.Dict[str, DbValue],
updated: int,
is_first_party: bool = False,
source: int = None,
) -> None:
# Since this isn't the bulk of the processing,
# here abstraction > performance
# Fields based on the source
self.enter_step(f'set_{table}_prepare')
cursor = self.conn.cursor()
if source is None:
first_party = int(is_first_party)
level = 0
def _set_domain(self,
hostname: bool,
domain_str: str,
updated: int,
is_first_party: bool = None,
source: Path = None) -> None:
self.enter_step('set_domain_pack')
if is_first_party:
raise NotImplementedError
domain = self.pack_domain(domain_str)
self.enter_step('set_domain_brws')
dic = self.domtree
for part in domain.path:
if dic.match_zone.active():
# Refuse to add domain whose zone is already matching
return
if part not in dic.children:
dic.children[part] = DomainTreeNode()
dic = dic.children[part]
if hostname:
match = dic.match_hostname
else:
self.enter_step(f'set_{table}_source')
cursor.execute(
'SELECT first_party, level FROM rules '
'WHERE id=?',
(source,)
)
first_party, level = cursor.fetchone()
level += 1
self.enter_step(f'set_{table}_select')
cursor.execute(select_query, prep)
rules_prep: typing.Dict[str, DbValue] = {
"source": source,
"updated": updated,
"first_party": first_party,
"level": level,
}
# If the entry already exists
for entry, in cursor: # only one
self.enter_step(f'set_{table}_update')
rules_prep['entry'] = entry
cursor.execute(
'UPDATE rules SET '
'source=:source, updated=:updated, '
'first_party=:first_party, level=:level '
'WHERE id=:entry AND (updated<:updated OR '
'first_party<:first_party OR level<:level)',
rules_prep
)
# Only update if any of the following:
# - the entry is outdated
# - the entry was not a first_party but this is
# - this is closer to the original rule
return
# If it does not exist
self.enter_step(f'set_{table}_insert')
cursor.execute(
'INSERT INTO rules '
'(source, updated, first_party, level) '
'VALUES (:source, :updated, :first_party, :level) ',
rules_prep
)
cursor.execute('SELECT id FROM rules WHERE rowid=?',
(cursor.lastrowid,))
for entry, in cursor: # only one
prep['entry'] = entry
cursor.execute(insert_query, prep)
return
assert False
def set_hostname(self, hostname: str,
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_hostname_prepare')
prep: typing.Dict[str, DbValue] = {
'val': self.pack_hostname(hostname),
}
self._set_generic(
'hostname',
'SELECT entry FROM hostname WHERE val=:val',
'INSERT INTO hostname (val, entry) '
'VALUES (:val, :entry)',
prep,
*args, **kwargs
match = dic.match_zone
match.set(
updated,
0, # TODO Level
source or RulePath(),
)
def set_asn(self, asn: str,
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_asn_prepare')
try:
asn_prep = self.pack_asn(asn)
except ValueError:
self.log.error("Invalid asn: %s", asn)
return
prep: typing.Dict[str, DbValue] = {
'val': asn_prep,
}
self._set_generic(
'asn',
'SELECT entry FROM asn WHERE val=:val',
'INSERT INTO asn (val, entry) '
'VALUES (:val, :entry)',
prep,
*args, **kwargs
def set_hostname(self,
*args: typing.Any, **kwargs: typing.Any
) -> None:
self._set_domain(True, *args, **kwargs)
def set_zone(self,
*args: typing.Any, **kwargs: typing.Any
) -> None:
self._set_domain(False, *args, **kwargs)
def set_asn(self,
asn_str: str,
updated: int,
is_first_party: bool = None,
source: Path = None) -> None:
self.enter_step('set_asn')
if is_first_party:
raise NotImplementedError
path = self.pack_asn(asn_str)
match = AsnNode()
match.set(
updated,
0,
source or RulePath()
)
self.asns[path.asn] = match
def _set_ip4(self,
ip4: Ip4Path,
updated: int,
is_first_party: bool = None,
source: Path = None) -> None:
if is_first_party:
raise NotImplementedError
dic = self.ip4tree
for i in reversed(range(ip4.prefixlen)):
part = (ip4.value >> i) & 0b1
if dic.match.active():
# Refuse to add ip4* whose network is already matching
return
next_dic = dic.children[part]
if next_dic is None:
next_dic = IpTreeNode()
dic.children[part] = next_dic
dic = next_dic
dic.match.set(
updated,
0, # TODO Level
source or RulePath(),
)
def set_ip4address(self, ip4address: str,
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_ip4add_prepare')
try:
ip4address_prep = self.pack_ip4address(ip4address)
except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", ip4address)
return
prep: typing.Dict[str, DbValue] = {
'val': ip4address_prep,
}
self._set_generic(
'ip4add',
'SELECT entry FROM ip4address WHERE val=:val',
'INSERT INTO ip4address (val, entry) '
'VALUES (:val, :entry)',
prep,
*args, **kwargs
)
def set_ip4address(self,
ip4address_str: str,
*args: typing.Any, **kwargs: typing.Any
) -> None:
self.enter_step('set_ip4add_pack')
ip4 = self.pack_ip4address(ip4address_str)
self.enter_step('set_ip4add_brws')
self._set_ip4(ip4, *args, **kwargs)
def set_zone(self, zone: str,
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_zone_prepare')
prep: typing.Dict[str, DbValue] = {
'val': self.pack_zone(zone),
}
self._set_generic(
'zone',
'SELECT entry FROM zone WHERE val=:val',
'INSERT INTO zone (val, entry) '
'VALUES (:val, :entry)',
prep,
*args, **kwargs
)
def set_ip4network(self, ip4network: str,
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_ip4net_prepare')
try:
ip4network_prep = self.pack_ip4network(ip4network)
except (ValueError, IndexError):
self.log.error("Invalid ip4network: %s", ip4network)
return
prep: typing.Dict[str, DbValue] = {
'mini': ip4network_prep[0],
'maxi': ip4network_prep[1],
}
self._set_generic(
'ip4net',
'SELECT entry FROM ip4network WHERE mini=:mini AND maxi=:maxi',
'INSERT INTO ip4network (mini, maxi, entry) '
'VALUES (:mini, :maxi, :entry)',
prep,
*args, **kwargs
)
if __name__ == '__main__':
# Parsing arguments
parser = argparse.ArgumentParser(
description="Database operations")
parser.add_argument(
'-i', '--initialize', action='store_true',
help="Reconstruct the whole database")
parser.add_argument(
'-p', '--prune', action='store_true',
help="Remove old (+6 months) entries from database")
parser.add_argument(
'-r', '--references', action='store_true',
help="Update the reference count")
args = parser.parse_args()
DB = Database(write=True)
if args.initialize:
DB.initialize()
if args.prune:
DB.prune(before=int(time.time()) - 60*60*24*31*6)
if args.references and not args.prune:
DB.update_references()
DB.close()
def set_ip4network(self,
ip4network_str: str,
*args: typing.Any, **kwargs: typing.Any
) -> None:
self.enter_step('set_ip4net_pack')
ip4 = self.pack_ip4network(ip4network_str)
self.enter_step('set_ip4net_brws')
self._set_ip4(ip4, *args, **kwargs)
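Taken together, this rewrite of database.py swaps the SQLite tables for two in-memory tries: reversed-label domain nodes (so a zone is a prefix of its hostnames) and a bit-level IPv4 trie over addresses packed into a single int, per the "Store Ip4Path as int" commit. Below is a minimal standalone sketch of both lookups, with a plain boolean standing in for the Match metadata the real nodes carry; names like DomainNode/IpNode are illustrative, not the module's own.

import typing

# --- Domain side: labels reversed so a zone is a prefix -------------
# pack_domain('ads.example.com') -> ['com', 'example', 'ads']
def pack_domain(domain: str) -> typing.List[str]:
    return domain.split('.')[::-1]

class DomainNode:
    def __init__(self) -> None:
        self.children: typing.Dict[str, 'DomainNode'] = dict()
        self.zone = False  # a zone rule ends on this node

def add_zone(root: DomainNode, zone: str) -> None:
    node = root
    for part in pack_domain(zone):
        node = node.children.setdefault(part, DomainNode())
    node.zone = True

def in_zone(root: DomainNode, hostname: str) -> bool:
    node = root
    for part in pack_domain(hostname):
        if node.zone:
            return True
        if part not in node.children:
            return False
        node = node.children[part]
    return node.zone

# --- IPv4 side: one int plus a prefix length, walked bit by bit -----
def pack_ip4(address: str) -> int:
    addr = 0
    for octet in address.split('.'):
        addr = (addr << 8) + int(octet)
    return addr

class IpNode:
    def __init__(self) -> None:
        self.children: typing.List[typing.Optional['IpNode']] = [None, None]
        self.match = False

def add_net(root: IpNode, value: int, prefixlen: int) -> None:
    node = root
    # walk the top prefixlen bits, most significant first
    for i in reversed(range(32 - prefixlen, 32)):
        bit = (value >> i) & 1
        nxt = node.children[bit]
        if nxt is None:
            nxt = IpNode()
            node.children[bit] = nxt
        node = nxt
    node.match = True

def in_net(root: IpNode, value: int) -> bool:
    node = root
    for i in reversed(range(32)):
        if node.match:  # a covering network was inserted above this point
            return True
        nxt = node.children[(value >> i) & 1]
        if nxt is None:
            return False
        node = nxt
    return node.match

droot = DomainNode()
add_zone(droot, 'example.com')
assert in_zone(droot, 'ads.example.com') and not in_zone(droot, 'example.org')

iroot = IpNode()
add_net(iroot, pack_ip4('93.184.216.0'), 24)
assert in_net(iroot, pack_ip4('93.184.216.34'))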

database_schema.sql (deleted)

@@ -1,59 +0,0 @@
-- Remember to increment DB_VERSION
-- in database.py on changes to this file
CREATE TABLE rules (
id INTEGER PRIMARY KEY AUTOINCREMENT,
source INTEGER, -- The rule this one is based on
updated INTEGER, -- If the row was updated during last data import (0: No, 1: Yes)
first_party INTEGER, -- 1: this blocks a first party for sure, 0: maybe
refs INTEGER, -- Number of entries issued from this one
level INTEGER, -- Level of recursion to the root source rule (used for source priority)
FOREIGN KEY (source) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX rules_source ON rules (source); -- for references recounting
CREATE INDEX rules_updated ON rules (updated); -- for pruning
CREATE INDEX rules_level_firstparty ON rules (level, first_party); -- for counting rules
CREATE TABLE asn (
val INTEGER PRIMARY KEY,
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX asn_entry ON asn (entry); -- for explanations
CREATE TABLE hostname (
val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for consistency with zone)
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX hostname_entry ON hostname (entry); -- for explanations
CREATE TABLE zone (
val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for easier matching)
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX zone_entry ON zone (entry); -- for explanations
CREATE TABLE ip4address (
val INTEGER PRIMARY KEY,
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX ip4address_entry ON ip4address (entry); -- for explanations
CREATE TABLE ip4network (
-- val TEXT PRIMARY KEY,
mini INTEGER,
maxi INTEGER,
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX ip4network_minmax ON ip4network (mini, maxi);
CREATE INDEX ip4network_entry ON ip4network (entry); -- for explanations
-- Store various things
CREATE TABLE meta (
key TEXT PRIMARY KEY,
value integer
);
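This schema file is deleted outright: the reworked Database pickles a single (VERSION, (domtree, asns, ip4tree)) tuple instead of maintaining SQL tables, as its save()/load() methods above show. A minimal sketch of that round-trip, using a throwaway path and a stand-in data tuple:

import pickle

VERSION = 10
PATH = 'blocking-sketch.p'  # throwaway path for the sketch

def save(data: tuple) -> None:
    with open(PATH, 'wb') as fd:
        pickle.dump((VERSION, data), fd)

def load() -> tuple:
    try:
        with open(PATH, 'rb') as fd:
            version, data = pickle.load(fd)
        if version == VERSION:
            return data
        print(f'outdated version {version}, rebuilding')
    except FileNotFoundError:
        pass
    return (None, dict(), None)  # fresh (domtree, asns, ip4tree) stand-ins

save(('domtree', {64496: 'asn node'}, 'ip4tree'))
assert load() == ('domtree', {64496: 'asn node'}, 'ip4tree')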


@@ -45,5 +45,3 @@ if __name__ == '__main__':
explain=args.explain,
):
print(domain, file=args.output)
DB.close()

feed_asn.py

@@ -31,23 +31,22 @@ if __name__ == '__main__':
args = parser.parse_args()
DB = database.Database()
DBW = database.Database(write=True)
for asn, entry in DB.list_asn():
for path in DB.list_asn():
asn_str = database.Database.unpack_asn(path)
DB.enter_step('asn_get_ranges')
for prefix in get_ranges(asn):
for prefix in get_ranges(asn_str):
parsed_prefix: IPNetwork = ipaddress.ip_network(prefix)
if parsed_prefix.version == 4:
DBW.set_ip4network(
DB.set_ip4network(
prefix,
source=entry,
source=path,
updated=int(time.time())
)
log.info('Added %s from %s (id=%s)', prefix, asn, entry)
log.info('Added %s from %s (%s)', prefix, asn_str, path)
elif parsed_prefix.version == 6:
log.warning('Unimplemented prefix version: %s', prefix)
else:
log.error('Unknown prefix version: %s', prefix)
DB.close()
DBW.close()
DB.save()

feed_dns.old.py (new executable file): 147 changes

@@ -0,0 +1,147 @@
#!/usr/bin/env python3
import argparse
import database
import logging
import sys
import typing
import enum
RecordType = enum.Enum('RecordType', 'A AAAA CNAME PTR')
Record = typing.Tuple[RecordType, int, str, str]
# select, write
FUNCTION_MAP: typing.Any = {
RecordType.A: (
database.Database.get_ip4,
database.Database.set_hostname,
),
RecordType.CNAME: (
database.Database.get_domain,
database.Database.set_hostname,
),
RecordType.PTR: (
database.Database.get_domain,
database.Database.set_ip4address,
),
}
class Parser():
def __init__(self, buf: typing.Any) -> None:
self.buf = buf
self.log = logging.getLogger('parser')
self.db = database.Database()
def end(self) -> None:
self.db.save()
def register(self,
rtype: RecordType,
updated: int,
name: str,
value: str
) -> None:
self.db.enter_step('register')
select, write = FUNCTION_MAP[rtype]
for source in select(self.db, value):
# write(self.db, name, updated, source=source)
write(self.db, name, updated)
def consume(self) -> None:
raise NotImplementedError
class Rapid7Parser(Parser):
TYPES = {
'a': RecordType.A,
'aaaa': RecordType.AAAA,
'cname': RecordType.CNAME,
'ptr': RecordType.PTR,
}
def consume(self) -> None:
data = dict()
for line in self.buf:
self.db.enter_step('parse_rapid7')
split = line.split('"')
for k in range(1, 14, 4):
key = split[k]
val = split[k+2]
data[key] = val
self.register(
Rapid7Parser.TYPES[data['type']],
int(data['timestamp']),
data['name'],
data['value']
)
class DnsMassParser(Parser):
# dnsmass --output Snrql
# --retry REFUSED,SERVFAIL --resolvers nameservers-ipv4
TYPES = {
'A': (RecordType.A, -1, None),
'AAAA': (RecordType.AAAA, -1, None),
'CNAME': (RecordType.CNAME, -1, -1),
}
def consume(self) -> None:
self.db.enter_step('parse_dnsmass')
timestamp = 0
header = True
for line in self.buf:
line = line[:-1]
if not line:
header = True
continue
split = line.split(' ')
try:
if header:
timestamp = int(split[1])
header = False
else:
dtype, name_offset, value_offset = \
DnsMassParser.TYPES[split[1]]
self.register(
dtype,
timestamp,
split[0][:name_offset],
split[2][:value_offset],
)
self.db.enter_step('parse_dnsmass')
except KeyError:
continue
PARSERS = {
'rapid7': Rapid7Parser,
'dnsmass': DnsMassParser,
}
if __name__ == '__main__':
# Parsing arguments
log = logging.getLogger('feed_dns')
args_parser = argparse.ArgumentParser(
description="TODO")
args_parser.add_argument(
'parser',
choices=PARSERS.keys(),
help="TODO")
args_parser.add_argument(
'-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
help="TODO")
args = args_parser.parse_args()
parser = PARSERS[args.parser](args.input)
try:
parser.consume()
except KeyboardInterrupt:
pass
parser.end()
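Worth noting: Rapid7Parser above never calls json.loads. Because each line of the Rapid7 dumps is flat JSON carrying the same four keys in a fixed order (an assumption the parser shares), splitting on double quotes puts the keys at indices 1, 5, 9, 13 and each value two slots later, hence range(1, 14, 4). A quick check with an invented record:

# Why range(1, 14, 4) works: keys and values land at fixed offsets
# once a fixed-shape JSON line is split on '"'. Sample line is invented.
line = ('{"timestamp":"1576401205","name":"ads.example.com",'
        '"type":"cname","value":"tracker.example.net"}')
split = line.split('"')
data = dict()
for k in range(1, 14, 4):
    data[split[k]] = split[k + 2]
print(data)
# {'timestamp': '1576401205', 'name': 'ads.example.com',
#  'type': 'cname', 'value': 'tracker.example.net'}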

feed_dns.py

@@ -1,64 +1,202 @@
#!/usr/bin/env python3
import database
import argparse
import sys
import database
import logging
import csv
import json
import sys
import typing
import multiprocessing
import enum
Record = typing.Tuple[typing.Callable, typing.Callable, int, str, str]
# select, write
FUNCTION_MAP: typing.Any = {
'a': (
database.Database.get_ip4,
database.Database.set_hostname,
),
'cname': (
database.Database.get_domain,
database.Database.set_hostname,
),
'ptr': (
database.Database.get_domain,
database.Database.set_ip4address,
),
}
class Writer(multiprocessing.Process):
def __init__(self,
recs_queue: multiprocessing.Queue,
index: int = 0):
super(Writer, self).__init__()
self.log = logging.getLogger(f'wr')
self.recs_queue = recs_queue
def run(self) -> None:
self.db = database.Database()
self.db.log = logging.getLogger(f'wr')
self.db.enter_step('block_wait')
block: typing.List[Record]
for block in iter(self.recs_queue.get, None):
record: Record
for record in block:
select, write, updated, name, value = record
self.db.enter_step('feed_switch')
try:
for source in select(self.db, value):
# write(self.db, name, updated, source=source)
write(self.db, name, updated)
except ValueError:
self.log.exception("Cannot execute: %s", record)
self.db.enter_step('block_wait')
self.db.enter_step('end')
self.db.save()
class Parser():
def __init__(self,
buf: typing.Any,
recs_queue: multiprocessing.Queue,
block_size: int,
):
super(Parser, self).__init__()
self.buf = buf
self.log = logging.getLogger('pr')
self.recs_queue = recs_queue
self.block: typing.List[Record] = list()
self.block_size = block_size
self.prof = database.Profiler()
self.prof.log = logging.getLogger('pr')
def register(self, record: Record) -> None:
self.prof.enter_step('register')
self.block.append(record)
if len(self.block) >= self.block_size:
self.prof.enter_step('put_block')
self.recs_queue.put(self.block)
self.block = list()
def run(self) -> None:
self.consume()
self.recs_queue.put(self.block)
self.prof.profile()
def consume(self) -> None:
raise NotImplementedError
class Rapid7Parser(Parser):
def consume(self) -> None:
data = dict()
for line in self.buf:
self.prof.enter_step('parse_rapid7')
split = line.split('"')
try:
for k in range(1, 14, 4):
key = split[k]
val = split[k+2]
data[key] = val
select, writer = FUNCTION_MAP[data['type']]
record = (
select,
writer,
int(data['timestamp']),
data['name'],
data['value']
)
except IndexError:
self.log.exception("Cannot parse: %s", line)
self.register(record)
class DnsMassParser(Parser):
# dnsmass --output Snrql
# --retry REFUSED,SERVFAIL --resolvers nameservers-ipv4
TYPES = {
'A': (FUNCTION_MAP['a'][0], FUNCTION_MAP['a'][1], -1, None),
# 'AAAA': (FUNCTION_MAP['aaaa'][0], FUNCTION_MAP['aaaa'][1], -1, None),
'CNAME': (FUNCTION_MAP['cname'][0], FUNCTION_MAP['cname'][1], -1, -1),
}
def consume(self) -> None:
self.prof.enter_step('parse_dnsmass')
timestamp = 0
header = True
for line in self.buf:
line = line[:-1]
if not line:
header = True
continue
split = line.split(' ')
try:
if header:
timestamp = int(split[1])
header = False
else:
select, write, name_offset, value_offset = \
DnsMassParser.TYPES[split[1]]
record = (
select,
write,
timestamp,
split[0][:name_offset],
split[2][:value_offset],
)
self.register(record)
self.prof.enter_step('parse_dnsmass')
except KeyError:
continue
PARSERS = {
'rapid7': Rapid7Parser,
'dnsmass': DnsMassParser,
}
if __name__ == '__main__':
# Parsing arguments
log = logging.getLogger('feed_dns')
parser = argparse.ArgumentParser(
args_parser = argparse.ArgumentParser(
description="TODO")
parser.add_argument(
# '-i', '--input', type=argparse.FileType('rb'), default=sys.stdin.buffer,
args_parser.add_argument(
'parser',
choices=PARSERS.keys(),
help="TODO")
args_parser.add_argument(
'-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
help="TODO")
args = parser.parse_args()
args_parser.add_argument(
'-j', '--workers', type=int, default=4,
help="TODO")
args_parser.add_argument(
'-b', '--block-size', type=int, default=100,
help="TODO")
args_parser.add_argument(
'-q', '--queue-size', type=int, default=10,
help="TODO")
args = args_parser.parse_args()
DB = database.Database(write=True)
recs_queue: multiprocessing.Queue = multiprocessing.Queue(
maxsize=args.queue_size)
try:
DB.enter_step('iowait')
for row in csv.reader(args.input):
# for line in args.input:
DB.enter_step('feed_csv_parse')
dtype, timestamp, name, value = row
# DB.enter_step('feed_json_parse')
# data = json.loads(line)
# dtype = data['type'][0]
# # timestamp = data['timestamp']
# name = data['name']
# value = data['value']
writer = Writer(recs_queue)
writer.start()
DB.enter_step('feed_switch')
if dtype == 'a':
for rule in DB.get_ip4(value):
if not list(DB.get_domain_in_zone(name)):
parser = PARSERS[args.parser](args.input, recs_queue, args.block_size)
parser.run()
DB.set_hostname(name, source=rule,
updated=int(timestamp))
# updated=int(data['timestamp']))
elif dtype == 'c':
for rule in DB.get_domain(value):
if not list(DB.get_domain_in_zone(name)):
DB.set_hostname(name, source=rule,
updated=int(timestamp))
# updated=int(data['timestamp']))
elif dtype == 'p':
for rule in DB.get_domain(value):
if not list(DB.get_ip4_in_network(name)):
DB.set_ip4address(name, source=rule,
updated=int(timestamp))
# updated=int(data['timestamp']))
else:
raise NotImplementedError(f'Type: {dtype}')
DB.enter_step('iowait')
except KeyboardInterrupt:
log.warning("Interupted.")
pass
DB.close()
recs_queue.put(None)
writer.join()
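This new feed_dns.py is the "Multi-processed parser" commit: the parser batches records into fixed-size blocks, pushes them onto a bounded multiprocessing.Queue, and a single Writer process drains it until it sees a None sentinel via iter(get, None). A minimal sketch of that pattern, with integers standing in for the (select, write, updated, name, value) record tuples:

import multiprocessing
import typing

def writer(queue: multiprocessing.Queue) -> None:
    # iter(queue.get, None) keeps yielding blocks until the sentinel arrives
    total = 0
    for block in iter(queue.get, None):
        total += len(block)
    print(f'wrote {total} records')

if __name__ == '__main__':
    q: multiprocessing.Queue = multiprocessing.Queue(maxsize=10)
    w = multiprocessing.Process(target=writer, args=(q,))
    w.start()
    block: typing.List[int] = list()
    for record in range(250):        # stand-in for parsed DNS records
        block.append(record)
        if len(block) >= 100:        # same batching as Parser.register
            q.put(block)
            block = list()
    q.put(block)                     # flush the last partial block
    q.put(None)                      # sentinel: let the writer finish
    w.join()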

feed_rules.py

@@ -28,15 +28,15 @@ if __name__ == '__main__':
help="The input only comes from verified first-party sources")
args = parser.parse_args()
DB = database.Database(write=True)
DB = database.Database()
fun = FUNCTION_MAP[args.type]
for rule in args.input:
fun(DB,
rule.strip(),
is_first_party=args.first_party,
# is_first_party=args.first_party,
updated=int(time.time()),
)
DB.close()
DB.save()


@@ -18,7 +18,7 @@ log "Retrieving rules…"
rm -f rules*/*.cache.*
dl https://easylist.to/easylist/easyprivacy.txt rules_adblock/easyprivacy.cache.txt
# From firebog.net Tracking & Telemetry Lists
dl https://v.firebog.net/hosts/Prigent-Ads.txt rules/prigent-ads.cache.list
# dl https://v.firebog.net/hosts/Prigent-Ads.txt rules/prigent-ads.cache.list
# dl https://gitlab.com/quidsup/notrack-blocklists/raw/master/notrack-blocklist.txt rules/notrack-blocklist.cache.list
# False positives: https://github.com/WaLLy3K/wally3k.github.io/issues/73 -> 69.media.tumblr.com chicdn.net
dl https://raw.githubusercontent.com/StevenBlack/hosts/master/data/add.2o7Net/hosts rules_hosts/add2o7.cache.txt
@@ -51,3 +51,4 @@ then
else
mv temp/cisco-umbrella_popularity.fresh.list subdomains/cisco-umbrella_popularity.cache.list
fi
dl https://www.orwell1984.today/cname/eulerian.net.txt subdomains/orwell-eulerian-cname-list.cache.list


@@ -5,11 +5,12 @@ function log() {
}
log "Importing rules…"
cat rules_adblock/*.txt | grep -v '^!' | grep -v '^\[Adblock' | ./adblock_to_domain_list.py | ./feed_rules.py zone
cat rules_hosts/*.txt | grep -v '^#' | grep -v '^$' | cut -d ' ' -f2 | ./feed_rules.py zone
cat rules/*.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone
cat rules_ip/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py ip4network
cat rules_asn/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py asn
BEFORE="$(date +%s)"
# cat rules_adblock/*.txt | grep -v '^!' | grep -v '^\[Adblock' | ./adblock_to_domain_list.py | ./feed_rules.py zone
# cat rules_hosts/*.txt | grep -v '^#' | grep -v '^$' | cut -d ' ' -f2 | ./feed_rules.py zone
# cat rules/*.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone
# cat rules_ip/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py ip4network
# cat rules_asn/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py asn
cat rules/first-party.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone --first-party
cat rules_ip/first-party.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py ip4network --first-party
@@ -17,3 +18,5 @@ cat rules_asn/first-party.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py as
./feed_asn.py
log "Pruning old rules…"
./db.py --prune --prune-before "$BEFORE" --prune-base

json_to_csv.py (deleted)

@@ -1,36 +0,0 @@
#!/usr/bin/env python3
import argparse
import sys
import logging
import json
import csv
if __name__ == '__main__':
# Parsing arguments
log = logging.getLogger('json_to_csv')
parser = argparse.ArgumentParser(
description="TODO")
parser.add_argument(
# '-i', '--input', type=argparse.FileType('rb'), default=sys.stdin.buffer,
'-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
help="TODO")
parser.add_argument(
# '-i', '--output', type=argparse.FileType('wb'), default=sys.stdout.buffer,
'-o', '--output', type=argparse.FileType('w'), default=sys.stdout,
help="TODO")
args = parser.parse_args()
writer = csv.writer(args.output)
for line in args.input:
data = json.loads(line)
try:
writer.writerow([
data['type'][0],
data['timestamp'],
data['name'],
data['value']])
except IndexError:
log.error('Could not parse line: %s', line)
pass


@@ -9,11 +9,11 @@ function log() {
# TODO Fetch 'em
log "Reading PTR records…"
pv ptr.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
pv ptr.json.gz | gunzip | ./feed_dns.py
log "Reading A records…"
pv a.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
pv a.json.gz | gunzip | ./feed_dns.py
log "Reading CNAME records…"
pv cname.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
pv cname.json.gz | gunzip | ./feed_dns.py
log "Pruning old data…"
./database.py --prune


@@ -1,264 +0,0 @@
#!/usr/bin/env python3
"""
From a list of subdomains, output only
the ones resolving to a first-party tracker.
"""
import argparse
import logging
import os
import queue
import sys
import threading
import typing
import time
import coloredlogs
import dns.exception
import dns.resolver
DNS_TIMEOUT = 5.0
NUMBER_TRIES = 5
class Worker(threading.Thread):
"""
Worker process for a DNS resolver.
Will resolve DNS to match first-party subdomains.
"""
def change_nameserver(self) -> None:
"""
Assign this worker another nameserver from the queue.
"""
server = None
while server is None:
try:
server = self.orchestrator.nameservers_queue.get(block=False)
except queue.Empty:
self.orchestrator.refill_nameservers_queue()
self.log.info("Using nameserver: %s", server)
self.resolver.nameservers = [server]
def __init__(self,
orchestrator: 'Orchestrator',
index: int = 0):
super(Worker, self).__init__()
self.log = logging.getLogger(f'worker{index:03d}')
self.orchestrator = orchestrator
self.resolver = dns.resolver.Resolver()
self.change_nameserver()
def resolve_subdomain(self, subdomain: str) -> typing.Optional[
typing.List[
dns.rrset.RRset
]
]:
"""
Returns the resolution chain of the subdomain to an A record,
including any intermediary CNAME.
The last element is an IP address.
Returns None if the nameserver was unable to satisfy the request.
Returns [] if the request points to nothing.
"""
self.log.debug("Querying %s", subdomain)
try:
query = self.resolver.query(subdomain, 'A', lifetime=DNS_TIMEOUT)
except dns.resolver.NXDOMAIN:
return []
except dns.resolver.NoAnswer:
return []
except dns.resolver.YXDOMAIN:
self.log.warning("Query name too long for %s", subdomain)
return None
except dns.resolver.NoNameservers:
# NOTE Most of the time this error message means that the domain
# does not exist, but sometimes it means that the server
# itself is broken. So we count on the retry logic.
self.log.warning("All nameservers broken for %s", subdomain)
return None
except dns.exception.Timeout:
# NOTE Same as above
self.log.warning("Timeout for %s", subdomain)
return None
except dns.name.EmptyLabel:
self.log.warning("Empty label for %s", subdomain)
return None
return query.response.answer
def run(self) -> None:
self.log.info("Started")
subdomain: str
for subdomain in iter(self.orchestrator.subdomains_queue.get, None):
for _ in range(NUMBER_TRIES):
resolved = self.resolve_subdomain(subdomain)
# Retry with another nameserver if error
if resolved is None:
self.change_nameserver()
else:
break
# If it wasn't found after multiple tries
if resolved is None:
self.log.error("Gave up on %s", subdomain)
resolved = []
assert isinstance(resolved, list)
self.orchestrator.results_queue.put(resolved)
self.orchestrator.results_queue.put(None)
self.log.info("Stopped")
class Orchestrator():
"""
Orchestrator of the different Worker threads.
"""
def refill_nameservers_queue(self) -> None:
"""
Re-fill the given nameservers into the nameservers queue.
Done every time the queue is empty, making it
effectively an infinite loop over the list.
"""
# Might be in a race condition but that's probably fine
for nameserver in self.nameservers:
self.nameservers_queue.put(nameserver)
self.log.info("Refilled nameserver queue")
def __init__(self, subdomains: typing.Iterable[str],
nameservers: typing.List[str] = None,
nb_workers: int = 1,
):
self.log = logging.getLogger('orchestrator')
self.subdomains = subdomains
self.nb_workers = nb_workers
# Use internal resolver by default
self.nameservers = nameservers or dns.resolver.Resolver().nameservers
self.subdomains_queue: queue.Queue = queue.Queue(
maxsize=self.nb_workers)
self.results_queue: queue.Queue = queue.Queue()
self.nameservers_queue: queue.Queue = queue.Queue()
self.refill_nameservers_queue()
def fill_subdomain_queue(self) -> None:
"""
Read the subdomains in input and put them into the queue.
Done in a thread so we can both:
- yield the results as they come
- not store all the subdomains at once
"""
self.log.info("Started reading subdomains")
# Send data to workers
for subdomain in self.subdomains:
self.subdomains_queue.put(subdomain)
self.log.info("Finished reading subdomains")
# Send sentinel to each worker
# sentinel = None ~= EOF
for _ in range(self.nb_workers):
self.subdomains_queue.put(None)
@staticmethod
def format_rrset(rrset: dns.rrset.RRset) -> typing.Iterable[str]:
if rrset.rdtype == dns.rdatatype.CNAME:
dtype = 'c'
elif rrset.rdtype == dns.rdatatype.A:
dtype = 'a'
else:
raise NotImplementedError
name = rrset.name.to_text()[:-1]
for item in rrset.items:
value = item.to_text()
if rrset.rdtype == dns.rdatatype.CNAME:
value = value[:-1]
yield f'{dtype},{int(time.time())},{name},{value}\n'
def run(self) -> typing.Iterable[str]:
"""
Yield the results.
"""
# Create workers
self.log.info("Creating workers")
for i in range(self.nb_workers):
Worker(self, i).start()
fill_thread = threading.Thread(target=self.fill_subdomain_queue)
fill_thread.start()
# Wait for one sentinel per worker
# In the meantime output results
for _ in range(self.nb_workers):
resolved: typing.List[dns.rrset.RRset]
for resolved in iter(self.results_queue.get, None):
for rrset in resolved:
yield from self.format_rrset(rrset)
self.log.info("Waiting for reader thread")
fill_thread.join()
self.log.info("Done!")
def main() -> None:
"""
Main function when used directly.
Read the subdomains provided and output it,
the last CNAME resolved and the IP address it resolves to.
Takes as an input a filename (or nothing, for stdin),
and as an output a filename (or nothing, for stdout).
The input must be a subdomain per line, the output is a TODO
Use the file `nameservers` as the list of nameservers
to use, or else it will use the system defaults.
"""
# Initialization
coloredlogs.install(
level='DEBUG',
fmt='%(asctime)s %(name)s %(levelname)s %(message)s'
)
# Parsing arguments
parser = argparse.ArgumentParser(
description="Massively resolves subdomains and store them in a file.")
parser.add_argument(
'-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
help="Input file with one subdomain per line")
parser.add_argument(
'-o', '--output', type=argparse.FileType('w'), default=sys.stdout,
help="Outptut file with DNS chains")
parser.add_argument(
'-n', '--nameservers', default='nameservers',
help="File with one nameserver per line")
parser.add_argument(
'-j', '--workers', type=int, default=512,
help="Number of threads to use")
args = parser.parse_args()
# Cleaning input
iterator = iter(args.input)
iterator = map(str.strip, iterator)
iterator = filter(None, iterator)
# Reading nameservers
servers: typing.List[str] = list()
if os.path.isfile(args.nameservers):
servers = open(args.nameservers).readlines()
servers = list(filter(None, map(str.strip, servers)))
for resolved in Orchestrator(
iterator,
servers,
nb_workers=args.workers
).run():
args.output.write(resolved)
if __name__ == '__main__':
main()