Compare commits

..

No commits in common. "f165e5a0946538f14984a1e522cdb8e09f8b80c7" and "edf444cc28c45fbef765fa54c1d786c5fe4780f7" have entirely different histories.

10 changed files with 384 additions and 426 deletions

View file

@ -16,36 +16,25 @@ import abp.filters
def get_domains(rule: abp.filters.parser.Filter) -> typing.Iterable[str]: def get_domains(rule: abp.filters.parser.Filter) -> typing.Iterable[str]:
if rule.options: if rule.options:
return return
selector_type = rule.selector["type"] selector_type = rule.selector['type']
selector_value = rule.selector["value"] selector_value = rule.selector['value']
if ( if selector_type == 'url-pattern' \
selector_type == "url-pattern" and selector_value.startswith('||') \
and selector_value.startswith("||") and selector_value.endswith('^'):
and selector_value.endswith("^")
):
yield selector_value[2:-1] yield selector_value[2:-1]
if __name__ == "__main__": if __name__ == '__main__':
# Parsing arguments # Parsing arguments
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Extract whole domains from an AdBlock blocking list" description="Extract whole domains from an AdBlock blocking list")
)
parser.add_argument( parser.add_argument(
"-i", '-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
"--input", help="Input file with AdBlock rules")
type=argparse.FileType("r"),
default=sys.stdin,
help="Input file with AdBlock rules",
)
parser.add_argument( parser.add_argument(
"-o", '-o', '--output', type=argparse.FileType('w'), default=sys.stdout,
"--output", help="Outptut file with one rule tracking subdomain per line")
type=argparse.FileType("w"),
default=sys.stdout,
help="Outptut file with one rule tracking subdomain per line",
)
args = parser.parse_args() args = parser.parse_args()
# Reading rules # Reading rules

View file

@ -16,25 +16,26 @@ import selenium.webdriver.firefox.options
import seleniumwire.webdriver import seleniumwire.webdriver
import logging import logging
log = logging.getLogger("cs") log = logging.getLogger('cs')
DRIVER = None DRIVER = None
SCROLL_TIME = 10.0 SCROLL_TIME = 10.0
SCROLL_STEPS = 100 SCROLL_STEPS = 100
SCROLL_CMD = f"window.scrollBy(0,document.body.scrollHeight/{SCROLL_STEPS})" SCROLL_CMD = f'window.scrollBy(0,document.body.scrollHeight/{SCROLL_STEPS})'
def new_driver() -> seleniumwire.webdriver.browser.Firefox: def new_driver() -> seleniumwire.webdriver.browser.Firefox:
profile = selenium.webdriver.FirefoxProfile() profile = selenium.webdriver.FirefoxProfile()
profile.set_preference("privacy.trackingprotection.enabled", False) profile.set_preference('privacy.trackingprotection.enabled', False)
profile.set_preference("network.cookie.cookieBehavior", 0) profile.set_preference('network.cookie.cookieBehavior', 0)
profile.set_preference("privacy.trackingprotection.pbmode.enabled", False) profile.set_preference('privacy.trackingprotection.pbmode.enabled', False)
profile.set_preference("privacy.trackingprotection.cryptomining.enabled", False) profile.set_preference(
profile.set_preference("privacy.trackingprotection.fingerprinting.enabled", False) 'privacy.trackingprotection.cryptomining.enabled', False)
profile.set_preference(
'privacy.trackingprotection.fingerprinting.enabled', False)
options = selenium.webdriver.firefox.options.Options() options = selenium.webdriver.firefox.options.Options()
# options.add_argument('-headless') # options.add_argument('-headless')
driver = seleniumwire.webdriver.Firefox( driver = seleniumwire.webdriver.Firefox(profile,
profile, executable_path="geckodriver", options=options executable_path='geckodriver', options=options)
)
return driver return driver
@ -59,11 +60,11 @@ def collect_subdomains(url: str) -> typing.Iterable[str]:
DRIVER.get(url) DRIVER.get(url)
for s in range(SCROLL_STEPS): for s in range(SCROLL_STEPS):
DRIVER.execute_script(SCROLL_CMD) DRIVER.execute_script(SCROLL_CMD)
time.sleep(SCROLL_TIME / SCROLL_STEPS) time.sleep(SCROLL_TIME/SCROLL_STEPS)
for request in DRIVER.requests: for request in DRIVER.requests:
if request.response: if request.response:
yield subdomain_from_url(request.path) yield subdomain_from_url(request.path)
except Exception: except:
log.exception("Error") log.exception("Error")
DRIVER.quit() DRIVER.quit()
DRIVER = None DRIVER = None
@ -77,10 +78,10 @@ def collect_subdomains_standalone(url: str) -> None:
print(subdomain) print(subdomain)
if __name__ == "__main__": if __name__ == '__main__':
assert len(sys.argv) <= 2 assert len(sys.argv) <= 2
filename = None filename = None
if len(sys.argv) == 2 and sys.argv[1] != "-": if len(sys.argv) == 2 and sys.argv[1] != '-':
filename = sys.argv[1] filename = sys.argv[1]
num_lines = sum(1 for line in open(filename)) num_lines = sum(1 for line in open(filename))
iterator = progressbar.progressbar(open(filename), max_value=num_lines) iterator = progressbar.progressbar(open(filename), max_value=num_lines)

View file

@ -15,30 +15,33 @@ import os
TLD_LIST: typing.Set[str] = set() TLD_LIST: typing.Set[str] = set()
coloredlogs.install(level="DEBUG", fmt="%(asctime)s %(name)s %(levelname)s %(message)s") coloredlogs.install(
level='DEBUG',
fmt='%(asctime)s %(name)s %(levelname)s %(message)s'
)
Asn = int Asn = int
Timestamp = int Timestamp = int
Level = int Level = int
class Path: class Path():
pass pass
class RulePath(Path): class RulePath(Path):
def __str__(self) -> str: def __str__(self) -> str:
return "(rule)" return '(rule)'
class RuleFirstPath(RulePath): class RuleFirstPath(RulePath):
def __str__(self) -> str: def __str__(self) -> str:
return "(first-party rule)" return '(first-party rule)'
class RuleMultiPath(RulePath): class RuleMultiPath(RulePath):
def __str__(self) -> str: def __str__(self) -> str:
return "(multi-party rule)" return '(multi-party rule)'
class DomainPath(Path): class DomainPath(Path):
@ -46,7 +49,7 @@ class DomainPath(Path):
self.parts = parts self.parts = parts
def __str__(self) -> str: def __str__(self) -> str:
return "?." + Database.unpack_domain(self) return '?.' + Database.unpack_domain(self)
class HostnamePath(DomainPath): class HostnamePath(DomainPath):
@ -56,7 +59,7 @@ class HostnamePath(DomainPath):
class ZonePath(DomainPath): class ZonePath(DomainPath):
def __str__(self) -> str: def __str__(self) -> str:
return "*." + Database.unpack_domain(self) return '*.' + Database.unpack_domain(self)
class AsnPath(Path): class AsnPath(Path):
@ -76,7 +79,7 @@ class Ip4Path(Path):
return Database.unpack_ip4network(self) return Database.unpack_ip4network(self)
class Match: class Match():
def __init__(self) -> None: def __init__(self) -> None:
self.source: typing.Optional[Path] = None self.source: typing.Optional[Path] = None
self.updated: int = 0 self.updated: int = 0
@ -99,10 +102,10 @@ class Match:
class AsnNode(Match): class AsnNode(Match):
def __init__(self) -> None: def __init__(self) -> None:
Match.__init__(self) Match.__init__(self)
self.name = "" self.name = ''
class DomainTreeNode: class DomainTreeNode():
def __init__(self) -> None: def __init__(self) -> None:
self.children: typing.Dict[str, DomainTreeNode] = dict() self.children: typing.Dict[str, DomainTreeNode] = dict()
self.match_zone = Match() self.match_zone = Match()
@ -117,16 +120,18 @@ class IpTreeNode(Match):
Node = typing.Union[DomainTreeNode, IpTreeNode, AsnNode] Node = typing.Union[DomainTreeNode, IpTreeNode, AsnNode]
MatchCallable = typing.Callable[[Path, Match], typing.Any] MatchCallable = typing.Callable[[Path,
Match],
typing.Any]
class Profiler: class Profiler():
def __init__(self) -> None: def __init__(self) -> None:
do_profile = int(os.environ.get("PROFILE", "0")) do_profile = int(os.environ.get('PROFILE', '0'))
if do_profile: if do_profile:
self.log = logging.getLogger("profiler") self.log = logging.getLogger('profiler')
self.time_last = time.perf_counter() self.time_last = time.perf_counter()
self.time_step = "init" self.time_step = 'init'
self.time_dict: typing.Dict[str, float] = dict() self.time_dict: typing.Dict[str, float] = dict()
self.step_dict: typing.Dict[str, int] = dict() self.step_dict: typing.Dict[str, int] = dict()
self.enter_step = self.enter_step_real self.enter_step = self.enter_step_real
@ -153,17 +158,14 @@ class Profiler:
return return
def profile_real(self) -> None: def profile_real(self) -> None:
self.enter_step("profile") self.enter_step('profile')
total = sum(self.time_dict.values()) total = sum(self.time_dict.values())
for key, secs in sorted(self.time_dict.items(), key=lambda t: t[1]): for key, secs in sorted(self.time_dict.items(), key=lambda t: t[1]):
times = self.step_dict[key] times = self.step_dict[key]
self.log.debug( self.log.debug(f"{key:<20}: {times:9d} × {secs/times:5.3e} "
f"{key:<20}: {times:9d} × {secs/times:5.3e} " f"= {secs:9.2f} s ({secs/total:7.2%}) ")
f"= {secs:9.2f} s ({secs/total:7.2%}) " self.log.debug(f"{'total':<20}: "
) f"{total:9.2f} s ({1:7.2%})")
self.log.debug(
f"{'total':<20}: " f"{total:9.2f} s ({1:7.2%})"
)
class Database(Profiler): class Database(Profiler):
@ -171,7 +173,9 @@ class Database(Profiler):
PATH = "blocking.p" PATH = "blocking.p"
def initialize(self) -> None: def initialize(self) -> None:
self.log.warning("Creating database version: %d ", Database.VERSION) self.log.warning(
"Creating database version: %d ",
Database.VERSION)
# Dummy match objects that everything refer to # Dummy match objects that everything refer to
self.rules: typing.List[Match] = list() self.rules: typing.List[Match] = list()
for first_party in (False, True): for first_party in (False, True):
@ -185,77 +189,76 @@ class Database(Profiler):
self.ip4tree = IpTreeNode() self.ip4tree = IpTreeNode()
def load(self) -> None: def load(self) -> None:
self.enter_step("load") self.enter_step('load')
try: try:
with open(self.PATH, "rb") as db_fdsec: with open(self.PATH, 'rb') as db_fdsec:
version, data = pickle.load(db_fdsec) version, data = pickle.load(db_fdsec)
if version == Database.VERSION: if version == Database.VERSION:
self.rules, self.domtree, self.asns, self.ip4tree = data self.rules, self.domtree, self.asns, self.ip4tree = data
return return
self.log.warning( self.log.warning(
"Outdated database version found: %d, " "it will be rebuilt.", "Outdated database version found: %d, "
version, "it will be rebuilt.",
) version)
except (TypeError, AttributeError, EOFError): except (TypeError, AttributeError, EOFError):
self.log.error( self.log.error(
"Corrupt (or heavily outdated) database found, " "it will be rebuilt." "Corrupt (or heavily outdated) database found, "
) "it will be rebuilt.")
except FileNotFoundError: except FileNotFoundError:
pass pass
self.initialize() self.initialize()
def save(self) -> None: def save(self) -> None:
self.enter_step("save") self.enter_step('save')
with open(self.PATH, "wb") as db_fdsec: with open(self.PATH, 'wb') as db_fdsec:
data = self.rules, self.domtree, self.asns, self.ip4tree data = self.rules, self.domtree, self.asns, self.ip4tree
pickle.dump((self.VERSION, data), db_fdsec) pickle.dump((self.VERSION, data), db_fdsec)
self.profile() self.profile()
def __init__(self) -> None: def __init__(self) -> None:
Profiler.__init__(self) Profiler.__init__(self)
self.log = logging.getLogger("db") self.log = logging.getLogger('db')
self.load() self.load()
self.ip4cache_shift: int = 32 self.ip4cache_shift: int = 32
self.ip4cache = numpy.ones(1) self.ip4cache = numpy.ones(1)
def _set_ip4cache(self, path: Path, _: Match) -> None: def _set_ip4cache(self, path: Path, _: Match) -> None:
assert isinstance(path, Ip4Path) assert isinstance(path, Ip4Path)
self.enter_step("set_ip4cache") self.enter_step('set_ip4cache')
mini = path.value >> self.ip4cache_shift mini = path.value >> self.ip4cache_shift
maxi = (path.value + 2 ** (32 - path.prefixlen)) >> self.ip4cache_shift maxi = (path.value + 2**(32-path.prefixlen)) >> self.ip4cache_shift
if mini == maxi: if mini == maxi:
self.ip4cache[mini] = True self.ip4cache[mini] = True
else: else:
self.ip4cache[mini:maxi] = True self.ip4cache[mini:maxi] = True
def fill_ip4cache(self, max_size: int = 512 * 1024 ** 2) -> None: def fill_ip4cache(self, max_size: int = 512*1024**2) -> None:
""" """
Size in bytes Size in bytes
""" """
if max_size > 2 ** 32 / 8: if max_size > 2**32/8:
self.log.warning( self.log.warning("Allocating more than 512 MiB of RAM for "
"Allocating more than 512 MiB of RAM for " "the Ip4 cache is not necessary.")
"the Ip4 cache is not necessary." max_cache_width = int(math.log2(max(1, max_size*8)))
)
max_cache_width = int(math.log2(max(1, max_size * 8)))
allocated = False allocated = False
cache_width = min(32, max_cache_width) cache_width = min(2**32, max_cache_width)
while not allocated: while not allocated:
cache_size = 2 ** cache_width cache_size = 2**cache_width
try: try:
self.ip4cache = numpy.zeros(cache_size, dtype=bool) self.ip4cache = numpy.zeros(cache_size, dtype=numpy.bool)
except MemoryError: except MemoryError:
self.log.exception("Could not allocate cache. Retrying a smaller one.") self.log.exception(
"Could not allocate cache. Retrying a smaller one.")
cache_width -= 1 cache_width -= 1
continue continue
allocated = True allocated = True
self.ip4cache_shift = 32 - cache_width self.ip4cache_shift = 32-cache_width
for _ in self.exec_each_ip4(self._set_ip4cache): for _ in self.exec_each_ip4(self._set_ip4cache):
pass pass
@staticmethod @staticmethod
def populate_tld_list() -> None: def populate_tld_list() -> None:
with open("temp/all_tld.list", "r") as tld_fdesc: with open('temp/all_tld.list', 'r') as tld_fdesc:
for tld in tld_fdesc: for tld in tld_fdesc:
tld = tld.strip() tld = tld.strip()
TLD_LIST.add(tld) TLD_LIST.add(tld)
@ -264,7 +267,7 @@ class Database(Profiler):
def validate_domain(path: str) -> bool: def validate_domain(path: str) -> bool:
if len(path) > 255: if len(path) > 255:
return False return False
splits = path.split(".") splits = path.split('.')
if not TLD_LIST: if not TLD_LIST:
Database.populate_tld_list() Database.populate_tld_list()
if splits[-1] not in TLD_LIST: if splits[-1] not in TLD_LIST:
@ -276,26 +279,26 @@ class Database(Profiler):
@staticmethod @staticmethod
def pack_domain(domain: str) -> DomainPath: def pack_domain(domain: str) -> DomainPath:
return DomainPath(domain.split(".")[::-1]) return DomainPath(domain.split('.')[::-1])
@staticmethod @staticmethod
def unpack_domain(domain: DomainPath) -> str: def unpack_domain(domain: DomainPath) -> str:
return ".".join(domain.parts[::-1]) return '.'.join(domain.parts[::-1])
@staticmethod @staticmethod
def pack_asn(asn: str) -> AsnPath: def pack_asn(asn: str) -> AsnPath:
asn = asn.upper() asn = asn.upper()
if asn.startswith("AS"): if asn.startswith('AS'):
asn = asn[2:] asn = asn[2:]
return AsnPath(int(asn)) return AsnPath(int(asn))
@staticmethod @staticmethod
def unpack_asn(asn: AsnPath) -> str: def unpack_asn(asn: AsnPath) -> str:
return f"AS{asn.asn}" return f'AS{asn.asn}'
@staticmethod @staticmethod
def validate_ip4address(path: str) -> bool: def validate_ip4address(path: str) -> bool:
splits = path.split(".") splits = path.split('.')
if len(splits) != 4: if len(splits) != 4:
return False return False
for split in splits: for split in splits:
@ -309,7 +312,7 @@ class Database(Profiler):
@staticmethod @staticmethod
def pack_ip4address_low(address: str) -> int: def pack_ip4address_low(address: str) -> int:
addr = 0 addr = 0
for split in address.split("."): for split in address.split('.'):
octet = int(split) octet = int(split)
addr = (addr << 8) + octet addr = (addr << 8) + octet
return addr return addr
@ -327,12 +330,12 @@ class Database(Profiler):
for o in reversed(range(4)): for o in reversed(range(4)):
octets[o] = addr & 0xFF octets[o] = addr & 0xFF
addr >>= 8 addr >>= 8
return ".".join(map(str, octets)) return '.'.join(map(str, octets))
@staticmethod @staticmethod
def validate_ip4network(path: str) -> bool: def validate_ip4network(path: str) -> bool:
# A bit generous but ok for our usage # A bit generous but ok for our usage
splits = path.split("/") splits = path.split('/')
if len(splits) != 2: if len(splits) != 2:
return False return False
if not Database.validate_ip4address(splits[0]): if not Database.validate_ip4address(splits[0]):
@ -346,7 +349,7 @@ class Database(Profiler):
@staticmethod @staticmethod
def pack_ip4network(network: str) -> Ip4Path: def pack_ip4network(network: str) -> Ip4Path:
address, prefixlen_str = network.split("/") address, prefixlen_str = network.split('/')
prefixlen = int(prefixlen_str) prefixlen = int(prefixlen_str)
addr = Database.pack_ip4address(address) addr = Database.pack_ip4address(address)
addr.prefixlen = prefixlen addr.prefixlen = prefixlen
@ -360,7 +363,7 @@ class Database(Profiler):
for o in reversed(range(4)): for o in reversed(range(4)):
octets[o] = addr & 0xFF octets[o] = addr & 0xFF
addr >>= 8 addr >>= 8
return ".".join(map(str, octets)) + "/" + str(network.prefixlen) return '.'.join(map(str, octets)) + '/' + str(network.prefixlen)
def get_match(self, path: Path) -> Match: def get_match(self, path: Path) -> Match:
if isinstance(path, RuleMultiPath): if isinstance(path, RuleMultiPath):
@ -381,7 +384,7 @@ class Database(Profiler):
raise ValueError raise ValueError
elif isinstance(path, Ip4Path): elif isinstance(path, Ip4Path):
dici = self.ip4tree dici = self.ip4tree
for i in range(31, 31 - path.prefixlen, -1): for i in range(31, 31-path.prefixlen, -1):
bit = (path.value >> i) & 0b1 bit = (path.value >> i) & 0b1
dici_next = dici.one if bit else dici.zero dici_next = dici.one if bit else dici.zero
if not dici_next: if not dici_next:
@ -391,10 +394,9 @@ class Database(Profiler):
else: else:
raise ValueError raise ValueError
def exec_each_asn( def exec_each_asn(self,
self, callback: MatchCallable,
callback: MatchCallable, ) -> typing.Any:
) -> typing.Any:
for asn in self.asns: for asn in self.asns:
match = self.asns[asn] match = self.asns[asn]
if match.active(): if match.active():
@ -407,12 +409,11 @@ class Database(Profiler):
except TypeError: # not iterable except TypeError: # not iterable
pass pass
def exec_each_domain( def exec_each_domain(self,
self, callback: MatchCallable,
callback: MatchCallable, _dic: DomainTreeNode = None,
_dic: DomainTreeNode = None, _par: DomainPath = None,
_par: DomainPath = None, ) -> typing.Any:
) -> typing.Any:
_dic = _dic or self.domtree _dic = _dic or self.domtree
_par = _par or DomainPath([]) _par = _par or DomainPath([])
if _dic.match_hostname.active(): if _dic.match_hostname.active():
@ -436,15 +437,16 @@ class Database(Profiler):
for part in _dic.children: for part in _dic.children:
dic = _dic.children[part] dic = _dic.children[part]
yield from self.exec_each_domain( yield from self.exec_each_domain(
callback, _dic=dic, _par=DomainPath(_par.parts + [part]) callback,
_dic=dic,
_par=DomainPath(_par.parts + [part])
) )
def exec_each_ip4( def exec_each_ip4(self,
self, callback: MatchCallable,
callback: MatchCallable, _dic: IpTreeNode = None,
_dic: IpTreeNode = None, _par: Ip4Path = None,
_par: Ip4Path = None, ) -> typing.Any:
) -> typing.Any:
_dic = _dic or self.ip4tree _dic = _dic or self.ip4tree
_par = _par or Ip4Path(0, 0) _par = _par or Ip4Path(0, 0)
if _dic.active(): if _dic.active():
@ -464,18 +466,25 @@ class Database(Profiler):
# addr0 = _par.value & (0xFFFFFFFF ^ (1 << (32-pref))) # addr0 = _par.value & (0xFFFFFFFF ^ (1 << (32-pref)))
# assert addr0 == _par.value # assert addr0 == _par.value
addr0 = _par.value addr0 = _par.value
yield from self.exec_each_ip4(callback, _dic=dic, _par=Ip4Path(addr0, pref)) yield from self.exec_each_ip4(
callback,
_dic=dic,
_par=Ip4Path(addr0, pref)
)
# 1 # 1
dic = _dic.one dic = _dic.one
if dic: if dic:
addr1 = _par.value | (1 << (32 - pref)) addr1 = _par.value | (1 << (32-pref))
# assert addr1 != _par.value # assert addr1 != _par.value
yield from self.exec_each_ip4(callback, _dic=dic, _par=Ip4Path(addr1, pref)) yield from self.exec_each_ip4(
callback,
_dic=dic,
_par=Ip4Path(addr1, pref)
)
def exec_each( def exec_each(self,
self, callback: MatchCallable,
callback: MatchCallable, ) -> typing.Any:
) -> typing.Any:
yield from self.exec_each_domain(callback) yield from self.exec_each_domain(callback)
yield from self.exec_each_ip4(callback) yield from self.exec_each_ip4(callback)
yield from self.exec_each_asn(callback) yield from self.exec_each_asn(callback)
@ -483,17 +492,19 @@ class Database(Profiler):
def update_references(self) -> None: def update_references(self) -> None:
# Should be correctly calculated normally, # Should be correctly calculated normally,
# keeping this just in case # keeping this just in case
def reset_references_cb(path: Path, match: Match) -> None: def reset_references_cb(path: Path,
match: Match
) -> None:
match.references = 0 match.references = 0
for _ in self.exec_each(reset_references_cb): for _ in self.exec_each(reset_references_cb):
pass pass
def increment_references_cb(path: Path, match: Match) -> None: def increment_references_cb(path: Path,
match: Match
) -> None:
if match.source: if match.source:
source = self.get_match(match.source) source = self.get_match(match.source)
source.references += 1 source.references += 1
for _ in self.exec_each(increment_references_cb): for _ in self.exec_each(increment_references_cb):
pass pass
@ -502,7 +513,9 @@ class Database(Profiler):
# matches until all disabled matches reference count = 0 # matches until all disabled matches reference count = 0
did_something = True did_something = True
def clean_deps_cb(path: Path, match: Match) -> None: def clean_deps_cb(path: Path,
match: Match
) -> None:
nonlocal did_something nonlocal did_something
if not match.source: if not match.source:
return return
@ -517,13 +530,15 @@ class Database(Profiler):
while did_something: while did_something:
did_something = False did_something = False
self.enter_step("pass_clean_deps") self.enter_step('pass_clean_deps')
for _ in self.exec_each(clean_deps_cb): for _ in self.exec_each(clean_deps_cb):
pass pass
def prune(self, before: int, base_only: bool = False) -> None: def prune(self, before: int, base_only: bool = False) -> None:
# Disable the matches targeted # Disable the matches targeted
def prune_cb(path: Path, match: Match) -> None: def prune_cb(path: Path,
match: Match
) -> None:
if base_only and match.level > 1: if base_only and match.level > 1:
return return
if match.updated > before: if match.updated > before:
@ -531,7 +546,7 @@ class Database(Profiler):
self._unset_match(match) self._unset_match(match)
self.log.debug("Print: disabled %s", path) self.log.debug("Print: disabled %s", path)
self.enter_step("pass_prune") self.enter_step('pass_prune')
for _ in self.exec_each(prune_cb): for _ in self.exec_each(prune_cb):
pass pass
@ -544,24 +559,25 @@ class Database(Profiler):
match = self.get_match(path) match = self.get_match(path)
string = str(path) string = str(path)
if isinstance(match, AsnNode): if isinstance(match, AsnNode):
string += f" ({match.name})" string += f' ({match.name})'
party_char = "F" if match.first_party else "M" party_char = 'F' if match.first_party else 'M'
dup_char = "D" if match.dupplicate else "_" dup_char = 'D' if match.dupplicate else '_'
string += f" {match.level}{party_char}{dup_char}{match.references}" string += f' {match.level}{party_char}{dup_char}{match.references}'
if match.source: if match.source:
string += f"{self.explain(match.source)}" string += f'{self.explain(match.source)}'
return string return string
def list_records( def list_records(self,
self, first_party_only: bool = False,
first_party_only: bool = False, end_chain_only: bool = False,
end_chain_only: bool = False, no_dupplicates: bool = False,
no_dupplicates: bool = False, rules_only: bool = False,
rules_only: bool = False, hostnames_only: bool = False,
hostnames_only: bool = False, explain: bool = False,
explain: bool = False, ) -> typing.Iterable[str]:
) -> typing.Iterable[str]:
def export_cb(path: Path, match: Match) -> typing.Iterable[str]: def export_cb(path: Path, match: Match
) -> typing.Iterable[str]:
if first_party_only and not match.first_party: if first_party_only and not match.first_party:
return return
if end_chain_only and match.references > 0: if end_chain_only and match.references > 0:
@ -580,14 +596,13 @@ class Database(Profiler):
yield from self.exec_each(export_cb) yield from self.exec_each(export_cb)
def count_records( def count_records(self,
self, first_party_only: bool = False,
first_party_only: bool = False, end_chain_only: bool = False,
end_chain_only: bool = False, no_dupplicates: bool = False,
no_dupplicates: bool = False, rules_only: bool = False,
rules_only: bool = False, hostnames_only: bool = False,
hostnames_only: bool = False, ) -> str:
) -> str:
memo: typing.Dict[str, int] = dict() memo: typing.Dict[str, int] = dict()
def count_records_cb(path: Path, match: Match) -> None: def count_records_cb(path: Path, match: Match) -> None:
@ -612,80 +627,75 @@ class Database(Profiler):
split: typing.List[str] = list() split: typing.List[str] = list()
for key, value in sorted(memo.items(), key=lambda s: s[0]): for key, value in sorted(memo.items(), key=lambda s: s[0]):
split.append(f"{key[:-4].lower()}s: {value}") split.append(f'{key[:-4].lower()}s: {value}')
return ", ".join(split) return ', '.join(split)
def get_domain(self, domain_str: str) -> typing.Iterable[DomainPath]: def get_domain(self, domain_str: str) -> typing.Iterable[DomainPath]:
self.enter_step("get_domain_pack") self.enter_step('get_domain_pack')
domain = self.pack_domain(domain_str) domain = self.pack_domain(domain_str)
self.enter_step("get_domain_brws") self.enter_step('get_domain_brws')
dic = self.domtree dic = self.domtree
depth = 0 depth = 0
for part in domain.parts: for part in domain.parts:
if dic.match_zone.active(): if dic.match_zone.active():
self.enter_step("get_domain_yield") self.enter_step('get_domain_yield')
yield ZonePath(domain.parts[:depth]) yield ZonePath(domain.parts[:depth])
self.enter_step("get_domain_brws") self.enter_step('get_domain_brws')
if part not in dic.children: if part not in dic.children:
return return
dic = dic.children[part] dic = dic.children[part]
depth += 1 depth += 1
if dic.match_zone.active(): if dic.match_zone.active():
self.enter_step("get_domain_yield") self.enter_step('get_domain_yield')
yield ZonePath(domain.parts) yield ZonePath(domain.parts)
if dic.match_hostname.active(): if dic.match_hostname.active():
self.enter_step("get_domain_yield") self.enter_step('get_domain_yield')
yield HostnamePath(domain.parts) yield HostnamePath(domain.parts)
def get_ip4(self, ip4_str: str) -> typing.Iterable[Path]: def get_ip4(self, ip4_str: str) -> typing.Iterable[Path]:
self.enter_step("get_ip4_pack") self.enter_step('get_ip4_pack')
ip4val = self.pack_ip4address_low(ip4_str) ip4val = self.pack_ip4address_low(ip4_str)
self.enter_step("get_ip4_cache") self.enter_step('get_ip4_cache')
if not self.ip4cache[ip4val >> self.ip4cache_shift]: if not self.ip4cache[ip4val >> self.ip4cache_shift]:
return return
self.enter_step("get_ip4_brws") self.enter_step('get_ip4_brws')
dic = self.ip4tree dic = self.ip4tree
for i in range(31, -1, -1): for i in range(31, -1, -1):
bit = (ip4val >> i) & 0b1 bit = (ip4val >> i) & 0b1
if dic.active(): if dic.active():
self.enter_step("get_ip4_yield") self.enter_step('get_ip4_yield')
yield Ip4Path(ip4val >> (i + 1) << (i + 1), 31 - i) yield Ip4Path(ip4val >> (i+1) << (i+1), 31-i)
self.enter_step("get_ip4_brws") self.enter_step('get_ip4_brws')
next_dic = dic.one if bit else dic.zero next_dic = dic.one if bit else dic.zero
if next_dic is None: if next_dic is None:
return return
dic = next_dic dic = next_dic
if dic.active(): if dic.active():
self.enter_step("get_ip4_yield") self.enter_step('get_ip4_yield')
yield Ip4Path(ip4val, 32) yield Ip4Path(ip4val, 32)
def _unset_match( def _unset_match(self,
self, match: Match,
match: Match, ) -> None:
) -> None:
match.disable() match.disable()
if match.source: if match.source:
source_match = self.get_match(match.source) source_match = self.get_match(match.source)
source_match.references -= 1 source_match.references -= 1
def _set_match( def _set_match(self,
self, match: Match,
match: Match, updated: int,
updated: int, source: Path,
source: Path, source_match: Match = None,
source_match: Match = None, dupplicate: bool = False,
dupplicate: bool = False, ) -> None:
) -> None:
# source_match is in parameters because most of the time # source_match is in parameters because most of the time
# its parent function needs it too, # its parent function needs it too,
# so it can pass it to save a traversal # so it can pass it to save a traversal
source_match = source_match or self.get_match(source) source_match = source_match or self.get_match(source)
new_level = source_match.level + 1 new_level = source_match.level + 1
if ( if updated > match.updated or new_level < match.level \
updated > match.updated or source_match.first_party > match.first_party:
or new_level < match.level
or source_match.first_party > match.first_party
):
# NOTE FP and level of matches referencing this one # NOTE FP and level of matches referencing this one
# won't be updated until run or prune # won't be updated until run or prune
if match.source: if match.source:
@ -698,18 +708,20 @@ class Database(Profiler):
source_match.references += 1 source_match.references += 1
match.dupplicate = dupplicate match.dupplicate = dupplicate
def _set_domain( def _set_domain(self,
self, hostname: bool, domain_str: str, updated: int, source: Path hostname: bool,
) -> None: domain_str: str,
self.enter_step("set_domain_val") updated: int,
source: Path) -> None:
self.enter_step('set_domain_val')
if not Database.validate_domain(domain_str): if not Database.validate_domain(domain_str):
raise ValueError(f"Invalid domain: {domain_str}") raise ValueError(f"Invalid domain: {domain_str}")
self.enter_step("set_domain_pack") self.enter_step('set_domain_pack')
domain = self.pack_domain(domain_str) domain = self.pack_domain(domain_str)
self.enter_step("set_domain_fp") self.enter_step('set_domain_fp')
source_match = self.get_match(source) source_match = self.get_match(source)
is_first_party = source_match.first_party is_first_party = source_match.first_party
self.enter_step("set_domain_brws") self.enter_step('set_domain_brws')
dic = self.domtree dic = self.domtree
dupplicate = False dupplicate = False
for part in domain.parts: for part in domain.parts:
@ -730,14 +742,21 @@ class Database(Profiler):
dupplicate=dupplicate, dupplicate=dupplicate,
) )
def set_hostname(self, *args: typing.Any, **kwargs: typing.Any) -> None: def set_hostname(self,
*args: typing.Any, **kwargs: typing.Any
) -> None:
self._set_domain(True, *args, **kwargs) self._set_domain(True, *args, **kwargs)
def set_zone(self, *args: typing.Any, **kwargs: typing.Any) -> None: def set_zone(self,
*args: typing.Any, **kwargs: typing.Any
) -> None:
self._set_domain(False, *args, **kwargs) self._set_domain(False, *args, **kwargs)
def set_asn(self, asn_str: str, updated: int, source: Path) -> None: def set_asn(self,
self.enter_step("set_asn") asn_str: str,
updated: int,
source: Path) -> None:
self.enter_step('set_asn')
path = self.pack_asn(asn_str) path = self.pack_asn(asn_str)
if path.asn in self.asns: if path.asn in self.asns:
match = self.asns[path.asn] match = self.asns[path.asn]
@ -750,14 +769,17 @@ class Database(Profiler):
source, source,
) )
def _set_ip4(self, ip4: Ip4Path, updated: int, source: Path) -> None: def _set_ip4(self,
self.enter_step("set_ip4_fp") ip4: Ip4Path,
updated: int,
source: Path) -> None:
self.enter_step('set_ip4_fp')
source_match = self.get_match(source) source_match = self.get_match(source)
is_first_party = source_match.first_party is_first_party = source_match.first_party
self.enter_step("set_ip4_brws") self.enter_step('set_ip4_brws')
dic = self.ip4tree dic = self.ip4tree
dupplicate = False dupplicate = False
for i in range(31, 31 - ip4.prefixlen, -1): for i in range(31, 31-ip4.prefixlen, -1):
bit = (ip4.value >> i) & 0b1 bit = (ip4.value >> i) & 0b1
next_dic = dic.one if bit else dic.zero next_dic = dic.one if bit else dic.zero
if next_dic is None: if next_dic is None:
@ -778,22 +800,24 @@ class Database(Profiler):
) )
self._set_ip4cache(ip4, dic) self._set_ip4cache(ip4, dic)
def set_ip4address( def set_ip4address(self,
self, ip4address_str: str, *args: typing.Any, **kwargs: typing.Any ip4address_str: str,
) -> None: *args: typing.Any, **kwargs: typing.Any
self.enter_step("set_ip4add_val") ) -> None:
self.enter_step('set_ip4add_val')
if not Database.validate_ip4address(ip4address_str): if not Database.validate_ip4address(ip4address_str):
raise ValueError(f"Invalid ip4address: {ip4address_str}") raise ValueError(f"Invalid ip4address: {ip4address_str}")
self.enter_step("set_ip4add_pack") self.enter_step('set_ip4add_pack')
ip4 = self.pack_ip4address(ip4address_str) ip4 = self.pack_ip4address(ip4address_str)
self._set_ip4(ip4, *args, **kwargs) self._set_ip4(ip4, *args, **kwargs)
def set_ip4network( def set_ip4network(self,
self, ip4network_str: str, *args: typing.Any, **kwargs: typing.Any ip4network_str: str,
) -> None: *args: typing.Any, **kwargs: typing.Any
self.enter_step("set_ip4net_val") ) -> None:
self.enter_step('set_ip4net_val')
if not Database.validate_ip4network(ip4network_str): if not Database.validate_ip4network(ip4network_str):
raise ValueError(f"Invalid ip4network: {ip4network_str}") raise ValueError(f"Invalid ip4network: {ip4network_str}")
self.enter_step("set_ip4net_pack") self.enter_step('set_ip4net_pack')
ip4 = self.pack_ip4network(ip4network_str) ip4 = self.pack_ip4network(ip4network_str)
self._set_ip4(ip4, *args, **kwargs) self._set_ip4(ip4, *args, **kwargs)

38
db.py
View file

@ -5,37 +5,29 @@ import database
import time import time
import os import os
if __name__ == "__main__": if __name__ == '__main__':
# Parsing arguments # Parsing arguments
parser = argparse.ArgumentParser(description="Database operations") parser = argparse.ArgumentParser(
description="Database operations")
parser.add_argument( parser.add_argument(
"-i", "--initialize", action="store_true", help="Reconstruct the whole database" '-i', '--initialize', action='store_true',
) help="Reconstruct the whole database")
parser.add_argument( parser.add_argument(
"-p", "--prune", action="store_true", help="Remove old entries from database" '-p', '--prune', action='store_true',
) help="Remove old entries from database")
parser.add_argument( parser.add_argument(
"-b", '-b', '--prune-base', action='store_true',
"--prune-base",
action="store_true",
help="With --prune, only prune base rules " help="With --prune, only prune base rules "
"(the ones added by ./feed_rules.py)", "(the ones added by ./feed_rules.py)")
)
parser.add_argument( parser.add_argument(
"-s", '-s', '--prune-before', type=int,
"--prune-before", default=(int(time.time()) - 60*60*24*31*6),
type=int,
default=(int(time.time()) - 60 * 60 * 24 * 31 * 6),
help="With --prune, only rules updated before " help="With --prune, only rules updated before "
"this UNIX timestamp will be deleted", "this UNIX timestamp will be deleted")
)
parser.add_argument( parser.add_argument(
"-r", '-r', '--references', action='store_true',
"--references", help="DEBUG: Update the reference count")
action="store_true",
help="DEBUG: Update the reference count",
)
args = parser.parse_args() args = parser.parse_args()
if not args.initialize: if not args.initialize:
@ -45,7 +37,7 @@ if __name__ == "__main__":
os.unlink(database.Database.PATH) os.unlink(database.Database.PATH)
DB = database.Database() DB = database.Database()
DB.enter_step("main") DB.enter_step('main')
if args.prune: if args.prune:
DB.prune(before=args.prune_before, base_only=args.prune_base) DB.prune(before=args.prune_before, base_only=args.prune_base)
if args.references: if args.references:

View file

@ -5,80 +5,53 @@ import argparse
import sys import sys
if __name__ == "__main__": if __name__ == '__main__':
# Parsing arguments # Parsing arguments
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Export the hostnames rules stored " "in the Database as plain text" description="Export the hostnames rules stored "
) "in the Database as plain text")
parser.add_argument( parser.add_argument(
"-o", '-o', '--output', type=argparse.FileType('w'), default=sys.stdout,
"--output", help="Output file, one rule per line")
type=argparse.FileType("w"),
default=sys.stdout,
help="Output file, one rule per line",
)
parser.add_argument( parser.add_argument(
"-f", '-f', '--first-party', action='store_true',
"--first-party", help="Only output rules issued from first-party sources")
action="store_true",
help="Only output rules issued from first-party sources",
)
parser.add_argument( parser.add_argument(
"-e", '-e', '--end-chain', action='store_true',
"--end-chain", help="Only output rules that are not referenced by any other")
action="store_true",
help="Only output rules that are not referenced by any other",
)
parser.add_argument( parser.add_argument(
"-r", '-r', '--rules', action='store_true',
"--rules", help="Output all kinds of rules, not just hostnames")
action="store_true",
help="Output all kinds of rules, not just hostnames",
)
parser.add_argument( parser.add_argument(
"-b", '-b', '--base-rules', action='store_true',
"--base-rules",
action="store_true",
help="Output base rules " help="Output base rules "
"(the ones added by ./feed_rules.py) " "(the ones added by ./feed_rules.py) "
"(implies --rules)", "(implies --rules)")
)
parser.add_argument( parser.add_argument(
"-d", '-d', '--no-dupplicates', action='store_true',
"--no-dupplicates",
action="store_true",
help="Do not output rules that already match a zone/network rule " help="Do not output rules that already match a zone/network rule "
"(e.g. dummy.example.com when there's a zone example.com rule)", "(e.g. dummy.example.com when there's a zone example.com rule)")
)
parser.add_argument( parser.add_argument(
"-x", '-x', '--explain', action='store_true',
"--explain",
action="store_true",
help="Show the chain of rules leading to one " help="Show the chain of rules leading to one "
"(and the number of references they have)", "(and the number of references they have)")
)
parser.add_argument( parser.add_argument(
"-c", '-c', '--count', action='store_true',
"--count", help="Show the number of rules per type instead of listing them")
action="store_true",
help="Show the number of rules per type instead of listing them",
)
args = parser.parse_args() args = parser.parse_args()
DB = database.Database() DB = database.Database()
if args.count: if args.count:
assert not args.explain assert not args.explain
print( print(DB.count_records(
DB.count_records( first_party_only=args.first_party,
first_party_only=args.first_party, end_chain_only=args.end_chain,
end_chain_only=args.end_chain, no_dupplicates=args.no_dupplicates,
no_dupplicates=args.no_dupplicates, rules_only=args.base_rules,
rules_only=args.base_rules, hostnames_only=not (args.rules or args.base_rules),
hostnames_only=not (args.rules or args.base_rules), ))
)
)
else: else:
for domain in DB.list_records( for domain in DB.list_records(
first_party_only=args.first_party, first_party_only=args.first_party,

View file

@ -13,54 +13,57 @@ IPNetwork = typing.Union[ipaddress.IPv4Network, ipaddress.IPv6Network]
def get_ranges(asn: str) -> typing.Iterable[str]: def get_ranges(asn: str) -> typing.Iterable[str]:
req = requests.get( req = requests.get(
"https://stat.ripe.net/data/as-routing-consistency/data.json", 'https://stat.ripe.net/data/as-routing-consistency/data.json',
params={"resource": asn}, params={'resource': asn}
) )
data = req.json() data = req.json()
for pref in data["data"]["prefixes"]: for pref in data['data']['prefixes']:
yield pref["prefix"] yield pref['prefix']
def get_name(asn: str) -> str: def get_name(asn: str) -> str:
req = requests.get( req = requests.get(
"https://stat.ripe.net/data/as-overview/data.json", params={"resource": asn} 'https://stat.ripe.net/data/as-overview/data.json',
params={'resource': asn}
) )
data = req.json() data = req.json()
return data["data"]["holder"] return data['data']['holder']
if __name__ == "__main__": if __name__ == '__main__':
log = logging.getLogger("feed_asn") log = logging.getLogger('feed_asn')
# Parsing arguments # Parsing arguments
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Add the IP ranges associated to the AS in the database" description="Add the IP ranges associated to the AS in the database")
)
args = parser.parse_args() args = parser.parse_args()
DB = database.Database() DB = database.Database()
def add_ranges( def add_ranges(path: database.Path,
path: database.Path, match: database.Match,
match: database.Match, ) -> None:
) -> None:
assert isinstance(path, database.AsnPath) assert isinstance(path, database.AsnPath)
assert isinstance(match, database.AsnNode) assert isinstance(match, database.AsnNode)
asn_str = database.Database.unpack_asn(path) asn_str = database.Database.unpack_asn(path)
DB.enter_step("asn_get_name") DB.enter_step('asn_get_name')
name = get_name(asn_str) name = get_name(asn_str)
match.name = name match.name = name
DB.enter_step("asn_get_ranges") DB.enter_step('asn_get_ranges')
for prefix in get_ranges(asn_str): for prefix in get_ranges(asn_str):
parsed_prefix: IPNetwork = ipaddress.ip_network(prefix) parsed_prefix: IPNetwork = ipaddress.ip_network(prefix)
if parsed_prefix.version == 4: if parsed_prefix.version == 4:
DB.set_ip4network(prefix, source=path, updated=int(time.time())) DB.set_ip4network(
log.info("Added %s from %s (%s)", prefix, path, name) prefix,
source=path,
updated=int(time.time())
)
log.info('Added %s from %s (%s)', prefix, path, name)
elif parsed_prefix.version == 6: elif parsed_prefix.version == 6:
log.warning("Unimplemented prefix version: %s", prefix) log.warning('Unimplemented prefix version: %s', prefix)
else: else:
log.error("Unknown prefix version: %s", prefix) log.error('Unknown prefix version: %s', prefix)
for _ in DB.exec_each_asn(add_ranges): for _ in DB.exec_each_asn(add_ranges):
pass pass

View file

@ -12,15 +12,15 @@ Record = typing.Tuple[typing.Callable, typing.Callable, int, str, str]
# select, write # select, write
FUNCTION_MAP: typing.Any = { FUNCTION_MAP: typing.Any = {
"a": ( 'a': (
database.Database.get_ip4, database.Database.get_ip4,
database.Database.set_hostname, database.Database.set_hostname,
), ),
"cname": ( 'cname': (
database.Database.get_domain, database.Database.get_domain,
database.Database.set_hostname, database.Database.set_hostname,
), ),
"ptr": ( 'ptr': (
database.Database.get_domain, database.Database.get_domain,
database.Database.set_ip4address, database.Database.set_ip4address,
), ),
@ -28,16 +28,15 @@ FUNCTION_MAP: typing.Any = {
class Writer(multiprocessing.Process): class Writer(multiprocessing.Process):
def __init__( def __init__(self,
self, recs_queue: multiprocessing.Queue = None,
recs_queue: multiprocessing.Queue = None, autosave_interval: int = 0,
autosave_interval: int = 0, ip4_cache: int = 0,
ip4_cache: int = 0, ):
):
if recs_queue: # MP if recs_queue: # MP
super(Writer, self).__init__() super(Writer, self).__init__()
self.recs_queue = recs_queue self.recs_queue = recs_queue
self.log = logging.getLogger("wr") self.log = logging.getLogger(f'wr')
self.autosave_interval = autosave_interval self.autosave_interval = autosave_interval
self.ip4_cache = ip4_cache self.ip4_cache = ip4_cache
if not recs_queue: # No MP if not recs_queue: # No MP
@ -45,11 +44,11 @@ class Writer(multiprocessing.Process):
def open_db(self) -> None: def open_db(self) -> None:
self.db = database.Database() self.db = database.Database()
self.db.log = logging.getLogger("wr") self.db.log = logging.getLogger(f'wr')
self.db.fill_ip4cache(max_size=self.ip4_cache) self.db.fill_ip4cache(max_size=self.ip4_cache)
def exec_record(self, record: Record) -> None: def exec_record(self, record: Record) -> None:
self.db.enter_step("exec_record") self.db.enter_step('exec_record')
select, write, updated, name, value = record select, write, updated, name, value = record
try: try:
for source in select(self.db, value): for source in select(self.db, value):
@ -60,7 +59,7 @@ class Writer(multiprocessing.Process):
self.log.exception("Cannot execute: %s", record) self.log.exception("Cannot execute: %s", record)
def end(self) -> None: def end(self) -> None:
self.db.enter_step("end") self.db.enter_step('end')
self.db.save() self.db.save()
def run(self) -> None: def run(self) -> None:
@ -70,11 +69,10 @@ class Writer(multiprocessing.Process):
else: else:
next_save = 0 next_save = 0
self.db.enter_step("block_wait") self.db.enter_step('block_wait')
block: typing.List[Record] block: typing.List[Record]
for block in iter(self.recs_queue.get, None): for block in iter(self.recs_queue.get, None):
assert block
record: Record record: Record
for record in block: for record in block:
self.exec_record(record) self.exec_record(record)
@ -85,21 +83,20 @@ class Writer(multiprocessing.Process):
self.log.info("Done!") self.log.info("Done!")
next_save = time.time() + self.autosave_interval next_save = time.time() + self.autosave_interval
self.db.enter_step("block_wait") self.db.enter_step('block_wait')
self.end() self.end()
class Parser: class Parser():
def __init__( def __init__(self,
self, buf: typing.Any,
buf: typing.Any, recs_queue: multiprocessing.Queue = None,
recs_queue: multiprocessing.Queue = None, block_size: int = 0,
block_size: int = 0, writer: Writer = None,
writer: Writer = None, ):
):
assert bool(writer) ^ bool(block_size and recs_queue) assert bool(writer) ^ bool(block_size and recs_queue)
self.buf = buf self.buf = buf
self.log = logging.getLogger("pr") self.log = logging.getLogger('pr')
self.recs_queue = recs_queue self.recs_queue = recs_queue
if writer: # No MP if writer: # No MP
self.prof: database.Profiler = writer.db self.prof: database.Profiler = writer.db
@ -108,14 +105,14 @@ class Parser:
self.block: typing.List[Record] = list() self.block: typing.List[Record] = list()
self.block_size = block_size self.block_size = block_size
self.prof = database.Profiler() self.prof = database.Profiler()
self.prof.log = logging.getLogger("pr") self.prof.log = logging.getLogger('pr')
self.register = self.add_to_queue self.register = self.add_to_queue
def add_to_queue(self, record: Record) -> None: def add_to_queue(self, record: Record) -> None:
self.prof.enter_step("register") self.prof.enter_step('register')
self.block.append(record) self.block.append(record)
if len(self.block) >= self.block_size: if len(self.block) >= self.block_size:
self.prof.enter_step("put_block") self.prof.enter_step('put_block')
assert self.recs_queue assert self.recs_queue
self.recs_queue.put(self.block) self.recs_queue.put(self.block)
self.block = list() self.block = list()
@ -134,26 +131,26 @@ class Rapid7Parser(Parser):
def consume(self) -> None: def consume(self) -> None:
data = dict() data = dict()
for line in self.buf: for line in self.buf:
self.prof.enter_step("parse_rapid7") self.prof.enter_step('parse_rapid7')
split = line.split('"') split = line.split('"')
try: try:
for k in range(1, 14, 4): for k in range(1, 14, 4):
key = split[k] key = split[k]
val = split[k + 2] val = split[k+2]
data[key] = val data[key] = val
select, writer = FUNCTION_MAP[data["type"]] select, writer = FUNCTION_MAP[data['type']]
record = ( record = (
select, select,
writer, writer,
int(data["timestamp"]), int(data['timestamp']),
data["name"], data['name'],
data["value"], data['value']
) )
except (IndexError, KeyError): except (IndexError, KeyError):
# IndexError: missing field # IndexError: missing field
# KeyError: Unknown type field # KeyError: Unknown type field
self.log.exception("Cannot parse: %s", line) self.log.exception("Cannot parse: %s", line)
self.register(record) self.register(record)
@ -162,13 +159,13 @@ class MassDnsParser(Parser):
# massdns --output Snrql # massdns --output Snrql
# --retry REFUSED,SERVFAIL --resolvers nameservers-ipv4 # --retry REFUSED,SERVFAIL --resolvers nameservers-ipv4
TYPES = { TYPES = {
"A": (FUNCTION_MAP["a"][0], FUNCTION_MAP["a"][1], -1, None), 'A': (FUNCTION_MAP['a'][0], FUNCTION_MAP['a'][1], -1, None),
# 'AAAA': (FUNCTION_MAP['aaaa'][0], FUNCTION_MAP['aaaa'][1], -1, None), # 'AAAA': (FUNCTION_MAP['aaaa'][0], FUNCTION_MAP['aaaa'][1], -1, None),
"CNAME": (FUNCTION_MAP["cname"][0], FUNCTION_MAP["cname"][1], -1, -1), 'CNAME': (FUNCTION_MAP['cname'][0], FUNCTION_MAP['cname'][1], -1, -1),
} }
def consume(self) -> None: def consume(self) -> None:
self.prof.enter_step("parse_massdns") self.prof.enter_step('parse_massdns')
timestamp = 0 timestamp = 0
header = True header = True
for line in self.buf: for line in self.buf:
@ -177,15 +174,14 @@ class MassDnsParser(Parser):
header = True header = True
continue continue
split = line.split(" ") split = line.split(' ')
try: try:
if header: if header:
timestamp = int(split[1]) timestamp = int(split[1])
header = False header = False
else: else:
select, write, name_offset, value_offset = MassDnsParser.TYPES[ select, write, name_offset, value_offset = \
split[1] MassDnsParser.TYPES[split[1]]
]
record = ( record = (
select, select,
write, write,
@ -194,86 +190,75 @@ class MassDnsParser(Parser):
split[2][:value_offset].lower(), split[2][:value_offset].lower(),
) )
self.register(record) self.register(record)
self.prof.enter_step("parse_massdns") self.prof.enter_step('parse_massdns')
except KeyError: except KeyError:
continue continue
PARSERS = { PARSERS = {
"rapid7": Rapid7Parser, 'rapid7': Rapid7Parser,
"massdns": MassDnsParser, 'massdns': MassDnsParser,
} }
if __name__ == "__main__": if __name__ == '__main__':
# Parsing arguments # Parsing arguments
log = logging.getLogger("feed_dns") log = logging.getLogger('feed_dns')
args_parser = argparse.ArgumentParser( args_parser = argparse.ArgumentParser(
description="Read DNS records and import " description="Read DNS records and import "
"tracking-relevant data into the database" "tracking-relevant data into the database")
)
args_parser.add_argument("parser", choices=PARSERS.keys(), help="Input format")
args_parser.add_argument( args_parser.add_argument(
"-i", 'parser',
"--input", choices=PARSERS.keys(),
type=argparse.FileType("r"), help="Input format")
default=sys.stdin,
help="Input file",
)
args_parser.add_argument( args_parser.add_argument(
"-b", "--block-size", type=int, default=1024, help="Performance tuning value" '-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
) help="Input file")
args_parser.add_argument( args_parser.add_argument(
"-q", "--queue-size", type=int, default=128, help="Performance tuning value" '-b', '--block-size', type=int, default=1024,
) help="Performance tuning value")
args_parser.add_argument( args_parser.add_argument(
"-a", '-q', '--queue-size', type=int, default=128,
"--autosave-interval", help="Performance tuning value")
type=int,
default=900,
help="Interval to which the database will save in seconds. " "0 to disable.",
)
args_parser.add_argument( args_parser.add_argument(
"-s", '-a', '--autosave-interval', type=int, default=900,
"--single-process", help="Interval to which the database will save in seconds. "
action="store_true", "0 to disable.")
help="Only use one process. " "Might be useful for single core computers.",
)
args_parser.add_argument( args_parser.add_argument(
"-4", '-s', '--single-process', action='store_true',
"--ip4-cache", help="Only use one process. "
type=int, "Might be useful for single core computers.")
default=0, args_parser.add_argument(
'-4', '--ip4-cache', type=int, default=0,
help="RAM cache for faster IPv4 lookup. " help="RAM cache for faster IPv4 lookup. "
"Maximum useful value: 512 MiB (536870912). " "Maximum useful value: 512 MiB (536870912). "
"Warning: Depending on the rules, this might already " "Warning: Depending on the rules, this might already "
"be a memory-heavy process, even without the cache.", "be a memory-heavy process, even without the cache.")
)
args = args_parser.parse_args() args = args_parser.parse_args()
parser_cls = PARSERS[args.parser] parser_cls = PARSERS[args.parser]
if args.single_process: if args.single_process:
writer = Writer( writer = Writer(
autosave_interval=args.autosave_interval, ip4_cache=args.ip4_cache autosave_interval=args.autosave_interval,
ip4_cache=args.ip4_cache
) )
parser = parser_cls(args.input, writer=writer) parser = parser_cls(args.input, writer=writer)
parser.run() parser.run()
writer.end() writer.end()
else: else:
recs_queue: multiprocessing.Queue = multiprocessing.Queue( recs_queue: multiprocessing.Queue = multiprocessing.Queue(
maxsize=args.queue_size maxsize=args.queue_size)
)
writer = Writer( writer = Writer(recs_queue,
recs_queue, autosave_interval=args.autosave_interval,
autosave_interval=args.autosave_interval, ip4_cache=args.ip4_cache
ip4_cache=args.ip4_cache, )
)
writer.start() writer.start()
parser = parser_cls( parser = parser_cls(args.input,
args.input, recs_queue=recs_queue, block_size=args.block_size recs_queue=recs_queue,
) block_size=args.block_size
)
parser.run() parser.run()
recs_queue.put(None) recs_queue.put(None)

View file

@ -4,36 +4,30 @@ import database
import argparse import argparse
import sys import sys
import time import time
import typing
FUNCTION_MAP = { FUNCTION_MAP = {
"zone": database.Database.set_zone, 'zone': database.Database.set_zone,
"hostname": database.Database.set_hostname, 'hostname': database.Database.set_hostname,
"asn": database.Database.set_asn, 'asn': database.Database.set_asn,
"ip4network": database.Database.set_ip4network, 'ip4network': database.Database.set_ip4network,
"ip4address": database.Database.set_ip4address, 'ip4address': database.Database.set_ip4address,
} }
if __name__ == "__main__": if __name__ == '__main__':
# Parsing arguments # Parsing arguments
parser = argparse.ArgumentParser(description="Import base rules to the database") parser = argparse.ArgumentParser(
description="Import base rules to the database")
parser.add_argument( parser.add_argument(
"type", choices=FUNCTION_MAP.keys(), help="Type of rule inputed" 'type',
) choices=FUNCTION_MAP.keys(),
help="Type of rule inputed")
parser.add_argument( parser.add_argument(
"-i", '-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
"--input", help="File with one rule per line")
type=argparse.FileType("r"),
default=sys.stdin,
help="File with one rule per line",
)
parser.add_argument( parser.add_argument(
"-f", '-f', '--first-party', action='store_true',
"--first-party", help="The input only comes from verified first-party sources")
action="store_true",
help="The input only comes from verified first-party sources",
)
args = parser.parse_args() args = parser.parse_args()
DB = database.Database() DB = database.Database()
@ -49,12 +43,11 @@ if __name__ == "__main__":
for rule in args.input: for rule in args.input:
rule = rule.strip() rule = rule.strip()
try: try:
fun( fun(DB,
DB,
rule, rule,
source=source, source=source,
updated=int(time.time()), updated=int(time.time()),
) )
except ValueError: except ValueError:
DB.log.error(f"Could not add rule: {rule}") DB.log.error(f"Could not add rule: {rule}")

View file

@ -2,9 +2,11 @@
import markdown2 import markdown2
extras = ["header-ids"] extras = [
"header-ids"
]
with open("dist/README.md", "r") as fdesc: with open('dist/README.md', 'r') as fdesc:
body = markdown2.markdown(fdesc.read(), extras=extras) body = markdown2.markdown(fdesc.read(), extras=extras)
output = f"""<!DOCTYPE html> output = f"""<!DOCTYPE html>
@ -21,5 +23,5 @@ output = f"""<!DOCTYPE html>
</html> </html>
""" """
with open("dist/index.html", "w") as fdesc: with open('dist/index.html', 'w') as fdesc:
fdesc.write(output) fdesc.write(output)

View file

@ -57,11 +57,7 @@ if __name__ == "__main__":
perc_all = (100 * pass_all / count_all) if count_all else 100 perc_all = (100 * pass_all / count_all) if count_all else 100
perc_den = (100 * pass_den / count_den) if count_den else 100 perc_den = (100 * pass_den / count_den) if count_den else 100
log.info( log.info(
( "%s: Entries %d/%d (%.2f%%) | Allow %d/%d (%.2f%%) | Deny %d/%d (%.2f%%)",
"%s: Entries %d/%d (%.2f%%)"
" | Allow %d/%d (%.2f%%)"
"| Deny %d/%d (%.2f%%)"
),
filename, filename,
pass_ent, pass_ent,
count_ent, count_ent,