Black pass
This commit is contained in:
parent
a023dc8322
commit
3dcccad39a
|
@ -16,25 +16,36 @@ import abp.filters
|
||||||
def get_domains(rule: abp.filters.parser.Filter) -> typing.Iterable[str]:
|
def get_domains(rule: abp.filters.parser.Filter) -> typing.Iterable[str]:
|
||||||
if rule.options:
|
if rule.options:
|
||||||
return
|
return
|
||||||
selector_type = rule.selector['type']
|
selector_type = rule.selector["type"]
|
||||||
selector_value = rule.selector['value']
|
selector_value = rule.selector["value"]
|
||||||
if selector_type == 'url-pattern' \
|
if (
|
||||||
and selector_value.startswith('||') \
|
selector_type == "url-pattern"
|
||||||
and selector_value.endswith('^'):
|
and selector_value.startswith("||")
|
||||||
|
and selector_value.endswith("^")
|
||||||
|
):
|
||||||
yield selector_value[2:-1]
|
yield selector_value[2:-1]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
|
|
||||||
# Parsing arguments
|
# Parsing arguments
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Extract whole domains from an AdBlock blocking list")
|
description="Extract whole domains from an AdBlock blocking list"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
|
"-i",
|
||||||
help="Input file with AdBlock rules")
|
"--input",
|
||||||
|
type=argparse.FileType("r"),
|
||||||
|
default=sys.stdin,
|
||||||
|
help="Input file with AdBlock rules",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-o', '--output', type=argparse.FileType('w'), default=sys.stdout,
|
"-o",
|
||||||
help="Outptut file with one rule tracking subdomain per line")
|
"--output",
|
||||||
|
type=argparse.FileType("w"),
|
||||||
|
default=sys.stdout,
|
||||||
|
help="Outptut file with one rule tracking subdomain per line",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Reading rules
|
# Reading rules
|
||||||
|
|
|
@ -16,26 +16,25 @@ import selenium.webdriver.firefox.options
|
||||||
import seleniumwire.webdriver
|
import seleniumwire.webdriver
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
log = logging.getLogger('cs')
|
log = logging.getLogger("cs")
|
||||||
DRIVER = None
|
DRIVER = None
|
||||||
SCROLL_TIME = 10.0
|
SCROLL_TIME = 10.0
|
||||||
SCROLL_STEPS = 100
|
SCROLL_STEPS = 100
|
||||||
SCROLL_CMD = f'window.scrollBy(0,document.body.scrollHeight/{SCROLL_STEPS})'
|
SCROLL_CMD = f"window.scrollBy(0,document.body.scrollHeight/{SCROLL_STEPS})"
|
||||||
|
|
||||||
|
|
||||||
def new_driver() -> seleniumwire.webdriver.browser.Firefox:
|
def new_driver() -> seleniumwire.webdriver.browser.Firefox:
|
||||||
profile = selenium.webdriver.FirefoxProfile()
|
profile = selenium.webdriver.FirefoxProfile()
|
||||||
profile.set_preference('privacy.trackingprotection.enabled', False)
|
profile.set_preference("privacy.trackingprotection.enabled", False)
|
||||||
profile.set_preference('network.cookie.cookieBehavior', 0)
|
profile.set_preference("network.cookie.cookieBehavior", 0)
|
||||||
profile.set_preference('privacy.trackingprotection.pbmode.enabled', False)
|
profile.set_preference("privacy.trackingprotection.pbmode.enabled", False)
|
||||||
profile.set_preference(
|
profile.set_preference("privacy.trackingprotection.cryptomining.enabled", False)
|
||||||
'privacy.trackingprotection.cryptomining.enabled', False)
|
profile.set_preference("privacy.trackingprotection.fingerprinting.enabled", False)
|
||||||
profile.set_preference(
|
|
||||||
'privacy.trackingprotection.fingerprinting.enabled', False)
|
|
||||||
options = selenium.webdriver.firefox.options.Options()
|
options = selenium.webdriver.firefox.options.Options()
|
||||||
# options.add_argument('-headless')
|
# options.add_argument('-headless')
|
||||||
driver = seleniumwire.webdriver.Firefox(profile,
|
driver = seleniumwire.webdriver.Firefox(
|
||||||
executable_path='geckodriver', options=options)
|
profile, executable_path="geckodriver", options=options
|
||||||
|
)
|
||||||
return driver
|
return driver
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,7 +59,7 @@ def collect_subdomains(url: str) -> typing.Iterable[str]:
|
||||||
DRIVER.get(url)
|
DRIVER.get(url)
|
||||||
for s in range(SCROLL_STEPS):
|
for s in range(SCROLL_STEPS):
|
||||||
DRIVER.execute_script(SCROLL_CMD)
|
DRIVER.execute_script(SCROLL_CMD)
|
||||||
time.sleep(SCROLL_TIME/SCROLL_STEPS)
|
time.sleep(SCROLL_TIME / SCROLL_STEPS)
|
||||||
for request in DRIVER.requests:
|
for request in DRIVER.requests:
|
||||||
if request.response:
|
if request.response:
|
||||||
yield subdomain_from_url(request.path)
|
yield subdomain_from_url(request.path)
|
||||||
|
@ -78,10 +77,10 @@ def collect_subdomains_standalone(url: str) -> None:
|
||||||
print(subdomain)
|
print(subdomain)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
assert len(sys.argv) <= 2
|
assert len(sys.argv) <= 2
|
||||||
filename = None
|
filename = None
|
||||||
if len(sys.argv) == 2 and sys.argv[1] != '-':
|
if len(sys.argv) == 2 and sys.argv[1] != "-":
|
||||||
filename = sys.argv[1]
|
filename = sys.argv[1]
|
||||||
num_lines = sum(1 for line in open(filename))
|
num_lines = sum(1 for line in open(filename))
|
||||||
iterator = progressbar.progressbar(open(filename), max_value=num_lines)
|
iterator = progressbar.progressbar(open(filename), max_value=num_lines)
|
||||||
|
|
296
database.py
296
database.py
|
@ -15,33 +15,30 @@ import os
|
||||||
|
|
||||||
TLD_LIST: typing.Set[str] = set()
|
TLD_LIST: typing.Set[str] = set()
|
||||||
|
|
||||||
coloredlogs.install(
|
coloredlogs.install(level="DEBUG", fmt="%(asctime)s %(name)s %(levelname)s %(message)s")
|
||||||
level='DEBUG',
|
|
||||||
fmt='%(asctime)s %(name)s %(levelname)s %(message)s'
|
|
||||||
)
|
|
||||||
|
|
||||||
Asn = int
|
Asn = int
|
||||||
Timestamp = int
|
Timestamp = int
|
||||||
Level = int
|
Level = int
|
||||||
|
|
||||||
|
|
||||||
class Path():
|
class Path:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class RulePath(Path):
|
class RulePath(Path):
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return '(rule)'
|
return "(rule)"
|
||||||
|
|
||||||
|
|
||||||
class RuleFirstPath(RulePath):
|
class RuleFirstPath(RulePath):
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return '(first-party rule)'
|
return "(first-party rule)"
|
||||||
|
|
||||||
|
|
||||||
class RuleMultiPath(RulePath):
|
class RuleMultiPath(RulePath):
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return '(multi-party rule)'
|
return "(multi-party rule)"
|
||||||
|
|
||||||
|
|
||||||
class DomainPath(Path):
|
class DomainPath(Path):
|
||||||
|
@ -49,7 +46,7 @@ class DomainPath(Path):
|
||||||
self.parts = parts
|
self.parts = parts
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return '?.' + Database.unpack_domain(self)
|
return "?." + Database.unpack_domain(self)
|
||||||
|
|
||||||
|
|
||||||
class HostnamePath(DomainPath):
|
class HostnamePath(DomainPath):
|
||||||
|
@ -59,7 +56,7 @@ class HostnamePath(DomainPath):
|
||||||
|
|
||||||
class ZonePath(DomainPath):
|
class ZonePath(DomainPath):
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return '*.' + Database.unpack_domain(self)
|
return "*." + Database.unpack_domain(self)
|
||||||
|
|
||||||
|
|
||||||
class AsnPath(Path):
|
class AsnPath(Path):
|
||||||
|
@ -79,7 +76,7 @@ class Ip4Path(Path):
|
||||||
return Database.unpack_ip4network(self)
|
return Database.unpack_ip4network(self)
|
||||||
|
|
||||||
|
|
||||||
class Match():
|
class Match:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.source: typing.Optional[Path] = None
|
self.source: typing.Optional[Path] = None
|
||||||
self.updated: int = 0
|
self.updated: int = 0
|
||||||
|
@ -102,10 +99,10 @@ class Match():
|
||||||
class AsnNode(Match):
|
class AsnNode(Match):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
Match.__init__(self)
|
Match.__init__(self)
|
||||||
self.name = ''
|
self.name = ""
|
||||||
|
|
||||||
|
|
||||||
class DomainTreeNode():
|
class DomainTreeNode:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.children: typing.Dict[str, DomainTreeNode] = dict()
|
self.children: typing.Dict[str, DomainTreeNode] = dict()
|
||||||
self.match_zone = Match()
|
self.match_zone = Match()
|
||||||
|
@ -120,18 +117,16 @@ class IpTreeNode(Match):
|
||||||
|
|
||||||
|
|
||||||
Node = typing.Union[DomainTreeNode, IpTreeNode, AsnNode]
|
Node = typing.Union[DomainTreeNode, IpTreeNode, AsnNode]
|
||||||
MatchCallable = typing.Callable[[Path,
|
MatchCallable = typing.Callable[[Path, Match], typing.Any]
|
||||||
Match],
|
|
||||||
typing.Any]
|
|
||||||
|
|
||||||
|
|
||||||
class Profiler():
|
class Profiler:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
do_profile = int(os.environ.get('PROFILE', '0'))
|
do_profile = int(os.environ.get("PROFILE", "0"))
|
||||||
if do_profile:
|
if do_profile:
|
||||||
self.log = logging.getLogger('profiler')
|
self.log = logging.getLogger("profiler")
|
||||||
self.time_last = time.perf_counter()
|
self.time_last = time.perf_counter()
|
||||||
self.time_step = 'init'
|
self.time_step = "init"
|
||||||
self.time_dict: typing.Dict[str, float] = dict()
|
self.time_dict: typing.Dict[str, float] = dict()
|
||||||
self.step_dict: typing.Dict[str, int] = dict()
|
self.step_dict: typing.Dict[str, int] = dict()
|
||||||
self.enter_step = self.enter_step_real
|
self.enter_step = self.enter_step_real
|
||||||
|
@ -158,14 +153,17 @@ class Profiler():
|
||||||
return
|
return
|
||||||
|
|
||||||
def profile_real(self) -> None:
|
def profile_real(self) -> None:
|
||||||
self.enter_step('profile')
|
self.enter_step("profile")
|
||||||
total = sum(self.time_dict.values())
|
total = sum(self.time_dict.values())
|
||||||
for key, secs in sorted(self.time_dict.items(), key=lambda t: t[1]):
|
for key, secs in sorted(self.time_dict.items(), key=lambda t: t[1]):
|
||||||
times = self.step_dict[key]
|
times = self.step_dict[key]
|
||||||
self.log.debug(f"{key:<20}: {times:9d} × {secs/times:5.3e} "
|
self.log.debug(
|
||||||
f"= {secs:9.2f} s ({secs/total:7.2%}) ")
|
f"{key:<20}: {times:9d} × {secs/times:5.3e} "
|
||||||
self.log.debug(f"{'total':<20}: "
|
f"= {secs:9.2f} s ({secs/total:7.2%}) "
|
||||||
f"{total:9.2f} s ({1:7.2%})")
|
)
|
||||||
|
self.log.debug(
|
||||||
|
f"{'total':<20}: " f"{total:9.2f} s ({1:7.2%})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Database(Profiler):
|
class Database(Profiler):
|
||||||
|
@ -173,9 +171,7 @@ class Database(Profiler):
|
||||||
PATH = "blocking.p"
|
PATH = "blocking.p"
|
||||||
|
|
||||||
def initialize(self) -> None:
|
def initialize(self) -> None:
|
||||||
self.log.warning(
|
self.log.warning("Creating database version: %d ", Database.VERSION)
|
||||||
"Creating database version: %d ",
|
|
||||||
Database.VERSION)
|
|
||||||
# Dummy match objects that everything refer to
|
# Dummy match objects that everything refer to
|
||||||
self.rules: typing.List[Match] = list()
|
self.rules: typing.List[Match] = list()
|
||||||
for first_party in (False, True):
|
for first_party in (False, True):
|
||||||
|
@ -189,76 +185,77 @@ class Database(Profiler):
|
||||||
self.ip4tree = IpTreeNode()
|
self.ip4tree = IpTreeNode()
|
||||||
|
|
||||||
def load(self) -> None:
|
def load(self) -> None:
|
||||||
self.enter_step('load')
|
self.enter_step("load")
|
||||||
try:
|
try:
|
||||||
with open(self.PATH, 'rb') as db_fdsec:
|
with open(self.PATH, "rb") as db_fdsec:
|
||||||
version, data = pickle.load(db_fdsec)
|
version, data = pickle.load(db_fdsec)
|
||||||
if version == Database.VERSION:
|
if version == Database.VERSION:
|
||||||
self.rules, self.domtree, self.asns, self.ip4tree = data
|
self.rules, self.domtree, self.asns, self.ip4tree = data
|
||||||
return
|
return
|
||||||
self.log.warning(
|
self.log.warning(
|
||||||
"Outdated database version found: %d, "
|
"Outdated database version found: %d, " "it will be rebuilt.",
|
||||||
"it will be rebuilt.",
|
version,
|
||||||
version)
|
)
|
||||||
except (TypeError, AttributeError, EOFError):
|
except (TypeError, AttributeError, EOFError):
|
||||||
self.log.error(
|
self.log.error(
|
||||||
"Corrupt (or heavily outdated) database found, "
|
"Corrupt (or heavily outdated) database found, " "it will be rebuilt."
|
||||||
"it will be rebuilt.")
|
)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
self.initialize()
|
self.initialize()
|
||||||
|
|
||||||
def save(self) -> None:
|
def save(self) -> None:
|
||||||
self.enter_step('save')
|
self.enter_step("save")
|
||||||
with open(self.PATH, 'wb') as db_fdsec:
|
with open(self.PATH, "wb") as db_fdsec:
|
||||||
data = self.rules, self.domtree, self.asns, self.ip4tree
|
data = self.rules, self.domtree, self.asns, self.ip4tree
|
||||||
pickle.dump((self.VERSION, data), db_fdsec)
|
pickle.dump((self.VERSION, data), db_fdsec)
|
||||||
self.profile()
|
self.profile()
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
Profiler.__init__(self)
|
Profiler.__init__(self)
|
||||||
self.log = logging.getLogger('db')
|
self.log = logging.getLogger("db")
|
||||||
self.load()
|
self.load()
|
||||||
self.ip4cache_shift: int = 32
|
self.ip4cache_shift: int = 32
|
||||||
self.ip4cache = numpy.ones(1)
|
self.ip4cache = numpy.ones(1)
|
||||||
|
|
||||||
def _set_ip4cache(self, path: Path, _: Match) -> None:
|
def _set_ip4cache(self, path: Path, _: Match) -> None:
|
||||||
assert isinstance(path, Ip4Path)
|
assert isinstance(path, Ip4Path)
|
||||||
self.enter_step('set_ip4cache')
|
self.enter_step("set_ip4cache")
|
||||||
mini = path.value >> self.ip4cache_shift
|
mini = path.value >> self.ip4cache_shift
|
||||||
maxi = (path.value + 2**(32-path.prefixlen)) >> self.ip4cache_shift
|
maxi = (path.value + 2 ** (32 - path.prefixlen)) >> self.ip4cache_shift
|
||||||
if mini == maxi:
|
if mini == maxi:
|
||||||
self.ip4cache[mini] = True
|
self.ip4cache[mini] = True
|
||||||
else:
|
else:
|
||||||
self.ip4cache[mini:maxi] = True
|
self.ip4cache[mini:maxi] = True
|
||||||
|
|
||||||
def fill_ip4cache(self, max_size: int = 512*1024**2) -> None:
|
def fill_ip4cache(self, max_size: int = 512 * 1024 ** 2) -> None:
|
||||||
"""
|
"""
|
||||||
Size in bytes
|
Size in bytes
|
||||||
"""
|
"""
|
||||||
if max_size > 2**32/8:
|
if max_size > 2 ** 32 / 8:
|
||||||
self.log.warning("Allocating more than 512 MiB of RAM for "
|
self.log.warning(
|
||||||
"the Ip4 cache is not necessary.")
|
"Allocating more than 512 MiB of RAM for "
|
||||||
max_cache_width = int(math.log2(max(1, max_size*8)))
|
"the Ip4 cache is not necessary."
|
||||||
|
)
|
||||||
|
max_cache_width = int(math.log2(max(1, max_size * 8)))
|
||||||
allocated = False
|
allocated = False
|
||||||
cache_width = min(32, max_cache_width)
|
cache_width = min(32, max_cache_width)
|
||||||
while not allocated:
|
while not allocated:
|
||||||
cache_size = 2**cache_width
|
cache_size = 2 ** cache_width
|
||||||
try:
|
try:
|
||||||
self.ip4cache = numpy.zeros(cache_size, dtype=bool)
|
self.ip4cache = numpy.zeros(cache_size, dtype=bool)
|
||||||
except MemoryError:
|
except MemoryError:
|
||||||
self.log.exception(
|
self.log.exception("Could not allocate cache. Retrying a smaller one.")
|
||||||
"Could not allocate cache. Retrying a smaller one.")
|
|
||||||
cache_width -= 1
|
cache_width -= 1
|
||||||
continue
|
continue
|
||||||
allocated = True
|
allocated = True
|
||||||
self.ip4cache_shift = 32-cache_width
|
self.ip4cache_shift = 32 - cache_width
|
||||||
for _ in self.exec_each_ip4(self._set_ip4cache):
|
for _ in self.exec_each_ip4(self._set_ip4cache):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def populate_tld_list() -> None:
|
def populate_tld_list() -> None:
|
||||||
with open('temp/all_tld.list', 'r') as tld_fdesc:
|
with open("temp/all_tld.list", "r") as tld_fdesc:
|
||||||
for tld in tld_fdesc:
|
for tld in tld_fdesc:
|
||||||
tld = tld.strip()
|
tld = tld.strip()
|
||||||
TLD_LIST.add(tld)
|
TLD_LIST.add(tld)
|
||||||
|
@ -267,7 +264,7 @@ class Database(Profiler):
|
||||||
def validate_domain(path: str) -> bool:
|
def validate_domain(path: str) -> bool:
|
||||||
if len(path) > 255:
|
if len(path) > 255:
|
||||||
return False
|
return False
|
||||||
splits = path.split('.')
|
splits = path.split(".")
|
||||||
if not TLD_LIST:
|
if not TLD_LIST:
|
||||||
Database.populate_tld_list()
|
Database.populate_tld_list()
|
||||||
if splits[-1] not in TLD_LIST:
|
if splits[-1] not in TLD_LIST:
|
||||||
|
@ -279,26 +276,26 @@ class Database(Profiler):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def pack_domain(domain: str) -> DomainPath:
|
def pack_domain(domain: str) -> DomainPath:
|
||||||
return DomainPath(domain.split('.')[::-1])
|
return DomainPath(domain.split(".")[::-1])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def unpack_domain(domain: DomainPath) -> str:
|
def unpack_domain(domain: DomainPath) -> str:
|
||||||
return '.'.join(domain.parts[::-1])
|
return ".".join(domain.parts[::-1])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def pack_asn(asn: str) -> AsnPath:
|
def pack_asn(asn: str) -> AsnPath:
|
||||||
asn = asn.upper()
|
asn = asn.upper()
|
||||||
if asn.startswith('AS'):
|
if asn.startswith("AS"):
|
||||||
asn = asn[2:]
|
asn = asn[2:]
|
||||||
return AsnPath(int(asn))
|
return AsnPath(int(asn))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def unpack_asn(asn: AsnPath) -> str:
|
def unpack_asn(asn: AsnPath) -> str:
|
||||||
return f'AS{asn.asn}'
|
return f"AS{asn.asn}"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def validate_ip4address(path: str) -> bool:
|
def validate_ip4address(path: str) -> bool:
|
||||||
splits = path.split('.')
|
splits = path.split(".")
|
||||||
if len(splits) != 4:
|
if len(splits) != 4:
|
||||||
return False
|
return False
|
||||||
for split in splits:
|
for split in splits:
|
||||||
|
@ -312,7 +309,7 @@ class Database(Profiler):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def pack_ip4address_low(address: str) -> int:
|
def pack_ip4address_low(address: str) -> int:
|
||||||
addr = 0
|
addr = 0
|
||||||
for split in address.split('.'):
|
for split in address.split("."):
|
||||||
octet = int(split)
|
octet = int(split)
|
||||||
addr = (addr << 8) + octet
|
addr = (addr << 8) + octet
|
||||||
return addr
|
return addr
|
||||||
|
@ -330,12 +327,12 @@ class Database(Profiler):
|
||||||
for o in reversed(range(4)):
|
for o in reversed(range(4)):
|
||||||
octets[o] = addr & 0xFF
|
octets[o] = addr & 0xFF
|
||||||
addr >>= 8
|
addr >>= 8
|
||||||
return '.'.join(map(str, octets))
|
return ".".join(map(str, octets))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def validate_ip4network(path: str) -> bool:
|
def validate_ip4network(path: str) -> bool:
|
||||||
# A bit generous but ok for our usage
|
# A bit generous but ok for our usage
|
||||||
splits = path.split('/')
|
splits = path.split("/")
|
||||||
if len(splits) != 2:
|
if len(splits) != 2:
|
||||||
return False
|
return False
|
||||||
if not Database.validate_ip4address(splits[0]):
|
if not Database.validate_ip4address(splits[0]):
|
||||||
|
@ -349,7 +346,7 @@ class Database(Profiler):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def pack_ip4network(network: str) -> Ip4Path:
|
def pack_ip4network(network: str) -> Ip4Path:
|
||||||
address, prefixlen_str = network.split('/')
|
address, prefixlen_str = network.split("/")
|
||||||
prefixlen = int(prefixlen_str)
|
prefixlen = int(prefixlen_str)
|
||||||
addr = Database.pack_ip4address(address)
|
addr = Database.pack_ip4address(address)
|
||||||
addr.prefixlen = prefixlen
|
addr.prefixlen = prefixlen
|
||||||
|
@ -363,7 +360,7 @@ class Database(Profiler):
|
||||||
for o in reversed(range(4)):
|
for o in reversed(range(4)):
|
||||||
octets[o] = addr & 0xFF
|
octets[o] = addr & 0xFF
|
||||||
addr >>= 8
|
addr >>= 8
|
||||||
return '.'.join(map(str, octets)) + '/' + str(network.prefixlen)
|
return ".".join(map(str, octets)) + "/" + str(network.prefixlen)
|
||||||
|
|
||||||
def get_match(self, path: Path) -> Match:
|
def get_match(self, path: Path) -> Match:
|
||||||
if isinstance(path, RuleMultiPath):
|
if isinstance(path, RuleMultiPath):
|
||||||
|
@ -384,7 +381,7 @@ class Database(Profiler):
|
||||||
raise ValueError
|
raise ValueError
|
||||||
elif isinstance(path, Ip4Path):
|
elif isinstance(path, Ip4Path):
|
||||||
dici = self.ip4tree
|
dici = self.ip4tree
|
||||||
for i in range(31, 31-path.prefixlen, -1):
|
for i in range(31, 31 - path.prefixlen, -1):
|
||||||
bit = (path.value >> i) & 0b1
|
bit = (path.value >> i) & 0b1
|
||||||
dici_next = dici.one if bit else dici.zero
|
dici_next = dici.one if bit else dici.zero
|
||||||
if not dici_next:
|
if not dici_next:
|
||||||
|
@ -394,7 +391,8 @@ class Database(Profiler):
|
||||||
else:
|
else:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
def exec_each_asn(self,
|
def exec_each_asn(
|
||||||
|
self,
|
||||||
callback: MatchCallable,
|
callback: MatchCallable,
|
||||||
) -> typing.Any:
|
) -> typing.Any:
|
||||||
for asn in self.asns:
|
for asn in self.asns:
|
||||||
|
@ -409,7 +407,8 @@ class Database(Profiler):
|
||||||
except TypeError: # not iterable
|
except TypeError: # not iterable
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def exec_each_domain(self,
|
def exec_each_domain(
|
||||||
|
self,
|
||||||
callback: MatchCallable,
|
callback: MatchCallable,
|
||||||
_dic: DomainTreeNode = None,
|
_dic: DomainTreeNode = None,
|
||||||
_par: DomainPath = None,
|
_par: DomainPath = None,
|
||||||
|
@ -437,12 +436,11 @@ class Database(Profiler):
|
||||||
for part in _dic.children:
|
for part in _dic.children:
|
||||||
dic = _dic.children[part]
|
dic = _dic.children[part]
|
||||||
yield from self.exec_each_domain(
|
yield from self.exec_each_domain(
|
||||||
callback,
|
callback, _dic=dic, _par=DomainPath(_par.parts + [part])
|
||||||
_dic=dic,
|
|
||||||
_par=DomainPath(_par.parts + [part])
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def exec_each_ip4(self,
|
def exec_each_ip4(
|
||||||
|
self,
|
||||||
callback: MatchCallable,
|
callback: MatchCallable,
|
||||||
_dic: IpTreeNode = None,
|
_dic: IpTreeNode = None,
|
||||||
_par: Ip4Path = None,
|
_par: Ip4Path = None,
|
||||||
|
@ -466,23 +464,16 @@ class Database(Profiler):
|
||||||
# addr0 = _par.value & (0xFFFFFFFF ^ (1 << (32-pref)))
|
# addr0 = _par.value & (0xFFFFFFFF ^ (1 << (32-pref)))
|
||||||
# assert addr0 == _par.value
|
# assert addr0 == _par.value
|
||||||
addr0 = _par.value
|
addr0 = _par.value
|
||||||
yield from self.exec_each_ip4(
|
yield from self.exec_each_ip4(callback, _dic=dic, _par=Ip4Path(addr0, pref))
|
||||||
callback,
|
|
||||||
_dic=dic,
|
|
||||||
_par=Ip4Path(addr0, pref)
|
|
||||||
)
|
|
||||||
# 1
|
# 1
|
||||||
dic = _dic.one
|
dic = _dic.one
|
||||||
if dic:
|
if dic:
|
||||||
addr1 = _par.value | (1 << (32-pref))
|
addr1 = _par.value | (1 << (32 - pref))
|
||||||
# assert addr1 != _par.value
|
# assert addr1 != _par.value
|
||||||
yield from self.exec_each_ip4(
|
yield from self.exec_each_ip4(callback, _dic=dic, _par=Ip4Path(addr1, pref))
|
||||||
callback,
|
|
||||||
_dic=dic,
|
|
||||||
_par=Ip4Path(addr1, pref)
|
|
||||||
)
|
|
||||||
|
|
||||||
def exec_each(self,
|
def exec_each(
|
||||||
|
self,
|
||||||
callback: MatchCallable,
|
callback: MatchCallable,
|
||||||
) -> typing.Any:
|
) -> typing.Any:
|
||||||
yield from self.exec_each_domain(callback)
|
yield from self.exec_each_domain(callback)
|
||||||
|
@ -492,19 +483,17 @@ class Database(Profiler):
|
||||||
def update_references(self) -> None:
|
def update_references(self) -> None:
|
||||||
# Should be correctly calculated normally,
|
# Should be correctly calculated normally,
|
||||||
# keeping this just in case
|
# keeping this just in case
|
||||||
def reset_references_cb(path: Path,
|
def reset_references_cb(path: Path, match: Match) -> None:
|
||||||
match: Match
|
|
||||||
) -> None:
|
|
||||||
match.references = 0
|
match.references = 0
|
||||||
|
|
||||||
for _ in self.exec_each(reset_references_cb):
|
for _ in self.exec_each(reset_references_cb):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def increment_references_cb(path: Path,
|
def increment_references_cb(path: Path, match: Match) -> None:
|
||||||
match: Match
|
|
||||||
) -> None:
|
|
||||||
if match.source:
|
if match.source:
|
||||||
source = self.get_match(match.source)
|
source = self.get_match(match.source)
|
||||||
source.references += 1
|
source.references += 1
|
||||||
|
|
||||||
for _ in self.exec_each(increment_references_cb):
|
for _ in self.exec_each(increment_references_cb):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -513,9 +502,7 @@ class Database(Profiler):
|
||||||
# matches until all disabled matches reference count = 0
|
# matches until all disabled matches reference count = 0
|
||||||
did_something = True
|
did_something = True
|
||||||
|
|
||||||
def clean_deps_cb(path: Path,
|
def clean_deps_cb(path: Path, match: Match) -> None:
|
||||||
match: Match
|
|
||||||
) -> None:
|
|
||||||
nonlocal did_something
|
nonlocal did_something
|
||||||
if not match.source:
|
if not match.source:
|
||||||
return
|
return
|
||||||
|
@ -530,15 +517,13 @@ class Database(Profiler):
|
||||||
|
|
||||||
while did_something:
|
while did_something:
|
||||||
did_something = False
|
did_something = False
|
||||||
self.enter_step('pass_clean_deps')
|
self.enter_step("pass_clean_deps")
|
||||||
for _ in self.exec_each(clean_deps_cb):
|
for _ in self.exec_each(clean_deps_cb):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def prune(self, before: int, base_only: bool = False) -> None:
|
def prune(self, before: int, base_only: bool = False) -> None:
|
||||||
# Disable the matches targeted
|
# Disable the matches targeted
|
||||||
def prune_cb(path: Path,
|
def prune_cb(path: Path, match: Match) -> None:
|
||||||
match: Match
|
|
||||||
) -> None:
|
|
||||||
if base_only and match.level > 1:
|
if base_only and match.level > 1:
|
||||||
return
|
return
|
||||||
if match.updated > before:
|
if match.updated > before:
|
||||||
|
@ -546,7 +531,7 @@ class Database(Profiler):
|
||||||
self._unset_match(match)
|
self._unset_match(match)
|
||||||
self.log.debug("Print: disabled %s", path)
|
self.log.debug("Print: disabled %s", path)
|
||||||
|
|
||||||
self.enter_step('pass_prune')
|
self.enter_step("pass_prune")
|
||||||
for _ in self.exec_each(prune_cb):
|
for _ in self.exec_each(prune_cb):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -559,15 +544,16 @@ class Database(Profiler):
|
||||||
match = self.get_match(path)
|
match = self.get_match(path)
|
||||||
string = str(path)
|
string = str(path)
|
||||||
if isinstance(match, AsnNode):
|
if isinstance(match, AsnNode):
|
||||||
string += f' ({match.name})'
|
string += f" ({match.name})"
|
||||||
party_char = 'F' if match.first_party else 'M'
|
party_char = "F" if match.first_party else "M"
|
||||||
dup_char = 'D' if match.dupplicate else '_'
|
dup_char = "D" if match.dupplicate else "_"
|
||||||
string += f' {match.level}{party_char}{dup_char}{match.references}'
|
string += f" {match.level}{party_char}{dup_char}{match.references}"
|
||||||
if match.source:
|
if match.source:
|
||||||
string += f' ← {self.explain(match.source)}'
|
string += f" ← {self.explain(match.source)}"
|
||||||
return string
|
return string
|
||||||
|
|
||||||
def list_records(self,
|
def list_records(
|
||||||
|
self,
|
||||||
first_party_only: bool = False,
|
first_party_only: bool = False,
|
||||||
end_chain_only: bool = False,
|
end_chain_only: bool = False,
|
||||||
no_dupplicates: bool = False,
|
no_dupplicates: bool = False,
|
||||||
|
@ -575,9 +561,7 @@ class Database(Profiler):
|
||||||
hostnames_only: bool = False,
|
hostnames_only: bool = False,
|
||||||
explain: bool = False,
|
explain: bool = False,
|
||||||
) -> typing.Iterable[str]:
|
) -> typing.Iterable[str]:
|
||||||
|
def export_cb(path: Path, match: Match) -> typing.Iterable[str]:
|
||||||
def export_cb(path: Path, match: Match
|
|
||||||
) -> typing.Iterable[str]:
|
|
||||||
if first_party_only and not match.first_party:
|
if first_party_only and not match.first_party:
|
||||||
return
|
return
|
||||||
if end_chain_only and match.references > 0:
|
if end_chain_only and match.references > 0:
|
||||||
|
@ -596,7 +580,8 @@ class Database(Profiler):
|
||||||
|
|
||||||
yield from self.exec_each(export_cb)
|
yield from self.exec_each(export_cb)
|
||||||
|
|
||||||
def count_records(self,
|
def count_records(
|
||||||
|
self,
|
||||||
first_party_only: bool = False,
|
first_party_only: bool = False,
|
||||||
end_chain_only: bool = False,
|
end_chain_only: bool = False,
|
||||||
no_dupplicates: bool = False,
|
no_dupplicates: bool = False,
|
||||||
|
@ -627,54 +612,55 @@ class Database(Profiler):
|
||||||
|
|
||||||
split: typing.List[str] = list()
|
split: typing.List[str] = list()
|
||||||
for key, value in sorted(memo.items(), key=lambda s: s[0]):
|
for key, value in sorted(memo.items(), key=lambda s: s[0]):
|
||||||
split.append(f'{key[:-4].lower()}s: {value}')
|
split.append(f"{key[:-4].lower()}s: {value}")
|
||||||
return ', '.join(split)
|
return ", ".join(split)
|
||||||
|
|
||||||
def get_domain(self, domain_str: str) -> typing.Iterable[DomainPath]:
|
def get_domain(self, domain_str: str) -> typing.Iterable[DomainPath]:
|
||||||
self.enter_step('get_domain_pack')
|
self.enter_step("get_domain_pack")
|
||||||
domain = self.pack_domain(domain_str)
|
domain = self.pack_domain(domain_str)
|
||||||
self.enter_step('get_domain_brws')
|
self.enter_step("get_domain_brws")
|
||||||
dic = self.domtree
|
dic = self.domtree
|
||||||
depth = 0
|
depth = 0
|
||||||
for part in domain.parts:
|
for part in domain.parts:
|
||||||
if dic.match_zone.active():
|
if dic.match_zone.active():
|
||||||
self.enter_step('get_domain_yield')
|
self.enter_step("get_domain_yield")
|
||||||
yield ZonePath(domain.parts[:depth])
|
yield ZonePath(domain.parts[:depth])
|
||||||
self.enter_step('get_domain_brws')
|
self.enter_step("get_domain_brws")
|
||||||
if part not in dic.children:
|
if part not in dic.children:
|
||||||
return
|
return
|
||||||
dic = dic.children[part]
|
dic = dic.children[part]
|
||||||
depth += 1
|
depth += 1
|
||||||
if dic.match_zone.active():
|
if dic.match_zone.active():
|
||||||
self.enter_step('get_domain_yield')
|
self.enter_step("get_domain_yield")
|
||||||
yield ZonePath(domain.parts)
|
yield ZonePath(domain.parts)
|
||||||
if dic.match_hostname.active():
|
if dic.match_hostname.active():
|
||||||
self.enter_step('get_domain_yield')
|
self.enter_step("get_domain_yield")
|
||||||
yield HostnamePath(domain.parts)
|
yield HostnamePath(domain.parts)
|
||||||
|
|
||||||
def get_ip4(self, ip4_str: str) -> typing.Iterable[Path]:
|
def get_ip4(self, ip4_str: str) -> typing.Iterable[Path]:
|
||||||
self.enter_step('get_ip4_pack')
|
self.enter_step("get_ip4_pack")
|
||||||
ip4val = self.pack_ip4address_low(ip4_str)
|
ip4val = self.pack_ip4address_low(ip4_str)
|
||||||
self.enter_step('get_ip4_cache')
|
self.enter_step("get_ip4_cache")
|
||||||
if not self.ip4cache[ip4val >> self.ip4cache_shift]:
|
if not self.ip4cache[ip4val >> self.ip4cache_shift]:
|
||||||
return
|
return
|
||||||
self.enter_step('get_ip4_brws')
|
self.enter_step("get_ip4_brws")
|
||||||
dic = self.ip4tree
|
dic = self.ip4tree
|
||||||
for i in range(31, -1, -1):
|
for i in range(31, -1, -1):
|
||||||
bit = (ip4val >> i) & 0b1
|
bit = (ip4val >> i) & 0b1
|
||||||
if dic.active():
|
if dic.active():
|
||||||
self.enter_step('get_ip4_yield')
|
self.enter_step("get_ip4_yield")
|
||||||
yield Ip4Path(ip4val >> (i+1) << (i+1), 31-i)
|
yield Ip4Path(ip4val >> (i + 1) << (i + 1), 31 - i)
|
||||||
self.enter_step('get_ip4_brws')
|
self.enter_step("get_ip4_brws")
|
||||||
next_dic = dic.one if bit else dic.zero
|
next_dic = dic.one if bit else dic.zero
|
||||||
if next_dic is None:
|
if next_dic is None:
|
||||||
return
|
return
|
||||||
dic = next_dic
|
dic = next_dic
|
||||||
if dic.active():
|
if dic.active():
|
||||||
self.enter_step('get_ip4_yield')
|
self.enter_step("get_ip4_yield")
|
||||||
yield Ip4Path(ip4val, 32)
|
yield Ip4Path(ip4val, 32)
|
||||||
|
|
||||||
def _unset_match(self,
|
def _unset_match(
|
||||||
|
self,
|
||||||
match: Match,
|
match: Match,
|
||||||
) -> None:
|
) -> None:
|
||||||
match.disable()
|
match.disable()
|
||||||
|
@ -682,7 +668,8 @@ class Database(Profiler):
|
||||||
source_match = self.get_match(match.source)
|
source_match = self.get_match(match.source)
|
||||||
source_match.references -= 1
|
source_match.references -= 1
|
||||||
|
|
||||||
def _set_match(self,
|
def _set_match(
|
||||||
|
self,
|
||||||
match: Match,
|
match: Match,
|
||||||
updated: int,
|
updated: int,
|
||||||
source: Path,
|
source: Path,
|
||||||
|
@ -694,8 +681,11 @@ class Database(Profiler):
|
||||||
# so it can pass it to save a traversal
|
# so it can pass it to save a traversal
|
||||||
source_match = source_match or self.get_match(source)
|
source_match = source_match or self.get_match(source)
|
||||||
new_level = source_match.level + 1
|
new_level = source_match.level + 1
|
||||||
if updated > match.updated or new_level < match.level \
|
if (
|
||||||
or source_match.first_party > match.first_party:
|
updated > match.updated
|
||||||
|
or new_level < match.level
|
||||||
|
or source_match.first_party > match.first_party
|
||||||
|
):
|
||||||
# NOTE FP and level of matches referencing this one
|
# NOTE FP and level of matches referencing this one
|
||||||
# won't be updated until run or prune
|
# won't be updated until run or prune
|
||||||
if match.source:
|
if match.source:
|
||||||
|
@ -708,20 +698,18 @@ class Database(Profiler):
|
||||||
source_match.references += 1
|
source_match.references += 1
|
||||||
match.dupplicate = dupplicate
|
match.dupplicate = dupplicate
|
||||||
|
|
||||||
def _set_domain(self,
|
def _set_domain(
|
||||||
hostname: bool,
|
self, hostname: bool, domain_str: str, updated: int, source: Path
|
||||||
domain_str: str,
|
) -> None:
|
||||||
updated: int,
|
self.enter_step("set_domain_val")
|
||||||
source: Path) -> None:
|
|
||||||
self.enter_step('set_domain_val')
|
|
||||||
if not Database.validate_domain(domain_str):
|
if not Database.validate_domain(domain_str):
|
||||||
raise ValueError(f"Invalid domain: {domain_str}")
|
raise ValueError(f"Invalid domain: {domain_str}")
|
||||||
self.enter_step('set_domain_pack')
|
self.enter_step("set_domain_pack")
|
||||||
domain = self.pack_domain(domain_str)
|
domain = self.pack_domain(domain_str)
|
||||||
self.enter_step('set_domain_fp')
|
self.enter_step("set_domain_fp")
|
||||||
source_match = self.get_match(source)
|
source_match = self.get_match(source)
|
||||||
is_first_party = source_match.first_party
|
is_first_party = source_match.first_party
|
||||||
self.enter_step('set_domain_brws')
|
self.enter_step("set_domain_brws")
|
||||||
dic = self.domtree
|
dic = self.domtree
|
||||||
dupplicate = False
|
dupplicate = False
|
||||||
for part in domain.parts:
|
for part in domain.parts:
|
||||||
|
@ -742,21 +730,14 @@ class Database(Profiler):
|
||||||
dupplicate=dupplicate,
|
dupplicate=dupplicate,
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_hostname(self,
|
def set_hostname(self, *args: typing.Any, **kwargs: typing.Any) -> None:
|
||||||
*args: typing.Any, **kwargs: typing.Any
|
|
||||||
) -> None:
|
|
||||||
self._set_domain(True, *args, **kwargs)
|
self._set_domain(True, *args, **kwargs)
|
||||||
|
|
||||||
def set_zone(self,
|
def set_zone(self, *args: typing.Any, **kwargs: typing.Any) -> None:
|
||||||
*args: typing.Any, **kwargs: typing.Any
|
|
||||||
) -> None:
|
|
||||||
self._set_domain(False, *args, **kwargs)
|
self._set_domain(False, *args, **kwargs)
|
||||||
|
|
||||||
def set_asn(self,
|
def set_asn(self, asn_str: str, updated: int, source: Path) -> None:
|
||||||
asn_str: str,
|
self.enter_step("set_asn")
|
||||||
updated: int,
|
|
||||||
source: Path) -> None:
|
|
||||||
self.enter_step('set_asn')
|
|
||||||
path = self.pack_asn(asn_str)
|
path = self.pack_asn(asn_str)
|
||||||
if path.asn in self.asns:
|
if path.asn in self.asns:
|
||||||
match = self.asns[path.asn]
|
match = self.asns[path.asn]
|
||||||
|
@ -769,17 +750,14 @@ class Database(Profiler):
|
||||||
source,
|
source,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _set_ip4(self,
|
def _set_ip4(self, ip4: Ip4Path, updated: int, source: Path) -> None:
|
||||||
ip4: Ip4Path,
|
self.enter_step("set_ip4_fp")
|
||||||
updated: int,
|
|
||||||
source: Path) -> None:
|
|
||||||
self.enter_step('set_ip4_fp')
|
|
||||||
source_match = self.get_match(source)
|
source_match = self.get_match(source)
|
||||||
is_first_party = source_match.first_party
|
is_first_party = source_match.first_party
|
||||||
self.enter_step('set_ip4_brws')
|
self.enter_step("set_ip4_brws")
|
||||||
dic = self.ip4tree
|
dic = self.ip4tree
|
||||||
dupplicate = False
|
dupplicate = False
|
||||||
for i in range(31, 31-ip4.prefixlen, -1):
|
for i in range(31, 31 - ip4.prefixlen, -1):
|
||||||
bit = (ip4.value >> i) & 0b1
|
bit = (ip4.value >> i) & 0b1
|
||||||
next_dic = dic.one if bit else dic.zero
|
next_dic = dic.one if bit else dic.zero
|
||||||
if next_dic is None:
|
if next_dic is None:
|
||||||
|
@ -800,24 +778,22 @@ class Database(Profiler):
|
||||||
)
|
)
|
||||||
self._set_ip4cache(ip4, dic)
|
self._set_ip4cache(ip4, dic)
|
||||||
|
|
||||||
def set_ip4address(self,
|
def set_ip4address(
|
||||||
ip4address_str: str,
|
self, ip4address_str: str, *args: typing.Any, **kwargs: typing.Any
|
||||||
*args: typing.Any, **kwargs: typing.Any
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.enter_step('set_ip4add_val')
|
self.enter_step("set_ip4add_val")
|
||||||
if not Database.validate_ip4address(ip4address_str):
|
if not Database.validate_ip4address(ip4address_str):
|
||||||
raise ValueError(f"Invalid ip4address: {ip4address_str}")
|
raise ValueError(f"Invalid ip4address: {ip4address_str}")
|
||||||
self.enter_step('set_ip4add_pack')
|
self.enter_step("set_ip4add_pack")
|
||||||
ip4 = self.pack_ip4address(ip4address_str)
|
ip4 = self.pack_ip4address(ip4address_str)
|
||||||
self._set_ip4(ip4, *args, **kwargs)
|
self._set_ip4(ip4, *args, **kwargs)
|
||||||
|
|
||||||
def set_ip4network(self,
|
def set_ip4network(
|
||||||
ip4network_str: str,
|
self, ip4network_str: str, *args: typing.Any, **kwargs: typing.Any
|
||||||
*args: typing.Any, **kwargs: typing.Any
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.enter_step('set_ip4net_val')
|
self.enter_step("set_ip4net_val")
|
||||||
if not Database.validate_ip4network(ip4network_str):
|
if not Database.validate_ip4network(ip4network_str):
|
||||||
raise ValueError(f"Invalid ip4network: {ip4network_str}")
|
raise ValueError(f"Invalid ip4network: {ip4network_str}")
|
||||||
self.enter_step('set_ip4net_pack')
|
self.enter_step("set_ip4net_pack")
|
||||||
ip4 = self.pack_ip4network(ip4network_str)
|
ip4 = self.pack_ip4network(ip4network_str)
|
||||||
self._set_ip4(ip4, *args, **kwargs)
|
self._set_ip4(ip4, *args, **kwargs)
|
||||||
|
|
38
db.py
38
db.py
|
@ -5,29 +5,37 @@ import database
|
||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
|
|
||||||
# Parsing arguments
|
# Parsing arguments
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(description="Database operations")
|
||||||
description="Database operations")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-i', '--initialize', action='store_true',
|
"-i", "--initialize", action="store_true", help="Reconstruct the whole database"
|
||||||
help="Reconstruct the whole database")
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-p', '--prune', action='store_true',
|
"-p", "--prune", action="store_true", help="Remove old entries from database"
|
||||||
help="Remove old entries from database")
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-b', '--prune-base', action='store_true',
|
"-b",
|
||||||
|
"--prune-base",
|
||||||
|
action="store_true",
|
||||||
help="With --prune, only prune base rules "
|
help="With --prune, only prune base rules "
|
||||||
"(the ones added by ./feed_rules.py)")
|
"(the ones added by ./feed_rules.py)",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-s', '--prune-before', type=int,
|
"-s",
|
||||||
default=(int(time.time()) - 60*60*24*31*6),
|
"--prune-before",
|
||||||
|
type=int,
|
||||||
|
default=(int(time.time()) - 60 * 60 * 24 * 31 * 6),
|
||||||
help="With --prune, only rules updated before "
|
help="With --prune, only rules updated before "
|
||||||
"this UNIX timestamp will be deleted")
|
"this UNIX timestamp will be deleted",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-r', '--references', action='store_true',
|
"-r",
|
||||||
help="DEBUG: Update the reference count")
|
"--references",
|
||||||
|
action="store_true",
|
||||||
|
help="DEBUG: Update the reference count",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not args.initialize:
|
if not args.initialize:
|
||||||
|
@ -37,7 +45,7 @@ if __name__ == '__main__':
|
||||||
os.unlink(database.Database.PATH)
|
os.unlink(database.Database.PATH)
|
||||||
DB = database.Database()
|
DB = database.Database()
|
||||||
|
|
||||||
DB.enter_step('main')
|
DB.enter_step("main")
|
||||||
if args.prune:
|
if args.prune:
|
||||||
DB.prune(before=args.prune_before, base_only=args.prune_base)
|
DB.prune(before=args.prune_before, base_only=args.prune_base)
|
||||||
if args.references:
|
if args.references:
|
||||||
|
|
69
export.py
69
export.py
|
@ -5,53 +5,80 @@ import argparse
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
|
|
||||||
# Parsing arguments
|
# Parsing arguments
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Export the hostnames rules stored "
|
description="Export the hostnames rules stored " "in the Database as plain text"
|
||||||
"in the Database as plain text")
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-o', '--output', type=argparse.FileType('w'), default=sys.stdout,
|
"-o",
|
||||||
help="Output file, one rule per line")
|
"--output",
|
||||||
|
type=argparse.FileType("w"),
|
||||||
|
default=sys.stdout,
|
||||||
|
help="Output file, one rule per line",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-f', '--first-party', action='store_true',
|
"-f",
|
||||||
help="Only output rules issued from first-party sources")
|
"--first-party",
|
||||||
|
action="store_true",
|
||||||
|
help="Only output rules issued from first-party sources",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-e', '--end-chain', action='store_true',
|
"-e",
|
||||||
help="Only output rules that are not referenced by any other")
|
"--end-chain",
|
||||||
|
action="store_true",
|
||||||
|
help="Only output rules that are not referenced by any other",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-r', '--rules', action='store_true',
|
"-r",
|
||||||
help="Output all kinds of rules, not just hostnames")
|
"--rules",
|
||||||
|
action="store_true",
|
||||||
|
help="Output all kinds of rules, not just hostnames",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-b', '--base-rules', action='store_true',
|
"-b",
|
||||||
|
"--base-rules",
|
||||||
|
action="store_true",
|
||||||
help="Output base rules "
|
help="Output base rules "
|
||||||
"(the ones added by ./feed_rules.py) "
|
"(the ones added by ./feed_rules.py) "
|
||||||
"(implies --rules)")
|
"(implies --rules)",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-d', '--no-dupplicates', action='store_true',
|
"-d",
|
||||||
|
"--no-dupplicates",
|
||||||
|
action="store_true",
|
||||||
help="Do not output rules that already match a zone/network rule "
|
help="Do not output rules that already match a zone/network rule "
|
||||||
"(e.g. dummy.example.com when there's a zone example.com rule)")
|
"(e.g. dummy.example.com when there's a zone example.com rule)",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-x', '--explain', action='store_true',
|
"-x",
|
||||||
|
"--explain",
|
||||||
|
action="store_true",
|
||||||
help="Show the chain of rules leading to one "
|
help="Show the chain of rules leading to one "
|
||||||
"(and the number of references they have)")
|
"(and the number of references they have)",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-c', '--count', action='store_true',
|
"-c",
|
||||||
help="Show the number of rules per type instead of listing them")
|
"--count",
|
||||||
|
action="store_true",
|
||||||
|
help="Show the number of rules per type instead of listing them",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
DB = database.Database()
|
DB = database.Database()
|
||||||
|
|
||||||
if args.count:
|
if args.count:
|
||||||
assert not args.explain
|
assert not args.explain
|
||||||
print(DB.count_records(
|
print(
|
||||||
|
DB.count_records(
|
||||||
first_party_only=args.first_party,
|
first_party_only=args.first_party,
|
||||||
end_chain_only=args.end_chain,
|
end_chain_only=args.end_chain,
|
||||||
no_dupplicates=args.no_dupplicates,
|
no_dupplicates=args.no_dupplicates,
|
||||||
rules_only=args.base_rules,
|
rules_only=args.base_rules,
|
||||||
hostnames_only=not (args.rules or args.base_rules),
|
hostnames_only=not (args.rules or args.base_rules),
|
||||||
))
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
for domain in DB.list_records(
|
for domain in DB.list_records(
|
||||||
first_party_only=args.first_party,
|
first_party_only=args.first_party,
|
||||||
|
|
39
feed_asn.py
39
feed_asn.py
|
@ -13,57 +13,54 @@ IPNetwork = typing.Union[ipaddress.IPv4Network, ipaddress.IPv6Network]
|
||||||
|
|
||||||
def get_ranges(asn: str) -> typing.Iterable[str]:
|
def get_ranges(asn: str) -> typing.Iterable[str]:
|
||||||
req = requests.get(
|
req = requests.get(
|
||||||
'https://stat.ripe.net/data/as-routing-consistency/data.json',
|
"https://stat.ripe.net/data/as-routing-consistency/data.json",
|
||||||
params={'resource': asn}
|
params={"resource": asn},
|
||||||
)
|
)
|
||||||
data = req.json()
|
data = req.json()
|
||||||
for pref in data['data']['prefixes']:
|
for pref in data["data"]["prefixes"]:
|
||||||
yield pref['prefix']
|
yield pref["prefix"]
|
||||||
|
|
||||||
|
|
||||||
def get_name(asn: str) -> str:
|
def get_name(asn: str) -> str:
|
||||||
req = requests.get(
|
req = requests.get(
|
||||||
'https://stat.ripe.net/data/as-overview/data.json',
|
"https://stat.ripe.net/data/as-overview/data.json", params={"resource": asn}
|
||||||
params={'resource': asn}
|
|
||||||
)
|
)
|
||||||
data = req.json()
|
data = req.json()
|
||||||
return data['data']['holder']
|
return data["data"]["holder"]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
|
|
||||||
log = logging.getLogger('feed_asn')
|
log = logging.getLogger("feed_asn")
|
||||||
|
|
||||||
# Parsing arguments
|
# Parsing arguments
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Add the IP ranges associated to the AS in the database")
|
description="Add the IP ranges associated to the AS in the database"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
DB = database.Database()
|
DB = database.Database()
|
||||||
|
|
||||||
def add_ranges(path: database.Path,
|
def add_ranges(
|
||||||
|
path: database.Path,
|
||||||
match: database.Match,
|
match: database.Match,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert isinstance(path, database.AsnPath)
|
assert isinstance(path, database.AsnPath)
|
||||||
assert isinstance(match, database.AsnNode)
|
assert isinstance(match, database.AsnNode)
|
||||||
asn_str = database.Database.unpack_asn(path)
|
asn_str = database.Database.unpack_asn(path)
|
||||||
DB.enter_step('asn_get_name')
|
DB.enter_step("asn_get_name")
|
||||||
name = get_name(asn_str)
|
name = get_name(asn_str)
|
||||||
match.name = name
|
match.name = name
|
||||||
DB.enter_step('asn_get_ranges')
|
DB.enter_step("asn_get_ranges")
|
||||||
for prefix in get_ranges(asn_str):
|
for prefix in get_ranges(asn_str):
|
||||||
parsed_prefix: IPNetwork = ipaddress.ip_network(prefix)
|
parsed_prefix: IPNetwork = ipaddress.ip_network(prefix)
|
||||||
if parsed_prefix.version == 4:
|
if parsed_prefix.version == 4:
|
||||||
DB.set_ip4network(
|
DB.set_ip4network(prefix, source=path, updated=int(time.time()))
|
||||||
prefix,
|
log.info("Added %s from %s (%s)", prefix, path, name)
|
||||||
source=path,
|
|
||||||
updated=int(time.time())
|
|
||||||
)
|
|
||||||
log.info('Added %s from %s (%s)', prefix, path, name)
|
|
||||||
elif parsed_prefix.version == 6:
|
elif parsed_prefix.version == 6:
|
||||||
log.warning('Unimplemented prefix version: %s', prefix)
|
log.warning("Unimplemented prefix version: %s", prefix)
|
||||||
else:
|
else:
|
||||||
log.error('Unknown prefix version: %s', prefix)
|
log.error("Unknown prefix version: %s", prefix)
|
||||||
|
|
||||||
for _ in DB.exec_each_asn(add_ranges):
|
for _ in DB.exec_each_asn(add_ranges):
|
||||||
pass
|
pass
|
||||||
|
|
134
feed_dns.py
134
feed_dns.py
|
@ -12,15 +12,15 @@ Record = typing.Tuple[typing.Callable, typing.Callable, int, str, str]
|
||||||
|
|
||||||
# select, write
|
# select, write
|
||||||
FUNCTION_MAP: typing.Any = {
|
FUNCTION_MAP: typing.Any = {
|
||||||
'a': (
|
"a": (
|
||||||
database.Database.get_ip4,
|
database.Database.get_ip4,
|
||||||
database.Database.set_hostname,
|
database.Database.set_hostname,
|
||||||
),
|
),
|
||||||
'cname': (
|
"cname": (
|
||||||
database.Database.get_domain,
|
database.Database.get_domain,
|
||||||
database.Database.set_hostname,
|
database.Database.set_hostname,
|
||||||
),
|
),
|
||||||
'ptr': (
|
"ptr": (
|
||||||
database.Database.get_domain,
|
database.Database.get_domain,
|
||||||
database.Database.set_ip4address,
|
database.Database.set_ip4address,
|
||||||
),
|
),
|
||||||
|
@ -28,7 +28,8 @@ FUNCTION_MAP: typing.Any = {
|
||||||
|
|
||||||
|
|
||||||
class Writer(multiprocessing.Process):
|
class Writer(multiprocessing.Process):
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
recs_queue: multiprocessing.Queue = None,
|
recs_queue: multiprocessing.Queue = None,
|
||||||
autosave_interval: int = 0,
|
autosave_interval: int = 0,
|
||||||
ip4_cache: int = 0,
|
ip4_cache: int = 0,
|
||||||
|
@ -36,7 +37,7 @@ class Writer(multiprocessing.Process):
|
||||||
if recs_queue: # MP
|
if recs_queue: # MP
|
||||||
super(Writer, self).__init__()
|
super(Writer, self).__init__()
|
||||||
self.recs_queue = recs_queue
|
self.recs_queue = recs_queue
|
||||||
self.log = logging.getLogger(f'wr')
|
self.log = logging.getLogger(f"wr")
|
||||||
self.autosave_interval = autosave_interval
|
self.autosave_interval = autosave_interval
|
||||||
self.ip4_cache = ip4_cache
|
self.ip4_cache = ip4_cache
|
||||||
if not recs_queue: # No MP
|
if not recs_queue: # No MP
|
||||||
|
@ -44,11 +45,11 @@ class Writer(multiprocessing.Process):
|
||||||
|
|
||||||
def open_db(self) -> None:
|
def open_db(self) -> None:
|
||||||
self.db = database.Database()
|
self.db = database.Database()
|
||||||
self.db.log = logging.getLogger(f'wr')
|
self.db.log = logging.getLogger(f"wr")
|
||||||
self.db.fill_ip4cache(max_size=self.ip4_cache)
|
self.db.fill_ip4cache(max_size=self.ip4_cache)
|
||||||
|
|
||||||
def exec_record(self, record: Record) -> None:
|
def exec_record(self, record: Record) -> None:
|
||||||
self.db.enter_step('exec_record')
|
self.db.enter_step("exec_record")
|
||||||
select, write, updated, name, value = record
|
select, write, updated, name, value = record
|
||||||
try:
|
try:
|
||||||
for source in select(self.db, value):
|
for source in select(self.db, value):
|
||||||
|
@ -59,7 +60,7 @@ class Writer(multiprocessing.Process):
|
||||||
self.log.exception("Cannot execute: %s", record)
|
self.log.exception("Cannot execute: %s", record)
|
||||||
|
|
||||||
def end(self) -> None:
|
def end(self) -> None:
|
||||||
self.db.enter_step('end')
|
self.db.enter_step("end")
|
||||||
self.db.save()
|
self.db.save()
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
|
@ -69,7 +70,7 @@ class Writer(multiprocessing.Process):
|
||||||
else:
|
else:
|
||||||
next_save = 0
|
next_save = 0
|
||||||
|
|
||||||
self.db.enter_step('block_wait')
|
self.db.enter_step("block_wait")
|
||||||
block: typing.List[Record]
|
block: typing.List[Record]
|
||||||
for block in iter(self.recs_queue.get, None):
|
for block in iter(self.recs_queue.get, None):
|
||||||
|
|
||||||
|
@ -83,12 +84,13 @@ class Writer(multiprocessing.Process):
|
||||||
self.log.info("Done!")
|
self.log.info("Done!")
|
||||||
next_save = time.time() + self.autosave_interval
|
next_save = time.time() + self.autosave_interval
|
||||||
|
|
||||||
self.db.enter_step('block_wait')
|
self.db.enter_step("block_wait")
|
||||||
self.end()
|
self.end()
|
||||||
|
|
||||||
|
|
||||||
class Parser():
|
class Parser:
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
buf: typing.Any,
|
buf: typing.Any,
|
||||||
recs_queue: multiprocessing.Queue = None,
|
recs_queue: multiprocessing.Queue = None,
|
||||||
block_size: int = 0,
|
block_size: int = 0,
|
||||||
|
@ -96,7 +98,7 @@ class Parser():
|
||||||
):
|
):
|
||||||
assert bool(writer) ^ bool(block_size and recs_queue)
|
assert bool(writer) ^ bool(block_size and recs_queue)
|
||||||
self.buf = buf
|
self.buf = buf
|
||||||
self.log = logging.getLogger('pr')
|
self.log = logging.getLogger("pr")
|
||||||
self.recs_queue = recs_queue
|
self.recs_queue = recs_queue
|
||||||
if writer: # No MP
|
if writer: # No MP
|
||||||
self.prof: database.Profiler = writer.db
|
self.prof: database.Profiler = writer.db
|
||||||
|
@ -105,14 +107,14 @@ class Parser():
|
||||||
self.block: typing.List[Record] = list()
|
self.block: typing.List[Record] = list()
|
||||||
self.block_size = block_size
|
self.block_size = block_size
|
||||||
self.prof = database.Profiler()
|
self.prof = database.Profiler()
|
||||||
self.prof.log = logging.getLogger('pr')
|
self.prof.log = logging.getLogger("pr")
|
||||||
self.register = self.add_to_queue
|
self.register = self.add_to_queue
|
||||||
|
|
||||||
def add_to_queue(self, record: Record) -> None:
|
def add_to_queue(self, record: Record) -> None:
|
||||||
self.prof.enter_step('register')
|
self.prof.enter_step("register")
|
||||||
self.block.append(record)
|
self.block.append(record)
|
||||||
if len(self.block) >= self.block_size:
|
if len(self.block) >= self.block_size:
|
||||||
self.prof.enter_step('put_block')
|
self.prof.enter_step("put_block")
|
||||||
assert self.recs_queue
|
assert self.recs_queue
|
||||||
self.recs_queue.put(self.block)
|
self.recs_queue.put(self.block)
|
||||||
self.block = list()
|
self.block = list()
|
||||||
|
@ -131,22 +133,22 @@ class Rapid7Parser(Parser):
|
||||||
def consume(self) -> None:
|
def consume(self) -> None:
|
||||||
data = dict()
|
data = dict()
|
||||||
for line in self.buf:
|
for line in self.buf:
|
||||||
self.prof.enter_step('parse_rapid7')
|
self.prof.enter_step("parse_rapid7")
|
||||||
split = line.split('"')
|
split = line.split('"')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for k in range(1, 14, 4):
|
for k in range(1, 14, 4):
|
||||||
key = split[k]
|
key = split[k]
|
||||||
val = split[k+2]
|
val = split[k + 2]
|
||||||
data[key] = val
|
data[key] = val
|
||||||
|
|
||||||
select, writer = FUNCTION_MAP[data['type']]
|
select, writer = FUNCTION_MAP[data["type"]]
|
||||||
record = (
|
record = (
|
||||||
select,
|
select,
|
||||||
writer,
|
writer,
|
||||||
int(data['timestamp']),
|
int(data["timestamp"]),
|
||||||
data['name'],
|
data["name"],
|
||||||
data['value']
|
data["value"],
|
||||||
)
|
)
|
||||||
except (IndexError, KeyError):
|
except (IndexError, KeyError):
|
||||||
# IndexError: missing field
|
# IndexError: missing field
|
||||||
|
@ -159,13 +161,13 @@ class MassDnsParser(Parser):
|
||||||
# massdns --output Snrql
|
# massdns --output Snrql
|
||||||
# --retry REFUSED,SERVFAIL --resolvers nameservers-ipv4
|
# --retry REFUSED,SERVFAIL --resolvers nameservers-ipv4
|
||||||
TYPES = {
|
TYPES = {
|
||||||
'A': (FUNCTION_MAP['a'][0], FUNCTION_MAP['a'][1], -1, None),
|
"A": (FUNCTION_MAP["a"][0], FUNCTION_MAP["a"][1], -1, None),
|
||||||
# 'AAAA': (FUNCTION_MAP['aaaa'][0], FUNCTION_MAP['aaaa'][1], -1, None),
|
# 'AAAA': (FUNCTION_MAP['aaaa'][0], FUNCTION_MAP['aaaa'][1], -1, None),
|
||||||
'CNAME': (FUNCTION_MAP['cname'][0], FUNCTION_MAP['cname'][1], -1, -1),
|
"CNAME": (FUNCTION_MAP["cname"][0], FUNCTION_MAP["cname"][1], -1, -1),
|
||||||
}
|
}
|
||||||
|
|
||||||
def consume(self) -> None:
|
def consume(self) -> None:
|
||||||
self.prof.enter_step('parse_massdns')
|
self.prof.enter_step("parse_massdns")
|
||||||
timestamp = 0
|
timestamp = 0
|
||||||
header = True
|
header = True
|
||||||
for line in self.buf:
|
for line in self.buf:
|
||||||
|
@ -174,14 +176,15 @@ class MassDnsParser(Parser):
|
||||||
header = True
|
header = True
|
||||||
continue
|
continue
|
||||||
|
|
||||||
split = line.split(' ')
|
split = line.split(" ")
|
||||||
try:
|
try:
|
||||||
if header:
|
if header:
|
||||||
timestamp = int(split[1])
|
timestamp = int(split[1])
|
||||||
header = False
|
header = False
|
||||||
else:
|
else:
|
||||||
select, write, name_offset, value_offset = \
|
select, write, name_offset, value_offset = MassDnsParser.TYPES[
|
||||||
MassDnsParser.TYPES[split[1]]
|
split[1]
|
||||||
|
]
|
||||||
record = (
|
record = (
|
||||||
select,
|
select,
|
||||||
write,
|
write,
|
||||||
|
@ -190,74 +193,85 @@ class MassDnsParser(Parser):
|
||||||
split[2][:value_offset].lower(),
|
split[2][:value_offset].lower(),
|
||||||
)
|
)
|
||||||
self.register(record)
|
self.register(record)
|
||||||
self.prof.enter_step('parse_massdns')
|
self.prof.enter_step("parse_massdns")
|
||||||
except KeyError:
|
except KeyError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
||||||
PARSERS = {
|
PARSERS = {
|
||||||
'rapid7': Rapid7Parser,
|
"rapid7": Rapid7Parser,
|
||||||
'massdns': MassDnsParser,
|
"massdns": MassDnsParser,
|
||||||
}
|
}
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
|
|
||||||
# Parsing arguments
|
# Parsing arguments
|
||||||
log = logging.getLogger('feed_dns')
|
log = logging.getLogger("feed_dns")
|
||||||
args_parser = argparse.ArgumentParser(
|
args_parser = argparse.ArgumentParser(
|
||||||
description="Read DNS records and import "
|
description="Read DNS records and import "
|
||||||
"tracking-relevant data into the database")
|
"tracking-relevant data into the database"
|
||||||
|
)
|
||||||
|
args_parser.add_argument("parser", choices=PARSERS.keys(), help="Input format")
|
||||||
args_parser.add_argument(
|
args_parser.add_argument(
|
||||||
'parser',
|
"-i",
|
||||||
choices=PARSERS.keys(),
|
"--input",
|
||||||
help="Input format")
|
type=argparse.FileType("r"),
|
||||||
|
default=sys.stdin,
|
||||||
|
help="Input file",
|
||||||
|
)
|
||||||
args_parser.add_argument(
|
args_parser.add_argument(
|
||||||
'-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
|
"-b", "--block-size", type=int, default=1024, help="Performance tuning value"
|
||||||
help="Input file")
|
)
|
||||||
args_parser.add_argument(
|
args_parser.add_argument(
|
||||||
'-b', '--block-size', type=int, default=1024,
|
"-q", "--queue-size", type=int, default=128, help="Performance tuning value"
|
||||||
help="Performance tuning value")
|
)
|
||||||
args_parser.add_argument(
|
args_parser.add_argument(
|
||||||
'-q', '--queue-size', type=int, default=128,
|
"-a",
|
||||||
help="Performance tuning value")
|
"--autosave-interval",
|
||||||
|
type=int,
|
||||||
|
default=900,
|
||||||
|
help="Interval to which the database will save in seconds. " "0 to disable.",
|
||||||
|
)
|
||||||
args_parser.add_argument(
|
args_parser.add_argument(
|
||||||
'-a', '--autosave-interval', type=int, default=900,
|
"-s",
|
||||||
help="Interval to which the database will save in seconds. "
|
"--single-process",
|
||||||
"0 to disable.")
|
action="store_true",
|
||||||
|
help="Only use one process. " "Might be useful for single core computers.",
|
||||||
|
)
|
||||||
args_parser.add_argument(
|
args_parser.add_argument(
|
||||||
'-s', '--single-process', action='store_true',
|
"-4",
|
||||||
help="Only use one process. "
|
"--ip4-cache",
|
||||||
"Might be useful for single core computers.")
|
type=int,
|
||||||
args_parser.add_argument(
|
default=0,
|
||||||
'-4', '--ip4-cache', type=int, default=0,
|
|
||||||
help="RAM cache for faster IPv4 lookup. "
|
help="RAM cache for faster IPv4 lookup. "
|
||||||
"Maximum useful value: 512 MiB (536870912). "
|
"Maximum useful value: 512 MiB (536870912). "
|
||||||
"Warning: Depending on the rules, this might already "
|
"Warning: Depending on the rules, this might already "
|
||||||
"be a memory-heavy process, even without the cache.")
|
"be a memory-heavy process, even without the cache.",
|
||||||
|
)
|
||||||
args = args_parser.parse_args()
|
args = args_parser.parse_args()
|
||||||
|
|
||||||
parser_cls = PARSERS[args.parser]
|
parser_cls = PARSERS[args.parser]
|
||||||
if args.single_process:
|
if args.single_process:
|
||||||
writer = Writer(
|
writer = Writer(
|
||||||
autosave_interval=args.autosave_interval,
|
autosave_interval=args.autosave_interval, ip4_cache=args.ip4_cache
|
||||||
ip4_cache=args.ip4_cache
|
|
||||||
)
|
)
|
||||||
parser = parser_cls(args.input, writer=writer)
|
parser = parser_cls(args.input, writer=writer)
|
||||||
parser.run()
|
parser.run()
|
||||||
writer.end()
|
writer.end()
|
||||||
else:
|
else:
|
||||||
recs_queue: multiprocessing.Queue = multiprocessing.Queue(
|
recs_queue: multiprocessing.Queue = multiprocessing.Queue(
|
||||||
maxsize=args.queue_size)
|
maxsize=args.queue_size
|
||||||
|
)
|
||||||
|
|
||||||
writer = Writer(recs_queue,
|
writer = Writer(
|
||||||
|
recs_queue,
|
||||||
autosave_interval=args.autosave_interval,
|
autosave_interval=args.autosave_interval,
|
||||||
ip4_cache=args.ip4_cache
|
ip4_cache=args.ip4_cache,
|
||||||
)
|
)
|
||||||
writer.start()
|
writer.start()
|
||||||
|
|
||||||
parser = parser_cls(args.input,
|
parser = parser_cls(
|
||||||
recs_queue=recs_queue,
|
args.input, recs_queue=recs_queue, block_size=args.block_size
|
||||||
block_size=args.block_size
|
|
||||||
)
|
)
|
||||||
parser.run()
|
parser.run()
|
||||||
|
|
||||||
|
|
|
@ -6,28 +6,33 @@ import sys
|
||||||
import time
|
import time
|
||||||
|
|
||||||
FUNCTION_MAP = {
|
FUNCTION_MAP = {
|
||||||
'zone': database.Database.set_zone,
|
"zone": database.Database.set_zone,
|
||||||
'hostname': database.Database.set_hostname,
|
"hostname": database.Database.set_hostname,
|
||||||
'asn': database.Database.set_asn,
|
"asn": database.Database.set_asn,
|
||||||
'ip4network': database.Database.set_ip4network,
|
"ip4network": database.Database.set_ip4network,
|
||||||
'ip4address': database.Database.set_ip4address,
|
"ip4address": database.Database.set_ip4address,
|
||||||
}
|
}
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
|
|
||||||
# Parsing arguments
|
# Parsing arguments
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(description="Import base rules to the database")
|
||||||
description="Import base rules to the database")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'type',
|
"type", choices=FUNCTION_MAP.keys(), help="Type of rule inputed"
|
||||||
choices=FUNCTION_MAP.keys(),
|
)
|
||||||
help="Type of rule inputed")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
|
"-i",
|
||||||
help="File with one rule per line")
|
"--input",
|
||||||
|
type=argparse.FileType("r"),
|
||||||
|
default=sys.stdin,
|
||||||
|
help="File with one rule per line",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-f', '--first-party', action='store_true',
|
"-f",
|
||||||
help="The input only comes from verified first-party sources")
|
"--first-party",
|
||||||
|
action="store_true",
|
||||||
|
help="The input only comes from verified first-party sources",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
DB = database.Database()
|
DB = database.Database()
|
||||||
|
@ -43,7 +48,8 @@ if __name__ == '__main__':
|
||||||
for rule in args.input:
|
for rule in args.input:
|
||||||
rule = rule.strip()
|
rule = rule.strip()
|
||||||
try:
|
try:
|
||||||
fun(DB,
|
fun(
|
||||||
|
DB,
|
||||||
rule,
|
rule,
|
||||||
source=source,
|
source=source,
|
||||||
updated=int(time.time()),
|
updated=int(time.time()),
|
||||||
|
|
|
@ -2,11 +2,9 @@
|
||||||
|
|
||||||
import markdown2
|
import markdown2
|
||||||
|
|
||||||
extras = [
|
extras = ["header-ids"]
|
||||||
"header-ids"
|
|
||||||
]
|
|
||||||
|
|
||||||
with open('dist/README.md', 'r') as fdesc:
|
with open("dist/README.md", "r") as fdesc:
|
||||||
body = markdown2.markdown(fdesc.read(), extras=extras)
|
body = markdown2.markdown(fdesc.read(), extras=extras)
|
||||||
|
|
||||||
output = f"""<!DOCTYPE html>
|
output = f"""<!DOCTYPE html>
|
||||||
|
@ -23,5 +21,5 @@ output = f"""<!DOCTYPE html>
|
||||||
</html>
|
</html>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with open('dist/index.html', 'w') as fdesc:
|
with open("dist/index.html", "w") as fdesc:
|
||||||
fdesc.write(output)
|
fdesc.write(output)
|
||||||
|
|
Loading…
Reference in a new issue