Compare commits

..

15 commits

Author SHA1 Date Message

Geoffrey Frogeye a0e68f0848 2019-12-15 23:13:25 +01:00
    Reworked match and node system
    For level, and first_party later
    Next: add get_match to retrieve level of source and have correct levels

    ... am I going somewhere with all this?

Geoffrey Frogeye aec8d3f8de 2019-12-15 22:21:05 +01:00
    Reworked how paths work
    Get those tuples out of my eyes

Geoffrey Frogeye 7af2074c7a 2019-12-15 17:12:44 +01:00
    Small optimisation of feed_switch

Geoffrey Frogeye 45325782d2 2019-12-15 17:05:41 +01:00
    Multi-processed parser

Geoffrey Frogeye ce52897d30 2019-12-15 16:48:17 +01:00
    Smol fixes

Geoffrey Frogeye 954b33b2a6 2019-12-15 16:38:01 +01:00
    Slightly better Rapid7 parser

Geoffrey Frogeye d976752797 2019-12-15 16:26:18 +01:00
    Store Ip4Path as int instead of List[int]

Geoffrey Frogeye 4d966371b2 2019-12-15 15:56:26 +01:00
    Workflow: SQL -> Tree
    Welp. All that for this.

Geoffrey Frogeye 040ce4c14e 2019-12-15 01:52:45 +01:00
    Typo in source

Geoffrey Frogeye b50c01f740 2019-12-15 01:30:03 +01:00
    Merge branch 'master' into newworkflow

Geoffrey Frogeye ddceed3d25 2019-12-15 00:28:08 +01:00
    Workflow: Can now import DnsMass output
    Well, in a specific format but DnsMass nonetheless

Geoffrey Frogeye 189deeb559 2019-12-14 17:27:46 +01:00
    Workflow: Multiprocess
    Still trying.
    It's better than multithread though.

    Merge branch 'newworkflow' into newworkflow_threaded

Geoffrey Frogeye d7c239a6f6 2019-12-14 16:04:19 +01:00
    Workflow: Some modifications

Geoffrey Frogeye 231bb83667 2019-12-13 12:36:11 +01:00
    Threaded feed_dns
    Largely disapointing

Geoffrey Frogeye 12dcafe606 2019-12-12 19:13:54 +01:00
    Added alternate source of Eulerian CNAMES
    It was requested so.
    It should be temporary, once I have a bigger subdomain list
    that shouldn't be required.
13 changed files with 726 additions and 899 deletions

.gitignore vendored (5 changes)

@@ -1,7 +1,4 @@
 *.log
-*.db
-*.db-journal
+*.p
 nameservers
 nameservers.head
-*.o
-*.so

database.py Executable file → Normal file (823 changes)

@@ -4,111 +4,115 @@
 Utility functions to interact with the database.
 """
-import sqlite3
 import typing
 import time
-import os
 import logging
-import argparse
 import coloredlogs
-import ipaddress
-import math
+import pickle

 coloredlogs.install(
     level='DEBUG',
     fmt='%(asctime)s %(name)s %(levelname)s %(message)s'
 )

-DbValue = typing.Union[None, int, float, str, bytes]
+Asn = int
+Timestamp = int
+Level = int


-class Database():
-    VERSION = 5
-    PATH = "blocking.db"
-
-    def open(self) -> None:
-        mode = 'rwc' if self.write else 'ro'
-        uri = f'file:{self.PATH}?mode={mode}'
-        self.conn = sqlite3.connect(uri, uri=True)
-        cursor = self.conn.cursor()
-        cursor.execute("PRAGMA foreign_keys = ON")
-        self.conn.create_function("unpack_asn", 1,
-                                  self.unpack_asn,
-                                  deterministic=True)
-        self.conn.create_function("unpack_ip4address", 1,
-                                  self.unpack_ip4address,
-                                  deterministic=True)
-        self.conn.create_function("unpack_ip4network", 2,
-                                  self.unpack_ip4network,
-                                  deterministic=True)
-        self.conn.create_function("unpack_domain", 1,
-                                  lambda s: s[:-1][::-1],
-                                  deterministic=True)
-        self.conn.create_function("format_zone", 1,
-                                  lambda s: '*' + s[::-1],
-                                  deterministic=True)
-
-    def get_meta(self, key: str) -> typing.Optional[int]:
-        cursor = self.conn.cursor()
-        try:
-            cursor.execute("SELECT value FROM meta WHERE key=?", (key,))
-        except sqlite3.OperationalError:
-            return None
-        for ver, in cursor:
-            return ver
-        return None
-
-    def set_meta(self, key: str, val: int) -> None:
-        cursor = self.conn.cursor()
-        cursor.execute("INSERT INTO meta VALUES (?, ?) "
-                       "ON CONFLICT (key) DO "
-                       "UPDATE set value=?",
-                       (key, val, val))
-
-    def close(self) -> None:
-        self.enter_step('close_commit')
-        self.conn.commit()
-        self.enter_step('close')
-        self.conn.close()
-        self.profile()
-
-    def initialize(self) -> None:
-        self.close()
-        self.enter_step('initialize')
-        if not self.write:
-            self.log.error("Cannot initialize in read-only mode.")
-            raise
-        os.unlink(self.PATH)
-        self.open()
-        self.log.info("Creating database version %d.", self.VERSION)
-        cursor = self.conn.cursor()
-        with open("database_schema.sql", 'r') as db_schema:
-            cursor.executescript(db_schema.read())
-        self.set_meta('version', self.VERSION)
-        self.conn.commit()
-
-    def __init__(self, write: bool = False) -> None:
-        self.log = logging.getLogger('db')
+class Path():
+    # FP add boolean here
+    pass
+
+
+class RulePath(Path):
+    pass
+
+
+class DomainPath(Path):
+    def __init__(self, path: typing.List[str]):
+        self.path = path
+
+
+class HostnamePath(DomainPath):
+    pass
+
+
+class ZonePath(DomainPath):
+    pass
+
+
+class AsnPath(Path):
+    def __init__(self, asn: Asn):
+        self.asn = asn
+
+
+class Ip4Path(Path):
+    def __init__(self, value: int, prefixlen: int):
+        self.value = value
+        self.prefixlen = prefixlen
+
+
+class Match():
+    def __init__(self) -> None:
+        self.updated: int = 0
+        self.level: int = 0
+        self.source: Path = RulePath()
+
+    # FP dupplicate args
+    def set(self,
+            updated: int,
+            level: int,
+            source: Path,
+            ) -> None:
+        if updated > self.updated or level > self.level:
+            self.updated = updated
+            self.level = level
+            self.source = source
+
+    # FP dupplicate function
+    def active(self) -> bool:
+        return self.updated > 0
+
+
+class AsnNode(Match):
+    pass
+
+
+class DomainTreeNode():
+    def __init__(self) -> None:
+        self.children: typing.Dict[str, DomainTreeNode] = dict()
+        self.match_zone = Match()
+        self.match_hostname = Match()
+
+
+class IpTreeNode():
+    def __init__(self) -> None:
+        self.children: typing.List[typing.Optional[IpTreeNode]] = [None, None]
+        self.match = Match()
+
+
+Node = typing.Union[DomainTreeNode, IpTreeNode, AsnNode]
+NodeCallable = typing.Callable[[Path,
+                                Node,
+                                typing.Optional[typing.Any]],
+                               typing.Any]
+
+
+class Profiler():
+    def __init__(self) -> None:
+        self.log = logging.getLogger('profiler')
         self.time_last = time.perf_counter()
         self.time_step = 'init'
         self.time_dict: typing.Dict[str, float] = dict()
         self.step_dict: typing.Dict[str, int] = dict()
-        self.write = write
-
-        self.open()
-        version = self.get_meta('version')
-        if version != self.VERSION:
-            if version is not None:
-                self.log.warning(
-                    "Outdated database version: %d found, will be rebuilt.",
-                    version)
-            self.initialize()

     def enter_step(self, name: str) -> None:
+        return
         now = time.perf_counter()
         try:
             self.time_dict[self.time_step] += now - self.time_last
-            self.step_dict[self.time_step] += 1
+            self.step_dict[self.time_step] += int(name != self.time_step)
         except KeyError:
             self.time_dict[self.time_step] = now - self.time_last
             self.step_dict[self.time_step] = 1
@@ -125,435 +129,334 @@ class Database():
         self.log.debug(f"{'total':<20}: "
                        f"{total:9.2f} s ({1:7.2%})")

-    @staticmethod
-    def pack_hostname(hostname: str) -> str:
-        return hostname[::-1] + '.'
+
+class Database(Profiler):
+    VERSION = 10
+    PATH = "blocking.p"
+
+    def initialize(self) -> None:
+        self.log.warning(
+            "Creating database version: %d ",
+            Database.VERSION)
+        self.domtree = DomainTreeNode()
+        self.asns: typing.Dict[Asn, AsnNode] = dict()
+        self.ip4tree = IpTreeNode()
+
+    def load(self) -> None:
+        self.enter_step('load')
+        try:
+            with open(self.PATH, 'rb') as db_fdsec:
+                version, data = pickle.load(db_fdsec)
+                if version == Database.VERSION:
+                    self.domtree, self.asns, self.ip4tree = data
+                    return
+                self.log.warning(
+                    "Outdated database version found: %d, "
+                    "it will be rebuilt.",
+                    version)
+        except (TypeError, AttributeError, EOFError):
+            self.log.error(
+                "Corrupt (or heavily outdated) database found, "
+                "it will be rebuilt.")
+        except FileNotFoundError:
+            pass
+        self.initialize()
+
+    def save(self) -> None:
+        self.enter_step('save')
+        with open(self.PATH, 'wb') as db_fdsec:
+            data = self.domtree, self.asns, self.ip4tree
+            pickle.dump((self.VERSION, data), db_fdsec)
+        self.profile()
+
+    def __init__(self) -> None:
+        Profiler.__init__(self)
+        self.log = logging.getLogger('db')
+        self.load()

     @staticmethod
-    def pack_zone(zone: str) -> str:
-        return Database.pack_hostname(zone)
+    def pack_domain(domain: str) -> DomainPath:
+        return DomainPath(domain.split('.')[::-1])

     @staticmethod
-    def pack_asn(asn: str) -> int:
+    def unpack_domain(domain: DomainPath) -> str:
+        return '.'.join(domain.path[::-1])
+
+    @staticmethod
+    def pack_asn(asn: str) -> AsnPath:
         asn = asn.upper()
         if asn.startswith('AS'):
             asn = asn[2:]
-        return int(asn)
+        return AsnPath(int(asn))

     @staticmethod
-    def unpack_asn(asn: int) -> str:
-        return f'AS{asn}'
+    def unpack_asn(asn: AsnPath) -> str:
+        return f'AS{asn.asn}'

     @staticmethod
-    def pack_ip4address(address: str) -> int:
-        total = 0
-        for i, octet in enumerate(address.split('.')):
-            total += int(octet) << (3-i)*8
-        return total
-        # return '{:02x}{:02x}{:02x}{:02x}'.format(
-        #     *[int(c) for c in address.split('.')])
-        # return base64.b16encode(packed).decode()
-        # return '{:08b}{:08b}{:08b}{:08b}'.format(
-        #     *[int(c) for c in address.split('.')])
-        # carg = ctypes.c_wchar_p(address)
-        # ret = ACCEL.ip4_flat(carg, self.accel_ip4_buf)
-        # if ret != 0:
-        #     raise ValueError
-        # return self.accel_ip4_buf.value
-        # packed = ipaddress.ip_address(address).packed
-        # return packed
+    def pack_ip4address(address: str) -> Ip4Path:
+        addr = 0
+        for split in address.split('.'):
+            addr = (addr << 8) + int(split)
+        return Ip4Path(addr, 32)

     @staticmethod
-    def unpack_ip4address(address: int) -> str:
-        return '.'.join(str((address >> (i * 8)) & 0xFF)
-                        for i in reversed(range(4)))
+    def unpack_ip4address(address: Ip4Path) -> str:
+        addr = address.value
+        assert address.prefixlen == 32
+        octets: typing.List[int] = list()
+        octets = [0] * 4
+        for o in reversed(range(4)):
+            octets[o] = addr & 0xFF
+            addr >>= 8
+        return '.'.join(map(str, octets))

     @staticmethod
-    def pack_ip4network(network: str) -> typing.Tuple[int, int]:
-        # def pack_ip4network(network: str) -> str:
-        net = ipaddress.ip_network(network)
-        mini = Database.pack_ip4address(net.network_address.exploded)
-        maxi = Database.pack_ip4address(net.broadcast_address.exploded)
-        # mini = net.network_address.packed
-        # maxi = net.broadcast_address.packed
-        return mini, maxi
-        # return Database.pack_ip4address(net.network_address.exploded)[:net.prefixlen]
+    def pack_ip4network(network: str) -> Ip4Path:
+        address, prefixlen_str = network.split('/')
+        prefixlen = int(prefixlen_str)
+        addr = Database.pack_ip4address(address)
+        addr.prefixlen = prefixlen
+        return addr

     @staticmethod
-    def unpack_ip4network(mini: int, maxi: int) -> str:
-        addr = Database.unpack_ip4address(mini)
-        prefixlen = 32-int(math.log2(maxi-mini+1))
-        return f'{addr}/{prefixlen}'
+    def unpack_ip4network(network: Ip4Path) -> str:
+        addr = network.value
+        octets: typing.List[int] = list()
+        octets = [0] * 4
+        for o in reversed(range(4)):
+            octets[o] = addr & 0xFF
+            addr >>= 8
+        return '.'.join(map(str, octets)) + '/' + str(network.prefixlen)
+
+    def exec_each_domain(self,
+                         callback: NodeCallable,
+                         arg: typing.Any = None,
+                         _dic: DomainTreeNode = None,
+                         _par: DomainPath = None,
+                         ) -> typing.Any:
+        _dic = _dic or self.domtree
+        _par = _par or DomainPath([])
+        yield from callback(_par, _dic, arg)
+        for part in _dic.children:
+            dic = _dic.children[part]
+            yield from self.exec_each_domain(
+                callback,
+                arg,
+                _dic=dic,
+                _par=DomainPath(_par.path + [part])
+            )
+
+    def exec_each_ip4(self,
+                      callback: NodeCallable,
+                      arg: typing.Any = None,
+                      _dic: IpTreeNode = None,
+                      _par: Ip4Path = None,
+                      ) -> typing.Any:
+        _dic = _dic or self.ip4tree
+        _par = _par or Ip4Path(0, 0)
+        callback(_par, _dic, arg)
+
+        # 0
+        dic = _dic.children[0]
+        if dic:
+            addr0 = _par.value & (0xFFFFFFFF ^ (1 << (32-_par.prefixlen)))
+            assert addr0 == _par.value
+            yield from self.exec_each_ip4(
+                callback,
+                arg,
+                _dic=dic,
+                _par=Ip4Path(addr0, _par.prefixlen+1)
+            )
+        # 1
+        dic = _dic.children[1]
+        if dic:
+            addr1 = _par.value | (1 << (32-_par.prefixlen))
+            yield from self.exec_each_ip4(
+                callback,
+                arg,
+                _dic=dic,
+                _par=Ip4Path(addr1, _par.prefixlen+1)
+            )
+
+    def exec_each(self,
+                  callback: NodeCallable,
+                  arg: typing.Any = None,
+                  ) -> typing.Any:
+        yield from self.exec_each_domain(callback)
+        yield from self.exec_each_ip4(callback)

     def update_references(self) -> None:
-        self.enter_step('update_refs')
-        cursor = self.conn.cursor()
-        cursor.execute('UPDATE rules AS r SET refs='
-                       '(SELECT count(*) FROM rules '
-                       'WHERE source=r.id)')
+        raise NotImplementedError

-    def prune(self, before: int) -> None:
-        self.enter_step('prune')
-        cursor = self.conn.cursor()
-        cursor.execute('DELETE FROM rules WHERE updated<?', (before,))
+    def prune(self, before: int, base_only: bool = False) -> None:
+        raise NotImplementedError

     def explain(self, entry: int) -> str:
-        # Format current
-        string = '???'
-        cursor = self.conn.cursor()
-        cursor.execute(
-            'SELECT unpack_asn(val) FROM asn WHERE entry=:entry '
-            'UNION '
-            'SELECT unpack_domain(val) FROM hostname WHERE entry=:entry '
-            'UNION '
-            'SELECT format_zone(val) FROM zone WHERE entry=:entry '
-            'UNION '
-            'SELECT unpack_ip4address(val) FROM ip4address WHERE entry=:entry '
-            'UNION '
-            'SELECT unpack_ip4network(mini, maxi) '
-            'FROM ip4network WHERE entry=:entry ',
-            {"entry": entry}
-        )
-        for val, in cursor:  # only one
-            string = str(val)
-        string += f' #{entry}'
-        # Add source if any
-        cursor.execute('SELECT source FROM rules WHERE id=?', (entry,))
-        for source, in cursor:
-            if source:
-                string += f'{self.explain(source)}'
-        return string
+        raise NotImplementedError

     def export(self,
                first_party_only: bool = False,
                end_chain_only: bool = False,
                explain: bool = False,
                ) -> typing.Iterable[str]:
-        selection = 'entry' if explain else 'unpack_domain(val)'
-        command = f'SELECT {selection} FROM rules ' \
-            'INNER JOIN hostname ON rules.id = hostname.entry'
-        restrictions: typing.List[str] = list()
-        if first_party_only:
-            restrictions.append('rules.first_party = 1')
-        if end_chain_only:
-            restrictions.append('rules.refs = 0')
-        if restrictions:
-            command += ' WHERE ' + ' AND '.join(restrictions)
-        if not explain:
-            command += ' ORDER BY unpack_domain(val) ASC'
-        cursor = self.conn.cursor()
-        cursor.execute(command)
-        for val, in cursor:
-            if explain:
-                yield self.explain(val)
-            else:
-                yield val
+        if first_party_only or end_chain_only or explain:
+            raise NotImplementedError
+
+        def export_cb(path: Path, node: Node, _: typing.Any
+                      ) -> typing.Iterable[str]:
+            assert isinstance(path, DomainPath)
+            assert isinstance(node, DomainTreeNode)
+            if node.match_hostname:
+                a = self.unpack_domain(path)
+                yield a
+
+        yield from self.exec_each_domain(export_cb, None)

     def count_rules(self,
                     first_party_only: bool = False,
                     ) -> str:
-        counts: typing.List[str] = list()
-        cursor = self.conn.cursor()
-        for table in ['asn', 'ip4network', 'ip4address', 'zone', 'hostname']:
-            command = f'SELECT count(*) FROM rules ' \
-                f'INNER JOIN {table} ON rules.id = {table}.entry ' \
-                'WHERE rules.level = 0'
-            if first_party_only:
-                command += ' AND first_party=1'
-            cursor.execute(command)
-            count, = cursor.fetchone()
-            if count > 0:
-                counts.append(f'{table}: {count}')
-        return ', '.join(counts)
+        raise NotImplementedError

-    def get_domain(self, domain: str) -> typing.Iterable[int]:
-        self.enter_step('get_domain_prepare')
-        domain_prep = self.pack_hostname(domain)
-        cursor = self.conn.cursor()
-        self.enter_step('get_domain_select')
-        cursor.execute(
-            'SELECT null, entry FROM hostname '
-            'WHERE val=:d '
-            'UNION '
-            'SELECT * FROM ('
-            'SELECT val, entry FROM zone '
-            'WHERE val<=:d '
-            'ORDER BY val DESC LIMIT 1'
-            ')',
-            {'d': domain_prep}
-        )
-        for val, entry in cursor:
-            self.enter_step('get_domain_confirm')
-            if not (val is None or domain_prep.startswith(val)):
-                continue
-            self.enter_step('get_domain_yield')
-            yield entry
-
-    def get_domain_in_zone(self, domain: str) -> typing.Iterable[int]:
-        self.enter_step('get_domainiz_prepare')
-        domain_prep = self.pack_hostname(domain)
-        cursor = self.conn.cursor()
-        self.enter_step('get_domainiz_select')
-        cursor.execute(
-            'SELECT val, entry FROM zone '
-            'WHERE val<=:d '
-            'ORDER BY val DESC LIMIT 1',
-            {'d': domain_prep}
-        )
-        for val, entry in cursor:
-            self.enter_step('get_domainiz_confirm')
-            if not (val is None or domain_prep.startswith(val)):
-                continue
-            self.enter_step('get_domainiz_yield')
-            yield entry
+    def get_domain(self, domain_str: str) -> typing.Iterable[DomainPath]:
+        self.enter_step('get_domain_pack')
+        domain = self.pack_domain(domain_str)
+        self.enter_step('get_domain_brws')
+        dic = self.domtree
+        depth = 0
+        for part in domain.path:
+            if dic.match_zone.active():
+                self.enter_step('get_domain_yield')
+                yield ZonePath(domain.path[:depth])
+                self.enter_step('get_domain_brws')
+            if part not in dic.children:
+                return
+            dic = dic.children[part]
+            depth += 1
+        if dic.match_zone.active():
+            self.enter_step('get_domain_yield')
+            yield ZonePath(domain.path)
+        if dic.match_hostname.active():
+            self.enter_step('get_domain_yield')
+            yield HostnamePath(domain.path)

-    def get_ip4(self, address: str) -> typing.Iterable[int]:
-        self.enter_step('get_ip4_prepare')
-        try:
-            address_prep = self.pack_ip4address(address)
-        except (ValueError, IndexError):
-            self.log.error("Invalid ip4address: %s", address)
-            return
-        cursor = self.conn.cursor()
-        self.enter_step('get_ip4_select')
-        cursor.execute(
-            'SELECT entry FROM ip4address '
-            # 'SELECT null, entry FROM ip4address '
-            'WHERE val=:a '
-            'UNION '
-            # 'SELECT * FROM ('
-            # 'SELECT val, entry FROM ip4network '
-            # 'WHERE val<=:a '
-            # 'AND instr(:a, val) > 0 '
-            # 'ORDER BY val DESC'
-            # ')'
-            'SELECT entry FROM ip4network '
-            'WHERE :a BETWEEN mini AND maxi ',
-            {'a': address_prep}
-        )
-        for entry, in cursor:
-            # self.enter_step('get_ip4_confirm')
-            # if not (val is None or val.startswith(address_prep)):
-            #     # PERF startswith but from the end
-            #     continue
-            self.enter_step('get_ip4_yield')
-            yield entry
-
-    def get_ip4_in_network(self, address: str) -> typing.Iterable[int]:
-        self.enter_step('get_ip4in_prepare')
-        try:
-            address_prep = self.pack_ip4address(address)
-        except (ValueError, IndexError):
-            self.log.error("Invalid ip4address: %s", address)
-            return
-        cursor = self.conn.cursor()
-        self.enter_step('get_ip4in_select')
-        cursor.execute(
-            'SELECT entry FROM ip4network '
-            'WHERE :a BETWEEN mini AND maxi ',
-            {'a': address_prep}
-        )
-        for entry, in cursor:
-            self.enter_step('get_ip4in_yield')
-            yield entry
+    def get_ip4(self, ip4_str: str) -> typing.Iterable[Path]:
+        self.enter_step('get_ip4_pack')
+        ip4 = self.pack_ip4address(ip4_str)
+        self.enter_step('get_ip4_brws')
+        dic = self.ip4tree
+        for i in reversed(range(ip4.prefixlen)):
+            part = (ip4.value >> i) & 0b1
+            if dic.match.active():
+                self.enter_step('get_ip4_yield')
+                yield Ip4Path(ip4.value, 32-i)
+                self.enter_step('get_ip4_brws')
+            next_dic = dic.children[part]
+            if next_dic is None:
+                return
+            dic = next_dic
+        if dic.match.active():
+            self.enter_step('get_ip4_yield')
+            yield ip4

-    def list_asn(self) -> typing.Iterable[typing.Tuple[str, int]]:
-        cursor = self.conn.cursor()
-        self.enter_step('list_asn_select')
-        cursor.execute('SELECT val, entry FROM asn')
-        for val, entry in cursor:
-            yield f'AS{val}', entry
+    def list_asn(self) -> typing.Iterable[AsnPath]:
+        for asn in self.asns:
+            yield AsnPath(asn)

-    def _set_generic(self,
-                     table: str,
-                     select_query: str,
-                     insert_query: str,
-                     prep: typing.Dict[str, DbValue],
-                     updated: int,
-                     is_first_party: bool = False,
-                     source: int = None,
-                     ) -> None:
-        # Since this isn't the bulk of the processing,
-        # here abstraction > performaces
-
-        # Fields based on the source
-        self.enter_step(f'set_{table}_prepare')
-        cursor = self.conn.cursor()
-        if source is None:
-            first_party = int(is_first_party)
-            level = 0
-        else:
-            self.enter_step(f'set_{table}_source')
-            cursor.execute(
-                'SELECT first_party, level FROM rules '
-                'WHERE id=?',
-                (source,)
-            )
-            first_party, level = cursor.fetchone()
-            level += 1
-        self.enter_step(f'set_{table}_select')
-        cursor.execute(select_query, prep)
-
-        rules_prep: typing.Dict[str, DbValue] = {
-            "source": source,
-            "updated": updated,
-            "first_party": first_party,
-            "level": level,
-        }
-
-        # If the entry already exists
-        for entry, in cursor:  # only one
-            self.enter_step(f'set_{table}_update')
-            rules_prep['entry'] = entry
-            cursor.execute(
-                'UPDATE rules SET '
-                'source=:source, updated=:updated, '
-                'first_party=:first_party, level=:level '
-                'WHERE id=:entry AND (updated<:updated OR '
-                'first_party<:first_party OR level<:level)',
-                rules_prep
-            )
-            # Only update if any of the following:
-            # - the entry is outdataed
-            # - the entry was not a first_party but this is
-            # - this is closer to the original rule
-            return
-
-        # If it does not exist
-        self.enter_step(f'set_{table}_insert')
-        cursor.execute(
-            'INSERT INTO rules '
-            '(source, updated, first_party, level) '
-            'VALUES (:source, :updated, :first_party, :level) ',
-            rules_prep
-        )
-        cursor.execute('SELECT id FROM rules WHERE rowid=?',
-                       (cursor.lastrowid,))
-        for entry, in cursor:  # only one
-            prep['entry'] = entry
-            cursor.execute(insert_query, prep)
-            return
-        assert False
+    def _set_domain(self,
+                    hostname: bool,
+                    domain_str: str,
+                    updated: int,
+                    is_first_party: bool = None,
+                    source: Path = None) -> None:
+        self.enter_step('set_domain_pack')
+        if is_first_party:
+            raise NotImplementedError
+        domain = self.pack_domain(domain_str)
+        self.enter_step('set_domain_brws')
+        dic = self.domtree
+        for part in domain.path:
+            if dic.match_zone.active():
+                # Refuse to add domain whose zone is already matching
+                return
+            if part not in dic.children:
+                dic.children[part] = DomainTreeNode()
+            dic = dic.children[part]
+        if hostname:
+            match = dic.match_hostname
+        else:
+            match = dic.match_zone
+        match.set(
+            updated,
+            0,  # TODO Level
+            source or RulePath(),
+        )

-    def set_hostname(self, hostname: str,
-                     *args: typing.Any, **kwargs: typing.Any) -> None:
-        self.enter_step('set_hostname_prepare')
-        prep: typing.Dict[str, DbValue] = {
-            'val': self.pack_hostname(hostname),
-        }
-        self._set_generic(
-            'hostname',
-            'SELECT entry FROM hostname WHERE val=:val',
-            'INSERT INTO hostname (val, entry) '
-            'VALUES (:val, :entry)',
-            prep,
-            *args, **kwargs
-        )
+    def set_hostname(self,
+                     *args: typing.Any, **kwargs: typing.Any
+                     ) -> None:
+        self._set_domain(True, *args, **kwargs)
+
+    def set_zone(self,
+                 *args: typing.Any, **kwargs: typing.Any
+                 ) -> None:
+        self._set_domain(False, *args, **kwargs)

-    def set_asn(self, asn: str,
-                *args: typing.Any, **kwargs: typing.Any) -> None:
-        self.enter_step('set_asn_prepare')
-        try:
-            asn_prep = self.pack_asn(asn)
-        except ValueError:
-            self.log.error("Invalid asn: %s", asn)
-            return
-        prep: typing.Dict[str, DbValue] = {
-            'val': asn_prep,
-        }
-        self._set_generic(
-            'asn',
-            'SELECT entry FROM asn WHERE val=:val',
-            'INSERT INTO asn (val, entry) '
-            'VALUES (:val, :entry)',
-            prep,
-            *args, **kwargs
-        )
+    def set_asn(self,
+                asn_str: str,
+                updated: int,
+                is_first_party: bool = None,
+                source: Path = None) -> None:
+        self.enter_step('set_asn')
+        if is_first_party:
+            raise NotImplementedError
+        path = self.pack_asn(asn_str)
+        match = AsnNode()
+        match.set(
+            updated,
+            0,
+            source or RulePath()
+        )
+        self.asns[path.asn] = match
+
+    def _set_ip4(self,
+                 ip4: Ip4Path,
+                 updated: int,
+                 is_first_party: bool = None,
+                 source: Path = None) -> None:
+        if is_first_party:
+            raise NotImplementedError
+        dic = self.ip4tree
+        for i in reversed(range(ip4.prefixlen)):
+            part = (ip4.value >> i) & 0b1
+            if dic.match.active():
+                # Refuse to add ip4* whose network is already matching
+                return
+            next_dic = dic.children[part]
+            if next_dic is None:
+                next_dic = IpTreeNode()
+                dic.children[part] = next_dic
+            dic = next_dic
+        dic.match.set(
+            updated,
+            0,  # TODO Level
+            source or RulePath(),
+        )

-    def set_ip4address(self, ip4address: str,
-                       *args: typing.Any, **kwargs: typing.Any) -> None:
-        self.enter_step('set_ip4add_prepare')
-        try:
-            ip4address_prep = self.pack_ip4address(ip4address)
-        except (ValueError, IndexError):
-            self.log.error("Invalid ip4address: %s", ip4address)
-            return
-        prep: typing.Dict[str, DbValue] = {
-            'val': ip4address_prep,
-        }
-        self._set_generic(
-            'ip4add',
-            'SELECT entry FROM ip4address WHERE val=:val',
-            'INSERT INTO ip4address (val, entry) '
-            'VALUES (:val, :entry)',
-            prep,
-            *args, **kwargs
-        )
+    def set_ip4address(self,
+                       ip4address_str: str,
+                       *args: typing.Any, **kwargs: typing.Any
+                       ) -> None:
+        self.enter_step('set_ip4add_pack')
+        ip4 = self.pack_ip4address(ip4address_str)
+        self.enter_step('set_ip4add_brws')
+        self._set_ip4(ip4, *args, **kwargs)

-    def set_zone(self, zone: str,
-                 *args: typing.Any, **kwargs: typing.Any) -> None:
-        self.enter_step('set_zone_prepare')
-        prep: typing.Dict[str, DbValue] = {
-            'val': self.pack_zone(zone),
-        }
-        self._set_generic(
-            'zone',
-            'SELECT entry FROM zone WHERE val=:val',
-            'INSERT INTO zone (val, entry) '
-            'VALUES (:val, :entry)',
-            prep,
-            *args, **kwargs
-        )
+    def set_ip4network(self,
+                       ip4network_str: str,
+                       *args: typing.Any, **kwargs: typing.Any
+                       ) -> None:
+        self.enter_step('set_ip4net_pack')
+        ip4 = self.pack_ip4network(ip4network_str)
+        self.enter_step('set_ip4net_brws')
+        self._set_ip4(ip4, *args, **kwargs)

-    def set_ip4network(self, ip4network: str,
-                       *args: typing.Any, **kwargs: typing.Any) -> None:
-        self.enter_step('set_ip4net_prepare')
-        try:
-            ip4network_prep = self.pack_ip4network(ip4network)
-        except (ValueError, IndexError):
-            self.log.error("Invalid ip4network: %s", ip4network)
-            return
-        prep: typing.Dict[str, DbValue] = {
-            'mini': ip4network_prep[0],
-            'maxi': ip4network_prep[1],
-        }
-        self._set_generic(
-            'ip4net',
-            'SELECT entry FROM ip4network WHERE mini=:mini AND maxi=:maxi',
-            'INSERT INTO ip4network (mini, maxi, entry) '
-            'VALUES (:mini, :maxi, :entry)',
-            prep,
-            *args, **kwargs
-        )
-
-
-if __name__ == '__main__':
-
-    # Parsing arguments
-    parser = argparse.ArgumentParser(
-        description="Database operations")
-    parser.add_argument(
-        '-i', '--initialize', action='store_true',
-        help="Reconstruct the whole database")
-    parser.add_argument(
-        '-p', '--prune', action='store_true',
-        help="Remove old (+6 months) entries from database")
-    parser.add_argument(
-        '-r', '--references', action='store_true',
-        help="Update the reference count")
-    args = parser.parse_args()
-
-    DB = Database(write=True)
-
-    if args.initialize:
-        DB.initialize()
-    if args.prune:
-        DB.prune(before=int(time.time()) - 60*60*24*31*6)
-    if args.references and not args.prune:
-        DB.update_references()
-
-    DB.close()
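The diff above trades every SQL query for walks over pickled in-memory trees. As a rough illustration of how the new pieces chain together — a hypothetical usage sketch, not code from the repository, assuming the new database.py shown above is importable:

import time
import database

db = database.Database()  # loads blocking.p, or initializes empty trees

# pack_domain() reverses the labels ('example.com' -> ['com', 'example']),
# so a zone rule is a DomainTreeNode with match_zone set along that path.
db.set_zone('example.com', updated=int(time.time()))

# get_domain() walks the tree and yields every rule matching on the way down;
# here the zone rule catches the subdomain.
for path in db.get_domain('tracker.example.com'):
    print(database.Database.unpack_domain(path))  # prints: example.com

db.save()  # pickles (VERSION, (domtree, asns, ip4tree)) to blocking.p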

database_schema.sql (deleted file)

@@ -1,59 +0,0 @@
-- Remember to increment DB_VERSION
-- in database.py on changes to this file
CREATE TABLE rules (
id INTEGER PRIMARY KEY AUTOINCREMENT,
source INTEGER, -- The rule this one is based on
updated INTEGER, -- If the row was updated during last data import (0: No, 1: Yes)
first_party INTEGER, -- 1: this blocks a first party for sure, 0: maybe
refs INTEGER, -- Number of entries issued from this one
level INTEGER, -- Level of recursion to the root source rule (used for source priority)
FOREIGN KEY (source) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX rules_source ON rules (source); -- for references recounting
CREATE INDEX rules_updated ON rules (updated); -- for pruning
CREATE INDEX rules_level_firstparty ON rules (level, first_party); -- for counting rules
CREATE TABLE asn (
val INTEGER PRIMARY KEY,
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX asn_entry ON asn (entry); -- for explainations
CREATE TABLE hostname (
val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for consistency with zone)
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX hostname_entry ON hostname (entry); -- for explainations
CREATE TABLE zone (
val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for easier matching)
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX zone_entry ON zone (entry); -- for explainations
CREATE TABLE ip4address (
val INTEGER PRIMARY KEY,
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX ip4address_entry ON ip4address (entry); -- for explainations
CREATE TABLE ip4network (
-- val TEXT PRIMARY KEY,
mini INTEGER,
maxi INTEGER,
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX ip4network_minmax ON ip4network (mini, maxi);
CREATE INDEX ip4network_entry ON ip4network (entry); -- for explainations
-- Store various things
CREATE TABLE meta (
key TEXT PRIMARY KEY,
value integer
);
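The zone and hostname tables stored values reversed with a trailing dot so that the old get_domain() query (val<=:d ORDER BY val DESC LIMIT 1, followed by a startswith() check) could find the closest enclosing zone. A self-contained illustration of why that ordering trick works (not part of the diff):

def pack_hostname(hostname: str) -> str:
    return hostname[::-1] + '.'

zone = pack_hostname('example.com')          # 'moc.elpmaxe.'
host = pack_hostname('tracker.example.com')  # 'moc.elpmaxe.rekcart.'
# A packed zone sorts immediately before its subdomains and is a string
# prefix of them, so the greatest stored value <= host is the best candidate.
assert zone <= host and host.startswith(zone)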

export.py

@@ -45,5 +45,3 @@ if __name__ == '__main__':
             explain=args.explain,
             ):
         print(domain, file=args.output)
-
-    DB.close()

feed_asn.py

@@ -31,23 +31,22 @@ if __name__ == '__main__':
     args = parser.parse_args()

     DB = database.Database()
-    DBW = database.Database(write=True)

-    for asn, entry in DB.list_asn():
+    for path in DB.list_asn():
+        asn_str = database.Database.unpack_asn(path)
         DB.enter_step('asn_get_ranges')
-        for prefix in get_ranges(asn):
+        for prefix in get_ranges(asn_str):
             parsed_prefix: IPNetwork = ipaddress.ip_network(prefix)
             if parsed_prefix.version == 4:
-                DBW.set_ip4network(
+                DB.set_ip4network(
                     prefix,
-                    source=entry,
+                    source=path,
                     updated=int(time.time())
                 )
-                log.info('Added %s from %s (id=%s)', prefix, asn, entry)
+                log.info('Added %s from %s (%s)', prefix, asn_str, path)
             elif parsed_prefix.version == 6:
                 log.warning('Unimplemented prefix version: %s', prefix)
             else:
                 log.error('Unknown prefix version: %s', prefix)

-    DB.close()
-    DBW.close()
+    DB.save()
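feed_asn.py now hands set_ip4network() the raw prefix string, and the database packs it into an Ip4Path: the address as one big-endian integer plus a prefix length, replacing the old (mini, maxi) pair. A quick round-trip check against the pack/unpack code above (hypothetical values, not from the diff):

import database

net = database.Database.pack_ip4network('203.0.113.0/24')
assert net.value == (((203 << 8 | 0) << 8 | 113) << 8 | 0)  # big-endian int
assert net.prefixlen == 24
assert database.Database.unpack_ip4network(net) == '203.0.113.0/24'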

feed_dns.old.py Executable file (147 additions)

@@ -0,0 +1,147 @@
#!/usr/bin/env python3
import argparse
import database
import logging
import sys
import typing
import enum
RecordType = enum.Enum('RecordType', 'A AAAA CNAME PTR')
Record = typing.Tuple[RecordType, int, str, str]
# select, write
FUNCTION_MAP: typing.Any = {
RecordType.A: (
database.Database.get_ip4,
database.Database.set_hostname,
),
RecordType.CNAME: (
database.Database.get_domain,
database.Database.set_hostname,
),
RecordType.PTR: (
database.Database.get_domain,
database.Database.set_ip4address,
),
}
class Parser():
def __init__(self, buf: typing.Any) -> None:
self.buf = buf
self.log = logging.getLogger('parser')
self.db = database.Database()
def end(self) -> None:
self.db.save()
def register(self,
rtype: RecordType,
updated: int,
name: str,
value: str
) -> None:
self.db.enter_step('register')
select, write = FUNCTION_MAP[rtype]
for source in select(self.db, value):
# write(self.db, name, updated, source=source)
write(self.db, name, updated)
def consume(self) -> None:
raise NotImplementedError
class Rapid7Parser(Parser):
TYPES = {
'a': RecordType.A,
'aaaa': RecordType.AAAA,
'cname': RecordType.CNAME,
'ptr': RecordType.PTR,
}
def consume(self) -> None:
data = dict()
for line in self.buf:
self.db.enter_step('parse_rapid7')
split = line.split('"')
for k in range(1, 14, 4):
key = split[k]
val = split[k+2]
data[key] = val
self.register(
Rapid7Parser.TYPES[data['type']],
int(data['timestamp']),
data['name'],
data['value']
)
class DnsMassParser(Parser):
# dnsmass --output Snrql
# --retry REFUSED,SERVFAIL --resolvers nameservers-ipv4
TYPES = {
'A': (RecordType.A, -1, None),
'AAAA': (RecordType.AAAA, -1, None),
'CNAME': (RecordType.CNAME, -1, -1),
}
def consume(self) -> None:
self.db.enter_step('parse_dnsmass')
timestamp = 0
header = True
for line in self.buf:
line = line[:-1]
if not line:
header = True
continue
split = line.split(' ')
try:
if header:
timestamp = int(split[1])
header = False
else:
dtype, name_offset, value_offset = \
DnsMassParser.TYPES[split[1]]
self.register(
dtype,
timestamp,
split[0][:name_offset],
split[2][:value_offset],
)
self.db.enter_step('parse_dnsmass')
except KeyError:
continue
PARSERS = {
'rapid7': Rapid7Parser,
'dnsmass': DnsMassParser,
}
if __name__ == '__main__':
# Parsing arguments
log = logging.getLogger('feed_dns')
args_parser = argparse.ArgumentParser(
description="TODO")
args_parser.add_argument(
'parser',
choices=PARSERS.keys(),
help="TODO")
args_parser.add_argument(
'-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
help="TODO")
args = args_parser.parse_args()
parser = PARSERS[args.parser](args.input)
try:
parser.consume()
except KeyboardInterrupt:
pass
parser.end()
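Rapid7Parser skips json.loads() and splits each line on double quotes, which only works because the FDNS export writes its four keys in a fixed order (timestamp, name, type, value): the keys land at indices 1, 5, 9, 13 of the split and each value two slots later. A sketch with a made-up line:

line = ('{"timestamp":"1576418037","name":"sub.example.com",'
        '"type":"cname","value":"tracker.example.net"}')
split = line.split('"')
data = {split[k]: split[k + 2] for k in range(1, 14, 4)}
assert data == {'timestamp': '1576418037', 'name': 'sub.example.com',
                'type': 'cname', 'value': 'tracker.example.net'}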

feed_dns.py

@@ -1,64 +1,202 @@
 #!/usr/bin/env python3

-import database
 import argparse
-import sys
+import database
 import logging
-import csv
-import json
+import sys
+import typing
+import multiprocessing
+import enum
+
+Record = typing.Tuple[typing.Callable, typing.Callable, int, str, str]
+
+# select, write
+FUNCTION_MAP: typing.Any = {
+    'a': (
+        database.Database.get_ip4,
+        database.Database.set_hostname,
+    ),
+    'cname': (
+        database.Database.get_domain,
+        database.Database.set_hostname,
+    ),
+    'ptr': (
+        database.Database.get_domain,
+        database.Database.set_ip4address,
+    ),
+}
+
+
+class Writer(multiprocessing.Process):
+    def __init__(self,
+                 recs_queue: multiprocessing.Queue,
+                 index: int = 0):
+        super(Writer, self).__init__()
+        self.log = logging.getLogger(f'wr')
+        self.recs_queue = recs_queue
+
+    def run(self) -> None:
+        self.db = database.Database()
+        self.db.log = logging.getLogger(f'wr')
+        self.db.enter_step('block_wait')
+        block: typing.List[Record]
+        for block in iter(self.recs_queue.get, None):
+            record: Record
+            for record in block:
+                select, write, updated, name, value = record
+                self.db.enter_step('feed_switch')
+                try:
+                    for source in select(self.db, value):
+                        # write(self.db, name, updated, source=source)
+                        write(self.db, name, updated)
+                except ValueError:
+                    self.log.exception("Cannot execute: %s", record)
+            self.db.enter_step('block_wait')
+        self.db.enter_step('end')
+        self.db.save()
+
+
+class Parser():
+    def __init__(self,
+                 buf: typing.Any,
+                 recs_queue: multiprocessing.Queue,
+                 block_size: int,
+                 ):
+        super(Parser, self).__init__()
+        self.buf = buf
+        self.log = logging.getLogger('pr')
+        self.recs_queue = recs_queue
+        self.block: typing.List[Record] = list()
+        self.block_size = block_size
+        self.prof = database.Profiler()
+        self.prof.log = logging.getLogger('pr')
+
+    def register(self, record: Record) -> None:
+        self.prof.enter_step('register')
+        self.block.append(record)
+        if len(self.block) >= self.block_size:
+            self.prof.enter_step('put_block')
+            self.recs_queue.put(self.block)
+            self.block = list()
+
+    def run(self) -> None:
+        self.consume()
+        self.recs_queue.put(self.block)
+        self.prof.profile()
+
+    def consume(self) -> None:
+        raise NotImplementedError
+
+
+class Rapid7Parser(Parser):
+    def consume(self) -> None:
+        data = dict()
+        for line in self.buf:
+            self.prof.enter_step('parse_rapid7')
+            split = line.split('"')
+
+            try:
+                for k in range(1, 14, 4):
+                    key = split[k]
+                    val = split[k+2]
+                    data[key] = val
+
+                select, writer = FUNCTION_MAP[data['type']]
+                record = (
+                    select,
+                    writer,
+                    int(data['timestamp']),
+                    data['name'],
+                    data['value']
+                )
+            except IndexError:
+                self.log.exception("Cannot parse: %s", line)
+            self.register(record)
+
+
+class DnsMassParser(Parser):
+    # dnsmass --output Snrql
+    # --retry REFUSED,SERVFAIL --resolvers nameservers-ipv4
+    TYPES = {
+        'A': (FUNCTION_MAP['a'][0], FUNCTION_MAP['a'][1], -1, None),
+        # 'AAAA': (FUNCTION_MAP['aaaa'][0], FUNCTION_MAP['aaaa'][1], -1, None),
+        'CNAME': (FUNCTION_MAP['cname'][0], FUNCTION_MAP['cname'][1], -1, -1),
+    }
+
+    def consume(self) -> None:
+        self.prof.enter_step('parse_dnsmass')
+        timestamp = 0
+        header = True
+        for line in self.buf:
+            line = line[:-1]
+            if not line:
+                header = True
+                continue
+
+            split = line.split(' ')
+            try:
+                if header:
+                    timestamp = int(split[1])
+                    header = False
+                else:
+                    select, write, name_offset, value_offset = \
+                        DnsMassParser.TYPES[split[1]]
+                    record = (
+                        select,
+                        write,
+                        timestamp,
+                        split[0][:name_offset],
+                        split[2][:value_offset],
+                    )
+                    self.register(record)
+                    self.prof.enter_step('parse_dnsmass')
+            except KeyError:
+                continue
+
+
+PARSERS = {
+    'rapid7': Rapid7Parser,
+    'dnsmass': DnsMassParser,
+}

 if __name__ == '__main__':

     # Parsing arguments
     log = logging.getLogger('feed_dns')
-    parser = argparse.ArgumentParser(
+    args_parser = argparse.ArgumentParser(
         description="TODO")
-    parser.add_argument(
-        # '-i', '--input', type=argparse.FileType('rb'), default=sys.stdin.buffer,
+    args_parser.add_argument(
+        'parser',
+        choices=PARSERS.keys(),
+        help="TODO")
+    args_parser.add_argument(
         '-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
         help="TODO")
-    args = parser.parse_args()
+    args_parser.add_argument(
+        '-j', '--workers', type=int, default=4,
+        help="TODO")
+    args_parser.add_argument(
+        '-b', '--block-size', type=int, default=100,
+        help="TODO")
+    args_parser.add_argument(
+        '-q', '--queue-size', type=int, default=10,
+        help="TODO")
+    args = args_parser.parse_args()

-    DB = database.Database(write=True)
-
-    try:
-        DB.enter_step('iowait')
-        for row in csv.reader(args.input):
-            # for line in args.input:
-            DB.enter_step('feed_csv_parse')
-            dtype, timestamp, name, value = row
-            # DB.enter_step('feed_json_parse')
-            # data = json.loads(line)
-            # dtype = data['type'][0]
-            # # timestamp = data['timestamp']
-            # name = data['name']
-            # value = data['value']
-            DB.enter_step('feed_switch')
-            if dtype == 'a':
-                for rule in DB.get_ip4(value):
-                    if not list(DB.get_domain_in_zone(name)):
-                        DB.set_hostname(name, source=rule,
-                                        updated=int(timestamp))
-                        # updated=int(data['timestamp']))
-            elif dtype == 'c':
-                for rule in DB.get_domain(value):
-                    if not list(DB.get_domain_in_zone(name)):
-                        DB.set_hostname(name, source=rule,
-                                        updated=int(timestamp))
-                        # updated=int(data['timestamp']))
-            elif dtype == 'p':
-                for rule in DB.get_domain(value):
-                    if not list(DB.get_ip4_in_network(name)):
-                        DB.set_ip4address(name, source=rule,
-                                          updated=int(timestamp))
-                        # updated=int(data['timestamp']))
-            else:
-                raise NotImplementedError(f'Type: {dtype}')
-            DB.enter_step('iowait')
-    except KeyboardInterrupt:
-        log.warning("Interupted.")
-        pass
-
-    DB.close()
+    recs_queue: multiprocessing.Queue = multiprocessing.Queue(
+        maxsize=args.queue_size)
+
+    writer = Writer(recs_queue)
+    writer.start()
+
+    parser = PARSERS[args.parser](args.input, recs_queue, args.block_size)
+    parser.run()
+
+    recs_queue.put(None)
+    writer.join()
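The parser/writer split above relies on two conventions: records travel in blocks of block_size to amortize queue overhead, and a single None is the end-of-stream sentinel that iter(queue.get, None) turns into a clean loop exit. The same pattern stripped to its bones — a standalone sketch, not the code above:

import multiprocessing


def writer(queue: multiprocessing.Queue) -> None:
    # iter() with a sentinel ends the loop when None is received
    for block in iter(queue.get, None):
        for record in block:
            print('would write:', record)


if __name__ == '__main__':
    queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=10)
    proc = multiprocessing.Process(target=writer, args=(queue,))
    proc.start()
    queue.put([('a', 1576418037, 'sub.example.com', '192.0.2.1')])
    queue.put(None)  # sentinel: let the writer drain and exit
    proc.join()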

feed_rules.py

@@ -28,15 +28,15 @@ if __name__ == '__main__':
         help="The input only comes from verified first-party sources")
     args = parser.parse_args()

-    DB = database.Database(write=True)
+    DB = database.Database()

     fun = FUNCTION_MAP[args.type]

     for rule in args.input:
         fun(DB,
             rule.strip(),
-            is_first_party=args.first_party,
+            # is_first_party=args.first_party,
             updated=int(time.time()),
             )

-    DB.close()
+    DB.save()

fetch_resources.sh

@@ -18,7 +18,7 @@ log "Retrieving rules…"
 rm -f rules*/*.cache.*
 dl https://easylist.to/easylist/easyprivacy.txt rules_adblock/easyprivacy.cache.txt
 # From firebog.net Tracking & Telemetry Lists
-dl https://v.firebog.net/hosts/Prigent-Ads.txt rules/prigent-ads.cache.list
+# dl https://v.firebog.net/hosts/Prigent-Ads.txt rules/prigent-ads.cache.list
 # dl https://gitlab.com/quidsup/notrack-blocklists/raw/master/notrack-blocklist.txt rules/notrack-blocklist.cache.list
 # False positives: https://github.com/WaLLy3K/wally3k.github.io/issues/73 -> 69.media.tumblr.com chicdn.net
 dl https://raw.githubusercontent.com/StevenBlack/hosts/master/data/add.2o7Net/hosts rules_hosts/add2o7.cache.txt

@@ -51,3 +51,4 @@ then
 else
     mv temp/cisco-umbrella_popularity.fresh.list subdomains/cisco-umbrella_popularity.cache.list
 fi
+dl https://www.orwell1984.today/cname/eulerian.net.txt subdomains/orwell-eulerian-cname-list.cache.list

import_rules.sh

@@ -5,11 +5,12 @@ function log() {
 }

 log "Importing rules…"
-cat rules_adblock/*.txt | grep -v '^!' | grep -v '^\[Adblock' | ./adblock_to_domain_list.py | ./feed_rules.py zone
-cat rules_hosts/*.txt | grep -v '^#' | grep -v '^$' | cut -d ' ' -f2 | ./feed_rules.py zone
-cat rules/*.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone
-cat rules_ip/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py ip4network
-cat rules_asn/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py asn
+BEFORE="$(date +%s)"
+# cat rules_adblock/*.txt | grep -v '^!' | grep -v '^\[Adblock' | ./adblock_to_domain_list.py | ./feed_rules.py zone
+# cat rules_hosts/*.txt | grep -v '^#' | grep -v '^$' | cut -d ' ' -f2 | ./feed_rules.py zone
+# cat rules/*.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone
+# cat rules_ip/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py ip4network
+# cat rules_asn/*.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py asn

 cat rules/first-party.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone --first-party
 cat rules_ip/first-party.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py ip4network --first-party

@@ -17,3 +18,5 @@ cat rules_asn/first-party.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py as

 ./feed_asn.py
+
+log "Pruning old rules…"
+./db.py --prune --prune-before "$BEFORE" --prune-base

json_to_csv.py (deleted file)

@@ -1,36 +0,0 @@
#!/usr/bin/env python3
import argparse
import sys
import logging
import json
import csv
if __name__ == '__main__':
# Parsing arguments
log = logging.getLogger('json_to_csv')
parser = argparse.ArgumentParser(
description="TODO")
parser.add_argument(
# '-i', '--input', type=argparse.FileType('rb'), default=sys.stdin.buffer,
'-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
help="TODO")
parser.add_argument(
# '-i', '--output', type=argparse.FileType('wb'), default=sys.stdout.buffer,
'-o', '--output', type=argparse.FileType('w'), default=sys.stdout,
help="TODO")
args = parser.parse_args()
writer = csv.writer(args.output)
for line in args.input:
data = json.loads(line)
try:
writer.writerow([
data['type'][0],
data['timestamp'],
data['name'],
data['value']])
except IndexError:
log.error('Could not parse line: %s', line)
pass

import_rapid7.sh

@@ -9,11 +9,11 @@ function log() {

 # TODO Fetch 'em
 log "Reading PTR records…"
-pv ptr.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
+pv ptr.json.gz | gunzip | ./feed_dns.py
 log "Reading A records…"
-pv a.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
+pv a.json.gz | gunzip | ./feed_dns.py
 log "Reading CNAME records…"
-pv cname.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
+pv cname.json.gz | gunzip | ./feed_dns.py
 log "Pruning old data…"
 ./database.py --prune

resolve_subdomains.py (deleted file)

@@ -1,264 +0,0 @@
#!/usr/bin/env python3
"""
From a list of subdomains, output only
the ones resolving to a first-party tracker.
"""
import argparse
import logging
import os
import queue
import sys
import threading
import typing
import time
import coloredlogs
import dns.exception
import dns.resolver
DNS_TIMEOUT = 5.0
NUMBER_TRIES = 5
class Worker(threading.Thread):
"""
Worker process for a DNS resolver.
Will resolve DNS to match first-party subdomains.
"""
def change_nameserver(self) -> None:
"""
Assign a this worker another nameserver from the queue.
"""
server = None
while server is None:
try:
server = self.orchestrator.nameservers_queue.get(block=False)
except queue.Empty:
self.orchestrator.refill_nameservers_queue()
self.log.info("Using nameserver: %s", server)
self.resolver.nameservers = [server]
def __init__(self,
orchestrator: 'Orchestrator',
index: int = 0):
super(Worker, self).__init__()
self.log = logging.getLogger(f'worker{index:03d}')
self.orchestrator = orchestrator
self.resolver = dns.resolver.Resolver()
self.change_nameserver()
def resolve_subdomain(self, subdomain: str) -> typing.Optional[
typing.List[
dns.rrset.RRset
]
]:
"""
Returns the resolution chain of the subdomain to an A record,
including any intermediary CNAME.
The last element is an IP address.
Returns None if the nameserver was unable to satisfy the request.
Returns [] if the requests points to nothing.
"""
self.log.debug("Querying %s", subdomain)
try:
query = self.resolver.query(subdomain, 'A', lifetime=DNS_TIMEOUT)
except dns.resolver.NXDOMAIN:
return []
except dns.resolver.NoAnswer:
return []
except dns.resolver.YXDOMAIN:
self.log.warning("Query name too long for %s", subdomain)
return None
except dns.resolver.NoNameservers:
# NOTE Most of the time this error message means that the domain
# does not exists, but sometimes it means the that the server
# itself is broken. So we count on the retry logic.
self.log.warning("All nameservers broken for %s", subdomain)
return None
except dns.exception.Timeout:
# NOTE Same as above
self.log.warning("Timeout for %s", subdomain)
return None
except dns.name.EmptyLabel:
self.log.warning("Empty label for %s", subdomain)
return None
return query.response.answer
def run(self) -> None:
self.log.info("Started")
subdomain: str
for subdomain in iter(self.orchestrator.subdomains_queue.get, None):
for _ in range(NUMBER_TRIES):
resolved = self.resolve_subdomain(subdomain)
# Retry with another nameserver if error
if resolved is None:
self.change_nameserver()
else:
break
# If it wasn't found after multiple tries
if resolved is None:
self.log.error("Gave up on %s", subdomain)
resolved = []
assert isinstance(resolved, list)
self.orchestrator.results_queue.put(resolved)
self.orchestrator.results_queue.put(None)
self.log.info("Stopped")
class Orchestrator():
"""
Orchestrator of the different Worker threads.
"""
def refill_nameservers_queue(self) -> None:
"""
Re-fill the given nameservers into the nameservers queue.
Done every-time the queue is empty, making it
basically looping and infinite.
"""
# Might be in a race condition but that's probably fine
for nameserver in self.nameservers:
self.nameservers_queue.put(nameserver)
self.log.info("Refilled nameserver queue")
def __init__(self, subdomains: typing.Iterable[str],
nameservers: typing.List[str] = None,
nb_workers: int = 1,
):
self.log = logging.getLogger('orchestrator')
self.subdomains = subdomains
self.nb_workers = nb_workers
# Use interal resolver by default
self.nameservers = nameservers or dns.resolver.Resolver().nameservers
self.subdomains_queue: queue.Queue = queue.Queue(
maxsize=self.nb_workers)
self.results_queue: queue.Queue = queue.Queue()
self.nameservers_queue: queue.Queue = queue.Queue()
self.refill_nameservers_queue()
def fill_subdomain_queue(self) -> None:
"""
Read the subdomains in input and put them into the queue.
Done in a thread so we can both:
- yield the results as they come
- not store all the subdomains at once
"""
self.log.info("Started reading subdomains")
# Send data to workers
for subdomain in self.subdomains:
self.subdomains_queue.put(subdomain)
self.log.info("Finished reading subdomains")
# Send sentinel to each worker
# sentinel = None ~= EOF
for _ in range(self.nb_workers):
self.subdomains_queue.put(None)
@staticmethod
def format_rrset(rrset: dns.rrset.RRset) -> typing.Iterable[str]:
if rrset.rdtype == dns.rdatatype.CNAME:
dtype = 'c'
elif rrset.rdtype == dns.rdatatype.A:
dtype = 'a'
else:
raise NotImplementedError
name = rrset.name.to_text()[:-1]
for item in rrset.items:
value = item.to_text()
if rrset.rdtype == dns.rdatatype.CNAME:
value = value[:-1]
yield f'{dtype},{int(time.time())},{name},{value}\n'
def run(self) -> typing.Iterable[str]:
"""
Yield the results.
"""
# Create workers
self.log.info("Creating workers")
for i in range(self.nb_workers):
Worker(self, i).start()
fill_thread = threading.Thread(target=self.fill_subdomain_queue)
fill_thread.start()
# Wait for one sentinel per worker
# In the meantime output results
for _ in range(self.nb_workers):
resolved: typing.List[dns.rrset.RRset]
for resolved in iter(self.results_queue.get, None):
for rrset in resolved:
yield from self.format_rrset(rrset)
self.log.info("Waiting for reader thread")
fill_thread.join()
self.log.info("Done!")
def main() -> None:
"""
Main function when used directly.
Read the subdomains provided and output it,
the last CNAME resolved and the IP adress it resolves to.
Takes as an input a filename (or nothing, for stdin),
and as an output a filename (or nothing, for stdout).
The input must be a subdomain per line, the output is a TODO
Use the file `nameservers` as the list of nameservers
to use, or else it will use the system defaults.
"""
# Initialization
coloredlogs.install(
level='DEBUG',
fmt='%(asctime)s %(name)s %(levelname)s %(message)s'
)
# Parsing arguments
parser = argparse.ArgumentParser(
description="Massively resolves subdomains and store them in a file.")
parser.add_argument(
'-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
help="Input file with one subdomain per line")
parser.add_argument(
'-o', '--output', type=argparse.FileType('w'), default=sys.stdout,
help="Outptut file with DNS chains")
parser.add_argument(
'-n', '--nameservers', default='nameservers',
help="File with one nameserver per line")
parser.add_argument(
'-j', '--workers', type=int, default=512,
help="Number of threads to use")
args = parser.parse_args()
# Cleaning input
iterator = iter(args.input)
iterator = map(str.strip, iterator)
iterator = filter(None, iterator)
# Reading nameservers
servers: typing.List[str] = list()
if os.path.isfile(args.nameservers):
servers = open(args.nameservers).readlines()
servers = list(filter(None, map(str.strip, servers)))
for resolved in Orchestrator(
iterator,
servers,
nb_workers=args.workers
).run():
args.output.write(resolved)
if __name__ == '__main__':
main()