Compare commits

...

5 commits

Author SHA1 Message Date
Geoffrey Frogeye 269b8278b5
Worflow: Fixed rules counts 2019-12-13 18:36:08 +01:00
Geoffrey Frogeye ab7ef609dd
Workflow: Various optimisations and fixes
I forgot to close this one earlier, so:
Closes #7
2019-12-13 18:08:22 +01:00
Geoffrey Frogeye f3eedcba22
Updated now based on timestamp
Did I forget to add feed_asn.py a few commits ago?
Oh well...
2019-12-13 13:54:00 +01:00
Geoffrey Frogeye 8d94b80fd0
Integrated DNS resolving to workflow
Since the bigger datasets are only updated once a month,
this might help for quick updates.
2019-12-13 13:38:23 +01:00
Geoffrey Frogeye 9050a84670
Read-only mode 2019-12-13 12:35:05 +01:00
11 changed files with 368 additions and 209 deletions

View file

@ -12,7 +12,7 @@ import logging
import argparse import argparse
import coloredlogs import coloredlogs
import ipaddress import ipaddress
import ctypes import math
coloredlogs.install( coloredlogs.install(
level='DEBUG', level='DEBUG',
@ -23,36 +23,44 @@ DbValue = typing.Union[None, int, float, str, bytes]
class Database(): class Database():
VERSION = 3 VERSION = 5
PATH = "blocking.db" PATH = "blocking.db"
def open(self) -> None: def open(self) -> None:
self.conn = sqlite3.connect(self.PATH) mode = 'rwc' if self.write else 'ro'
self.cursor = self.conn.cursor() uri = f'file:{self.PATH}?mode={mode}'
self.execute("PRAGMA foreign_keys = ON") self.conn = sqlite3.connect(uri, uri=True)
# self.conn.create_function("prepare_ip4address", 1, cursor = self.conn.cursor()
# Database.prepare_ip4address, cursor.execute("PRAGMA foreign_keys = ON")
# deterministic=True) self.conn.create_function("unpack_asn", 1,
self.unpack_asn,
deterministic=True)
self.conn.create_function("unpack_ip4address", 1,
self.unpack_ip4address,
deterministic=True)
self.conn.create_function("unpack_ip4network", 2,
self.unpack_ip4network,
deterministic=True)
self.conn.create_function("unpack_domain", 1, self.conn.create_function("unpack_domain", 1,
lambda s: s[:-1][::-1], lambda s: s[:-1][::-1],
deterministic=True) deterministic=True)
self.conn.create_function("format_zone", 1,
def execute(self, cmd: str, args: typing.Union[ lambda s: '*' + s[::-1],
typing.Tuple[DbValue, ...], deterministic=True)
typing.Dict[str, DbValue]] = None) -> None:
self.cursor.execute(cmd, args or tuple())
def get_meta(self, key: str) -> typing.Optional[int]: def get_meta(self, key: str) -> typing.Optional[int]:
cursor = self.conn.cursor()
try: try:
self.execute("SELECT value FROM meta WHERE key=?", (key,)) cursor.execute("SELECT value FROM meta WHERE key=?", (key,))
except sqlite3.OperationalError: except sqlite3.OperationalError:
return None return None
for ver, in self.cursor: for ver, in cursor:
return ver return ver
return None return None
def set_meta(self, key: str, val: int) -> None: def set_meta(self, key: str, val: int) -> None:
self.execute("INSERT INTO meta VALUES (?, ?) " cursor = self.conn.cursor()
cursor.execute("INSERT INTO meta VALUES (?, ?) "
"ON CONFLICT (key) DO " "ON CONFLICT (key) DO "
"UPDATE set value=?", "UPDATE set value=?",
(key, val, val)) (key, val, val))
@ -65,23 +73,27 @@ class Database():
self.profile() self.profile()
def initialize(self) -> None: def initialize(self) -> None:
self.enter_step('initialize')
self.close() self.close()
self.enter_step('initialize')
if not self.write:
self.log.error("Cannot initialize in read-only mode.")
raise
os.unlink(self.PATH) os.unlink(self.PATH)
self.open() self.open()
self.log.info("Creating database version %d.", self.VERSION) self.log.info("Creating database version %d.", self.VERSION)
cursor = self.conn.cursor()
with open("database_schema.sql", 'r') as db_schema: with open("database_schema.sql", 'r') as db_schema:
self.cursor.executescript(db_schema.read()) cursor.executescript(db_schema.read())
self.set_meta('version', self.VERSION) self.set_meta('version', self.VERSION)
self.conn.commit() self.conn.commit()
def __init__(self) -> None: def __init__(self, write: bool = False) -> None:
self.log = logging.getLogger('db') self.log = logging.getLogger('db')
self.time_last = time.perf_counter() self.time_last = time.perf_counter()
self.time_step = 'init' self.time_step = 'init'
self.time_dict: typing.Dict[str, float] = dict() self.time_dict: typing.Dict[str, float] = dict()
self.step_dict: typing.Dict[str, int] = dict() self.step_dict: typing.Dict[str, int] = dict()
self.accel_ip4_buf = ctypes.create_unicode_buffer('Z'*32, 32) self.write = write
self.open() self.open()
version = self.get_meta('version') version = self.get_meta('version')
@ -92,13 +104,6 @@ class Database():
version) version)
self.initialize() self.initialize()
updated = self.get_meta('updated')
if updated is None:
self.execute('SELECT max(updated) FROM rules')
data = self.cursor.fetchone()
updated, = data
self.updated = updated or 1
def enter_step(self, name: str) -> None: def enter_step(self, name: str) -> None:
now = time.perf_counter() now = time.perf_counter()
try: try:
@ -120,21 +125,27 @@ class Database():
self.log.debug(f"{'total':<20}: " self.log.debug(f"{'total':<20}: "
f"{total:9.2f} s ({1:7.2%})") f"{total:9.2f} s ({1:7.2%})")
def prepare_hostname(self, hostname: str) -> str: @staticmethod
def pack_hostname(hostname: str) -> str:
return hostname[::-1] + '.' return hostname[::-1] + '.'
def prepare_zone(self, zone: str) -> str: @staticmethod
return self.prepare_hostname(zone) def pack_zone(zone: str) -> str:
return Database.pack_hostname(zone)
@staticmethod @staticmethod
def prepare_asn(asn: str) -> int: def pack_asn(asn: str) -> int:
asn = asn.upper() asn = asn.upper()
if asn.startswith('AS'): if asn.startswith('AS'):
asn = asn[2:] asn = asn[2:]
return int(asn) return int(asn)
@staticmethod @staticmethod
def prepare_ip4address(address: str) -> int: def unpack_asn(asn: int) -> str:
return f'AS{asn}'
@staticmethod
def pack_ip4address(address: str) -> int:
total = 0 total = 0
for i, octet in enumerate(address.split('.')): for i, octet in enumerate(address.split('.')):
total += int(octet) << (3-i)*8 total += int(octet) << (3-i)*8
@ -152,34 +163,75 @@ class Database():
# packed = ipaddress.ip_address(address).packed # packed = ipaddress.ip_address(address).packed
# return packed # return packed
def prepare_ip4network(self, network: str) -> typing.Tuple[int, int]: @staticmethod
# def prepare_ip4network(network: str) -> str: def unpack_ip4address(address: int) -> str:
return '.'.join(str((address >> (i * 8)) & 0xFF)
for i in reversed(range(4)))
@staticmethod
def pack_ip4network(network: str) -> typing.Tuple[int, int]:
# def pack_ip4network(network: str) -> str:
net = ipaddress.ip_network(network) net = ipaddress.ip_network(network)
mini = self.prepare_ip4address(net.network_address.exploded) mini = Database.pack_ip4address(net.network_address.exploded)
maxi = self.prepare_ip4address(net.broadcast_address.exploded) maxi = Database.pack_ip4address(net.broadcast_address.exploded)
# mini = net.network_address.packed # mini = net.network_address.packed
# maxi = net.broadcast_address.packed # maxi = net.broadcast_address.packed
return mini, maxi return mini, maxi
# return Database.prepare_ip4address(net.network_address.exploded)[:net.prefixlen] # return Database.pack_ip4address(net.network_address.exploded)[:net.prefixlen]
def expire(self) -> None: @staticmethod
self.enter_step('expire') def unpack_ip4network(mini: int, maxi: int) -> str:
self.updated += 1 addr = Database.unpack_ip4address(mini)
self.set_meta('updated', self.updated) prefixlen = 32-int(math.log2(maxi-mini+1))
return f'{addr}/{prefixlen}'
def update_references(self) -> None: def update_references(self) -> None:
self.enter_step('update_refs') self.enter_step('update_refs')
self.execute('UPDATE rules AS r SET refs=' cursor = self.conn.cursor()
cursor.execute('UPDATE rules AS r SET refs='
'(SELECT count(*) FROM rules ' '(SELECT count(*) FROM rules '
'WHERE source=r.id)') 'WHERE source=r.id)')
def prune(self) -> None: def prune(self, before: int) -> None:
self.enter_step('prune') self.enter_step('prune')
self.execute('DELETE FROM rules WHERE updated<?', (self.updated,)) cursor = self.conn.cursor()
cursor.execute('DELETE FROM rules WHERE updated<?', (before,))
def export(self, first_party_only: bool = False, def explain(self, entry: int) -> str:
end_chain_only: bool = False) -> typing.Iterable[str]: # Format current
command = 'SELECT unpack_domain(val) FROM rules ' \ string = '???'
cursor = self.conn.cursor()
cursor.execute(
'SELECT unpack_asn(val) FROM asn WHERE entry=:entry '
'UNION '
'SELECT unpack_domain(val) FROM hostname WHERE entry=:entry '
'UNION '
'SELECT format_zone(val) FROM zone WHERE entry=:entry '
'UNION '
'SELECT unpack_ip4address(val) FROM ip4address WHERE entry=:entry '
'UNION '
'SELECT unpack_ip4network(mini, maxi) '
'FROM ip4network WHERE entry=:entry ',
{"entry": entry}
)
for val, in cursor: # only one
string = str(val)
string += f' #{entry}'
# Add source if any
cursor.execute('SELECT source FROM rules WHERE id=?', (entry,))
for source, in cursor:
if source:
string += f'{self.explain(source)}'
return string
def export(self,
first_party_only: bool = False,
end_chain_only: bool = False,
explain: bool = False,
) -> typing.Iterable[str]:
selection = 'entry' if explain else 'unpack_domain(val)'
command = f'SELECT {selection} FROM rules ' \
'INNER JOIN hostname ON rules.id = hostname.entry' 'INNER JOIN hostname ON rules.id = hostname.entry'
restrictions: typing.List[str] = list() restrictions: typing.List[str] = list()
if first_party_only: if first_party_only:
@ -188,16 +240,40 @@ class Database():
restrictions.append('rules.refs = 0') restrictions.append('rules.refs = 0')
if restrictions: if restrictions:
command += ' WHERE ' + ' AND '.join(restrictions) command += ' WHERE ' + ' AND '.join(restrictions)
if not explain:
command += ' ORDER BY unpack_domain(val) ASC' command += ' ORDER BY unpack_domain(val) ASC'
self.execute(command) cursor = self.conn.cursor()
for val, in self.cursor: cursor.execute(command)
for val, in cursor:
if explain:
yield self.explain(val)
else:
yield val yield val
def count_rules(self,
first_party_only: bool = False,
) -> str:
counts: typing.List[str] = list()
cursor = self.conn.cursor()
for table in ['asn', 'ip4network', 'ip4address', 'zone', 'hostname']:
command = f'SELECT count(*) FROM rules ' \
f'INNER JOIN {table} ON rules.id = {table}.entry ' \
'WHERE rules.level = 0'
if first_party_only:
command += ' AND first_party=1'
cursor.execute(command)
count, = cursor.fetchone()
if count > 0:
counts.append(f'{table}: {count}')
return ', '.join(counts)
def get_domain(self, domain: str) -> typing.Iterable[int]: def get_domain(self, domain: str) -> typing.Iterable[int]:
self.enter_step('get_domain_prepare') self.enter_step('get_domain_prepare')
domain_prep = self.prepare_hostname(domain) domain_prep = self.pack_hostname(domain)
cursor = self.conn.cursor()
self.enter_step('get_domain_select') self.enter_step('get_domain_select')
self.execute( cursor.execute(
'SELECT null, entry FROM hostname ' 'SELECT null, entry FROM hostname '
'WHERE val=:d ' 'WHERE val=:d '
'UNION ' 'UNION '
@ -208,22 +284,41 @@ class Database():
')', ')',
{'d': domain_prep} {'d': domain_prep}
) )
for val, entry in self.cursor: for val, entry in cursor:
self.enter_step('get_domain_confirm') self.enter_step('get_domain_confirm')
if not (val is None or domain_prep.startswith(val)): if not (val is None or domain_prep.startswith(val)):
continue continue
self.enter_step('get_domain_yield') self.enter_step('get_domain_yield')
yield entry yield entry
def get_domain_in_zone(self, domain: str) -> typing.Iterable[int]:
self.enter_step('get_domainiz_prepare')
domain_prep = self.pack_hostname(domain)
cursor = self.conn.cursor()
self.enter_step('get_domainiz_select')
cursor.execute(
'SELECT val, entry FROM zone '
'WHERE val<=:d '
'ORDER BY val DESC LIMIT 1',
{'d': domain_prep}
)
for val, entry in cursor:
self.enter_step('get_domainiz_confirm')
if not (val is None or domain_prep.startswith(val)):
continue
self.enter_step('get_domainiz_yield')
yield entry
def get_ip4(self, address: str) -> typing.Iterable[int]: def get_ip4(self, address: str) -> typing.Iterable[int]:
self.enter_step('get_ip4_prepare') self.enter_step('get_ip4_prepare')
try: try:
address_prep = self.prepare_ip4address(address) address_prep = self.pack_ip4address(address)
except (ValueError, IndexError): except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", address) self.log.error("Invalid ip4address: %s", address)
return return
cursor = self.conn.cursor()
self.enter_step('get_ip4_select') self.enter_step('get_ip4_select')
self.execute( cursor.execute(
'SELECT entry FROM ip4address ' 'SELECT entry FROM ip4address '
# 'SELECT null, entry FROM ip4address ' # 'SELECT null, entry FROM ip4address '
'WHERE val=:a ' 'WHERE val=:a '
@ -238,7 +333,7 @@ class Database():
'WHERE :a BETWEEN mini AND maxi ', 'WHERE :a BETWEEN mini AND maxi ',
{'a': address_prep} {'a': address_prep}
) )
for val, entry in self.cursor: for entry, in cursor:
# self.enter_step('get_ip4_confirm') # self.enter_step('get_ip4_confirm')
# if not (val is None or val.startswith(address_prep)): # if not (val is None or val.startswith(address_prep)):
# # PERF startswith but from the end # # PERF startswith but from the end
@ -246,11 +341,29 @@ class Database():
self.enter_step('get_ip4_yield') self.enter_step('get_ip4_yield')
yield entry yield entry
def get_ip4_in_network(self, address: str) -> typing.Iterable[int]:
self.enter_step('get_ip4in_prepare')
try:
address_prep = self.pack_ip4address(address)
except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", address)
return
cursor = self.conn.cursor()
self.enter_step('get_ip4in_select')
cursor.execute(
'SELECT entry FROM ip4network '
'WHERE :a BETWEEN mini AND maxi ',
{'a': address_prep}
)
for entry, in cursor:
self.enter_step('get_ip4in_yield')
yield entry
def list_asn(self) -> typing.Iterable[typing.Tuple[str, int]]: def list_asn(self) -> typing.Iterable[typing.Tuple[str, int]]:
cursor = self.conn.cursor()
self.enter_step('list_asn_select') self.enter_step('list_asn_select')
self.enter_step('get_domain_select') cursor.execute('SELECT val, entry FROM asn')
self.execute('SELECT val, entry FROM asn') for val, entry in cursor:
for val, entry in self.cursor:
yield f'AS{val}', entry yield f'AS{val}', entry
def _set_generic(self, def _set_generic(self,
@ -258,6 +371,7 @@ class Database():
select_query: str, select_query: str,
insert_query: str, insert_query: str,
prep: typing.Dict[str, DbValue], prep: typing.Dict[str, DbValue],
updated: int,
is_first_party: bool = False, is_first_party: bool = False,
source: int = None, source: int = None,
) -> None: ) -> None:
@ -265,34 +379,36 @@ class Database():
# here abstraction > performaces # here abstraction > performaces
# Fields based on the source # Fields based on the source
self.enter_step(f'set_{table}_prepare')
cursor = self.conn.cursor()
if source is None: if source is None:
first_party = int(is_first_party) first_party = int(is_first_party)
level = 0 level = 0
else: else:
self.enter_step(f'set_{table}_source') self.enter_step(f'set_{table}_source')
self.execute( cursor.execute(
'SELECT first_party, level FROM rules ' 'SELECT first_party, level FROM rules '
'WHERE id=?', 'WHERE id=?',
(source,) (source,)
) )
first_party, level = self.cursor.fetchone() first_party, level = cursor.fetchone()
level += 1 level += 1
self.enter_step(f'set_{table}_select') self.enter_step(f'set_{table}_select')
self.execute(select_query, prep) cursor.execute(select_query, prep)
rules_prep = { rules_prep: typing.Dict[str, DbValue] = {
"source": source, "source": source,
"updated": self.updated, "updated": updated,
"first_party": first_party, "first_party": first_party,
"level": level, "level": level,
} }
# If the entry already exists # If the entry already exists
for entry, in self.cursor: # only one for entry, in cursor: # only one
self.enter_step(f'set_{table}_update') self.enter_step(f'set_{table}_update')
rules_prep['entry'] = entry rules_prep['entry'] = entry
self.execute( cursor.execute(
'UPDATE rules SET ' 'UPDATE rules SET '
'source=:source, updated=:updated, ' 'source=:source, updated=:updated, '
'first_party=:first_party, level=:level ' 'first_party=:first_party, level=:level '
@ -308,23 +424,18 @@ class Database():
# If it does not exist # If it does not exist
if source is not None:
self.enter_step(f'set_{table}_incsrc')
self.execute('UPDATE rules SET refs = refs + 1 WHERE id=?',
(source,))
self.enter_step(f'set_{table}_insert') self.enter_step(f'set_{table}_insert')
self.execute( cursor.execute(
'INSERT INTO rules ' 'INSERT INTO rules '
'(source, updated, first_party, refs, level) ' '(source, updated, first_party, level) '
'VALUES (:source, :updated, :first_party, 0, :level) ', 'VALUES (:source, :updated, :first_party, :level) ',
rules_prep rules_prep
) )
self.execute('SELECT id FROM rules WHERE rowid=?', cursor.execute('SELECT id FROM rules WHERE rowid=?',
(self.cursor.lastrowid,)) (cursor.lastrowid,))
for entry, in self.cursor: # only one for entry, in cursor: # only one
prep['entry'] = entry prep['entry'] = entry
self.execute(insert_query, prep) cursor.execute(insert_query, prep)
return return
assert False assert False
@ -332,7 +443,7 @@ class Database():
*args: typing.Any, **kwargs: typing.Any) -> None: *args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_hostname_prepare') self.enter_step('set_hostname_prepare')
prep: typing.Dict[str, DbValue] = { prep: typing.Dict[str, DbValue] = {
'val': self.prepare_hostname(hostname), 'val': self.pack_hostname(hostname),
} }
self._set_generic( self._set_generic(
'hostname', 'hostname',
@ -347,7 +458,7 @@ class Database():
*args: typing.Any, **kwargs: typing.Any) -> None: *args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_asn_prepare') self.enter_step('set_asn_prepare')
try: try:
asn_prep = self.prepare_asn(asn) asn_prep = self.pack_asn(asn)
except ValueError: except ValueError:
self.log.error("Invalid asn: %s", asn) self.log.error("Invalid asn: %s", asn)
return return
@ -365,10 +476,9 @@ class Database():
def set_ip4address(self, ip4address: str, def set_ip4address(self, ip4address: str,
*args: typing.Any, **kwargs: typing.Any) -> None: *args: typing.Any, **kwargs: typing.Any) -> None:
# TODO Do not add if already in ip4network
self.enter_step('set_ip4add_prepare') self.enter_step('set_ip4add_prepare')
try: try:
ip4address_prep = self.prepare_ip4address(ip4address) ip4address_prep = self.pack_ip4address(ip4address)
except (ValueError, IndexError): except (ValueError, IndexError):
self.log.error("Invalid ip4address: %s", ip4address) self.log.error("Invalid ip4address: %s", ip4address)
return return
@ -388,7 +498,7 @@ class Database():
*args: typing.Any, **kwargs: typing.Any) -> None: *args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_zone_prepare') self.enter_step('set_zone_prepare')
prep: typing.Dict[str, DbValue] = { prep: typing.Dict[str, DbValue] = {
'val': self.prepare_zone(zone), 'val': self.pack_zone(zone),
} }
self._set_generic( self._set_generic(
'zone', 'zone',
@ -403,7 +513,7 @@ class Database():
*args: typing.Any, **kwargs: typing.Any) -> None: *args: typing.Any, **kwargs: typing.Any) -> None:
self.enter_step('set_ip4net_prepare') self.enter_step('set_ip4net_prepare')
try: try:
ip4network_prep = self.prepare_ip4network(ip4network) ip4network_prep = self.pack_ip4network(ip4network)
except (ValueError, IndexError): except (ValueError, IndexError):
self.log.error("Invalid ip4network: %s", ip4network) self.log.error("Invalid ip4network: %s", ip4network)
return return
@ -431,23 +541,18 @@ if __name__ == '__main__':
help="Reconstruct the whole database") help="Reconstruct the whole database")
parser.add_argument( parser.add_argument(
'-p', '--prune', action='store_true', '-p', '--prune', action='store_true',
help="Remove old entries from database") help="Remove old (+6 months) entries from database")
parser.add_argument(
'-e', '--expire', action='store_true',
help="Set the whole database as an old source")
parser.add_argument( parser.add_argument(
'-r', '--references', action='store_true', '-r', '--references', action='store_true',
help="Update the reference count") help="Update the reference count")
args = parser.parse_args() args = parser.parse_args()
DB = Database() DB = Database(write=True)
if args.initialize: if args.initialize:
DB.initialize() DB.initialize()
if args.prune: if args.prune:
DB.prune() DB.prune(before=int(time.time()) - 60*60*24*31*6)
if args.expire:
DB.expire()
if args.references and not args.prune: if args.references and not args.prune:
DB.update_references() DB.update_references()

View file

@ -10,30 +10,37 @@ CREATE TABLE rules (
level INTEGER, -- Level of recursion to the root source rule (used for source priority) level INTEGER, -- Level of recursion to the root source rule (used for source priority)
FOREIGN KEY (source) REFERENCES rules(id) ON DELETE CASCADE FOREIGN KEY (source) REFERENCES rules(id) ON DELETE CASCADE
); );
CREATE INDEX rules_source ON rules (source); -- for references recounting
CREATE INDEX rules_updated ON rules (updated); -- for pruning
CREATE INDEX rules_level_firstparty ON rules (level, first_party); -- for counting rules
CREATE TABLE asn ( CREATE TABLE asn (
val INTEGER PRIMARY KEY, val INTEGER PRIMARY KEY,
entry INTEGER, entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
); );
CREATE INDEX asn_entry ON asn (entry); -- for explainations
CREATE TABLE hostname ( CREATE TABLE hostname (
val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for consistency with zone) val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for consistency with zone)
entry INTEGER, entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
); );
CREATE INDEX hostname_entry ON hostname (entry); -- for explainations
CREATE TABLE zone ( CREATE TABLE zone (
val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for easier matching) val TEXT PRIMARY KEY, -- rev'd, ends with a dot (for easier matching)
entry INTEGER, entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
); );
CREATE INDEX zone_entry ON zone (entry); -- for explainations
CREATE TABLE ip4address ( CREATE TABLE ip4address (
val INTEGER PRIMARY KEY, val INTEGER PRIMARY KEY,
entry INTEGER, entry INTEGER,
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
); );
CREATE INDEX ip4address_entry ON ip4address (entry); -- for explainations
CREATE TABLE ip4network ( CREATE TABLE ip4network (
-- val TEXT PRIMARY KEY, -- val TEXT PRIMARY KEY,
@ -43,6 +50,7 @@ CREATE TABLE ip4network (
FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE FOREIGN KEY (entry) REFERENCES rules(id) ON DELETE CASCADE
); );
CREATE INDEX ip4network_minmax ON ip4network (mini, maxi); CREATE INDEX ip4network_minmax ON ip4network (mini, maxi);
CREATE INDEX ip4network_entry ON ip4network (entry); -- for explainations
-- Store various things -- Store various things
CREATE TABLE meta ( CREATE TABLE meta (

View file

@ -19,12 +19,31 @@ if __name__ == '__main__':
parser.add_argument( parser.add_argument(
'-e', '--end-chain', action='store_true', '-e', '--end-chain', action='store_true',
help="TODO") help="TODO")
parser.add_argument(
'-x', '--explain', action='store_true',
help="TODO")
parser.add_argument(
'-r', '--rules', action='store_true',
help="TODO")
parser.add_argument(
'-c', '--count', action='store_true',
help="TODO")
args = parser.parse_args() args = parser.parse_args()
DB = database.Database() DB = database.Database()
for domain in DB.export(first_party_only=args.first_party, if args.rules:
end_chain_only=args.end_chain): if not args.count:
raise NotImplementedError
print(DB.count_rules(first_party_only=args.first_party))
else:
if args.count:
raise NotImplementedError
for domain in DB.export(
first_party_only=args.first_party,
end_chain_only=args.end_chain,
explain=args.explain,
):
print(domain, file=args.output) print(domain, file=args.output)
DB.close() DB.close()

53
feed_asn.py Executable file
View file

@ -0,0 +1,53 @@
#!/usr/bin/env python3
import database
import argparse
import requests
import typing
import ipaddress
import logging
import time
IPNetwork = typing.Union[ipaddress.IPv4Network, ipaddress.IPv6Network]
def get_ranges(asn: str) -> typing.Iterable[str]:
req = requests.get(
'https://stat.ripe.net/data/as-routing-consistency/data.json',
params={'resource': asn}
)
data = req.json()
for pref in data['data']['prefixes']:
yield pref['prefix']
if __name__ == '__main__':
log = logging.getLogger('feed_asn')
# Parsing arguments
parser = argparse.ArgumentParser(
description="TODO")
args = parser.parse_args()
DB = database.Database()
DBW = database.Database(write=True)
for asn, entry in DB.list_asn():
DB.enter_step('asn_get_ranges')
for prefix in get_ranges(asn):
parsed_prefix: IPNetwork = ipaddress.ip_network(prefix)
if parsed_prefix.version == 4:
DBW.set_ip4network(
prefix,
source=entry,
updated=int(time.time())
)
log.info('Added %s from %s (id=%s)', prefix, asn, entry)
elif parsed_prefix.version == 6:
log.warning('Unimplemented prefix version: %s', prefix)
else:
log.error('Unknown prefix version: %s', prefix)
DB.close()
DBW.close()

View file

@ -17,7 +17,7 @@ if __name__ == '__main__':
help="TODO") help="TODO")
args = parser.parse_args() args = parser.parse_args()
DB = database.Database() DB = database.Database(write=True)
try: try:
DB.enter_step('iowait') DB.enter_step('iowait')
@ -28,6 +28,7 @@ if __name__ == '__main__':
# split = line.split(b'"') # split = line.split(b'"')
split = line.split('"') split = line.split('"')
try: try:
updated = int(split[3])
name = split[7] name = split[7]
dtype = split[11] dtype = split[11]
value = split[15] value = split[15]
@ -43,13 +44,16 @@ if __name__ == '__main__':
DB.enter_step('feed_switch') DB.enter_step('feed_switch')
if dtype == 'a': if dtype == 'a':
for rule in DB.get_ip4(value): for rule in DB.get_ip4(value):
DB.set_hostname(name, source=rule) if not list(DB.get_domain_in_zone(name)):
DB.set_hostname(name, source=rule, updated=updated)
elif dtype == 'cname': elif dtype == 'cname':
for rule in DB.get_domain(value): for rule in DB.get_domain(value):
DB.set_hostname(name, source=rule) if not list(DB.get_domain_in_zone(name)):
DB.set_hostname(name, source=rule, updated=updated)
elif dtype == 'ptr': elif dtype == 'ptr':
for rule in DB.get_domain(value): for rule in DB.get_domain(value):
DB.set_ip4address(name, source=rule) if not list(DB.get_ip4_in_network(name)):
DB.set_ip4address(name, source=rule, updated=updated)
DB.enter_step('iowait') DB.enter_step('iowait')
except KeyboardInterrupt: except KeyboardInterrupt:
log.warning("Interupted.") log.warning("Interupted.")

View file

@ -3,6 +3,7 @@
import database import database
import argparse import argparse
import sys import sys
import time
FUNCTION_MAP = { FUNCTION_MAP = {
'zone': database.Database.set_zone, 'zone': database.Database.set_zone,
@ -27,11 +28,15 @@ if __name__ == '__main__':
help="The input only comes from verified first-party sources") help="The input only comes from verified first-party sources")
args = parser.parse_args() args = parser.parse_args()
DB = database.Database() DB = database.Database(write=True)
fun = FUNCTION_MAP[args.type] fun = FUNCTION_MAP[args.type]
for rule in args.input: for rule in args.input:
fun(DB, rule.strip(), is_first_party=args.first_party) fun(DB,
rule.strip(),
is_first_party=args.first_party,
updated=int(time.time()),
)
DB.close() DB.close()

View file

@ -4,6 +4,12 @@ function log() {
echo -e "\033[33m$@\033[0m" echo -e "\033[33m$@\033[0m"
} }
log "Pruning old data…"
./database.py --prune
log "Recounting references…"
./database.py --references
log "Exporting lists…" log "Exporting lists…"
./export.py --first-party --output dist/firstparty-trackers.txt ./export.py --first-party --output dist/firstparty-trackers.txt
./export.py --first-party --end-chain --output dist/firstparty-only-trackers.txt ./export.py --first-party --end-chain --output dist/firstparty-only-trackers.txt
@ -11,6 +17,8 @@ log "Exporting lists…"
./export.py --end-chain --output dist/multiparty-only-trackers.txt ./export.py --end-chain --output dist/multiparty-only-trackers.txt
log "Generating hosts lists…" log "Generating hosts lists…"
./export.py --rules --count --first-party > temp/count_rules_firstparty.txt
./export.py --rules --count > temp/count_rules_multiparty.txt
function generate_hosts { function generate_hosts {
basename="$1" basename="$1"
description="$2" description="$2"
@ -36,15 +44,16 @@ function generate_hosts {
echo "#" echo "#"
echo "# Generation date: $(date -Isec)" echo "# Generation date: $(date -Isec)"
echo "# Generation software: eulaurarien $(git describe --tags)" echo "# Generation software: eulaurarien $(git describe --tags)"
echo "# Number of source websites: TODO" echo "# Number of source websites: $(wc -l temp/all_websites.list | cut -d' ' -f1)"
echo "# Number of source subdomains: TODO" echo "# Number of source subdomains: $(wc -l temp/all_subdomains.list | cut -d' ' -f1)"
echo "# Number of source DNS records: ~2M + $(wc -l temp/all_resolved.json | cut -d' ' -f1)"
echo "#" echo "#"
echo "# Number of known first-party trackers: TODO" echo "# Known first-party trackers: $(cat temp/count_rules_firstparty.txt)"
echo "# Number of first-party subdomains: $(wc -l dist/firstparty-trackers.txt | cut -d' ' -f1)" echo "# Number of first-party hostnames: $(wc -l dist/firstparty-trackers.txt | cut -d' ' -f1)"
echo "# … excluding redirected: $(wc -l dist/firstparty-only-trackers.txt | cut -d' ' -f1)" echo "# … excluding redirected: $(wc -l dist/firstparty-only-trackers.txt | cut -d' ' -f1)"
echo "#" echo "#"
echo "# Number of known multi-party trackers: TODO" echo "# Known multi-party trackers: $(cat temp/count_rules_multiparty.txt)"
echo "# Number of multi-party subdomains: $(wc -l dist/multiparty-trackers.txt | cut -d' ' -f1)" echo "# Number of multi-party hostnames: $(wc -l dist/multiparty-trackers.txt | cut -d' ' -f1)"
echo "# … excluding redirected: $(wc -l dist/multiparty-only-trackers.txt | cut -d' ' -f1)" echo "# … excluding redirected: $(wc -l dist/multiparty-only-trackers.txt | cut -d' ' -f1)"
echo echo
sed 's|^|0.0.0.0 |' "dist/$basename.txt" sed 's|^|0.0.0.0 |' "dist/$basename.txt"

View file

@ -4,9 +4,7 @@ function log() {
echo -e "\033[33m$@\033[0m" echo -e "\033[33m$@\033[0m"
} }
log "Preparing database…" ./fetch_resources.sh
./database.py --expire
./import_rules.sh ./import_rules.sh
# TODO Fetch 'em # TODO Fetch 'em

View file

@ -1,21 +0,0 @@
#!/usr/bin/env python3
"""
List of regex matching first-party trackers.
"""
# Syntax: https://docs.python.org/3/library/re.html#regular-expression-syntax
REGEXES = [
r'^.+\.eulerian\.net\.$', # Eulerian
r'^.+\.criteo\.com\.$', # Criteo
r'^.+\.dnsdelegation\.io\.$', # Criteo
r'^.+\.keyade\.com\.$', # Keyade
r'^.+\.omtrdc\.net\.$', # Adobe Experience Cloud
r'^.+\.bp01\.net\.$', # NP6
r'^.+\.ati-host\.net\.$', # Xiti (AT Internet)
r'^.+\.at-o\.net\.$', # Xiti (AT Internet)
r'^.+\.edgkey\.net\.$', # Edgekey (Akamai)
r'^.+\.akaimaiedge\.net\.$', # Edgekey (Akamai)
r'^.+\.storetail\.io\.$', # Storetail (Criteo)
]

View file

@ -12,22 +12,15 @@ import queue
import sys import sys
import threading import threading
import typing import typing
import csv import time
import coloredlogs import coloredlogs
import dns.exception import dns.exception
import dns.resolver import dns.resolver
import progressbar
DNS_TIMEOUT = 5.0 DNS_TIMEOUT = 5.0
NUMBER_THREADS = 512
NUMBER_TRIES = 5 NUMBER_TRIES = 5
# TODO All the domains don't get treated,
# so it leaves with 4-5 subdomains not resolved
glob = None
class Worker(threading.Thread): class Worker(threading.Thread):
""" """
@ -60,7 +53,7 @@ class Worker(threading.Thread):
def resolve_subdomain(self, subdomain: str) -> typing.Optional[ def resolve_subdomain(self, subdomain: str) -> typing.Optional[
typing.List[ typing.List[
str dns.rrset.RRset
] ]
]: ]:
""" """
@ -93,18 +86,7 @@ class Worker(threading.Thread):
except dns.name.EmptyLabel: except dns.name.EmptyLabel:
self.log.warning("Empty label for %s", subdomain) self.log.warning("Empty label for %s", subdomain)
return None return None
resolved = list() return query.response.answer
last = len(query.response.answer) - 1
for a, answer in enumerate(query.response.answer):
if answer.rdtype == dns.rdatatype.CNAME:
assert a < last
resolved.append(answer.items[0].to_text()[:-1])
elif answer.rdtype == dns.rdatatype.A:
assert a == last
resolved.append(answer.items[0].address)
else:
assert False
return resolved
def run(self) -> None: def run(self) -> None:
self.log.info("Started") self.log.info("Started")
@ -124,7 +106,6 @@ class Worker(threading.Thread):
self.log.error("Gave up on %s", subdomain) self.log.error("Gave up on %s", subdomain)
resolved = [] resolved = []
resolved.insert(0, subdomain)
assert isinstance(resolved, list) assert isinstance(resolved, list)
self.orchestrator.results_queue.put(resolved) self.orchestrator.results_queue.put(resolved)
@ -150,15 +131,17 @@ class Orchestrator():
def __init__(self, subdomains: typing.Iterable[str], def __init__(self, subdomains: typing.Iterable[str],
nameservers: typing.List[str] = None, nameservers: typing.List[str] = None,
nb_workers: int = 1,
): ):
self.log = logging.getLogger('orchestrator') self.log = logging.getLogger('orchestrator')
self.subdomains = subdomains self.subdomains = subdomains
self.nb_workers = nb_workers
# Use interal resolver by default # Use interal resolver by default
self.nameservers = nameservers or dns.resolver.Resolver().nameservers self.nameservers = nameservers or dns.resolver.Resolver().nameservers
self.subdomains_queue: queue.Queue = queue.Queue( self.subdomains_queue: queue.Queue = queue.Queue(
maxsize=NUMBER_THREADS) maxsize=self.nb_workers)
self.results_queue: queue.Queue = queue.Queue() self.results_queue: queue.Queue = queue.Queue()
self.nameservers_queue: queue.Queue = queue.Queue() self.nameservers_queue: queue.Queue = queue.Queue()
@ -179,16 +162,32 @@ class Orchestrator():
self.log.info("Finished reading subdomains") self.log.info("Finished reading subdomains")
# Send sentinel to each worker # Send sentinel to each worker
# sentinel = None ~= EOF # sentinel = None ~= EOF
for _ in range(NUMBER_THREADS): for _ in range(self.nb_workers):
self.subdomains_queue.put(None) self.subdomains_queue.put(None)
def run(self) -> typing.Iterable[typing.List[str]]: @staticmethod
def format_rrset(rrset: dns.rrset.RRset) -> typing.Iterable[str]:
if rrset.rdtype == dns.rdatatype.CNAME:
dtype = 'cname'
elif rrset.rdtype == dns.rdatatype.A:
dtype = 'a'
else:
raise NotImplementedError
name = rrset.name.to_text()[:-1]
for item in rrset.items:
value = item.to_text()
if rrset.rdtype == dns.rdatatype.CNAME:
value = value[:-1]
yield '{"timestamp":"' + str(int(time.time())) + '","name":"' + \
name + '","type":"' + dtype + '","value":"' + value + '"}\n'
def run(self) -> typing.Iterable[str]:
""" """
Yield the results. Yield the results.
""" """
# Create workers # Create workers
self.log.info("Creating workers") self.log.info("Creating workers")
for i in range(NUMBER_THREADS): for i in range(self.nb_workers):
Worker(self, i).start() Worker(self, i).start()
fill_thread = threading.Thread(target=self.fill_subdomain_queue) fill_thread = threading.Thread(target=self.fill_subdomain_queue)
@ -196,10 +195,11 @@ class Orchestrator():
# Wait for one sentinel per worker # Wait for one sentinel per worker
# In the meantime output results # In the meantime output results
for _ in range(NUMBER_THREADS): for _ in range(self.nb_workers):
result: typing.List[str] resolved: typing.List[dns.rrset.RRset]
for result in iter(self.results_queue.get, None): for resolved in iter(self.results_queue.get, None):
yield result for rrset in resolved:
yield from self.format_rrset(rrset)
self.log.info("Waiting for reader thread") self.log.info("Waiting for reader thread")
fill_thread.join() fill_thread.join()
@ -214,11 +214,9 @@ def main() -> None:
the last CNAME resolved and the IP adress it resolves to. the last CNAME resolved and the IP adress it resolves to.
Takes as an input a filename (or nothing, for stdin), Takes as an input a filename (or nothing, for stdin),
and as an output a filename (or nothing, for stdout). and as an output a filename (or nothing, for stdout).
The input must be a subdomain per line, the output is a comma-sep The input must be a subdomain per line, the output is a TODO
file with the columns source CNAME and A.
Use the file `nameservers` as the list of nameservers Use the file `nameservers` as the list of nameservers
to use, or else it will use the system defaults. to use, or else it will use the system defaults.
Also shows a nice progressbar.
""" """
# Initialization # Initialization
@ -236,28 +234,14 @@ def main() -> None:
parser.add_argument( parser.add_argument(
'-o', '--output', type=argparse.FileType('w'), default=sys.stdout, '-o', '--output', type=argparse.FileType('w'), default=sys.stdout,
help="Outptut file with DNS chains") help="Outptut file with DNS chains")
# parser.add_argument( parser.add_argument(
# '-n', '--nameserver', type=argparse.FileType('r'), '-n', '--nameservers', default='nameservers',
# default='nameservers', help="File with one nameserver per line") help="File with one nameserver per line")
# parser.add_argument( parser.add_argument(
# '-j', '--workers', type=int, default=512, '-j', '--workers', type=int, default=512,
# help="Number of threads to use") help="Number of threads to use")
args = parser.parse_args() args = parser.parse_args()
# Progress bar
widgets = [
progressbar.Percentage(),
' ', progressbar.SimpleProgress(),
' ', progressbar.Bar(),
' ', progressbar.Timer(),
' ', progressbar.AdaptiveTransferSpeed(unit='req'),
' ', progressbar.AdaptiveETA(),
]
progress = progressbar.ProgressBar(widgets=widgets)
if args.input.seekable():
progress.max_value = len(args.input.readlines())
args.input.seek(0)
# Cleaning input # Cleaning input
iterator = iter(args.input) iterator = iter(args.input)
iterator = map(str.strip, iterator) iterator = map(str.strip, iterator)
@ -265,19 +249,16 @@ def main() -> None:
# Reading nameservers # Reading nameservers
servers: typing.List[str] = list() servers: typing.List[str] = list()
if os.path.isfile('nameservers'): if os.path.isfile(args.nameservers):
servers = open('nameservers').readlines() servers = open(args.nameservers).readlines()
servers = list(filter(None, map(str.strip, servers))) servers = list(filter(None, map(str.strip, servers)))
writer = csv.writer(args.output) for resolved in Orchestrator(
iterator,
progress.start() servers,
global glob nb_workers=args.workers
glob = Orchestrator(iterator, servers) ).run():
for resolved in glob.run(): args.output.write(resolved)
progress.update(progress.value + 1)
writer.writerow(resolved)
progress.finish()
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -4,11 +4,9 @@ function log() {
echo -e "\033[33m$@\033[0m" echo -e "\033[33m$@\033[0m"
} }
# Resolve the CNAME chain of all the known subdomains for later analysis log "Compiling locally known subdomain…"
log "Compiling subdomain lists..."
pv subdomains/*.list | sort -u > temp/all_subdomains.list
# Sort by last character to utilize the DNS server caching mechanism # Sort by last character to utilize the DNS server caching mechanism
pv temp/all_subdomains.list | rev | sort | rev > temp/all_subdomains_reversort.list pv subdomains/*.list | rev | sort -u | rev > temp/all_subdomains.list
./resolve_subdomains.py --input temp/all_subdomains_reversort.list --output temp/all_resolved.csv log "Resolving locally known subdomain…"
sort -u temp/all_resolved.csv > temp/all_resolved_sorted.csv pv temp/all_subdomains.list | ./resolve_subdomains.py --output temp/all_resolved.json