@ -1,256 +1,385 @@
#!/usr/bin/env python3
"""
Utility functions to interact with the database .
"""
import sqlite3
import typing
import time
import os
import logging
import argparse
import typing
import coloredlogs
import ipaddress
import enum
import time
import ctypes
"""
Utility functions to interact with the database .
"""
coloredlogs . install (
level = ' DEBUG ' ,
fmt = ' %(asctime)s %(name)s %(levelname)s %(message)s '
)
DbValue = typing . Union [ None , int , float , str , bytes ]
class Database ( ) :
VERSION = 3
PATH = " blocking.db "
def open ( self ) - > None :
self . conn = sqlite3 . connect ( self . PATH )
self . cursor = self . conn . cursor ( )
self . execute ( " PRAGMA foreign_keys = ON " )
# self.conn.create_function("prepare_ip4address", 1,
# Database.prepare_ip4address,
# deterministic=True)
def execute ( self , cmd : str , args : typing . Union [
typing . Tuple [ DbValue , . . . ] ,
typing . Dict [ str , DbValue ] ] = None ) - > None :
self . cursor . execute ( cmd , args or tuple ( ) )
def get_meta ( self , key : str ) - > typing . Optional [ int ] :
try :
self . execute ( " SELECT value FROM meta WHERE key=? " , ( key , ) )
except sqlite3 . OperationalError :
return None
for ver , in self . cursor :
return ver
return None
# TODO Rule level and source priority
VERSION = 2
PATH = f " blocking.db "
CONN = None
C = None # Cursor
TIME_DICT : typing . Dict [ str , float ] = dict ( )
TIME_LAST = time . perf_counter ( )
TIME_STEP = ' start '
ACCEL = ctypes . cdll . LoadLibrary ( ' ./libaccel.so ' )
ACCEL_IP4_BUF = ctypes . create_unicode_buffer ( ' Z ' * 32 , 32 )
def time_step ( step : str ) - > None :
global TIME_LAST
global TIME_STEP
now = time . perf_counter ( )
TIME_DICT . setdefault ( TIME_STEP , 0.0 )
TIME_DICT [ TIME_STEP ] + = now - TIME_LAST
TIME_STEP = step
TIME_LAST = time . perf_counter ( )
def time_print ( ) - > None :
time_step ( ' postprint ' )
total = sum ( TIME_DICT . values ( ) )
for key , secs in sorted ( TIME_DICT . items ( ) , key = lambda t : t [ 1 ] ) :
print ( f " {key:<20}: {secs/total:7.2 % } = {secs:.6f} s " )
print ( f " { ' total ' :<20}: {1:7.2 % } = {total:.6f} s " )
class RowType ( enum . Enum ) :
AS = 1
DomainTree = 2
Domain = 3
IPv4Network = 4
IPv6Network = 6
def open_db ( ) - > None :
time_step ( ' open_db ' )
global CONN
global C
CONN = sqlite3 . connect ( PATH )
C = CONN . cursor ( )
# C.execute("PRAGMA foreign_keys = ON");
initialized = False
try :
C . execute ( " SELECT value FROM meta WHERE key= ' version ' " )
version_ex = C . fetchone ( )
if version_ex :
if version_ex [ 0 ] == VERSION :
initialized = True
else :
print ( f " Database version {version_ex[0]} found, "
" it will be deleted. " )
except sqlite3 . OperationalError :
pass
if not initialized :
time_step ( ' init_db ' )
print ( f " Creating database version {VERSION}. " )
CONN . close ( )
os . unlink ( PATH )
CONN = sqlite3 . connect ( PATH )
C = CONN . cursor ( )
def set_meta ( self , key : str , val : int ) - > None :
self . execute ( " INSERT INTO meta VALUES (?, ?) "
" ON CONFLICT (key) DO "
" UPDATE set value=? " ,
( key , val , val ) )
def close ( self ) - > None :
self . enter_step ( ' close_commit ' )
self . conn . commit ( )
self . enter_step ( ' close ' )
self . conn . close ( )
self . profile ( )
def initialize ( self ) - > None :
self . enter_step ( ' initialize ' )
self . close ( )
os . unlink ( self . PATH )
self . open ( )
self . log . info ( " Creating database version %d . " , self . VERSION )
with open ( " database_schema.sql " , ' r ' ) as db_schema :
C . executescript ( db_schema . read ( ) )
C . execute ( " INSERT INTO meta VALUES ( ' version ' , ?) " , ( VERSION , ) )
CONN . commit ( )
time_step ( ' other ' )
def close_db ( ) - > None :
assert CONN
time_step ( ' close_db_commit ' )
CONN . commit ( )
time_step ( ' close_db ' )
CONN . close ( )
time_step ( ' other ' )
time_print ( )
def refresh ( ) - > None :
assert C
C . execute ( ' UPDATE blocking SET updated = 0 ' )
# TODO PERF Use a meta value instead
RULE_SUBDOMAIN_COMMAND = \
' INSERT INTO blocking (key, type, updated, firstpart, level) ' \
f ' VALUES (?, {RowType.DomainTree.value}, 1, ?, 0) ' \
' ON CONFLICT(key) ' \
f ' DO UPDATE SET source=null, type={RowType.DomainTree.value}, ' \
' updated=1, firstparty=?, level=0 '
def feed_rule_subdomains ( subdomain : str , first_party : bool = False ) - > None :
assert C
subdomain = subdomain [ : : - 1 ]
C . execute ( RULE_SUBDOMAIN_COMMAND ,
( subdomain , int ( first_party ) , int ( first_party ) ) )
# Since regex type takes precedence over domain type,
# and firstparty takes precedence over multiparty,
# we can afford to replace the whole row without checking
# the row without checking previous values and making sure
# firstparty subdomains are updated last
def ip_get_bits ( address : ipaddress . IPv4Address ) - > typing . Iterator [ int ] :
for char in address . packed :
for i in range ( 7 , - 1 , - 1 ) :
yield ( char >> i ) & 0b1
self . cursor . executescript ( db_schema . read ( ) )
self . set_meta ( ' version ' , self . VERSION )
self . conn . commit ( )
def __init__ ( self ) - > None :
self . log = logging . getLogger ( ' db ' )
self . time_last = time . perf_counter ( )
self . time_step = ' init '
self . time_dict : typing . Dict [ str , float ] = dict ( )
self . step_dict : typing . Dict [ str , int ] = dict ( )
self . accel_ip4_buf = ctypes . create_unicode_buffer ( ' Z ' * 32 , 32 )
self . open ( )
version = self . get_meta ( ' version ' )
if version != self . VERSION :
if version is not None :
self . log . warning (
" Outdated database version: %d found, will be rebuilt. " ,
version )
self . initialize ( )
updated = self . get_meta ( ' updated ' )
if updated is None :
self . execute ( ' SELECT max(updated) FROM rules ' )
data = self . cursor . fetchone ( )
updated , = data
self . updated = updated or 1
def enter_step ( self , name : str ) - > None :
now = time . perf_counter ( )
try :
self . time_dict [ self . time_step ] + = now - self . time_last
self . step_dict [ self . time_step ] + = 1
except KeyError :
self . time_dict [ self . time_step ] = now - self . time_last
self . step_dict [ self . time_step ] = 1
self . time_step = name
self . time_last = time . perf_counter ( )
def profile ( self ) - > None :
self . enter_step ( ' profile ' )
total = sum ( self . time_dict . values ( ) )
for key , secs in sorted ( self . time_dict . items ( ) , key = lambda t : t [ 1 ] ) :
times = self . step_dict [ key ]
self . log . debug ( f " {key:<20}: {times:9d} × {secs/times:5.3e} "
f " = {secs:9.2f} s ({secs/total:7.2 % }) " )
self . log . debug ( f " { ' total ' :<20}: "
f " {total:9.2f} s ({1:7.2 % }) " )
def prepare_hostname ( self , hostname : str ) - > str :
return hostname [ : : - 1 ] + ' . '
def prepare_zone ( self , zone : str ) - > str :
return self . prepare_hostname ( zone )
@staticmethod
def prepare_ip4address ( address : str ) - > int :
total = 0
for i , octet in enumerate ( address . split ( ' . ' ) ) :
total + = int ( octet ) << ( 3 - i ) * 8
return total
# return '{:02x}{:02x}{:02x}{:02x}'.format(
# *[int(c) for c in address.split('.')])
# return base64.b16encode(packed).decode()
# return '{:08b}{:08b}{:08b}{:08b}'.format(
# *[int(c) for c in address.split('.')])
# carg = ctypes.c_wchar_p(address)
# ret = ACCEL.ip4_flat(carg, self.accel_ip4_buf)
# if ret != 0:
# raise ValueError
# return self.accel_ip4_buf.value
# packed = ipaddress.ip_address(address).packed
# return packed
def prepare_ip4network ( self , network : str ) - > typing . Tuple [ int , int ] :
# def prepare_ip4network(network: str) -> str:
net = ipaddress . ip_network ( network )
mini = self . prepare_ip4address ( net . network_address . exploded )
maxi = self . prepare_ip4address ( net . broadcast_address . exploded )
# mini = net.network_address.packed
# maxi = net.broadcast_address.packed
return mini , maxi
# return Database.prepare_ip4address(net.network_address.exploded)[:net.prefixlen]
def expire ( self ) - > None :
self . enter_step ( ' expire ' )
self . updated + = 1
self . set_meta ( ' updated ' , self . updated )
def update_references ( self ) - > None :
self . enter_step ( ' update_refs ' )
self . execute ( ' UPDATE rules AS r SET refs= '
' (SELECT count(*) FROM rules '
' WHERE source=r.id) ' )
def prune ( self ) - > None :
self . enter_step ( ' prune ' )
self . execute ( ' DELETE FROM rules WHERE updated<? ' , ( self . updated , ) )
def export ( self , first_party_only : bool = False ,
end_chain_only : bool = False ) - > typing . Iterable [ str ] :
command = ' SELECT val FROM rules ' \
' INNER JOIN hostname ON rules.id = hostname.entry '
restrictions : typing . List [ str ] = list ( )
if first_party_only :
restrictions . append ( ' rules.first_party = 1 ' )
if end_chain_only :
restrictions . append ( ' rules.refs = 0 ' )
if restrictions :
command + = ' WHERE ' + ' AND ' . join ( restrictions )
self . execute ( command )
for val , in self . cursor :
yield val [ : - 1 ] [ : : - 1 ]
def get_domain ( self , domain : str ) - > typing . Iterable [ int ] :
self . enter_step ( ' get_domain_prepare ' )
domain_prep = self . prepare_hostname ( domain )
self . enter_step ( ' get_domain_select ' )
self . execute (
' SELECT null, entry FROM hostname '
' WHERE val=:d '
' UNION '
' SELECT * FROM ( '
' SELECT val, entry FROM zone '
' WHERE val<=:d '
' ORDER BY val DESC LIMIT 1 '
' ) ' ,
{ ' d ' : domain_prep }
)
for val , entry in self . cursor :
self . enter_step ( ' get_domain_confirm ' )
if not ( val is None or domain_prep . startswith ( val ) ) :
continue
self . enter_step ( ' get_domain_yield ' )
yield entry
def get_ip4 ( self , address : str ) - > typing . Iterable [ int ] :
self . enter_step ( ' get_ip4_prepare ' )
try :
address_prep = self . prepare_ip4address ( address )
except ( ValueError , IndexError ) :
self . log . error ( " Invalid ip4address: %s " , address )
return
self . enter_step ( ' get_ip4_select ' )
self . execute (
' SELECT entry FROM ip4address '
# 'SELECT null, entry FROM ip4address '
' WHERE val=:a '
' UNION '
# 'SELECT * FROM ('
# 'SELECT val, entry FROM ip4network '
# 'WHERE val<=:a '
# 'AND instr(:a, val) > 0 '
# 'ORDER BY val DESC'
# ')'
' SELECT entry FROM ip4network '
' WHERE :a BETWEEN mini AND maxi ' ,
{ ' a ' : address_prep }
)
for val , entry in self . cursor :
# self.enter_step('get_ip4_confirm')
# if not (val is None or val.startswith(address_prep)):
# # PERF startswith but from the end
# continue
self . enter_step ( ' get_ip4_yield ' )
yield entry
def _set_generic ( self ,
table : str ,
select_query : str ,
insert_query : str ,
prep : typing . Dict [ str , DbValue ] ,
is_first_party : bool = False ,
source : int = None ,
) - > None :
# Since this isn't the bulk of the processing,
# here abstraction > performaces
# Fields based on the source
if source is None :
first_party = int ( is_first_party )
level = 0
else :
self . enter_step ( f ' set_{table}_source ' )
self . execute (
' SELECT first_party, level FROM rules '
' WHERE id=? ' ,
( source , )
)
first_party , level = self . cursor . fetchone ( )
level + = 1
self . enter_step ( f ' set_{table}_select ' )
self . execute ( select_query , prep )
rules_prep = {
" source " : source ,
" updated " : self . updated ,
" first_party " : first_party ,
" level " : level ,
}
# If the entry already exists
for entry , in self . cursor : # only one
self . enter_step ( f ' set_{table}_update ' )
rules_prep [ ' entry ' ] = entry
self . execute (
' UPDATE rules SET '
' source=:source, updated=:updated, '
' first_party=:first_party, level=:level '
' WHERE id=:entry AND (updated<:updated OR '
' first_party<:first_party OR level<:level) ' ,
rules_prep
)
# Only update if any of the following:
# - the entry is outdataed
# - the entry was not a first_party but this is
# - this is closer to the original rule
return
# If it does not exist
if source is not None :
self . enter_step ( f ' set_{table}_incsrc ' )
self . execute ( ' UPDATE rules SET refs = refs + 1 WHERE id=? ' ,
( source , ) )
self . enter_step ( f ' set_{table}_insert ' )
self . execute (
' INSERT INTO rules '
' (source, updated, first_party, refs, level) '
' VALUES (:source, :updated, :first_party, 0, :level) ' ,
rules_prep
)
self . execute ( ' SELECT id FROM rules WHERE rowid=? ' ,
( self . cursor . lastrowid , ) )
for entry , in self . cursor : # only one
prep [ ' entry ' ] = entry
self . execute ( insert_query , prep )
return
assert False
def set_hostname ( self , hostname : str ,
* args : typing . Any , * * kwargs : typing . Any ) - > None :
self . enter_step ( ' set_hostname_prepare ' )
prep : typing . Dict [ str , DbValue ] = {
' val ' : self . prepare_hostname ( hostname ) ,
}
self . _set_generic (
' hostname ' ,
' SELECT entry FROM hostname WHERE val=:val ' ,
' INSERT INTO hostname (val, entry) '
' VALUES (:val, :entry) ' ,
prep ,
* args , * * kwargs
)
def ip_flat ( address : ipaddress . IPv4Address ) - > str :
return ' ' . join ( map ( str , ip_get_bits ( address ) ) )
def set_ip4address ( self , ip4address : str ,
* args : typing . Any , * * kwargs : typing . Any ) - > None :
self . enter_step ( ' set_ip4add_prepare ' )
try :
ip4address_prep = self . prepare_ip4address ( ip4address )
except ( ValueError , IndexError ) :
self . log . error ( " Invalid ip4address: %s " , ip4address )
return
prep : typing . Dict [ str , DbValue ] = {
' val ' : ip4address_prep ,
}
self . _set_generic (
' ip4add ' ,
' SELECT entry FROM ip4address WHERE val=:val ' ,
' INSERT INTO ip4address (val, entry) '
' VALUES (:val, :entry) ' ,
prep ,
* args , * * kwargs
)
def set_zone ( self , zone : str ,
* args : typing . Any , * * kwargs : typing . Any ) - > None :
self . enter_step ( ' set_zone_prepare ' )
prep : typing . Dict [ str , DbValue ] = {
' val ' : self . prepare_zone ( zone ) ,
}
self . _set_generic (
' zone ' ,
' SELECT entry FROM zone WHERE val=:val ' ,
' INSERT INTO zone (val, entry) '
' VALUES (:val, :entry) ' ,
prep ,
* args , * * kwargs
)
def ip4_flat ( address : bytes ) - > typing . Optional [ str ] :
carg = ctypes . c_char_p ( address )
ret = ACCEL . ip4_flat ( carg , ACCEL_IP4_BUF )
if ret != 0 :
return None
return ACCEL_IP4_BUF . value
RULE_IP4NETWORK_COMMAND = \
' INSERT INTO blocking (key, type, updated, firstparty, level) ' \
f ' VALUES (?, {RowType.IPv4Network.value}, 1, ?, 0) ' \
' ON CONFLICT(key) ' \
f ' DO UPDATE SET source=null, type={RowType.IPv4Network.value}, ' \
' updated=1, firstparty=?, level=0 '
def feed_rule_ip4network ( network : ipaddress . IPv4Network ,
first_party : bool = False ) - > None :
assert C
flat = ip_flat ( network . network_address ) [ : network . prefixlen ]
C . execute ( RULE_IP4NETWORK_COMMAND ,
( flat , int ( first_party ) , int ( first_party ) ) )
FEED_A_COMMAND_FETCH = \
' SELECT key, firstparty FROM blocking ' \
' WHERE key<=? ' \
' AND instr(?, key) > 0 ' \
f ' AND type={RowType.IPv4Network.value} ' \
' ORDER BY key DESC '
# UPSERT are not issued often relative to FETCH,
# merging the both might be counterproductive
FEED_A_COMMAND_UPSERT = \
' INSERT INTO blocking (key, source, type, updated, firstparty) ' \
f ' VALUES (?, ?, {RowType.Domain.value}, 1, ?) ' \
' ON CONFLICT(key) ' \
f ' DO UPDATE SET source=?, type={RowType.Domain.value}, ' \
' updated=1, firstparty=? ' \
' WHERE updated=0 OR firstparty<? '
def feed_a ( name : bytes , value_ip : bytes ) - > None :
assert C
assert CONN
time_step ( ' a_flat ' )
value_dec = ip4_flat ( value_ip )
if value_dec is None :
# Malformed IPs
time_step ( ' a_malformed ' )
return
time_step ( ' a_fetch ' )
C . execute ( FEED_A_COMMAND_FETCH , ( value_dec , value_dec ) )
base = C . fetchone ( )
time_step ( ' a_fetch_confirm ' )
name = name [ : : - 1 ]
for b_key , b_firstparty in C :
time_step ( ' a_upsert ' )
C . execute ( FEED_A_COMMAND_UPSERT ,
( name , b_key , b_firstparty , # Insert
b_key , b_firstparty , b_firstparty ) # Update
)
time_step ( ' a_fetch_confirm ' )
time_step ( ' a_end ' )
FEED_CNAME_COMMAND_FETCH = \
' SELECT key, type, firstparty FROM blocking ' \
' WHERE key<=? ' \
f ' AND (type={RowType.DomainTree.value} OR type={RowType.Domain.value}) ' \
' ORDER BY key DESC ' \
' LIMIT 1 '
# Optimisations that renders the index unused
# (and thus counterproductive until fixed):
# 'AND instr(?, key) > 0 ' \
# f'WHERE ((type={RowType.DomainTree.value} AND key<=?) OR ' \
# f'(type={RowType.Domain.value} AND key=?)) ' \
# Might be fixable by using multiple SELECT and a JOIN
# In the meantime the confirm is very light so it's ok
FEED_CNAME_COMMAND_UPSERT = \
' INSERT INTO blocking (key, source, type, updated, firstparty) ' \
f ' VALUES (?, ?, {RowType.Domain.value}, 1, ?) ' \
' ON CONFLICT(key) ' \
f ' DO UPDATE SET source=?, type={RowType.Domain.value}, ' \
' updated=1, firstparty=? ' \
' WHERE updated=0 OR firstparty<? '
def feed_cname ( name : bytes , value : bytes ) - > None :
assert C
assert CONN
time_step ( ' cname_decode ' )
value = value [ : : - 1 ]
value_dec = value . decode ( )
time_step ( ' cname_fetch ' )
C . execute ( FEED_CNAME_COMMAND_FETCH , ( value_dec , ) )
time_step ( ' cname_fetch_confirm ' )
for b_key , b_type , b_firstparty in C :
matching = b_key == value_dec [ : len ( b_key ) ] and (
len ( value_dec ) == len ( b_key )
or (
b_type == RowType . DomainTree . value
and value_dec [ len ( b_key ) ] == ' . '
)
def set_ip4network ( self , ip4network : str ,
* args : typing . Any , * * kwargs : typing . Any ) - > None :
self . enter_step ( ' set_ip4net_prepare ' )
try :
ip4network_prep = self . prepare_ip4network ( ip4network )
except ( ValueError , IndexError ) :
self . log . error ( " Invalid ip4network: %s " , ip4network )
return
prep : typing . Dict [ str , DbValue ] = {
' mini ' : ip4network_prep [ 0 ] ,
' maxi ' : ip4network_prep [ 1 ] ,
}
self . _set_generic (
' ip4net ' ,
' SELECT entry FROM ip4network WHERE mini=:mini AND maxi=:maxi ' ,
' INSERT INTO ip4network (mini, maxi, entry) '
' VALUES (:mini, :maxi, :entry) ' ,
prep ,
* args , * * kwargs
)
if not matching :
continue
name = name [ : : - 1 ]
time_step ( ' cname_upsert ' )
C . execute ( FEED_CNAME_COMMAND_UPSERT ,
( name , b_key , b_firstparty , # Insert
b_key , b_firstparty , b_firstparty ) # Update
)
time_step ( ' cname_fetch_confirm ' )
time_step ( ' cname_end ' )
if __name__ == ' __main__ ' :
@ -259,13 +388,28 @@ if __name__ == '__main__':
parser = argparse . ArgumentParser (
description = " Database operations " )
parser . add_argument (
' -r ' , ' --refresh ' , action = ' store_true ' ,
' -i ' , ' --initialize ' , action = ' store_true ' ,
help = " Reconstruct the whole database " )
parser . add_argument (
' -p ' , ' --prune ' , action = ' store_true ' ,
help = " Remove old entries from database " )
parser . add_argument (
' -e ' , ' --expire ' , action = ' store_true ' ,
help = " Set the whole database as an old source " )
parser . add_argument (
' -r ' , ' --references ' , action = ' store_true ' ,
help = " Update the reference count " )
args = parser . parse_args ( )
open_db ( )
DB = Database ( )
if args . refresh :
refresh ( )
if args . initialize :
DB . initialize ( )
if args . prune :
DB . prune ( )
if args . expire :
DB . expire ( )
if args . references and not args . prune :
DB . update_references ( )
close_db ( )
DB . close ( )