Compare commits


7 commits

Author SHA1 Message Date
Geoffrey Frogeye 189deeb559
Workflow: Multiprocess
Still trying.
It's better than multithreading, though.

Merge branch 'newworkflow' into newworkflow_threaded
2019-12-14 17:27:46 +01:00
Geoffrey Frogeye d7c239a6f6 Workflow: Some modifications 2019-12-14 16:04:19 +01:00
Geoffrey Frogeye 5023b85d7c
Added intermediate representation for DNS datasets
It's just CSV.
The DNS records from the datasets are not ordered consistently,
so we need to parse them completely.
It seems that converting to an IR before sending data to ./feed_dns.py
through a pipe is faster than decoding the JSON in ./feed_dns.py.
This will also reduce the storage needed for the resolved subdomains by
about 15% (compressed). (A conversion sketch follows the commit list below.)
2019-12-13 21:59:35 +01:00
Geoffrey Frogeye 269b8278b5
Workflow: Fixed rules counts 2019-12-13 18:36:08 +01:00
Geoffrey Frogeye ab7ef609dd
Workflow: Various optimisations and fixes
I forgot to close this one earlier, so:
Closes #7
2019-12-13 18:08:22 +01:00
Geoffrey Frogeye f3eedcba22
Updated now based on timestamp
Did I forget to add feed_asn.py a few commits ago?
Oh well...
2019-12-13 13:54:00 +01:00
Geoffrey Frogeye 8d94b80fd0
Integrated DNS resolving to workflow
Since the bigger datasets are only updated once a month,
this might help for quick updates.
2019-12-13 13:38:23 +01:00
15 changed files with 512 additions and 279 deletions
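Regarding commit 5023b85d7c above: a minimal sketch of the JSON-to-IR conversion, using the same field mapping as ./json_to_csv.py in this changeset (the sample record is made up):

import csv
import io
import json

# One resolved record as the datasets ship it (hypothetical values):
line = '{"type": "a", "timestamp": 1576255175, "name": "sub.example.com", "value": "93.184.216.34"}'
data = json.loads(line)
out = io.StringIO()
# Same row layout as json_to_csv.py: first letter of type, timestamp, name, value
csv.writer(out).writerow([data['type'][0], data['timestamp'], data['name'], data['value']])
print(out.getvalue().strip())  # -> a,1576255175,sub.example.com,93.184.216.34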

2
.gitignore vendored
View file

@@ -3,5 +3,3 @@
*.db-journal
nameservers
nameservers.head
*.o
*.so

View file

@@ -12,6 +12,7 @@ import logging
import argparse
import coloredlogs
import ipaddress
import math
coloredlogs.install(
level='DEBUG',
@@ -22,43 +23,47 @@ DbValue = typing.Union[None, int, float, str, bytes]
class Database():
VERSION = 3
VERSION = 5
PATH = "blocking.db"
def open(self) -> None:
mode = 'rwc' if self.write else 'ro'
uri = f'file:{self.PATH}?mode={mode}'
self.conn = sqlite3.connect(uri, uri=True)
self.cursor = self.conn.cursor()
self.execute("PRAGMA foreign_keys = ON")
# self.conn.create_function("prepare_ip4address", 1,
# Database.prepare_ip4address,
# deterministic=True)
cursor = self.conn.cursor()
cursor.execute("PRAGMA foreign_keys = ON")
self.conn.create_function("unpack_asn", 1,
self.unpack_asn,
deterministic=True)
self.conn.create_function("unpack_ip4address", 1,
self.unpack_ip4address,
deterministic=True)
self.conn.create_function("unpack_ip4network", 2,
self.unpack_ip4network,
deterministic=True)
self.conn.create_function("unpack_domain", 1,
lambda s: s[:-1][::-1],
deterministic=True)
def execute(self, cmd: str, args: typing.Union[
typing.Tuple[DbValue, ...],
typing.Dict[str, DbValue]] = None) -> None:
# self.log.debug(cmd)
# self.log.debug(args)
self.cursor.execute(cmd, args or tuple())
self.conn.create_function("format_zone", 1,
lambda s: '*' + s[::-1],
deterministic=True)
def get_meta(self, key: str) -> typing.Optional[int]:
cursor = self.conn.cursor()
try:
self.execute("SELECT value FROM meta WHERE key=?", (key,))
cursor.execute("SELECT value FROM meta WHERE key=?", (key,))
except sqlite3.OperationalError:
return None
for ver, in self.cursor:
for ver, in cursor:
return ver
return None
def set_meta(self, key: str, val: int) -> None:
self.execute("INSERT INTO meta VALUES (?, ?) "
"ON CONFLICT (key) DO "
"UPDATE set value=?",
(key, val, val))
cursor = self.conn.cursor()
cursor.execute("INSERT INTO meta VALUES (?, ?) "
"ON CONFLICT (key) DO "
"UPDATE set value=?",
(key, val, val))
def close(self) -> None:
self.enter_step('close_commit')
@@ -76,8 +81,9 @@ class Database():
os.unlink(self.PATH)
self.open()
self.log.info("Creating database version %d.", self.VERSION)
cursor = self.conn.cursor()
with open("database_schema.sql", 'r') as db_schema:
self.cursor.executescript(db_schema.read())
cursor.executescript(db_schema.read())
self.set_meta('version', self.VERSION)
self.conn.commit()
@@ -98,13 +104,6 @@ class Database():
version)
self.initialize()
updated = self.get_meta('updated')
if updated is None:
self.execute('SELECT max(updated) FROM rules')
data = self.cursor.fetchone()
updated, = data
self.updated = updated or 1
def enter_step(self, name: str) -> None:
now = time.perf_counter()
try:
@@ -126,24 +125,32 @@ class Database():
self.log.debug(f"{'total':<20}: "
f"{total:9.2f} s ({1:7.2%})")
def prepare_hostname(self, hostname: str) -> str:
@staticmethod
def pack_hostname(hostname: str) -> str:
return hostname[::-1] + '.'
def prepare_zone(self, zone: str) -> str:
return self.prepare_hostname(zone)
@staticmethod
def pack_zone(zone: str) -> str:
return Database.pack_hostname(zone)
@staticmethod
def prepare_asn(asn: str) -> int:
def pack_asn(asn: str) -> int:
asn = asn.upper()
if asn.startswith('AS'):
asn = asn[2:]
return int(asn)
@staticmethod
def prepare_ip4address(address: str) -> int:
def unpack_asn(asn: int) -> str:
return f'AS{asn}'
@staticmethod
def pack_ip4address(address: str) -> int:
total = 0
for i, octet in enumerate(address.split('.')):
total += int(octet) << (3-i)*8
if total > 0xFFFFFFFF:
raise ValueError
return total
# return '{:02x}{:02x}{:02x}{:02x}'.format(
# *[int(c) for c in address.split('.')])
@@ -158,34 +165,78 @@ class Database():
# packed = ipaddress.ip_address(address).packed
# return packed
def prepare_ip4network(self, network: str) -> typing.Tuple[int, int]:
# def prepare_ip4network(network: str) -> str:
@staticmethod
def unpack_ip4address(address: int) -> str:
return '.'.join(str((address >> (i * 8)) & 0xFF)
for i in reversed(range(4)))
@staticmethod
def pack_ip4network(network: str) -> typing.Tuple[int, int]:
# def pack_ip4network(network: str) -> str:
net = ipaddress.ip_network(network)
mini = self.prepare_ip4address(net.network_address.exploded)
maxi = self.prepare_ip4address(net.broadcast_address.exploded)
mini = Database.pack_ip4address(net.network_address.exploded)
maxi = Database.pack_ip4address(net.broadcast_address.exploded)
# mini = net.network_address.packed
# maxi = net.broadcast_address.packed
return mini, maxi
# return Database.prepare_ip4address(net.network_address.exploded)[:net.prefixlen]
# return Database.pack_ip4address(net.network_address.exploded)[:net.prefixlen]
def expire(self) -> None:
self.enter_step('expire')
self.updated += 1
self.set_meta('updated', self.updated)
@staticmethod
def unpack_ip4network(mini: int, maxi: int) -> str:
addr = Database.unpack_ip4address(mini)
prefixlen = 32-int(math.log2(maxi-mini+1))
return f'{addr}/{prefixlen}'
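# A round-trip sketch of the IPv4 packing helpers above (values illustrative):
#   Database.pack_ip4address('192.0.2.0')    -> 0xC0000200
#   Database.pack_ip4network('192.0.2.0/24') -> (0xC0000200, 0xC00002FF)
# unpack_ip4network recovers the prefix length from the range size:
#   32 - log2(0xC00002FF - 0xC0000200 + 1) = 32 - log2(256) = 24
#   Database.unpack_ip4network(0xC0000200, 0xC00002FF) -> '192.0.2.0/24'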
def update_references(self) -> None:
self.enter_step('update_refs')
self.execute('UPDATE rules AS r SET refs='
'(SELECT count(*) FROM rules '
'WHERE source=r.id)')
cursor = self.conn.cursor()
cursor.execute('UPDATE rules AS r SET refs='
'(SELECT count(*) FROM rules '
'WHERE source=r.id)')
def prune(self) -> None:
def prune(self, before: int, base_only: bool = False) -> None:
self.enter_step('prune')
self.execute('DELETE FROM rules WHERE updated<?', (self.updated,))
cursor = self.conn.cursor()
cmd = 'DELETE FROM rules WHERE updated<?'
if base_only:
cmd += ' AND level=0'
cursor.execute(cmd, (before,))
def export(self, first_party_only: bool = False,
end_chain_only: bool = False) -> typing.Iterable[str]:
command = 'SELECT unpack_domain(val) FROM rules ' \
def explain(self, entry: int) -> str:
# Format current
string = '???'
cursor = self.conn.cursor()
cursor.execute(
'SELECT unpack_asn(val) FROM asn WHERE entry=:entry '
'UNION '
'SELECT unpack_domain(val) FROM hostname WHERE entry=:entry '
'UNION '
'SELECT format_zone(val) FROM zone WHERE entry=:entry '
'UNION '
'SELECT unpack_ip4address(val) FROM ip4address WHERE entry=:entry '
'UNION '
'SELECT unpack_ip4network(mini, maxi) '
'FROM ip4network WHERE entry=:entry ',
{"entry": entry}
)
for val, in cursor: # only one
string = str(val)
string += f' #{entry}'
# Add source if any
cursor.execute('SELECT source FROM rules WHERE id=?', (entry,))
for source, in cursor:
if source:
string += f'{self.explain(source)}'
return string
def export(self,
first_party_only: bool = False,
end_chain_only: bool = False,
explain: bool = False,
) -> typing.Iterable[str]:
selection = 'entry' if explain else 'unpack_domain(val)'
command = f'SELECT {selection} FROM rules ' \
'INNER JOIN hostname ON rules.id = hostname.entry'
restrictions: typing.List[str] = list()
if first_party_only:
@@ -194,16 +245,40 @@ class Database():
restrictions.append('rules.refs = 0')
if restrictions:
command += ' WHERE ' + ' AND '.join(restrictions)
command += ' ORDER BY unpack_domain(val) ASC'
self.execute(command)
for val, in self.cursor:
yield val
if not explain:
command += ' ORDER BY unpack_domain(val) ASC'
cursor = self.conn.cursor()
cursor.execute(command)
for val, in cursor:
if explain:
yield self.explain(val)
else:
yield val
def count_rules(self,
first_party_only: bool = False,
) -> str:
counts: typing.List[str] = list()
cursor = self.conn.cursor()
for table in ['asn', 'ip4network', 'ip4address', 'zone', 'hostname']:
command = f'SELECT count(*) FROM rules ' \
f'INNER JOIN {table} ON rules.id = {table}.entry ' \
'WHERE rules.level = 0'
if first_party_only:
command += ' AND first_party=1'
cursor.execute(command)
count, = cursor.fetchone()
if count > 0:
counts.append(f'{table}: {count}')
return ', '.join(counts)
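# Hypothetical usage of count_rules (numbers invented; only level-0 rules
# are counted and tables with zero rows are omitted from the output):
#   db = Database()
#   db.count_rules(first_party_only=True)
#   -> 'asn: 3, zone: 42, hostname: 120'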
def get_domain(self, domain: str) -> typing.Iterable[int]:
self.enter_step('get_domain_prepare')
domain_prep = self.prepare_hostname(domain)
domain_prep = self.pack_hostname(domain)
cursor = self.conn.cursor()
self.enter_step('get_domain_select')
self.execute(
cursor.execute(
'SELECT null, entry FROM hostname '
'WHERE val=:d '
'UNION '
@@ -214,22 +289,41 @@ class Database():
')',
{'d': domain_prep}
)
for val, entry in self.cursor:
for val, entry in cursor:
self.enter_step('get_domain_confirm')
if not (val is None or domain_prep.startswith(val)):
continue
self.enter_step('get_domain_yield')
yield entry
def get_domain_in_zone(self, domain: str) -> typing.Iterable[int]:
self.enter_step('get_domainiz_prepare')
domain_prep = self.pack_hostname(domain)
cursor = self.conn.cursor()
self.enter_step('get_domainiz_select')
cursor.execute(
'SELECT val, entry FROM zone '
'WHERE val<=:d '
'ORDER BY val DESC LIMIT 1',
{'d': domain_prep}
)
for val, entry in cursor:
self.enter_step('get_domainiz_confirm')
if not (val is None or domain_prep.startswith(val)):
continue
self.enter_step('get_domainiz_yield')
yield entry
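# Why the reversed storage works (a sketch): hostnames are stored reversed
# with a trailing dot, so a zone is a plain string prefix of every
# hostname under it:
#   Database.pack_hostname('sub.example.com') -> 'moc.elpmaxe.bus.'
#   Database.pack_zone('example.com')         -> 'moc.elpmaxe.'
#   'moc.elpmaxe.bus.'.startswith('moc.elpmaxe.') -> True
# which is exactly the confirmation get_domain()/get_domain_in_zone()
# perform after the SQL range scan.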
def get_ip4(self, address: str) -> typing.Iterable[int]:
self.enter_step('get_ip4_prepare')
try:
address_prep = self.prepare_ip4address(address)
address_prep = self.pack_ip4address(address)
except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", address)
return
cursor = self.conn.cursor()
self.enter_step('get_ip4_select')
self.execute(
cursor.execute(
'SELECT entry FROM ip4address '
# 'SELECT null, entry FROM ip4address '
'WHERE val=:a '
@@ -244,7 +338,7 @@ class Database():
'WHERE :a BETWEEN mini AND maxi ',
{'a': address_prep}
)
for val, entry in self.cursor:
for entry, in cursor:
# self.enter_step('get_ip4_confirm')
# if not (val is None or val.startswith(address_prep)):
# # PERF startswith but from the end
@@ -252,11 +346,29 @@ class Database():
self.enter_step('get_ip4_yield')
yield entry
def get_ip4_in_network(self, address: str) -> typing.Iterable[int]:
self.enter_step('get_ip4in_prepare')
try:
address_prep = self.pack_ip4address(address)
except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", address)
return
cursor = self.conn.cursor()
self.enter_step('get_ip4in_select')
cursor.execute(
'SELECT entry FROM ip4network '
'WHERE :a BETWEEN mini AND maxi ',
{'a': address_prep}
)
for entry, in cursor:
self.enter_step('get_ip4in_yield')
yield entry
def list_asn(self) -> typing.Iterable[typing.Tuple[str, int]]:
cursor = self.conn.cursor()
self.enter_step('list_asn_select')
self.enter_step('get_domain_select')
self.execute('SELECT val, entry FROM asn')
for val, entry in self.cursor:
cursor.execute('SELECT val, entry FROM asn')
for val, entry in cursor:
yield f'AS{val}', entry
def _set_generic(self,
@@ -264,6 +376,7 @@ class Database():
select_query: str,
insert_query: str,
prep: typing.Dict[str, DbValue],
updated: int,
is_first_party: bool = False,
source: int = None,
) -> None:
@@ -271,34 +384,36 @@ class Database():
# here abstraction > performances
# Fields based on the source
self.enter_step(f'set_{table}_prepare')
cursor = self.conn.cursor()
if source is None:
first_party = int(is_first_party)
level = 0
else:
self.enter_step(f'set_{table}_source')
self.execute(
cursor.execute(
'SELECT first_party, level FROM rules '
'WHERE id=?',
(source,)
)
first_party, level = self.cursor.fetchone()
first_party, level = cursor.fetchone()
level += 1
self.enter_step(f'set_{table}_select')
self.execute(select_query, prep)
cursor.execute(select_query, prep)
rules_prep = {
rules_prep: typing.Dict[str, DbValue] = {
"source": source,
"updated": self.updated,
"updated": updated,
"first_party": first_party,
"level": level,
}
# If the entry already exists
for entry, in self.cursor: # only one
for entry, in cursor: # only one
self.enter_step(f'set_{table}_update')
rules_prep['entry'] = entry
self.execute(
cursor.execute(
'UPDATE rules SET '
'source=:source, updated=:updated, '
'first_party=:first_party, level=:level '
@@ -314,23 +429,18 @@ class Database():
# If it does not exist
if source is not None:
self.enter_step(f'set_{table}_incsrc')
self.execute('UPDATE rules SET refs = refs + 1 WHERE id=?',
(source,))
self.enter_step(f'set_{table}_insert')
self.execute(
cursor.execute(
'INSERT INTO rules '
'(source, updated, first_party, refs, level) '
'VALUES (:source, :updated, :first_party, 0, :level) ',
'(source, updated, first_party, level) '
'VALUES (:source, :updated, :first_party, :level) ',
rules_prep
)
self.execute('SELECT id FROM rules WHERE rowid=?',
(self.cursor.lastrowid,))
for entry, in self.cursor: # only one
cursor.execute('SELECT id FROM rules WHERE rowid=?',
(cursor.lastrowid,))
for entry, in cursor: # only one
prep['entry'] = entry
self.execute(insert_query, prep)
cursor.execute(insert_query, prep)
return
assert False
@@ -338,7 +448,7 @@ class Database():
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_hostname_prepare')
prep: typing.Dict[str, DbValue] = {
'val': self.prepare_hostname(hostname),
'val': self.pack_hostname(hostname),
}
self._set_generic(
'hostname',
@@ -353,7 +463,7 @@ class Database():
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_asn_prepare')
try:
asn_prep = self.prepare_asn(asn)
asn_prep = self.pack_asn(asn)
except ValueError:
self.log.error("Invalid asn: %s", asn)
return
@@ -371,10 +481,9 @@ class Database():
def set_ip4address(self, ip4address: str,
*args: typing.Any, **kwargs: typing.Any) -> None:
# TODO Do not add if already in ip4network
self.enter_step('set_ip4add_prepare')
try:
ip4address_prep = self.prepare_ip4address(ip4address)
ip4address_prep = self.pack_ip4address(ip4address)
except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", ip4address)
return
@@ -394,7 +503,7 @@ class Database():
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_zone_prepare')
prep: typing.Dict[str, DbValue] = {
'val': self.prepare_zone(zone),
'val': self.pack_zone(zone),
}
self._set_generic(
'zone',
@@ -409,7 +518,7 @@ class Database():
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_ip4net_prepare')
try:
ip4network_prep = self.prepare_ip4network(ip4network)
ip4network_prep = self.pack_ip4network(ip4network)
except (ValueError, IndexError):
self.log.error("Invalid ip4network: %s", ip4network)
return
@@ -439,8 +548,12 @@ if __name__ == '__main__':
'-p', '--prune', action='store_true',
help="Remove old entries from database")
parser.add_argument(
'-e', '--expire', action='store_true',
help="Set the whole database as an old source")
'-b', '--prune-base', action='store_true',
help="TODO")
parser.add_argument(
'-s', '--prune-before', type=int,
default=(int(time.time()) - 60*60*24*31*6),
help="TODO")
parser.add_argument(
'-r', '--references', action='store_true',
help="Update the reference count")
@@ -451,10 +564,8 @@ if __name__ == '__main__':
if args.initialize:
DB.initialize()
if args.prune:
DB.prune()
if args.expire:
DB.expire()
if args.references and not args.prune:
DB.prune(before=args.prune_before, base_only=args.prune_base)
if args.references:
DB.update_references()
DB.close()
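The new pruning flags replace the old --expire mechanism: --prune-before takes a Unix-timestamp cutoff (defaulting to roughly six months ago) and --prune-base limits deletion to level-0 (base) rules, matching the call in import_rules.sh further down. A sketch of the default cutoff computation, as in the argparse default above:

import time

# Default for --prune-before: about six months before now (31-day months)
default_before = int(time.time()) - 60*60*24*31*6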

View file

@@ -10,30 +10,37 @@ CREATE TABLE rules (
level INTEGER, -- Level of recursion to the root source rule (used for source priority)
FOREIGN KEY (source) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX rules_source ON rules (source); -- for references recounting
CREATE INDEX rules_updated ON rules (updated); -- for pruning
CREATE INDEX rules_level_firstparty ON rules (level, first_party); -- for counting rules
CREATE TABLE asn (
val INTEGER PRIMARY KEY,
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX asn_entry ON asn (entry); -- for explanations
CREATE TABLE hostname (
val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for consistency with zone)
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX hostname_entry ON hostname (entry); -- for explanations
CREATE TABLE zone (
val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for easier matching)
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX zone_entry ON zone (entry); -- for explanations
CREATE TABLE ip4address (
val INTEGER PRIMARY KEY,
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX ip4address_entry ON ip4address (entry); -- for explanations
CREATE TABLE ip4network (
-- val TEXT PRIMARY KEY,
@@ -43,6 +50,7 @@ CREATE TABLE ip4network (
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX ip4network_minmax ON ip4network (mini, maxi);
CREATE INDEX ip4network_entry ON ip4network (entry); -- for explanations
-- Store various things
CREATE TABLE meta (

View file

@@ -19,12 +19,31 @@ if __name__ == '__main__':
parser.add_argument(
'-e', '--end-chain', action='store_true',
help="TODO")
parser.add_argument(
'-x', '--explain', action='store_true',
help="TODO")
parser.add_argument(
'-r', '--rules', action='store_true',
help="TODO")
parser.add_argument(
'-c', '--count', action='store_true',
help="TODO")
args = parser.parse_args()
DB = database.Database()
for domain in DB.export(first_party_only=args.first_party,
end_chain_only=args.end_chain):
print(domain, file=args.output)
if args.rules:
if not args.count:
raise NotImplementedError
print(DB.count_rules(first_party_only=args.first_party))
else:
if args.count:
raise NotImplementedError
for domain in DB.export(
first_party_only=args.first_party,
end_chain_only=args.end_chain,
explain=args.explain,
):
print(domain, file=args.output)
DB.close()

53
feed_asn.py Executable file
View file

@@ -0,0 +1,53 @@
#!/usr/bin/env python3
import database
import argparse
import requests
import typing
import ipaddress
import logging
import time
IPNetwork = typing.Union[ipaddress.IPv4Network, ipaddress.IPv6Network]
def get_ranges(asn: str) -> typing.Iterable[str]:
req = requests.get(
'https://stat.ripe.net/data/as-routing-consistency/data.json',
params={'resource': asn}
)
data = req.json()
for pref in data['data']['prefixes']:
yield pref['prefix']
if __name__ == '__main__':
log = logging.getLogger('feed_asn')
# Parsing arguments
parser = argparse.ArgumentParser(
description="TODO")
args = parser.parse_args()
DB = database.Database()
DBW = database.Database(write=True)
for asn, entry in DB.list_asn():
DB.enter_step('asn_get_ranges')
for prefix in get_ranges(asn):
parsed_prefix: IPNetwork = ipaddress.ip_network(prefix)
if parsed_prefix.version == 4:
DBW.set_ip4network(
prefix,
source=entry,
updated=int(time.time())
)
log.info('Added %s from %s (id=%s)', prefix, asn, entry)
elif parsed_prefix.version == 6:
log.warning('Unimplemented prefix version: %s', prefix)
else:
log.error('Unknown prefix version: %s', prefix)
DB.close()
DBW.close()

184
feed_dns.py Normal file → Executable file
View file

@@ -1,23 +1,43 @@
#!/usr/bin/env python3
import database
import argparse
import sys
import database
import json
import logging
import threading
import queue
import sys
import typing
import multiprocessing
NUMBER_THREADS = 8
NUMBER_THREADS = 2
BLOCK_SIZE = 100
# select, confirm, write
FUNCTION_MAP: typing.Any = {
'a': (
database.Database.get_ip4,
database.Database.get_domain_in_zone,
database.Database.set_hostname,
),
'cname': (
database.Database.get_domain,
database.Database.get_domain_in_zone,
database.Database.set_hostname,
),
'ptr': (
database.Database.get_domain,
database.Database.get_ip4_in_network,
database.Database.set_ip4address,
),
}
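# Sketch of one parsed IR line flowing through the map (made-up values):
#   dtype, updated, name, value = 'a', 1576255175, 'sub.example.com', '93.184.216.34'
#   select, confirm, write = FUNCTION_MAP[dtype]
#   # select  -> Database.get_ip4: find rules matching the resolved IP
#   # confirm -> Database.get_domain_in_zone: skip if a zone already covers name
#   # write   -> Database.set_hostname: queued for the single Writer process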
class Worker(threading.Thread):
class Reader(multiprocessing.Process):
def __init__(self,
lines_queue: queue.Queue,
write_queue: queue.Queue,
lines_queue: multiprocessing.Queue,
write_queue: multiprocessing.Queue,
index: int = 0):
super(Worker, self).__init__()
self.log = logging.getLogger(f'worker{index:03d}')
super(Reader, self).__init__()
self.log = logging.getLogger(f'rd{index:03d}')
self.lines_queue = lines_queue
self.write_queue = write_queue
self.index = index
@@ -25,45 +45,51 @@ class Worker(threading.Thread):
def run(self) -> None:
self.db = database.Database(write=False)
self.db.log = logging.getLogger(f'db{self.index:03d}')
self.db.enter_step('wait_line')
line: str
for line in iter(self.lines_queue.get, None):
self.db.enter_step('feed_json_parse')
# split = line.split(b'"')
split = line.split('"')
try:
name = split[7]
dtype = split[11]
value = split[15]
except IndexError:
log.error("Invalid JSON: %s", line)
continue
# DB.enter_step('feed_json_assert')
# data = json.loads(line)
# assert dtype == data['type']
# assert name == data['name']
# assert value == data['value']
self.db.enter_step('feed_switch')
if dtype == 'a':
for rule in self.db.get_ip4(value):
self.db.enter_step('wait_put')
self.write_queue.put(
(database.Database.set_hostname, name, rule))
elif dtype == 'cname':
for rule in self.db.get_domain(value):
self.db.enter_step('wait_put')
self.write_queue.put(
(database.Database.set_hostname, name, rule))
elif dtype == 'ptr':
for rule in self.db.get_domain(value):
self.db.enter_step('wait_put')
self.write_queue.put(
(database.Database.set_ip4address, name, rule))
self.db.enter_step('wait_line')
self.db.enter_step('line_wait')
block: typing.List[str]
try:
for block in iter(self.lines_queue.get, None):
for line in block:
dtype, updated, name, value = line
self.db.enter_step('feed_switch')
select, confirm, write = FUNCTION_MAP[dtype]
for rule in select(self.db, value):
if not any(confirm(self.db, name)):
self.db.enter_step('wait_put')
self.write_queue.put((write, name, updated))
self.db.enter_step('line_wait')
except KeyboardInterrupt:
self.log.error('Interrupted')
self.db.enter_step('end')
self.db.close()
class Writer(multiprocessing.Process):
def __init__(self,
write_queue: multiprocessing.Queue,
):
super(Writer, self).__init__()
self.log = logging.getLogger(f'wr ')
self.write_queue = write_queue
def run(self) -> None:
self.db = database.Database(write=True)
self.db.log = logging.getLogger(f'dbw ')
self.db.enter_step('line_wait')
block: typing.List[str]
try:
fun: typing.Callable
name: str
updated: int
for fun, name, updated in iter(self.write_queue.get, None):
self.db.enter_step('exec')
fun(self.db, name, updated)
self.db.enter_step('line_wait')
except KeyboardInterrupt:
self.log.error('Interrupted')
self.db.enter_step('end')
self.write_queue.put(None)
self.db.close()
@@ -80,42 +106,52 @@ if __name__ == '__main__':
args = parser.parse_args()
DB = database.Database(write=False) # Not needed, just for timing
DB.log = logging.getLogger('dbf')
DBW = database.Database(write=True)
DBW.log = logging.getLogger('dbw')
DB.log = logging.getLogger('db ')
lines_queue: queue.Queue = queue.Queue(maxsize=NUMBER_THREADS)
write_queue: queue.Queue = queue.Queue(maxsize=NUMBER_THREADS)
lines_queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=100)
write_queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=100)
def fill_lines_queue() -> None:
DB.enter_step('proc_create')
readers: typing.List[Reader] = list()
for w in range(NUMBER_THREADS):
readers.append(Reader(lines_queue, write_queue, w))
writer = Writer(write_queue)
DB.enter_step('proc_start')
for reader in readers:
reader.start()
writer.start()
try:
block: typing.List[str] = list()
DB.enter_step('iowait')
for line in args.input:
DB.enter_step('wait_put')
lines_queue.put(line)
DB.enter_step('block_append')
DB.enter_step('feed_json_parse')
data = json.loads(line)
line = (data['type'],
int(data['timestamp']),
data['name'],
data['value'])
block.append(line)
if len(block) >= BLOCK_SIZE:
DB.enter_step('wait_put')
lines_queue.put(block)
block = list()
DB.enter_step('iowait')
DB.enter_step('wait_put')
lines_queue.put(block)
DB.enter_step('end_put')
for _ in range(NUMBER_THREADS):
lines_queue.put(None)
write_queue.put(None)
for w in range(NUMBER_THREADS):
Worker(lines_queue, write_queue, w).start()
DB.enter_step('proc_join')
for reader in readers:
reader.join()
writer.join()
except KeyboardInterrupt:
log.error('Interrupted')
threading.Thread(target=fill_lines_queue).start()
for _ in range(NUMBER_THREADS):
fun: typing.Callable
name: str
source: int
DBW.enter_step('wait_fun')
for fun, name, source in iter(write_queue.get, None):
DBW.enter_step('exec_fun')
fun(DBW, name, source=source)
DBW.enter_step('commit')
DBW.conn.commit()
DBW.enter_step('wait_fun')
DBW.enter_step('end')
DBW.close()
DB.close()

View file

@@ -3,6 +3,7 @@
import database
import argparse
import sys
import time
FUNCTION_MAP = {
'zone': database.Database.set_zone,
@@ -32,6 +33,10 @@ if __name__ == '__main__':
fun = FUNCTION_MAP[args.type]
for rule in args.input:
fun(DB, rule.strip(), is_first_party=args.first_party)
fun(DB,
rule.strip(),
is_first_party=args.first_party,
updated=int(time.time()),
)
DB.close()

View file

@@ -18,7 +18,7 @@ log "Retrieving rules…"
rm -f rules*/*.cache.*
dl https://easylist.to/easylist/easyprivacy.txt rules_adblock/easyprivacy.cache.txt
# From firebog.net Tracking & Telemetry Lists
dl https://v.firebog.net/hosts/Prigent-Ads.txt rules/prigent-ads.cache.list
# dl https://v.firebog.net/hosts/Prigent-Ads.txt rules/prigent-ads.cache.list
# dl https://gitlab.com/quidsup/notrack-blocklists/raw/master/notrack-blocklist.txt rules/notrack-blocklist.cache.list
# False positives: https://github.com/WaLLy3K/wally3k.github.io/issues/73 -> 69.media.tumblr.com chicdn.net
dl https://raw.githubusercontent.com/StevenBlack/hosts/master/data/add.2o7Net/hosts rules_hosts/add2o7.cache.txt

View file

@@ -4,6 +4,12 @@ function log() {
echo -e "\033[33m$@\033[0m"
}
log "Pruning old data…"
./database.py --prune
log "Recounting references…"
./database.py --references
log "Exporting lists…"
./export.py --first-party --output dist/firstparty-trackers.txt
./export.py --first-party --end-chain --output dist/firstparty-only-trackers.txt
@@ -11,6 +17,8 @@ log "Exporting lists…"
./export.py --end-chain --output dist/multiparty-only-trackers.txt
log "Generating hosts lists…"
./export.py --rules --count --first-party > temp/count_rules_firstparty.txt
./export.py --rules --count > temp/count_rules_multiparty.txt
function generate_hosts {
basename="$1"
description="$2"
@@ -36,15 +44,16 @@ function generate_hosts {
echo "#"
echo "# Generation date: $(date -Isec)"
echo "# Generation software: eulaurarien $(git describe --tags)"
echo "# Number of source websites: TODO"
echo "# Number of source subdomains: TODO"
echo "# Number of source websites: $(wc -l temp/all_websites.list | cut -d' ' -f1)"
echo "# Number of source subdomains: $(wc -l temp/all_subdomains.list | cut -d' ' -f1)"
echo "# Number of source DNS records: ~2M + $(wc -l temp/all_resolved.json | cut -d' ' -f1)"
echo "#"
echo "# Number of known first-party trackers: TODO"
echo "# Number of first-party subdomains: $(wc -l dist/firstparty-trackers.txt | cut -d' ' -f1)"
echo "# Known first-party trackers: $(cat temp/count_rules_firstparty.txt)"
echo "# Number of first-party hostnames: $(wc -l dist/firstparty-trackers.txt | cut -d' ' -f1)"
echo "# … excluding redirected: $(wc -l dist/firstparty-only-trackers.txt | cut -d' ' -f1)"
echo "#"
echo "# Number of known multi-party trackers: TODO"
echo "# Number of multi-party subdomains: $(wc -l dist/multiparty-trackers.txt | cut -d' ' -f1)"
echo "# Known multi-party trackers: $(cat temp/count_rules_multiparty.txt)"
echo "# Number of multi-party hostnames: $(wc -l dist/multiparty-trackers.txt | cut -d' ' -f1)"
echo "# … excluding redirected: $(wc -l dist/multiparty-only-trackers.txt | cut -d' ' -f1)"
echo
sed 's|^|0.0.0.0 |' "dist/$basename.txt"

View file

@@ -5,6 +5,7 @@
}
log "Importing rules…"
BEFORE="$(date +%s)"
cat rules_adblock/*.txt | grep -v '^!' | grep -v '^\[Adblock' | ./adblock_to_domain_list.py | ./feed_rules.py zone
cat rules_hosts/*.txt | grep -v '^#' | grep -v '^$' | cut -d ' ' -f2 | ./feed_rules.py zone
cat rules/*.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone
@@ -17,3 +18,5 @@ cat rules_asn/first-party.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py as
./feed_asn.py
log "Pruning old rules…"
./database.py --prune --prune-before "$BEFORE" --prune-base

36
json_to_csv.py Executable file
View file

@@ -0,0 +1,36 @@
#!/usr/bin/env python3
import argparse
import sys
import logging
import json
import csv
if __name__ == '__main__':
# Parsing arguments
log = logging.getLogger('json_to_csv')
parser = argparse.ArgumentParser(
description="TODO")
parser.add_argument(
# '-i', '--input', type=argparse.FileType('rb'), default=sys.stdin.buffer,
'-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
help="TODO")
parser.add_argument(
# '-i', '--output', type=argparse.FileType('wb'), default=sys.stdout.buffer,
'-o', '--output', type=argparse.FileType('w'), default=sys.stdout,
help="TODO")
args = parser.parse_args()
writer = csv.writer(args.output)
for line in args.input:
data = json.loads(line)
try:
writer.writerow([
data['type'][0], # First letter, will need to do something special for AAAA
data['timestamp'],
data['name'],
data['value']])
except (KeyError, json.decoder.JSONDecodeError):
log.error('Could not parse line: %s', line)
pass

View file

@@ -4,18 +4,16 @@ function log() {
echo -e "\033[33m$@\033[0m"
}
log "Preparing database…"
./database.py --expire
./fetch_resources.sh
./import_rules.sh
# TODO Fetch 'em
log "Reading PTR records…"
pv ptr.json.gz | gunzip | ./feed_dns.py
pv ptr.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
log "Reading A records…"
pv a.json.gz | gunzip | ./feed_dns.py
pv a.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
log "Reading CNAME records…"
pv cname.json.gz | gunzip | ./feed_dns.py
pv cname.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
log "Pruning old data…"
./database.py --prune

View file

@@ -1,21 +0,0 @@
#!/usr/bin/env python3
"""
List of regex matching first-party trackers.
"""
# Syntax: https://docs.python.org/3/library/re.html#regular-expression-syntax
REGEXES = [
r'^.+\.eulerian\.net\.$', # Eulerian
r'^.+\.criteo\.com\.$', # Criteo
r'^.+\.dnsdelegation\.io\.$', # Criteo
r'^.+\.keyade\.com\.$', # Keyade
r'^.+\.omtrdc\.net\.$', # Adobe Experience Cloud
r'^.+\.bp01\.net\.$', # NP6
r'^.+\.ati-host\.net\.$', # Xiti (AT Internet)
r'^.+\.at-o\.net\.$', # Xiti (AT Internet)
r'^.+\.edgkey\.net\.$', # Edgekey (Akamai)
r'^.+\.akaimaiedge\.net\.$', # Edgekey (Akamai)
r'^.+\.storetail\.io\.$', # Storetail (Criteo)
]

View file

@@ -12,22 +12,15 @@ import queue
import sys
import threading
import typing
import csv
import time
import coloredlogs
import dns.exception
import dns.resolver
import progressbar
DNS_TIMEOUT = 5.0
NUMBER_THREADS = 512
NUMBER_TRIES = 5
# TODO All the domains don't get treated,
# so it leaves with 4-5 subdomains not resolved
glob = None
class Worker(threading.Thread):
"""
@@ -59,9 +52,9 @@ class Worker(threading.Thread):
self.change_nameserver()
def resolve_subdomain(self, subdomain: str) -> typing.Optional[
typing.List[
str
]
typing.List[
dns.rrset.RRset
]
]:
"""
Returns the resolution chain of the subdomain to an A record,
@@ -93,18 +86,7 @@ class Worker(threading.Thread):
except dns.name.EmptyLabel:
self.log.warning("Empty label for %s", subdomain)
return None
resolved = list()
last = len(query.response.answer) - 1
for a, answer in enumerate(query.response.answer):
if answer.rdtype == dns.rdatatype.CNAME:
assert a < last
resolved.append(answer.items[0].to_text()[:-1])
elif answer.rdtype == dns.rdatatype.A:
assert a == last
resolved.append(answer.items[0].address)
else:
assert False
return resolved
return query.response.answer
def run(self) -> None:
self.log.info("Started")
@@ -124,7 +106,6 @@ class Worker(threading.Thread):
self.log.error("Gave up on %s", subdomain)
resolved = []
resolved.insert(0, subdomain)
assert isinstance(resolved, list)
self.orchestrator.results_queue.put(resolved)
@@ -150,15 +131,17 @@ class Orchestrator():
def __init__(self, subdomains: typing.Iterable[str],
nameservers: typing.List[str] = None,
nb_workers: int = 1,
):
self.log = logging.getLogger('orchestrator')
self.subdomains = subdomains
self.nb_workers = nb_workers
# Use internal resolver by default
self.nameservers = nameservers or dns.resolver.Resolver().nameservers
self.subdomains_queue: queue.Queue = queue.Queue(
maxsize=NUMBER_THREADS)
maxsize=self.nb_workers)
self.results_queue: queue.Queue = queue.Queue()
self.nameservers_queue: queue.Queue = queue.Queue()
@@ -179,16 +162,31 @@ class Orchestrator():
self.log.info("Finished reading subdomains")
# Send sentinel to each worker
# sentinel = None ~= EOF
for _ in range(NUMBER_THREADS):
for _ in range(self.nb_workers):
self.subdomains_queue.put(None)
def run(self) -> typing.Iterable[typing.List[str]]:
@staticmethod
def format_rrset(rrset: dns.rrset.RRset) -> typing.Iterable[str]:
if rrset.rdtype == dns.rdatatype.CNAME:
dtype = 'c'
elif rrset.rdtype == dns.rdatatype.A:
dtype = 'a'
else:
raise NotImplementedError
name = rrset.name.to_text()[:-1]
for item in rrset.items:
value = item.to_text()
if rrset.rdtype == dns.rdatatype.CNAME:
value = value[:-1]
yield f'{dtype},{int(time.time())},{name},{value}\n'
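# Sample lines this yields (hypothetical values), one per rrset item,
# in the CSV IR shape (type letter, timestamp, name, value):
#   'c,1576255175,sub.example.com,cdn.tracker.example\n'
#   'a,1576255175,cdn.tracker.example,93.184.216.34\n'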
def run(self) -> typing.Iterable[str]:
"""
Yield the results.
"""
# Create workers
self.log.info("Creating workers")
for i in range(NUMBER_THREADS):
for i in range(self.nb_workers):
Worker(self, i).start()
fill_thread = threading.Thread(target=self.fill_subdomain_queue)
@@ -196,10 +194,11 @@ class Orchestrator():
# Wait for one sentinel per worker
# In the meantime output results
for _ in range(NUMBER_THREADS):
result: typing.List[str]
for result in iter(self.results_queue.get, None):
yield result
for _ in range(self.nb_workers):
resolved: typing.List[dns.rrset.RRset]
for resolved in iter(self.results_queue.get, None):
for rrset in resolved:
yield from self.format_rrset(rrset)
self.log.info("Waiting for reader thread")
fill_thread.join()
@@ -214,11 +213,9 @@ def main() -> None:
the last CNAME resolved and the IP address it resolves to.
Takes as an input a filename (or nothing, for stdin),
and as an output a filename (or nothing, for stdout).
The input must be a subdomain per line, the output is a comma-sep
file with the columns source CNAME and A.
The input must be a subdomain per line, the output is a TODO
Use the file `nameservers` as the list of nameservers
to use, or else it will use the system defaults.
Also shows a nice progressbar.
"""
# Initialization
@@ -236,28 +233,14 @@ def main() -> None:
parser.add_argument(
'-o', '--output', type=argparse.FileType('w'), default=sys.stdout,
help="Outptut file with DNS chains")
# parser.add_argument(
# '-n', '--nameserver', type=argparse.FileType('r'),
# default='nameservers', help="File with one nameserver per line")
# parser.add_argument(
# '-j', '--workers', type=int, default=512,
# help="Number of threads to use")
parser.add_argument(
'-n', '--nameservers', default='nameservers',
help="File with one nameserver per line")
parser.add_argument(
'-j', '--workers', type=int, default=512,
help="Number of threads to use")
args = parser.parse_args()
# Progress bar
widgets = [
progressbar.Percentage(),
' ', progressbar.SimpleProgress(),
' ', progressbar.Bar(),
' ', progressbar.Timer(),
' ', progressbar.AdaptiveTransferSpeed(unit='req'),
' ', progressbar.AdaptiveETA(),
]
progress = progressbar.ProgressBar(widgets=widgets)
if args.input.seekable():
progress.max_value = len(args.input.readlines())
args.input.seek(0)
# Cleaning input
iterator = iter(args.input)
iterator = map(str.strip, iterator)
@@ -265,19 +248,16 @@ def main() -> None:
# Reading nameservers
servers: typing.List[str] = list()
if os.path.isfile('nameservers'):
servers = open('nameservers').readlines()
if os.path.isfile(args.nameservers):
servers = open(args.nameservers).readlines()
servers = list(filter(None, map(str.strip, servers)))
writer = csv.writer(args.output)
progress.start()
global glob
glob = Orchestrator(iterator, servers)
for resolved in glob.run():
progress.update(progress.value + 1)
writer.writerow(resolved)
progress.finish()
for resolved in Orchestrator(
iterator,
servers,
nb_workers=args.workers
).run():
args.output.write(resolved)
if __name__ == '__main__':

View file

@@ -4,11 +4,9 @@ function log() {
echo -e "\033[33m$@\033[0m"
}
# Resolve the CNAME chain of all the known subdomains for later analysis
log "Compiling subdomain lists..."
pv subdomains/*.list | sort -u > temp/all_subdomains.list
log "Compiling locally known subdomain…"
# Sort by last character to utilize the DNS server caching mechanism
pv temp/all_subdomains.list | rev | sort | rev > temp/all_subdomains_reversort.list
./resolve_subdomains.py --input temp/all_subdomains_reversort.list --output temp/all_resolved.csv
sort -u temp/all_resolved.csv > temp/all_resolved_sorted.csv
pv subdomains/*.list | sed 's/\r$//' | rev | sort -u | rev > temp/all_subdomains.list
log "Resolving locally known subdomain…"
pv temp/all_subdomains.list | ./resolve_subdomains.py --output temp/all_resolved.csv