Workflow: Multiprocess

Still trying.
It's better than multithread though.

Merge branch 'newworkflow' into newworkflow_threaded
newworkflow_threaded
Geoffrey Frogeye 2019-12-14 17:27:46 +01:00
commit 189deeb559
Signed by: geoffrey
GPG Key ID: D8A7ECA00A8CD3DD
15 changed files with 512 additions and 279 deletions

2
.gitignore vendored
View File

@ -3,5 +3,3 @@
*.db-journal
nameservers
nameservers.head
*.o
*.so

View File

@ -12,6 +12,7 @@ import logging
import argparse
import coloredlogs
import ipaddress
import math
coloredlogs.install(
level='DEBUG',
@ -22,40 +23,44 @@ DbValue = typing.Union[None, int, float, str, bytes]
class Database():
VERSION = 3
VERSION = 5
PATH = "blocking.db"
def open(self) -> None:
mode = 'rwc' if self.write else 'ro'
uri = f'file:{self.PATH}?mode={mode}'
self.conn = sqlite3.connect(uri, uri=True)
self.cursor = self.conn.cursor()
self.execute("PRAGMA foreign_keys = ON")
# self.conn.create_function("prepare_ip4address", 1,
# Database.prepare_ip4address,
# deterministic=True)
cursor = self.conn.cursor()
cursor.execute("PRAGMA foreign_keys = ON")
self.conn.create_function("unpack_asn", 1,
self.unpack_asn,
deterministic=True)
self.conn.create_function("unpack_ip4address", 1,
self.unpack_ip4address,
deterministic=True)
self.conn.create_function("unpack_ip4network", 2,
self.unpack_ip4network,
deterministic=True)
self.conn.create_function("unpack_domain", 1,
lambda s: s[:-1][::-1],
deterministic=True)
def execute(self, cmd: str, args: typing.Union[
typing.Tuple[DbValue, ...],
typing.Dict[str, DbValue]] = None) -> None:
# self.log.debug(cmd)
# self.log.debug(args)
self.cursor.execute(cmd, args or tuple())
self.conn.create_function("format_zone", 1,
lambda s: '*' + s[::-1],
deterministic=True)
def get_meta(self, key: str) -> typing.Optional[int]:
cursor = self.conn.cursor()
try:
self.execute("SELECT value FROM meta WHERE key=?", (key,))
cursor.execute("SELECT value FROM meta WHERE key=?", (key,))
except sqlite3.OperationalError:
return None
for ver, in self.cursor:
for ver, in cursor:
return ver
return None
def set_meta(self, key: str, val: int) -> None:
self.execute("INSERT INTO meta VALUES (?, ?) "
cursor = self.conn.cursor()
cursor.execute("INSERT INTO meta VALUES (?, ?) "
"ON CONFLICT (key) DO "
"UPDATE set value=?",
(key, val, val))
@ -76,8 +81,9 @@ class Database():
os.unlink(self.PATH)
self.open()
self.log.info("Creating database version %d.", self.VERSION)
cursor = self.conn.cursor()
with open("database_schema.sql", 'r') as db_schema:
self.cursor.executescript(db_schema.read())
cursor.executescript(db_schema.read())
self.set_meta('version', self.VERSION)
self.conn.commit()
@ -98,13 +104,6 @@ class Database():
version)
self.initialize()
updated = self.get_meta('updated')
if updated is None:
self.execute('SELECT max(updated) FROM rules')
data = self.cursor.fetchone()
updated, = data
self.updated = updated or 1
def enter_step(self, name: str) -> None:
now = time.perf_counter()
try:
@ -126,24 +125,32 @@ class Database():
self.log.debug(f"{'total':<20}: "
f"{total:9.2f} s ({1:7.2%})")
def prepare_hostname(self, hostname: str) -> str:
@staticmethod
def pack_hostname(hostname: str) -> str:
return hostname[::-1] + '.'
def prepare_zone(self, zone: str) -> str:
return self.prepare_hostname(zone)
@staticmethod
def pack_zone(zone: str) -> str:
return Database.pack_hostname(zone)
@staticmethod
def prepare_asn(asn: str) -> int:
def pack_asn(asn: str) -> int:
asn = asn.upper()
if asn.startswith('AS'):
asn = asn[2:]
return int(asn)
@staticmethod
def prepare_ip4address(address: str) -> int:
def unpack_asn(asn: int) -> str:
return f'AS{asn}'
@staticmethod
def pack_ip4address(address: str) -> int:
total = 0
for i, octet in enumerate(address.split('.')):
total += int(octet) << (3-i)*8
if total > 0xFFFFFFFF:
raise ValueError
return total
# return '{:02x}{:02x}{:02x}{:02x}'.format(
# *[int(c) for c in address.split('.')])
@ -158,34 +165,78 @@ class Database():
# packed = ipaddress.ip_address(address).packed
# return packed
def prepare_ip4network(self, network: str) -> typing.Tuple[int, int]:
# def prepare_ip4network(network: str) -> str:
@staticmethod
def unpack_ip4address(address: int) -> str:
return '.'.join(str((address >> (i * 8)) & 0xFF)
for i in reversed(range(4)))
@staticmethod
def pack_ip4network(network: str) -> typing.Tuple[int, int]:
# def pack_ip4network(network: str) -> str:
net = ipaddress.ip_network(network)
mini = self.prepare_ip4address(net.network_address.exploded)
maxi = self.prepare_ip4address(net.broadcast_address.exploded)
mini = Database.pack_ip4address(net.network_address.exploded)
maxi = Database.pack_ip4address(net.broadcast_address.exploded)
# mini = net.network_address.packed
# maxi = net.broadcast_address.packed
return mini, maxi
# return Database.prepare_ip4address(net.network_address.exploded)[:net.prefixlen]
# return Database.pack_ip4address(net.network_address.exploded)[:net.prefixlen]
def expire(self) -> None:
self.enter_step('expire')
self.updated += 1
self.set_meta('updated', self.updated)
@staticmethod
def unpack_ip4network(mini: int, maxi: int) -> str:
addr = Database.unpack_ip4address(mini)
prefixlen = 32-int(math.log2(maxi-mini+1))
return f'{addr}/{prefixlen}'
def update_references(self) -> None:
self.enter_step('update_refs')
self.execute('UPDATE rules AS r SET refs='
cursor = self.conn.cursor()
cursor.execute('UPDATE rules AS r SET refs='
'(SELECT count(*) FROM rules '
'WHERE source=r.id)')
def prune(self) -> None:
def prune(self, before: int, base_only: bool = False) -> None:
self.enter_step('prune')
self.execute('DELETE FROM rules WHERE updated<?', (self.updated,))
cursor = self.conn.cursor()
cmd = 'DELETE FROM rules WHERE updated<?'
if base_only:
cmd += ' AND level=0'
cursor.execute(cmd, (before,))
def export(self, first_party_only: bool = False,
end_chain_only: bool = False) -> typing.Iterable[str]:
command = 'SELECT unpack_domain(val) FROM rules ' \
def explain(self, entry: int) -> str:
# Format current
string = '???'
cursor = self.conn.cursor()
cursor.execute(
'SELECT unpack_asn(val) FROM asn WHERE entry=:entry '
'UNION '
'SELECT unpack_domain(val) FROM hostname WHERE entry=:entry '
'UNION '
'SELECT format_zone(val) FROM zone WHERE entry=:entry '
'UNION '
'SELECT unpack_ip4address(val) FROM ip4address WHERE entry=:entry '
'UNION '
'SELECT unpack_ip4network(mini, maxi) '
'FROM ip4network WHERE entry=:entry ',
{"entry": entry}
)
for val, in cursor: # only one
string = str(val)
string += f' #{entry}'
# Add source if any
cursor.execute('SELECT source FROM rules WHERE id=?', (entry,))
for source, in cursor:
if source:
string += f'{self.explain(source)}'
return string
def export(self,
first_party_only: bool = False,
end_chain_only: bool = False,
explain: bool = False,
) -> typing.Iterable[str]:
selection = 'entry' if explain else 'unpack_domain(val)'
command = f'SELECT {selection} FROM rules ' \
'INNER JOIN hostname ON rules.id = hostname.entry'
restrictions: typing.List[str] = list()
if first_party_only:
@ -194,16 +245,40 @@ class Database():
restrictions.append('rules.refs = 0')
if restrictions:
command += ' WHERE ' + ' AND '.join(restrictions)
if not explain:
command += ' ORDER BY unpack_domain(val) ASC'
self.execute(command)
for val, in self.cursor:
cursor = self.conn.cursor()
cursor.execute(command)
for val, in cursor:
if explain:
yield self.explain(val)
else:
yield val
def count_rules(self,
first_party_only: bool = False,
) -> str:
counts: typing.List[str] = list()
cursor = self.conn.cursor()
for table in ['asn', 'ip4network', 'ip4address', 'zone', 'hostname']:
command = f'SELECT count(*) FROM rules ' \
f'INNER JOIN {table} ON rules.id = {table}.entry ' \
'WHERE rules.level = 0'
if first_party_only:
command += ' AND first_party=1'
cursor.execute(command)
count, = cursor.fetchone()
if count > 0:
counts.append(f'{table}: {count}')
return ', '.join(counts)
def get_domain(self, domain: str) -> typing.Iterable[int]:
self.enter_step('get_domain_prepare')
domain_prep = self.prepare_hostname(domain)
domain_prep = self.pack_hostname(domain)
cursor = self.conn.cursor()
self.enter_step('get_domain_select')
self.execute(
cursor.execute(
'SELECT null, entry FROM hostname '
'WHERE val=:d '
'UNION '
@ -214,22 +289,41 @@ class Database():
')',
{'d': domain_prep}
)
for val, entry in self.cursor:
for val, entry in cursor:
self.enter_step('get_domain_confirm')
if not (val is None or domain_prep.startswith(val)):
continue
self.enter_step('get_domain_yield')
yield entry
def get_domain_in_zone(self, domain: str) -> typing.Iterable[int]:
self.enter_step('get_domainiz_prepare')
domain_prep = self.pack_hostname(domain)
cursor = self.conn.cursor()
self.enter_step('get_domainiz_select')
cursor.execute(
'SELECT val, entry FROM zone '
'WHERE val<=:d '
'ORDER BY val DESC LIMIT 1',
{'d': domain_prep}
)
for val, entry in cursor:
self.enter_step('get_domainiz_confirm')
if not (val is None or domain_prep.startswith(val)):
continue
self.enter_step('get_domainiz_yield')
yield entry
def get_ip4(self, address: str) -> typing.Iterable[int]:
self.enter_step('get_ip4_prepare')
try:
address_prep = self.prepare_ip4address(address)
address_prep = self.pack_ip4address(address)
except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", address)
return
cursor = self.conn.cursor()
self.enter_step('get_ip4_select')
self.execute(
cursor.execute(
'SELECT entry FROM ip4address '
# 'SELECT null, entry FROM ip4address '
'WHERE val=:a '
@ -244,7 +338,7 @@ class Database():
'WHERE :a BETWEEN mini AND maxi ',
{'a': address_prep}
)
for val, entry in self.cursor:
for entry, in cursor:
# self.enter_step('get_ip4_confirm')
# if not (val is None or val.startswith(address_prep)):
# # PERF startswith but from the end
@ -252,11 +346,29 @@ class Database():
self.enter_step('get_ip4_yield')
yield entry
def get_ip4_in_network(self, address: str) -> typing.Iterable[int]:
self.enter_step('get_ip4in_prepare')
try:
address_prep = self.pack_ip4address(address)
except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", address)
return
cursor = self.conn.cursor()
self.enter_step('get_ip4in_select')
cursor.execute(
'SELECT entry FROM ip4network '
'WHERE :a BETWEEN mini AND maxi ',
{'a': address_prep}
)
for entry, in cursor:
self.enter_step('get_ip4in_yield')
yield entry
def list_asn(self) -> typing.Iterable[typing.Tuple[str, int]]:
cursor = self.conn.cursor()
self.enter_step('list_asn_select')
self.enter_step('get_domain_select')
self.execute('SELECT val, entry FROM asn')
for val, entry in self.cursor:
cursor.execute('SELECT val, entry FROM asn')
for val, entry in cursor:
yield f'AS{val}', entry
def _set_generic(self,
@ -264,6 +376,7 @@ class Database():
select_query: str,
insert_query: str,
prep: typing.Dict[str, DbValue],
updated: int,
is_first_party: bool = False,
source: int = None,
) -> None:
@ -271,34 +384,36 @@ class Database():
# here abstraction > performaces
# Fields based on the source
self.enter_step(f'set_{table}_prepare')
cursor = self.conn.cursor()
if source is None:
first_party = int(is_first_party)
level = 0
else:
self.enter_step(f'set_{table}_source')
self.execute(
cursor.execute(
'SELECT first_party, level FROM rules '
'WHERE id=?',
(source,)
)
first_party, level = self.cursor.fetchone()
first_party, level = cursor.fetchone()
level += 1
self.enter_step(f'set_{table}_select')
self.execute(select_query, prep)
cursor.execute(select_query, prep)
rules_prep = {
rules_prep: typing.Dict[str, DbValue] = {
"source": source,
"updated": self.updated,
"updated": updated,
"first_party": first_party,
"level": level,
}
# If the entry already exists
for entry, in self.cursor: # only one
for entry, in cursor: # only one
self.enter_step(f'set_{table}_update')
rules_prep['entry'] = entry
self.execute(
cursor.execute(
'UPDATE rules SET '
'source=:source, updated=:updated, '
'first_party=:first_party, level=:level '
@ -314,23 +429,18 @@ class Database():
# If it does not exist
if source is not None:
self.enter_step(f'set_{table}_incsrc')
self.execute('UPDATE rules SET refs = refs + 1 WHERE id=?',
(source,))
self.enter_step(f'set_{table}_insert')
self.execute(
cursor.execute(
'INSERT INTO rules '
'(source, updated, first_party, refs, level) '
'VALUES (:source, :updated, :first_party, 0, :level) ',
'(source, updated, first_party, level) '
'VALUES (:source, :updated, :first_party, :level) ',
rules_prep
)
self.execute('SELECT id FROM rules WHERE rowid=?',
(self.cursor.lastrowid,))
for entry, in self.cursor: # only one
cursor.execute('SELECT id FROM rules WHERE rowid=?',
(cursor.lastrowid,))
for entry, in cursor: # only one
prep['entry'] = entry
self.execute(insert_query, prep)
cursor.execute(insert_query, prep)
return
assert False
@ -338,7 +448,7 @@ class Database():
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_hostname_prepare')
prep: typing.Dict[str, DbValue] = {
'val': self.prepare_hostname(hostname),
'val': self.pack_hostname(hostname),
}
self._set_generic(
'hostname',
@ -353,7 +463,7 @@ class Database():
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_asn_prepare')
try:
asn_prep = self.prepare_asn(asn)
asn_prep = self.pack_asn(asn)
except ValueError:
self.log.error("Invalid asn: %s", asn)
return
@ -371,10 +481,9 @@ class Database():
def set_ip4address(self, ip4address: str,
*args: typing.Any, **kwargs: typing.Any) -> None:
# TODO Do not add if already in ip4network
self.enter_step('set_ip4add_prepare')
try:
ip4address_prep = self.prepare_ip4address(ip4address)
ip4address_prep = self.pack_ip4address(ip4address)
except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", ip4address)
return
@ -394,7 +503,7 @@ class Database():
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_zone_prepare')
prep: typing.Dict[str, DbValue] = {
'val': self.prepare_zone(zone),
'val': self.pack_zone(zone),
}
self._set_generic(
'zone',
@ -409,7 +518,7 @@ class Database():
*args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_ip4net_prepare')
try:
ip4network_prep = self.prepare_ip4network(ip4network)
ip4network_prep = self.pack_ip4network(ip4network)
except (ValueError, IndexError):
self.log.error("Invalid ip4network: %s", ip4network)
return
@ -439,8 +548,12 @@ if __name__ == '__main__':
'-p', '--prune', action='store_true',
help="Remove old entries from database")
parser.add_argument(
'-e', '--expire', action='store_true',
help="Set the whole database as an old source")
'-b', '--prune-base', action='store_true',
help="TODO")
parser.add_argument(
'-s', '--prune-before', type=int,
default=(int(time.time()) - 60*60*24*31*6),
help="TODO")
parser.add_argument(
'-r', '--references', action='store_true',
help="Update the reference count")
@ -451,10 +564,8 @@ if __name__ == '__main__':
if args.initialize:
DB.initialize()
if args.prune:
DB.prune()
if args.expire:
DB.expire()
if args.references and not args.prune:
DB.prune(before=args.prune_before, base_only=args.prune_base)
if args.references:
DB.update_references()
DB.close()

View File

@ -10,30 +10,37 @@ CREATE TABLE rules (
level INTEGER, -- Level of recursion to the root source rule (used for source priority)
FOREIGN KEY (source) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX rules_source ON rules (source); -- for references recounting
CREATE INDEX rules_updated ON rules (updated); -- for pruning
CREATE INDEX rules_level_firstparty ON rules (level, first_party); -- for counting rules
CREATE TABLE asn (
val INTEGER PRIMARY KEY,
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX asn_entry ON asn (entry); -- for explainations
CREATE TABLE hostname (
val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for consistency with zone)
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX hostname_entry ON hostname (entry); -- for explainations
CREATE TABLE zone (
val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for easier matching)
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX zone_entry ON zone (entry); -- for explainations
CREATE TABLE ip4address (
val INTEGER PRIMARY KEY,
entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX ip4address_entry ON ip4address (entry); -- for explainations
CREATE TABLE ip4network (
-- val TEXT PRIMARY KEY,
@ -43,6 +50,7 @@ CREATE TABLE ip4network (
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
);
CREATE INDEX ip4network_minmax ON ip4network (mini, maxi);
CREATE INDEX ip4network_entry ON ip4network (entry); -- for explainations
-- Store various things
CREATE TABLE meta (

View File

@ -19,12 +19,31 @@ if __name__ == '__main__':
parser.add_argument(
'-e', '--end-chain', action='store_true',
help="TODO")
parser.add_argument(
'-x', '--explain', action='store_true',
help="TODO")
parser.add_argument(
'-r', '--rules', action='store_true',
help="TODO")
parser.add_argument(
'-c', '--count', action='store_true',
help="TODO")
args = parser.parse_args()
DB = database.Database()
for domain in DB.export(first_party_only=args.first_party,
end_chain_only=args.end_chain):
if args.rules:
if not args.count:
raise NotImplementedError
print(DB.count_rules(first_party_only=args.first_party))
else:
if args.count:
raise NotImplementedError
for domain in DB.export(
first_party_only=args.first_party,
end_chain_only=args.end_chain,
explain=args.explain,
):
print(domain, file=args.output)
DB.close()

53
feed_asn.py Executable file
View File

@ -0,0 +1,53 @@
#!/usr/bin/env python3
import database
import argparse
import requests
import typing
import ipaddress
import logging
import time
IPNetwork = typing.Union[ipaddress.IPv4Network, ipaddress.IPv6Network]
def get_ranges(asn: str) -> typing.Iterable[str]:
req = requests.get(
'https://stat.ripe.net/data/as-routing-consistency/data.json',
params={'resource': asn}
)
data = req.json()
for pref in data['data']['prefixes']:
yield pref['prefix']
if __name__ == '__main__':
log = logging.getLogger('feed_asn')
# Parsing arguments
parser = argparse.ArgumentParser(
description="TODO")
args = parser.parse_args()
DB = database.Database()
DBW = database.Database(write=True)
for asn, entry in DB.list_asn():
DB.enter_step('asn_get_ranges')
for prefix in get_ranges(asn):
parsed_prefix: IPNetwork = ipaddress.ip_network(prefix)
if parsed_prefix.version == 4:
DBW.set_ip4network(
prefix,
source=entry,
updated=int(time.time())
)
log.info('Added %s from %s (id=%s)', prefix, asn, entry)
elif parsed_prefix.version == 6:
log.warning('Unimplemented prefix version: %s', prefix)
else:
log.error('Unknown prefix version: %s', prefix)
DB.close()
DBW.close()

176
feed_dns.py Normal file → Executable file
View File

@ -1,23 +1,43 @@
#!/usr/bin/env python3
import database
import argparse
import sys
import database
import json
import logging
import threading
import queue
import sys
import typing
import multiprocessing
NUMBER_THREADS = 8
NUMBER_THREADS = 2
BLOCK_SIZE = 100
# select, confirm, write
FUNCTION_MAP: typing.Any = {
'a': (
database.Database.get_ip4,
database.Database.get_domain_in_zone,
database.Database.set_hostname,
),
'cname': (
database.Database.get_domain,
database.Database.get_domain_in_zone,
database.Database.set_hostname,
),
'ptr': (
database.Database.get_domain,
database.Database.get_ip4_in_network,
database.Database.set_ip4address,
),
}
class Worker(threading.Thread):
class Reader(multiprocessing.Process):
def __init__(self,
lines_queue: queue.Queue,
write_queue: queue.Queue,
lines_queue: multiprocessing.Queue,
write_queue: multiprocessing.Queue,
index: int = 0):
super(Worker, self).__init__()
self.log = logging.getLogger(f'worker{index:03d}')
super(Reader, self).__init__()
self.log = logging.getLogger(f'rd{index:03d}')
self.lines_queue = lines_queue
self.write_queue = write_queue
self.index = index
@ -25,45 +45,51 @@ class Worker(threading.Thread):
def run(self) -> None:
self.db = database.Database(write=False)
self.db.log = logging.getLogger(f'db{self.index:03d}')
self.db.enter_step('wait_line')
line: str
for line in iter(self.lines_queue.get, None):
self.db.enter_step('feed_json_parse')
# split = line.split(b'"')
split = line.split('"')
self.db.enter_step('line_wait')
block: typing.List[str]
try:
name = split[7]
dtype = split[11]
value = split[15]
except IndexError:
log.error("Invalid JSON: %s", line)
continue
# DB.enter_step('feed_json_assert')
# data = json.loads(line)
# assert dtype == data['type']
# assert name == data['name']
# assert value == data['value']
for block in iter(self.lines_queue.get, None):
for line in block:
dtype, updated, name, value = line
self.db.enter_step('feed_switch')
if dtype == 'a':
for rule in self.db.get_ip4(value):
select, confirm, write = FUNCTION_MAP[dtype]
for rule in select(self.db, value):
if not any(confirm(self.db, name)):
self.db.enter_step('wait_put')
self.write_queue.put(
(database.Database.set_hostname, name, rule))
elif dtype == 'cname':
for rule in self.db.get_domain(value):
self.db.enter_step('wait_put')
self.write_queue.put(
(database.Database.set_hostname, name, rule))
elif dtype == 'ptr':
for rule in self.db.get_domain(value):
self.db.enter_step('wait_put')
self.write_queue.put(
(database.Database.set_ip4address, name, rule))
self.db.enter_step('wait_line')
self.write_queue.put((write, name, updated))
self.db.enter_step('line_wait')
except KeyboardInterrupt:
self.log.error('Interrupted')
self.db.enter_step('end')
self.db.close()
class Writer(multiprocessing.Process):
def __init__(self,
write_queue: multiprocessing.Queue,
):
super(Writer, self).__init__()
self.log = logging.getLogger(f'wr ')
self.write_queue = write_queue
def run(self) -> None:
self.db = database.Database(write=True)
self.db.log = logging.getLogger(f'dbw ')
self.db.enter_step('line_wait')
block: typing.List[str]
try:
fun: typing.Callable
name: str
updated: int
for fun, name, updated in iter(self.write_queue.get, None):
self.db.enter_step('exec')
fun(self.db, name, updated)
self.db.enter_step('line_wait')
except KeyboardInterrupt:
self.log.error('Interrupted')
self.db.enter_step('end')
self.write_queue.put(None)
self.db.close()
@ -80,42 +106,52 @@ if __name__ == '__main__':
args = parser.parse_args()
DB = database.Database(write=False) # Not needed, just for timing
DB.log = logging.getLogger('dbf')
DBW = database.Database(write=True)
DBW.log = logging.getLogger('dbw')
DB.log = logging.getLogger('db ')
lines_queue: queue.Queue = queue.Queue(maxsize=NUMBER_THREADS)
write_queue: queue.Queue = queue.Queue(maxsize=NUMBER_THREADS)
lines_queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=100)
write_queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=100)
def fill_lines_queue() -> None:
DB.enter_step('proc_create')
readers: typing.List[Reader] = list()
for w in range(NUMBER_THREADS):
readers.append(Reader(lines_queue, write_queue, w))
writer = Writer(write_queue)
DB.enter_step('proc_start')
for reader in readers:
reader.start()
writer.start()
try:
block: typing.List[str] = list()
DB.enter_step('iowait')
for line in args.input:
DB.enter_step('block_append')
DB.enter_step('feed_json_parse')
data = json.loads(line)
line = (data['type'],
int(data['timestamp']),
data['name'],
data['value'])
block.append(line)
if len(block) >= BLOCK_SIZE:
DB.enter_step('wait_put')
lines_queue.put(line)
lines_queue.put(block)
block = list()
DB.enter_step('iowait')
DB.enter_step('wait_put')
lines_queue.put(block)
DB.enter_step('end_put')
for _ in range(NUMBER_THREADS):
lines_queue.put(None)
write_queue.put(None)
for w in range(NUMBER_THREADS):
Worker(lines_queue, write_queue, w).start()
DB.enter_step('proc_join')
for reader in readers:
reader.join()
writer.join()
except KeyboardInterrupt:
log.error('Interrupted')
threading.Thread(target=fill_lines_queue).start()
for _ in range(NUMBER_THREADS):
fun: typing.Callable
name: str
source: int
DBW.enter_step('wait_fun')
for fun, name, source in iter(write_queue.get, None):
DBW.enter_step('exec_fun')
fun(DBW, name, source=source)
DBW.enter_step('commit')
DBW.conn.commit()
DBW.enter_step('wait_fun')
DBW.enter_step('end')
DBW.close()
DB.close()

View File

@ -3,6 +3,7 @@
import database
import argparse
import sys
import time
FUNCTION_MAP = {
'zone': database.Database.set_zone,
@ -32,6 +33,10 @@ if __name__ == '__main__':
fun = FUNCTION_MAP[args.type]
for rule in args.input:
fun(DB, rule.strip(), is_first_party=args.first_party)
fun(DB,
rule.strip(),
is_first_party=args.first_party,
updated=int(time.time()),
)
DB.close()

View File

@ -18,7 +18,7 @@ log "Retrieving rules…"
rm -f rules*/*.cache.*
dl https://easylist.to/easylist/easyprivacy.txt rules_adblock/easyprivacy.cache.txt
# From firebog.net Tracking & Telemetry Lists
dl https://v.firebog.net/hosts/Prigent-Ads.txt rules/prigent-ads.cache.list
# dl https://v.firebog.net/hosts/Prigent-Ads.txt rules/prigent-ads.cache.list
# dl https://gitlab.com/quidsup/notrack-blocklists/raw/master/notrack-blocklist.txt rules/notrack-blocklist.cache.list
# False positives: https://github.com/WaLLy3K/wally3k.github.io/issues/73 -> 69.media.tumblr.com chicdn.net
dl https://raw.githubusercontent.com/StevenBlack/hosts/master/data/add.2o7Net/hosts rules_hosts/add2o7.cache.txt

View File

@ -4,6 +4,12 @@ function log() {
echo -e "\033[33m$@\033[0m"
}
log "Pruning old data…"
./database.py --prune
log "Recounting references…"
./database.py --references
log "Exporting lists…"
./export.py --first-party --output dist/firstparty-trackers.txt
./export.py --first-party --end-chain --output dist/firstparty-only-trackers.txt
@ -11,6 +17,8 @@ log "Exporting lists…"
./export.py --end-chain --output dist/multiparty-only-trackers.txt
log "Generating hosts lists…"
./export.py --rules --count --first-party > temp/count_rules_firstparty.txt
./export.py --rules --count > temp/count_rules_multiparty.txt
function generate_hosts {
basename="$1"
description="$2"
@ -36,15 +44,16 @@ function generate_hosts {
echo "#"
echo "# Generation date: $(date -Isec)"
echo "# Generation software: eulaurarien $(git describe --tags)"
echo "# Number of source websites: TODO"
echo "# Number of source subdomains: TODO"
echo "# Number of source websites: $(wc -l temp/all_websites.list | cut -d' ' -f1)"
echo "# Number of source subdomains: $(wc -l temp/all_subdomains.list | cut -d' ' -f1)"
echo "# Number of source DNS records: ~2M + $(wc -l temp/all_resolved.json | cut -d' ' -f1)"
echo "#"
echo "# Number of known first-party trackers: TODO"
echo "# Number of first-party subdomains: $(wc -l dist/firstparty-trackers.txt | cut -d' ' -f1)"
echo "# Known first-party trackers: $(cat temp/count_rules_firstparty.txt)"
echo "# Number of first-party hostnames: $(wc -l dist/firstparty-trackers.txt | cut -d' ' -f1)"
echo "# … excluding redirected: $(wc -l dist/firstparty-only-trackers.txt | cut -d' ' -f1)"
echo "#"
echo "# Number of known multi-party trackers: TODO"
echo "# Number of multi-party subdomains: $(wc -l dist/multiparty-trackers.txt | cut -d' ' -f1)"
echo "# Known multi-party trackers: $(cat temp/count_rules_multiparty.txt)"
echo "# Number of multi-party hostnames: $(wc -l dist/multiparty-trackers.txt | cut -d' ' -f1)"
echo "# … excluding redirected: $(wc -l dist/multiparty-only-trackers.txt | cut -d' ' -f1)"
echo
sed 's|^|0.0.0.0 |' "dist/$basename.txt"

View File

@ -5,6 +5,7 @@ function log() {
}
log "Importing rules…"
BEFORE="$(date +%s)"
cat rules_adblock/*.txt | grep -v '^!' | grep -v '^\[Adblock' | ./adblock_to_domain_list.py | ./feed_rules.py zone
cat rules_hosts/*.txt | grep -v '^#' | grep -v '^$' | cut -d ' ' -f2 | ./feed_rules.py zone
cat rules/*.list | grep -v '^#' | grep -v '^$' | ./feed_rules.py zone
@ -17,3 +18,5 @@ cat rules_asn/first-party.txt | grep -v '^#' | grep -v '^$' | ./feed_rules.py as
./feed_asn.py
log "Pruning old rules…"
./database.py --prune --prune-before "$BEFORE" --prune-base

36
json_to_csv.py Executable file
View File

@ -0,0 +1,36 @@
#!/usr/bin/env python3
import argparse
import sys
import logging
import json
import csv
if __name__ == '__main__':
# Parsing arguments
log = logging.getLogger('json_to_csv')
parser = argparse.ArgumentParser(
description="TODO")
parser.add_argument(
# '-i', '--input', type=argparse.FileType('rb'), default=sys.stdin.buffer,
'-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
help="TODO")
parser.add_argument(
# '-i', '--output', type=argparse.FileType('wb'), default=sys.stdout.buffer,
'-o', '--output', type=argparse.FileType('w'), default=sys.stdout,
help="TODO")
args = parser.parse_args()
writer = csv.writer(args.output)
for line in args.input:
data = json.loads(line)
try:
writer.writerow([
data['type'][0], # First letter, will need to do something special for AAAA
data['timestamp'],
data['name'],
data['value']])
except (KeyError, json.decoder.JSONDecodeError):
log.error('Could not parse line: %s', line)
pass

View File

@ -4,18 +4,16 @@ function log() {
echo -e "\033[33m$@\033[0m"
}
log "Preparing database…"
./database.py --expire
./fetch_resources.sh
./import_rules.sh
# TODO Fetch 'em
log "Reading PTR records…"
pv ptr.json.gz | gunzip | ./feed_dns.py
pv ptr.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
log "Reading A records…"
pv a.json.gz | gunzip | ./feed_dns.py
pv a.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
log "Reading CNAME records…"
pv cname.json.gz | gunzip | ./feed_dns.py
pv cname.json.gz | gunzip | ./json_to_csv.py | ./feed_dns.py
log "Pruning old data…"
./database.py --prune

View File

@ -1,21 +0,0 @@
#!/usr/bin/env python3
"""
List of regex matching first-party trackers.
"""
# Syntax: https://docs.python.org/3/library/re.html#regular-expression-syntax
REGEXES = [
r'^.+\.eulerian\.net\.$', # Eulerian
r'^.+\.criteo\.com\.$', # Criteo
r'^.+\.dnsdelegation\.io\.$', # Criteo
r'^.+\.keyade\.com\.$', # Keyade
r'^.+\.omtrdc\.net\.$', # Adobe Experience Cloud
r'^.+\.bp01\.net\.$', # NP6
r'^.+\.ati-host\.net\.$', # Xiti (AT Internet)
r'^.+\.at-o\.net\.$', # Xiti (AT Internet)
r'^.+\.edgkey\.net\.$', # Edgekey (Akamai)
r'^.+\.akaimaiedge\.net\.$', # Edgekey (Akamai)
r'^.+\.storetail\.io\.$', # Storetail (Criteo)
]

View File

@ -12,22 +12,15 @@ import queue
import sys
import threading
import typing
import csv
import time
import coloredlogs
import dns.exception
import dns.resolver
import progressbar
DNS_TIMEOUT = 5.0
NUMBER_THREADS = 512
NUMBER_TRIES = 5
# TODO All the domains don't get treated,
# so it leaves with 4-5 subdomains not resolved
glob = None
class Worker(threading.Thread):
"""
@ -60,7 +53,7 @@ class Worker(threading.Thread):
def resolve_subdomain(self, subdomain: str) -> typing.Optional[
typing.List[
str
dns.rrset.RRset
]
]:
"""
@ -93,18 +86,7 @@ class Worker(threading.Thread):
except dns.name.EmptyLabel:
self.log.warning("Empty label for %s", subdomain)
return None
resolved = list()
last = len(query.response.answer) - 1
for a, answer in enumerate(query.response.answer):
if answer.rdtype == dns.rdatatype.CNAME:
assert a < last
resolved.append(answer.items[0].to_text()[:-1])
elif answer.rdtype == dns.rdatatype.A:
assert a == last
resolved.append(answer.items[0].address)
else:
assert False
return resolved
return query.response.answer
def run(self) -> None:
self.log.info("Started")
@ -124,7 +106,6 @@ class Worker(threading.Thread):
self.log.error("Gave up on %s", subdomain)
resolved = []
resolved.insert(0, subdomain)
assert isinstance(resolved, list)
self.orchestrator.results_queue.put(resolved)
@ -150,15 +131,17 @@ class Orchestrator():
def __init__(self, subdomains: typing.Iterable[str],
nameservers: typing.List[str] = None,
nb_workers: int = 1,
):
self.log = logging.getLogger('orchestrator')
self.subdomains = subdomains
self.nb_workers = nb_workers
# Use interal resolver by default
self.nameservers = nameservers or dns.resolver.Resolver().nameservers
self.subdomains_queue: queue.Queue = queue.Queue(
maxsize=NUMBER_THREADS)
maxsize=self.nb_workers)
self.results_queue: queue.Queue = queue.Queue()
self.nameservers_queue: queue.Queue = queue.Queue()
@ -179,16 +162,31 @@ class Orchestrator():
self.log.info("Finished reading subdomains")
# Send sentinel to each worker
# sentinel = None ~= EOF
for _ in range(NUMBER_THREADS):
for _ in range(self.nb_workers):
self.subdomains_queue.put(None)
def run(self) -> typing.Iterable[typing.List[str]]:
@staticmethod
def format_rrset(rrset: dns.rrset.RRset) -> typing.Iterable[str]:
if rrset.rdtype == dns.rdatatype.CNAME:
dtype = 'c'
elif rrset.rdtype == dns.rdatatype.A:
dtype = 'a'
else:
raise NotImplementedError
name = rrset.name.to_text()[:-1]
for item in rrset.items:
value = item.to_text()
if rrset.rdtype == dns.rdatatype.CNAME:
value = value[:-1]
yield f'{dtype},{int(time.time())},{name},{value}\n'
def run(self) -> typing.Iterable[str]:
"""
Yield the results.
"""
# Create workers
self.log.info("Creating workers")
for i in range(NUMBER_THREADS):
for i in range(self.nb_workers):
Worker(self, i).start()
fill_thread = threading.Thread(target=self.fill_subdomain_queue)
@ -196,10 +194,11 @@ class Orchestrator():
# Wait for one sentinel per worker
# In the meantime output results
for _ in range(NUMBER_THREADS):
result: typing.List[str]
for result in iter(self.results_queue.get, None):
yield result
for _ in range(self.nb_workers):
resolved: typing.List[dns.rrset.RRset]
for resolved in iter(self.results_queue.get, None):
for rrset in resolved:
yield from self.format_rrset(rrset)
self.log.info("Waiting for reader thread")
fill_thread.join()
@ -214,11 +213,9 @@ def main() -> None:
the last CNAME resolved and the IP adress it resolves to.
Takes as an input a filename (or nothing, for stdin),
and as an output a filename (or nothing, for stdout).
The input must be a subdomain per line, the output is a comma-sep
file with the columns source CNAME and A.
The input must be a subdomain per line, the output is a TODO
Use the file `nameservers` as the list of nameservers
to use, or else it will use the system defaults.
Also shows a nice progressbar.
"""
# Initialization
@ -236,28 +233,14 @@ def main() -> None:
parser.add_argument(
'-o', '--output', type=argparse.FileType('w'), default=sys.stdout,
help="Outptut file with DNS chains")
# parser.add_argument(
# '-n', '--nameserver', type=argparse.FileType('r'),
# default='nameservers', help="File with one nameserver per line")
# parser.add_argument(
# '-j', '--workers', type=int, default=512,
# help="Number of threads to use")
parser.add_argument(
'-n', '--nameservers', default='nameservers',
help="File with one nameserver per line")
parser.add_argument(
'-j', '--workers', type=int, default=512,
help="Number of threads to use")
args = parser.parse_args()
# Progress bar
widgets = [
progressbar.Percentage(),
' ', progressbar.SimpleProgress(),
' ', progressbar.Bar(),
' ', progressbar.Timer(),
' ', progressbar.AdaptiveTransferSpeed(unit='req'),
' ', progressbar.AdaptiveETA(),
]
progress = progressbar.ProgressBar(widgets=widgets)
if args.input.seekable():
progress.max_value = len(args.input.readlines())
args.input.seek(0)
# Cleaning input
iterator = iter(args.input)
iterator = map(str.strip, iterator)
@ -265,19 +248,16 @@ def main() -> None:
# Reading nameservers
servers: typing.List[str] = list()
if os.path.isfile('nameservers'):
servers = open('nameservers').readlines()
if os.path.isfile(args.nameservers):
servers = open(args.nameservers).readlines()
servers = list(filter(None, map(str.strip, servers)))
writer = csv.writer(args.output)
progress.start()
global glob
glob = Orchestrator(iterator, servers)
for resolved in glob.run():
progress.update(progress.value + 1)
writer.writerow(resolved)
progress.finish()
for resolved in Orchestrator(
iterator,
servers,
nb_workers=args.workers
).run():
args.output.write(resolved)
if __name__ == '__main__':

View File

@ -4,11 +4,9 @@ function log() {
echo -e "\033[33m$@\033[0m"
}
# Resolve the CNAME chain of all the known subdomains for later analysis
log "Compiling subdomain lists..."
pv subdomains/*.list | sort -u > temp/all_subdomains.list
log "Compiling locally known subdomain…"
# Sort by last character to utilize the DNS server caching mechanism
pv temp/all_subdomains.list | rev | sort | rev > temp/all_subdomains_reversort.list
./resolve_subdomains.py --input temp/all_subdomains_reversort.list --output temp/all_resolved.csv
sort -u temp/all_resolved.csv > temp/all_resolved_sorted.csv
pv subdomains/*.list | sed 's/\r$//' | rev | sort -u | rev > temp/all_subdomains.list
log "Resolving locally known subdomain…"
pv temp/all_subdomains.list | ./resolve_subdomains.py --output temp/all_resolved.csv