IP parsing C accelerated, use bytes everywhere
This commit is contained in:
parent
7937496882
commit
55877be891
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -3,3 +3,5 @@
|
|||
*.db-journal
|
||||
nameservers
|
||||
nameservers.head
|
||||
*.o
|
||||
*.so
|
||||
|
|
5
Makefile
Normal file
5
Makefile
Normal file
|
@ -0,0 +1,5 @@
|
|||
libaccel.so: accel.o
|
||||
clang -shared -Wl,-soname,libaccel.so -o libaccel.so accel.o
|
||||
|
||||
accel.o: accel.c
|
||||
clang -c -fPIC -O3 accel.c -o accel.o
|
37
accel.c
Normal file
37
accel.c
Normal file
|
@ -0,0 +1,37 @@
|
|||
#include <stdlib.h>
|
||||
|
||||
int ip4_flat(char* value, wchar_t* flat)
|
||||
{
|
||||
unsigned char value_index = 0;
|
||||
unsigned char octet_index = 0;
|
||||
unsigned char octet_value = 0;
|
||||
char flat_index;
|
||||
unsigned char value_chara;
|
||||
do {
|
||||
value_chara = value[value_index];
|
||||
if (value_chara >= '0' && value_chara <= '9') {
|
||||
octet_value *= 10;
|
||||
octet_value += value_chara - '0';
|
||||
} else if (value_chara == '.') {
|
||||
for (flat_index = (octet_index+1)*8-1; flat_index >= octet_index*8; flat_index--) {
|
||||
flat[flat_index] = '0' + (octet_value & 1);
|
||||
octet_value >>= 1;
|
||||
}
|
||||
octet_index++;
|
||||
octet_value = 0;
|
||||
} else if (value_chara == '\0') {
|
||||
if (octet_index != 3) {
|
||||
return 1;
|
||||
}
|
||||
for (flat_index = 31; flat_index >= 24; flat_index--) {
|
||||
flat[flat_index] = '0' + (octet_value & 1);
|
||||
octet_value >>= 1;
|
||||
}
|
||||
return 0;
|
||||
} else {
|
||||
return 1;
|
||||
}
|
||||
value_index++;
|
||||
} while (1); // This ugly thing save one comparison
|
||||
return 1;
|
||||
}
|
36
database.py
36
database.py
|
@ -7,7 +7,7 @@ import typing
|
|||
import ipaddress
|
||||
import enum
|
||||
import time
|
||||
import pprint
|
||||
import ctypes
|
||||
|
||||
"""
|
||||
Utility functions to interact with the database.
|
||||
|
@ -20,6 +20,8 @@ C = None # Cursor
|
|||
TIME_DICT: typing.Dict[str, float] = dict()
|
||||
TIME_LAST = time.perf_counter()
|
||||
TIME_STEP = 'start'
|
||||
ACCEL = ctypes.cdll.LoadLibrary('./libaccel.so')
|
||||
ACCEL_IP4_BUF = ctypes.create_unicode_buffer('Z'*32, 32)
|
||||
|
||||
|
||||
def time_step(step: str) -> None:
|
||||
|
@ -127,9 +129,12 @@ def ip_flat(address: ipaddress.IPv4Address) -> str:
|
|||
return ''.join(map(str, ip_get_bits(address)))
|
||||
|
||||
|
||||
def ip4_flat(address: str) -> str:
|
||||
return '{:08b}{:08b}{:08b}{:08b}'.format(
|
||||
*[int(c) for c in address.split('.')])
|
||||
def ip4_flat(address: bytes) -> typing.Optional[str]:
|
||||
carg = ctypes.c_char_p(address)
|
||||
ret = ACCEL.ip4_flat(carg, ACCEL_IP4_BUF)
|
||||
if ret != 0:
|
||||
return None
|
||||
return ACCEL_IP4_BUF.value
|
||||
|
||||
|
||||
RULE_IP4NETWORK_COMMAND = \
|
||||
|
@ -165,23 +170,22 @@ FEED_A_COMMAND_UPSERT = \
|
|||
'WHERE updated=0 OR firstparty<?'
|
||||
|
||||
|
||||
def feed_a(name: str, value_ip: str) -> None:
|
||||
def feed_a(name: bytes, value_ip: bytes) -> None:
|
||||
assert C
|
||||
assert CONN
|
||||
time_step('a_flat')
|
||||
try:
|
||||
value = ip4_flat(value_ip)
|
||||
except (ValueError, IndexError):
|
||||
value_dec = ip4_flat(value_ip)
|
||||
if value_dec is None:
|
||||
# Malformed IPs
|
||||
return
|
||||
time_step('a_fetch')
|
||||
C.execute(FEED_A_COMMAND_FETCH, (value,))
|
||||
C.execute(FEED_A_COMMAND_FETCH, (value_dec,))
|
||||
base = C.fetchone()
|
||||
time_step('a_fetch_confirm')
|
||||
if not base:
|
||||
return
|
||||
b_key, b_firstparty = base
|
||||
if not value.startswith(b_key):
|
||||
if not value_dec.startswith(b_key):
|
||||
return
|
||||
name = name[::-1]
|
||||
time_step('a_upsert')
|
||||
|
@ -212,23 +216,25 @@ FEED_CNAME_COMMAND_UPSERT = \
|
|||
'WHERE updated=0 OR firstparty<?'
|
||||
|
||||
|
||||
def feed_cname(name: str, value: str) -> None:
|
||||
def feed_cname(name: bytes, value: bytes) -> None:
|
||||
assert C
|
||||
assert CONN
|
||||
time_step('cname_decode')
|
||||
value = value[::-1]
|
||||
value_dec = value.decode()
|
||||
time_step('cname_fetch')
|
||||
C.execute(FEED_CNAME_COMMAND_FETCH, (value,))
|
||||
C.execute(FEED_CNAME_COMMAND_FETCH, (value_dec,))
|
||||
base = C.fetchone()
|
||||
time_step('cname_fetch_confirm')
|
||||
if not base:
|
||||
# Should only happen at an extremum of the database
|
||||
return
|
||||
b_key, b_type, b_firstparty = base
|
||||
matching = b_key == value[:len(b_key)] and (
|
||||
len(value) == len(b_key)
|
||||
matching = b_key == value_dec[:len(b_key)] and (
|
||||
len(value_dec) == len(b_key)
|
||||
or (
|
||||
b_type == RowType.DomainTree.value
|
||||
and value[len(b_key)] == '.'
|
||||
and value_dec[len(b_key)] == '.'
|
||||
)
|
||||
)
|
||||
if not matching:
|
||||
|
|
|
@ -2,11 +2,12 @@
|
|||
-- in database.py on changes to this file
|
||||
|
||||
CREATE TABLE blocking (
|
||||
key text PRIMARY KEY, -- Contains the reversed domain name or IP in binary form
|
||||
key TEXT PRIMARY KEY, -- Contains the reversed domain name or IP in binary form
|
||||
source TEXT, -- The rule this one is based on
|
||||
type INTEGER, -- Type of the field: 1: AS, 2: domain tree, 3: domain, 4: IPv4 network, 6: IPv6 network
|
||||
updated INTEGER, -- If the row was updated during last data import (0: No, 1: Yes)
|
||||
firstparty INTEGER, -- Which blocking list this row is issued from (0: first-party, 1: multi-party)
|
||||
-- refs INTEGER, -- Which blocking list this row is issued from (0: first-party, 1: multi-party)
|
||||
FOREIGN KEY (source) REFERENCES blocking(key) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX "blocking_type_updated_key" ON "blocking" (
|
||||
|
@ -17,6 +18,6 @@ CREATE INDEX "blocking_type_updated_key" ON "blocking" (
|
|||
|
||||
-- Store various things
|
||||
CREATE TABLE meta (
|
||||
key text PRIMARY KEY,
|
||||
key TEXT PRIMARY KEY,
|
||||
value integer
|
||||
);
|
||||
|
|
|
@ -5,8 +5,8 @@ import argparse
|
|||
import sys
|
||||
|
||||
FUNCTION_MAP = {
|
||||
'a': database.feed_a,
|
||||
'cname': database.feed_cname,
|
||||
b'a': database.feed_a,
|
||||
b'cname': database.feed_cname,
|
||||
}
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -15,7 +15,7 @@ if __name__ == '__main__':
|
|||
parser = argparse.ArgumentParser(
|
||||
description="TODO")
|
||||
parser.add_argument(
|
||||
'-i', '--input', type=argparse.FileType('r'), default=sys.stdin,
|
||||
'-i', '--input', type=argparse.FileType('rb'), default=sys.stdin.buffer,
|
||||
help="TODO")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -23,9 +23,10 @@ if __name__ == '__main__':
|
|||
|
||||
try:
|
||||
database.time_step('iowait')
|
||||
line: bytes
|
||||
for line in args.input:
|
||||
database.time_step('feed_json_parse')
|
||||
split = line.split('"')
|
||||
split = line.split(b'"')
|
||||
name = split[7]
|
||||
dtype = split[11]
|
||||
value = split[15]
|
||||
|
|
Loading…
Reference in a new issue