From 55877be8912dc04db0ad35bb24f6c2a9579188ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Geoffrey=20=E2=80=9CFrogeye=E2=80=9D=20Preud=27homme?= Date: Mon, 9 Dec 2019 08:55:34 +0100 Subject: [PATCH] IP parsing C accelerated, use bytes everywhere --- .gitignore | 2 ++ Makefile | 5 +++++ accel.c | 37 +++++++++++++++++++++++++++++++++++++ database.py | 36 +++++++++++++++++++++--------------- database_schema.sql | 5 +++-- feed_dns.py | 9 +++++---- 6 files changed, 73 insertions(+), 21 deletions(-) create mode 100644 Makefile create mode 100644 accel.c diff --git a/.gitignore b/.gitignore index 188051c..aa3f3eb 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,5 @@ *.db-journal nameservers nameservers.head +*.o +*.so diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..fb06f61 --- /dev/null +++ b/Makefile @@ -0,0 +1,5 @@ +libaccel.so: accel.o + clang -shared -Wl,-soname,libaccel.so -o libaccel.so accel.o + +accel.o: accel.c + clang -c -fPIC -O3 accel.c -o accel.o diff --git a/accel.c b/accel.c new file mode 100644 index 0000000..bda0072 --- /dev/null +++ b/accel.c @@ -0,0 +1,37 @@ +#include + +int ip4_flat(char* value, wchar_t* flat) +{ + unsigned char value_index = 0; + unsigned char octet_index = 0; + unsigned char octet_value = 0; + char flat_index; + unsigned char value_chara; + do { + value_chara = value[value_index]; + if (value_chara >= '0' && value_chara <= '9') { + octet_value *= 10; + octet_value += value_chara - '0'; + } else if (value_chara == '.') { + for (flat_index = (octet_index+1)*8-1; flat_index >= octet_index*8; flat_index--) { + flat[flat_index] = '0' + (octet_value & 1); + octet_value >>= 1; + } + octet_index++; + octet_value = 0; + } else if (value_chara == '\0') { + if (octet_index != 3) { + return 1; + } + for (flat_index = 31; flat_index >= 24; flat_index--) { + flat[flat_index] = '0' + (octet_value & 1); + octet_value >>= 1; + } + return 0; + } else { + return 1; + } + value_index++; + } while (1); // This ugly thing save one comparison + return 1; +} diff --git a/database.py b/database.py index 370d25b..bdb92b0 100755 --- a/database.py +++ b/database.py @@ -7,7 +7,7 @@ import typing import ipaddress import enum import time -import pprint +import ctypes """ Utility functions to interact with the database. @@ -20,6 +20,8 @@ C = None # Cursor TIME_DICT: typing.Dict[str, float] = dict() TIME_LAST = time.perf_counter() TIME_STEP = 'start' +ACCEL = ctypes.cdll.LoadLibrary('./libaccel.so') +ACCEL_IP4_BUF = ctypes.create_unicode_buffer('Z'*32, 32) def time_step(step: str) -> None: @@ -127,9 +129,12 @@ def ip_flat(address: ipaddress.IPv4Address) -> str: return ''.join(map(str, ip_get_bits(address))) -def ip4_flat(address: str) -> str: - return '{:08b}{:08b}{:08b}{:08b}'.format( - *[int(c) for c in address.split('.')]) +def ip4_flat(address: bytes) -> typing.Optional[str]: + carg = ctypes.c_char_p(address) + ret = ACCEL.ip4_flat(carg, ACCEL_IP4_BUF) + if ret != 0: + return None + return ACCEL_IP4_BUF.value RULE_IP4NETWORK_COMMAND = \ @@ -165,23 +170,22 @@ FEED_A_COMMAND_UPSERT = \ 'WHERE updated=0 OR firstparty None: +def feed_a(name: bytes, value_ip: bytes) -> None: assert C assert CONN time_step('a_flat') - try: - value = ip4_flat(value_ip) - except (ValueError, IndexError): + value_dec = ip4_flat(value_ip) + if value_dec is None: # Malformed IPs return time_step('a_fetch') - C.execute(FEED_A_COMMAND_FETCH, (value,)) + C.execute(FEED_A_COMMAND_FETCH, (value_dec,)) base = C.fetchone() time_step('a_fetch_confirm') if not base: return b_key, b_firstparty = base - if not value.startswith(b_key): + if not value_dec.startswith(b_key): return name = name[::-1] time_step('a_upsert') @@ -212,23 +216,25 @@ FEED_CNAME_COMMAND_UPSERT = \ 'WHERE updated=0 OR firstparty None: +def feed_cname(name: bytes, value: bytes) -> None: assert C assert CONN + time_step('cname_decode') value = value[::-1] + value_dec = value.decode() time_step('cname_fetch') - C.execute(FEED_CNAME_COMMAND_FETCH, (value,)) + C.execute(FEED_CNAME_COMMAND_FETCH, (value_dec,)) base = C.fetchone() time_step('cname_fetch_confirm') if not base: # Should only happen at an extremum of the database return b_key, b_type, b_firstparty = base - matching = b_key == value[:len(b_key)] and ( - len(value) == len(b_key) + matching = b_key == value_dec[:len(b_key)] and ( + len(value_dec) == len(b_key) or ( b_type == RowType.DomainTree.value - and value[len(b_key)] == '.' + and value_dec[len(b_key)] == '.' ) ) if not matching: diff --git a/database_schema.sql b/database_schema.sql index 5e9618b..1985281 100644 --- a/database_schema.sql +++ b/database_schema.sql @@ -2,11 +2,12 @@ -- in database.py on changes to this file CREATE TABLE blocking ( - key text PRIMARY KEY, -- Contains the reversed domain name or IP in binary form + key TEXT PRIMARY KEY, -- Contains the reversed domain name or IP in binary form source TEXT, -- The rule this one is based on type INTEGER, -- Type of the field: 1: AS, 2: domain tree, 3: domain, 4: IPv4 network, 6: IPv6 network updated INTEGER, -- If the row was updated during last data import (0: No, 1: Yes) firstparty INTEGER, -- Which blocking list this row is issued from (0: first-party, 1: multi-party) + -- refs INTEGER, -- Which blocking list this row is issued from (0: first-party, 1: multi-party) FOREIGN KEY (source) REFERENCES blocking(key) ON DELETE CASCADE ); CREATE INDEX "blocking_type_updated_key" ON "blocking" ( @@ -17,6 +18,6 @@ CREATE INDEX "blocking_type_updated_key" ON "blocking" ( -- Store various things CREATE TABLE meta ( - key text PRIMARY KEY, + key TEXT PRIMARY KEY, value integer ); diff --git a/feed_dns.py b/feed_dns.py index 47ea5d8..1cc3247 100755 --- a/feed_dns.py +++ b/feed_dns.py @@ -5,8 +5,8 @@ import argparse import sys FUNCTION_MAP = { - 'a': database.feed_a, - 'cname': database.feed_cname, + b'a': database.feed_a, + b'cname': database.feed_cname, } if __name__ == '__main__': @@ -15,7 +15,7 @@ if __name__ == '__main__': parser = argparse.ArgumentParser( description="TODO") parser.add_argument( - '-i', '--input', type=argparse.FileType('r'), default=sys.stdin, + '-i', '--input', type=argparse.FileType('rb'), default=sys.stdin.buffer, help="TODO") args = parser.parse_args() @@ -23,9 +23,10 @@ if __name__ == '__main__': try: database.time_step('iowait') + line: bytes for line in args.input: database.time_step('feed_json_parse') - split = line.split('"') + split = line.split(b'"') name = split[7] dtype = split[11] value = split[15]