Repo re-org

This commit is contained in:
jaseg 2021-04-09 18:38:02 +02:00
parent 312fee491c
commit 50998fcfb9
270 changed files with 9 additions and 9 deletions

View file

@ -0,0 +1,93 @@
#!/usr/bin/env python3
import math
import sys
import contextlib
import scipy.signal as sig
import numpy as np
@contextlib.contextmanager
def wrap(left='{', right='}', file=None, end=''):
print(left, file=file, end=end)
yield
print(right, file=file, end=end)
@contextlib.contextmanager
def print_include_guards(macro_name):
print(f'#ifndef {macro_name}')
print(f'#define {macro_name}')
print()
yield
print()
print(f'#endif /* {macro_name} */')
macro_float = lambda f: f'{f}'.replace('.', 'F').replace('-', 'N').replace('+', 'P')
ordinal = lambda n: "%d%s" % (n,"tsnrhtdd"[(n//10%10!=1)*(n%10<4)*n%10::4])
SI_TABLE = {-18: 'a', -15: 'f', -12: 'p', -9: 'n', -6: 'µ', -3: 'm', 0: '', 3: 'k', 6: 'M', 9: 'G', 12: 'T', 15: 'P', 18: 'E'}
def siprefix(x, space=' ', unit=''):
l = math.log10(x)//3*3
if l in SI_TABLE:
return f'{x/10**l}{space}{SI_TABLE[l]}{unit}'
return f'{x}{space}{unit}'
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('-m', '--macro-name', default='butter_filter', help='Prefix for output macro names')
parser.add_argument('fc', type=float, help='Corner frequency [Hz]')
parser.add_argument('fs', type=float, help='Sampling rate [Hz]')
parser.add_argument('n', type=int, nargs='?', default=6, help='Filter order')
args = parser.parse_args()
sos = sig.butter(args.n, args.fc, fs=args.fs, output='sos')
print('/* THIS IS A GENERATED FILE. DO NOT EDIT! */')
print()
with print_include_guards(f'__BUTTER_FILTER_GENERATED_{args.n}_{macro_float(args.fc)}_{macro_float(args.fs)}__'):
print(f'/* {ordinal(args.n)} order Butterworth IIR filter coefficients')
print(f' *')
print(f' * corner frequency f_c = {siprefix(args.fc)}Hz')
print(f' * sampling rate f_s = {siprefix(args.fs)}Hz')
print(f' */')
print()
print(f'#define {args.macro_name.upper()}_ORDER {args.n}')
print(f'#define {args.macro_name.upper()}_CLEN {(args.n+1)//2}')
# scipy.signal.butter by default returns extremely small bs for the first biquad and large ones for subsequent
# sections. Balance magnitudes to reduce possible rounding errors.
first_biquad_bs = sos[0][:3]
approx_mag = round(math.log10(np.mean(first_biquad_bs)))
mags = [approx_mag // len(sos)] * len(sos)
mags[0] += approx_mag - sum(mags)
sos[0][:3] /= 10**approx_mag
sos = np.array([ sec * np.array([10**mag, 10**mag, 10**mag, 1, 1, 1]) for mag, sec in zip(mags, sos) ])
ones = np.ones([100000])
_, steady_state = sig.sosfilt(sos, ones, zi=np.zeros([(args.n+1)//2, 2]))
print(f'#define {args.macro_name.upper()}_COEFF ', end='')
for sec in sos:
bs, ases = sec[:3], sec[4:6]
with wrap():
print('.b=', end='')
with wrap():
print(', '.join(f'{v}' for v in bs), end='')
print(', .a=', end='')
with wrap():
print(', '.join(f'{v}' for v in ases), end='')
print(', ', end='')
print()
print(f'#define {args.macro_name.upper()}_STEADY_STATE ', end='')
for sec in steady_state:
with wrap():
print(', '.join(f'{v}' for v in sec), end='')
print(', ', end='')
print()

View file

@ -0,0 +1,46 @@
#include <stdint.h>
#include <math.h>
#include <unistd.h>
#include <stdio.h>
#include <string.h>
#include <errno.h>
#include <stdlib.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/fcntl.h>
#include "crypto.h"
void oob_trigger_activated(enum trigger_domain domain, int serial) {
printf("oob_trigger_activated(%d, %d)\n", domain, serial);
fflush(stdout);
}
void print_usage() {
fprintf(stderr, "Usage: crypto_test [auth_key_hex]\n");
}
int main(int argc, char **argv) {
if (argc != 2) {
fprintf(stderr, "Error: Invalid arguments.\n");
print_usage();
return 1;
}
uint8_t auth_key[16];
for (size_t i=0; argv[1][i+0] != '\0' && argv[1][i+1] != '\0' && i/2<sizeof(auth_key); i+= 2) {
char buf[3] = { argv[1][i+0], argv[1][i+1], 0};
char *endptr;
auth_key[i/2] = strtoul(buf, &endptr, 16);
if (!endptr || *endptr != '\0') {
fprintf(stderr, "Invalid authkey\n");
return 1;
}
}
printf("rc=%d\n", oob_message_received(auth_key));
return 0;
}

View file

@ -0,0 +1,46 @@
#!/usr/bin/env python3
import subprocess
from os import path
import binascii
import re
import presig_gen
def do_test(domain, value, height, root_key, binary, expect_fail=False):
auth = presig_gen.gen_at_height(domain, value, height, root_key)
auth = binascii.hexlify(auth).decode()
output = subprocess.check_output([binary, auth])
*lines, rc_line = output.decode().splitlines()
rc = int(re.match('^rc=(\d+)$', rc_line).group(1))
assert expect_fail == (rc == 0)
def run_tests(root_key, max_height, binary):
for domain, value in {
'all': 'all',
'vendor': presig_gen.TEST_VENDOR,
'series': presig_gen.TEST_SERIES,
'country': presig_gen.TEST_COUNTRY,
'region': presig_gen.TEST_REGION,
}.items():
for height in range(max_height):
do_test(domain, value, height, root_key, binary)
do_test(domain, 'fail', height, root_key, binary, expect_fail=True)
do_test('fail', 'fail', height, root_key, binary, expect_fail=True)
do_test('', '', height, root_key, binary, expect_fail=True)
do_test(domain, value, max_height, root_key, binary, expect_fail=True)
do_test(domain, value, max_height+1, root_key, binary, expect_fail=True)
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('keyfile', help='Root key file')
parser.add_argument('max_height', type=int, default=8, nargs='?', help='Height of generated prekeys')
default_binary = path.abspath(path.join(path.dirname(__file__), '../build/tools/crypto_test'))
parser.add_argument('binary', default=default_binary, nargs='?', help='crypto_test binary to use')
args = parser.parse_args()
with open(args.keyfile, 'r') as f:
root_key = binascii.unhexlify(f.read().strip())
run_tests(root_key, args.max_height, args.binary)

View file

@ -0,0 +1,29 @@
#!/usr/bin/env python3
import textwrap
import scipy.signal as sig
import numpy as np
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('n', type=int, help='Window size')
parser.add_argument('w', type=float, help='Wavelet width')
parser.add_argument('-v', '--variable', default='cwt_ricker_table', help='Name for alias variable pointing to generated wavelet LUT')
args = parser.parse_args()
print(f'/* CWT Ricker wavelet LUT for {args.n} sample window of width {args.w}. */')
varname = f'cwt_ricker_{args.n}_window_{str(args.w).replace(".", "F")}'
print(f'const float {varname}[{args.n}] = {{')
win = sig.ricker(args.n, args.w)
par = ' '.join(f'{f:>015.12e}f,' for f in win)
print(textwrap.fill(par,
initial_indent=' '*4, subsequent_indent=' '*4,
width=120,
replace_whitespace=False, drop_whitespace=False))
print('};')
print()
print(f'const float * const {args.variable} __attribute__((weak)) = {varname};')

View file

@ -0,0 +1,109 @@
#include <stdint.h>
#include <math.h>
#include <unistd.h>
#include <stdio.h>
#include <string.h>
#include <errno.h>
#include <stdlib.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/fcntl.h>
#include "dsss_demod.h"
void handle_dsss_received(symbol_t data[static TRANSMISSION_SYMBOLS]) {
printf("data sequence received: [ ");
for (size_t i=0; i<TRANSMISSION_SYMBOLS; i++) {
//printf("%+3d", ((data[i]&1) ? 1 : -1) * (data[i]>>1));
printf("%2d", data[i]);
if (i+1 < TRANSMISSION_SYMBOLS)
printf(", ");
}
printf(" ]\n");
}
void print_usage() {
fprintf(stderr, "Usage: dsss_demod_test [test_data.bin] [optional recording channel number]\n");
}
int main(int argc, char **argv) {
if (argc != 2 && argc != 3) {
fprintf(stderr, "Error: Invalid arguments.\n");
print_usage();
return 1;
}
int fd = open(argv[1], O_RDONLY);
struct stat st;
if (fstat(fd, &st)) {
fprintf(stderr, "Error querying test data file size: %s\n", strerror(errno));
return 2;
}
if (st.st_size < 0 || st.st_size > 10000000) {
fprintf(stderr, "Error reading test data: too much test data (size=%zd)\n", st.st_size);
return 2;
}
if (st.st_size % sizeof(float) != 0) {
fprintf(stderr, "Error reading test data: file size is not divisible by %zd (size=%zd)\n", sizeof(float), st.st_size);
return 2;
}
char *buf = malloc(st.st_size);
if (!buf) {
fprintf(stderr, "Error allocating memory");
return 2;
}
int record_channel = -1;
if (argc == 3) {
char *endptr;
record_channel = strtoul(argv[2], &endptr, 10);
if (!endptr || *endptr != '\0') {
fprintf(stderr, "Invalid channel number \"%s\"\n", argv[2]);
return 1;
}
}
if (record_channel != -1)
fprintf(stderr, "Reading %zd samples test data...", st.st_size/sizeof(float));
ssize_t nread = 0;
while (nread < st.st_size) {
ssize_t rc = read(fd, buf + nread, st.st_size - nread);
if (rc == -EINTR || rc == -EAGAIN)
continue;
if (rc < 0) {
fprintf(stderr, "\nError reading test data: %s\n", strerror(errno));
return 2;
}
if (rc == 0) {
fprintf(stderr, "\nError reading test data: Unexpected end of file\n");
return 2;
}
nread += rc;
}
if (record_channel != -1)
fprintf(stderr, " done.\n");
const size_t n_samples = st.st_size / sizeof(float);
float *buf_f = (float *)buf;
if (record_channel != -1)
fprintf(stderr, "Starting simulation.\n");
struct dsss_demod_state demod;
dsss_demod_init(&demod);
for (size_t i=0; i<n_samples; i++) {
//fprintf(stderr, "Iteration %zd/%zd\n", i, n_samples);
dsss_demod_step(&demod, buf_f[i], i);
}
free(buf);
return 0;
}

View file

@ -0,0 +1,241 @@
#!/usr/bin/env python3
import os
import sys
from os import path
import subprocess
import json
from collections import namedtuple, defaultdict
from tqdm import tqdm
import uuid
import multiprocessing
import sqlite3
import time
from urllib.parse import urlparse
import tempfile
import itertools
import numpy as np
np.set_printoptions(linewidth=240)
from dsss_demod_test_waveform_gen import load_noise_gen, modulate as dsss_modulate
def build_test_binary(nbits, thf, decimation, symbols, cachedir):
build_id = str(uuid.uuid4())
builddir = path.join(cachedir, build_id)
os.mkdir(builddir)
cwd = path.join(path.dirname(__file__), '..')
env = os.environ.copy()
env['BUILDDIR'] = path.abspath(builddir)
env['DSSS_GOLD_CODE_NBITS'] = str(nbits)
env['DSSS_DECIMATION'] = str(decimation)
env['DSSS_THRESHOLD_FACTOR'] = str(thf)
env['DSSS_WAVELET_WIDTH'] = str(0.73 * decimation)
env['DSSS_WAVELET_LUT_SIZE'] = str(10 * decimation)
env['TRANSMISSION_SYMBOLS'] = str(symbols)
with open(path.join(builddir, 'make_stdout.txt'), 'w') as stdout,\
open(path.join(builddir, 'make_stderr.txt'), 'w') as stderr:
subprocess.run(['make', 'clean', os.path.abspath(path.join(builddir, 'tools/dsss_demod_test'))],
env=env, cwd=cwd, check=True, stdout=stdout, stderr=stderr)
return build_id
def sequence_matcher(test_data, decoded, max_shift=3):
match_result = []
for shift in range(-max_shift, max_shift):
failures = -shift if shift < 0 else 0 # we're skipping the first $shift symbols
a = test_data if shift > 0 else test_data[-shift:]
b = decoded if shift < 0 else decoded[shift:]
for i, (ref, found) in enumerate(itertools.zip_longest(a, b)):
if ref is None: # end of signal
break
if ref != found:
failures += 1
match_result.append(failures)
failures = min(match_result)
return failures/len(test_data)
ResultParams = namedtuple('ResultParams', ['nbits', 'thf', 'decimation', 'symbols', 'seed', 'amplitude', 'background'])
def run_test(seed, amplitude_spec, background, nbits, decimation, symbols, thfs, lookup_binary, cachedir):
noise_gen, noise_params = load_noise_gen(background)
test_data = np.random.RandomState(seed=seed).randint(0, 2 * (2**nbits), symbols)
signal = np.repeat(dsss_modulate(test_data, nbits) * 2.0 - 1, decimation)
# We're re-using the seed here. This is not a problem.
noise = noise_gen(seed, len(signal), *noise_params)
amplitudes = amplitude_spec[0] * 10 ** np.linspace(0, amplitude_spec[1], amplitude_spec[2])
# DEBUG
my_pid = multiprocessing.current_process().pid
wql = len(amplitudes) * len(thfs)
print(f'[{my_pid}] starting, got workqueue of length {wql}')
i = 0
# Map lsb to sign to match test program
# test_data = (test_data>>1) * (2*(test_data&1) - 1)
# END DEBUG
output = []
for amp in amplitudes:
with tempfile.NamedTemporaryFile(dir=cachedir) as f:
waveform = signal*amp + noise
f.write(waveform.astype('float32').tobytes())
f.flush()
# DEBUG
fcopy = f'/tmp/test-{path.basename(f.name)}'
import shutil
shutil.copy(f.name, fcopy)
# END DEBUG
for thf in thfs:
rpars = ResultParams(nbits, thf, decimation, symbols, seed, amp, background)
cmdline = [lookup_binary(nbits, thf, decimation, symbols), f.name]
# DEBUG
starttime = time.time()
# END DEBUG
try:
proc = subprocess.run(cmdline, stdout=subprocess.PIPE, encoding='utf-8', check=True, timeout=300)
lines = proc.stdout.splitlines()
matched = [ l.partition('[')[2].partition(']')[0]
for l in lines if l.strip().startswith('data sequence received:') ]
matched = [ [ int(elem) for elem in l.split(',') ] for l in matched ]
ser = min(sequence_matcher(test_data, match) for match in matched) if matched else None
output.append((rpars, ser))
# DEBUG
#print(f'[{my_pid}] ran {i}/{wql}: time={time.time() - starttime}\n {ser=}\n {rpars}\n {" ".join(cmdline)}\n {fcopy}', flush=True)
i += 1
# END DEBUG
except subprocess.TimeoutExpired:
output.append((rpars, None))
# DEBUG
print(f'[{my_pid}] ran {i}/{wql}: Timeout!\n {rpars}\n {" ".join(cmdline)}\n {fcopy}', flush=True)
i += 1
# END DEBUG
print(f'[{my_pid}] finished.')
return output
def parallel_generator(db, table, columns, builder, param_list, desc, context={}, params_mapper=lambda *args: args,
disable_cache=False):
with multiprocessing.Pool(multiprocessing.cpu_count()) as pool:
with db as conn:
jobs = []
for params in param_list:
found_res = conn.execute(
f'SELECT result FROM {table} WHERE ({",".join(columns)}) = ({",".join("?"*len(columns))})',
params_mapper(*params)).fetchone()
if found_res and not disable_cache:
yield params, json.loads(*found_res)
else:
jobs.append((params, pool.apply_async(builder, params, context)))
pool.close()
print('Using', len(param_list) - len(jobs), 'cached jobs', flush=True)
with tqdm(total=len(jobs), desc=desc) as tq:
for i, (params, res) in enumerate(jobs):
# DEBUG
print('Got result', i, params, res)
# END DEBUG
tq.update(1)
result = res.get()
with db as conn:
conn.execute(f'INSERT INTO {table} VALUES ({"?,"*len(params)}?,?)',
(*params_mapper(*params), json.dumps(result), timestamp()))
yield params, result
pool.join()
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('-d', '--dump', help='Write results to JSON file')
parser.add_argument('-c', '--cachedir', default='dsss_test_cache', help='Directory to store build output and data in')
parser.add_argument('-n', '--no-cache', action='store_true', help='Disable result cache')
parser.add_argument('-b', '--batches', type=int, default=1, help='Number of batches to split the computation into')
parser.add_argument('-i', '--index', type=int, default=0, help='Batch index to compute')
parser.add_argument('-p', '--prepare', action='store_true', help='Prepare mode: compile runners, then exit.')
args = parser.parse_args()
DecoderParams = namedtuple('DecoderParams', ['nbits', 'thf', 'decimation', 'symbols'])
# dec_paramses = [ DecoderParams(nbits=nbits, thf=thf, decimation=decimation, symbols=20)
# for nbits in [5, 6]
# for thf in [4.5, 4.0, 5.0]
# for decimation in [10, 5, 22] ]
dec_paramses = [ DecoderParams(nbits=nbits, thf=thf, decimation=decimation, symbols=100)
for nbits in [5, 6]
for thf in [3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5, 8.0, 8.5, 9.0, 9.5, 10.0]
for decimation in [1, 2, 3, 4, 5, 6, 9, 10, 11, 12, 16, 22, 30, 40, 50] ]
# dec_paramses = [ DecoderParams(nbits=nbits, thf=thf, decimation=decimation, symbols=100)
# for nbits in [5, 6, 7, 8]
# for thf in [1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5, 8.0, 8.5, 9.0, 9.5, 10.0]
# for decimation in [1, 2, 3, 4, 5, 6, 9, 10, 11, 12, 16, 22, 30, 40, 50] ]
build_cache_dir = path.join(args.cachedir, 'builds')
data_cache_dir = path.join(args.cachedir, 'data')
os.makedirs(build_cache_dir, exist_ok=True)
os.makedirs(data_cache_dir, exist_ok=True)
build_db = sqlite3.connect(path.join(args.cachedir, 'build_db.sqlite3'))
build_db.execute('CREATE TABLE IF NOT EXISTS builds (nbits, thf, decimation, symbols, result, timestamp)')
timestamp = lambda: int(time.time()*1000)
builds = dict(parallel_generator(build_db, table='builds', columns=['nbits', 'thf', 'decimation', 'symbols'],
builder=build_test_binary, param_list=dec_paramses, desc='Building decoders',
context=dict(cachedir=build_cache_dir)))
print('Done building decoders.')
if args.prepare:
sys.exit(0)
GeneratorParams = namedtuple('GeneratorParams', ['seed', 'amplitude_spec', 'background'])
gen_params = [ GeneratorParams(rep, (5e-3, 1, 5), background)
#GeneratorParams(rep, (0.05e-3, 3.5, 50), background)
for rep in range(50)
for background in ['meas://fmeas_export_ocxo_2day.bin', 'synth://grid_freq_psd_spl_108pt.json'] ]
# gen_params = [ GeneratorParams(rep, (5e-3, 1, 5), background)
# for rep in range(1)
# for background in ['meas://fmeas_export_ocxo_2day.bin'] ]
data_db = sqlite3.connect(path.join(args.cachedir, 'data_db.sqlite3'))
data_db.execute('CREATE TABLE IF NOT EXISTS waveforms'
'(seed, amplitude_spec, background, nbits, decimation, symbols, thresholds, result, timestamp)')
'SELECT FROM waveforms GROUP BY (amplitude_spec, background, nbits, decimation, symbols, thresholds, result)'
dec_param_groups = defaultdict(lambda: [])
for nbits, thf, decimation, symbols in dec_paramses:
dec_param_groups[(nbits, decimation, symbols)].append(thf)
waveform_params = [ (*gp, *dp, thfs) for gp in gen_params for dp, thfs in dec_param_groups.items() ]
print(f'Generated {len(waveform_params)} parameter sets')
# Separate out our batch
waveform_params = waveform_params[args.index::args.batches]
def lookup_binary(*params):
return path.join(build_cache_dir, builds[tuple(params)], 'tools/dsss_demod_test')
def params_mapper(seed, amplitude_spec, background, nbits, decimation, symbols, thresholds):
amplitude_spec = ','.join(str(x) for x in amplitude_spec)
thresholds = ','.join(str(x) for x in thresholds)
return seed, amplitude_spec, background, nbits, decimation, symbols, thresholds
results = []
for _params, chunk in parallel_generator(data_db, 'waveforms',
['seed', 'amplitude_spec', 'background', 'nbits', 'decimation', 'symbols', 'thresholds'],
params_mapper=params_mapper,
builder=run_test,
param_list=waveform_params, desc='Simulating demodulation',
context=dict(cachedir=data_cache_dir, lookup_binary=lookup_binary),
disable_cache=args.no_cache):
results += chunk
if args.dump:
with open(args.dump, 'w') as f:
json.dump(results, f)

View file

@ -0,0 +1,86 @@
from os import path
import json
import functools
import numpy as np
import numbers
import math
from scipy import signal as sig
import scipy.fftpack
sampling_rate = 10 # sp/s
# From https://github.com/mubeta06/python/blob/master/signal_processing/sp/gold.py
preferred_pairs = {5:[[2],[1,2,3]], 6:[[5],[1,4,5]], 7:[[4],[4,5,6]],
8:[[1,2,3,6,7],[1,2,7]], 9:[[5],[3,5,6]],
10:[[2,5,9],[3,4,6,8,9]], 11:[[9],[3,6,9]]}
def gen_gold(seq1, seq2):
gold = [seq1, seq2]
for shift in range(len(seq1)):
gold.append(seq1 ^ np.roll(seq2, -shift))
return gold
def gold(n):
n = int(n)
if not n in preferred_pairs:
raise KeyError('preferred pairs for %s bits unknown' % str(n))
t0, t1 = preferred_pairs[n]
(seq0, _st0), (seq1, _st1) = sig.max_len_seq(n, taps=t0), sig.max_len_seq(n, taps=t1)
return gen_gold(seq0, seq1)
def modulate(data, nbits=5):
# 0, 1 -> -1, 1
mask = np.array(gold(nbits))*2 - 1
sel = mask[data>>1]
data_lsb_centered = ((data&1)*2 - 1)
signal = (np.multiply(sel, np.tile(data_lsb_centered, (2**nbits-1, 1)).T).flatten() + 1) // 2
return np.hstack([ np.zeros(len(mask)), signal, np.zeros(len(mask)) ])
def load_noise_meas_params(capture_file):
with open(capture_file, 'rb') as f:
meas_data = np.copy(np.frombuffer(f.read(), dtype='float32'))
meas_data -= np.mean(meas_data)
return (meas_data,)
def mains_noise_measured(seed, n, meas_data):
last_valid = len(meas_data) - n
st = np.random.RandomState(seed)
start = st.randint(last_valid)
return meas_data[start:start+n] + 50.00
def load_noise_synth_params(specfile):
with open(specfile) as f:
d = json.load(f)
return {'spl_x': np.linspace(*d['x_spec']),
'spl_N': d['x_spec'][2],
'psd_spl': (d['t'], d['c'], d['k']) }
def mains_noise_synthetic(seed, n, psd_spl, spl_N, spl_x):
st = np.random.RandomState(seed)
noise = st.normal(size=spl_N) * 2
spec = scipy.fftpack.fft(noise) **2
spec *= np.exp(scipy.interpolate.splev(spl_x, psd_spl))
spec **= 1/2
renoise = scipy.fftpack.ifft(spec)
return renoise[10000:][:n] + 50.00
@functools.lru_cache()
def load_noise_gen(url):
schema, refpath = url.split('://')
if not path.isabs(refpath):
refpath = path.abspath(path.join(path.dirname(__file__), refpath))
if schema == 'meas':
return mains_noise_measured, load_noise_meas_params(refpath)
elif schema == 'synth':
return mains_noise_synthetic, load_noise_synth_params(refpath)
else:
raise ValueError('Invalid schema', schema)

View file

@ -0,0 +1,111 @@
#include <stdint.h>
#include <math.h>
#include <unistd.h>
#include <stdio.h>
#include <string.h>
#include <errno.h>
#include <stdlib.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/fcntl.h>
#include "freq_meas.h"
#include "dsss_demod.h"
typedef uint16_t adc_data_t;
void handle_dsss_received(uint8_t data[static TRANSMISSION_SYMBOLS]) {
printf("data sequence received: [ ");
for (size_t i=0; i<TRANSMISSION_SYMBOLS; i++) {
printf("%+3d", ((data[i]&1) ? 1 : -1) * (data[i]>>1));
if (i+1 < TRANSMISSION_SYMBOLS)
printf(", ");
}
printf(" ]\n");
}
void print_usage(void);
void print_usage() {
fprintf(stderr, "Usage: e2e_test [emulated_adc_data.bin]\n");
}
int main(int argc, char **argv) {
if (argc != 2) {
fprintf(stderr, "Error: Invalid arguments.\n");
print_usage();
return 1;
}
int fd = open(argv[1], O_RDONLY);
struct stat st;
if (fstat(fd, &st)) {
fprintf(stderr, "Error querying test data file size: %s\n", strerror(errno));
return 2;
}
if (st.st_size < 0 || st.st_size > 100000000) {
fprintf(stderr, "Error reading test data: too much test data (size=%zd)\n", st.st_size);
return 2;
}
if (st.st_size % sizeof(adc_data_t) != 0) {
fprintf(stderr, "Error reading test data: file size is not divisible by %zd (size=%zd)\n", sizeof(adc_data_t), st.st_size);
return 2;
}
char *buf = malloc(st.st_size);
if (!buf) {
fprintf(stderr, "Error allocating memory");
return 2;
}
const size_t n_samples = st.st_size / sizeof(adc_data_t);
fprintf(stderr, "Reading %zd samples test data...", n_samples);
ssize_t nread = 0;
while (nread < st.st_size) {
ssize_t rc = read(fd, buf + nread, st.st_size - nread);
if (rc == -EINTR || rc == -EAGAIN)
continue;
if (rc < 0) {
fprintf(stderr, "\nError reading test data: %s\n", strerror(errno));
return 2;
}
if (rc == 0) {
fprintf(stderr, "\nError reading test data: Unexpected end of file\n");
return 2;
}
nread += rc;
}
fprintf(stderr, " done. Read %zd bytes.\n", nread);
adc_data_t *buf_d = (adc_data_t *)buf;
struct dsss_demod_state demod;
dsss_demod_init(&demod);
fprintf(stderr, "Starting simulation.\n");
size_t iterations = (n_samples-FMEAS_FFT_LEN)/(FMEAS_FFT_LEN/2);
for (size_t i=0; i<iterations; i++) {
/*
fprintf(stderr, "Iteration %zd/%zd\n", i, iterations);
*/
float res = NAN;
int rc = adc_buf_measure_freq(buf_d + i*(FMEAS_FFT_LEN/2), &res);
if (rc)
printf("ERROR: Simulation error in iteration %zd at position %zd: %d\n", i, i*(FMEAS_FFT_LEN/2), rc);
dsss_demod_step(&demod, res, i);
/*
printf("%09zd %12f\n", i, res);
*/
}
free(buf);
return 0;
}

View file

@ -0,0 +1,59 @@
#!/usr/bin/env python3
import textwrap
import scipy.signal as sig
import numpy as np
WINDOW_TYPES = [
'boxcar',
'triang',
'blackman',
'hamming',
'hann',
'bartlett',
'flattop',
'parzen',
'bohman',
'blackmanharris',
'nuttall',
'barthann',
'kaiser',
'gaussian',
'general_gaussian',
'slepian',
'dpss',
'chebwin',
'exponential',
'tukey',
]
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('window', choices=WINDOW_TYPES, help='Type of window function to use')
parser.add_argument('n', type=int, help='Width of window in samples')
parser.add_argument('window_args', nargs='*', type=float,
help='''Window argument(s) if required. See https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.get_window.html#scipy.signal.get_window for details.''')
parser.add_argument('-v', '--variable', default='fft_window_table', help='Name for alias variable pointing to generated window')
args = parser.parse_args()
print(f'/* FTT window table for {args.n} sample {args.window} window.')
if args.window_args:
print(f' * Window arguments were: ({" ,".join(str(arg) for arg in args.window_args)})')
print(f' */')
winargs = ''.join(f'_{arg:.4g}'.replace('.', 'F') for arg in args.window_args)
varname = f'fft_{args.n}_window_{args.window}{winargs}'
print(f'const float {varname}[{args.n}] = {{')
win = sig.get_window(args.window if not args.window_args else (args.window, *args.window_args),
Nx=args.n, fftbins=True)
par = ' '.join(f'{f:>013.8g},' for f in win)
print(textwrap.fill(par,
initial_indent=' '*4, subsequent_indent=' '*4,
width=120,
replace_whitespace=False, drop_whitespace=False))
print('};')
print()
print(f'const float * const {args.variable} __attribute__((weak)) = {varname};')

Binary file not shown.

View file

@ -0,0 +1,106 @@
#include <stdint.h>
#include <math.h>
#include <unistd.h>
#include <stdio.h>
#include <string.h>
#include <errno.h>
#include <stdlib.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/fcntl.h>
#include "freq_meas.h"
void print_usage(void);
void print_usage() {
fprintf(stderr, "Usage: freq_meas_test [test_data.bin]\n");
}
int main(int argc, char **argv) {
if (argc != 2) {
fprintf(stderr, "Error: Invalid arguments.\n");
print_usage();
return 1;
}
int fd = open(argv[1], O_RDONLY);
struct stat st;
if (fstat(fd, &st)) {
fprintf(stderr, "Error querying test data file size: %s\n", strerror(errno));
return 2;
}
if (st.st_size < 0 || st.st_size > 1000000) {
fprintf(stderr, "Error reading test data: too much test data (size=%zd)\n", st.st_size);
return 2;
}
if (st.st_size % sizeof(float) != 0) {
fprintf(stderr, "Error reading test data: file size is not divisible by %zd (size=%zd)\n", sizeof(float), st.st_size);
return 2;
}
char *buf = malloc(st.st_size);
if (!buf) {
fprintf(stderr, "Error allocating memory");
return 2;
}
fprintf(stderr, "Reading %zd samples test data...", st.st_size/sizeof(float));
ssize_t nread = 0;
while (nread < st.st_size) {
ssize_t rc = read(fd, buf + nread, st.st_size - nread);
if (rc == -EINTR || rc == -EAGAIN)
continue;
if (rc < 0) {
fprintf(stderr, "\nError reading test data: %s\n", strerror(errno));
return 2;
}
if (rc == 0) {
fprintf(stderr, "\nError reading test data: Unexpected end of file\n");
return 2;
}
nread += rc;
}
fprintf(stderr, " done.\n");
const size_t n_samples = st.st_size / sizeof(float);
float *buf_f = (float *)buf;
int16_t *sim_adc_buf = calloc(sizeof(int16_t), n_samples);
if (!sim_adc_buf) {
fprintf(stderr, "Error allocating memory\n");
return 2;
}
fprintf(stderr, "Converting and truncating test data...");
for (size_t i=0; i<n_samples; i++)
/* Note on scaling: We can't simply scale by 0x8000 (1/2 full range) here. Our test data is nominally 1Vp-p but
* certain tests such as the interharmonics one can have some samples exceeding that range. */
sim_adc_buf[i] = buf_f[i] * (0x4000-1);
fprintf(stderr, " done.\n");
fprintf(stderr, "Starting simulation.\n");
size_t iterations = (n_samples-FMEAS_FFT_LEN)/(FMEAS_FFT_LEN/2);
for (size_t i=0; i<iterations; i++) {
fprintf(stderr, "Iteration %zd/%zd\n", i, iterations);
float res = NAN;
int rc = adc_buf_measure_freq(sim_adc_buf + i*(FMEAS_FFT_LEN/2), &res);
if (rc)
printf("ERROR: Simulation error in iteration %zd at position %zd: %d\n", i, i*(FMEAS_FFT_LEN/2), rc);
printf("%09zd %12f\n", i, res);
}
free(buf);
free(sim_adc_buf);
return 0;
}

View file

@ -0,0 +1,39 @@
#!/usr/bin/env python3
import os
from os import path
import subprocess
import json
import numpy as np
np.set_printoptions(linewidth=240)
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(metavar='test_data_directory', dest='dir', help='Directory with test data .bin files')
default_binary = path.abspath(path.join(path.dirname(__file__), '../build/tools/freq_meas_test'))
parser.add_argument(metavar='test_binary', dest='binary', nargs='?', default=default_binary)
parser.add_argument('-d', '--dump', help='Write raw measurements to JSON file')
args = parser.parse_args()
bin_files = [ path.join(args.dir, d) for d in os.listdir(args.dir) if d.lower().endswith('.bin') ]
savedata = {}
for p in bin_files:
output = subprocess.check_output([args.binary, p], stderr=subprocess.DEVNULL)
measurements = np.array([ float(value) for _offset, value in [ line.split() for line in output.splitlines() ] ])
savedata[p] = list(measurements)
# Cut off first and last sample for mean and RMS calculations as these show boundary effects.
measurements = measurements[1:-1]
mean = np.mean(measurements)
rms = np.sqrt(np.mean(np.square(measurements - mean)))
print(f'{path.basename(p):<60}: mean={mean:<8.4f}Hz rms={rms*1000:.3f}mHz')
if args.dump:
with open(args.dump, 'w') as f:
json.dump(savedata, f)

View file

@ -0,0 +1,70 @@
#!/usr/bin/env python3
import sys
import math
import textwrap
import contextlib
import numpy as np
import scipy.signal as sig
# From https://github.com/mubeta06/python/blob/master/signal_processing/sp/gold.py
preferred_pairs = {5:[[2],[1,2,3]], 6:[[5],[1,4,5]], 7:[[4],[4,5,6]],
8:[[1,2,3,6,7],[1,2,7]], 9:[[5],[3,5,6]],
10:[[2,5,9],[3,4,6,8,9]], 11:[[9],[3,6,9]]}
def gen_gold(seq1, seq2):
gold = [seq1, seq2]
for shift in range(len(seq1)):
gold.append(seq1 ^ np.roll(seq2, -shift))
return gold
def gold(n):
n = int(n)
if not n in preferred_pairs:
raise KeyError('preferred pairs for %s bits unknown' % str(n))
t0, t1 = preferred_pairs[n]
(seq0, _st0), (seq1, _st1) = sig.max_len_seq(n, taps=t0), sig.max_len_seq(n, taps=t1)
return gen_gold(seq0, seq1)
@contextlib.contextmanager
def print_include_guards(macro_name):
print(f'#ifndef {macro_name}')
print(f'#define {macro_name}')
yield
print(f'#endif /* {macro_name} */')
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('n', type=int, choices=preferred_pairs, help='bit width of shift register. Generate 2**n + 1 sequences of length 2**n - 1.')
parser.add_argument('-v', '--variable', default='gold_code_table', help='Name for weak alias of generated table')
parser.add_argument('-h', '--header', action='store_true', help='Generate header file')
parser.add_argument('-c', '--source', action='store_true', help='Generate table source file')
args = parser.parse_args()
if not args.header != args.source:
print('Exactly one of --header and --source must be given.', file=sys.stderr)
sys.exit(1)
nbytes = math.ceil((2**args.n-1)/8)
if args.source:
print('/* THIS IS A GENERATED FILE. DO NOT EDIT! */')
print('#include <unistd.h>')
print('#include <stdint.h>')
print()
print(f'/* {args.n} bit gold sequences: {2**args.n+1} sequences of length {2**args.n-1} bit.')
print(f' *')
print(f' * Each code is packed left-aligned into {nbytes} bytes in big-endian byte order.')
print(f' */')
print(f'const uint8_t {args.variable}[{2**args.n+1}][{nbytes}] = {{')
for i, code in enumerate(gold(args.n)):
par = '{' + ' '.join(f'0x{d:02x},' for d in np.packbits(code)) + f'}}, /* {i: 3d} "{"".join(str(x) for x in code)}" */'
print(textwrap.fill(par, initial_indent=' '*4, subsequent_indent=' '*4, width=120))
print('};')
print()
else:
print('/* THIS IS A GENERATED FILE. DO NOT EDIT! */')
with print_include_guards(f'__GOLD_CODE_GENERATED_HEADER_{args.n}__'):
print(f'extern const uint8_t {args.variable}[{2**args.n+1}][{nbytes}];')

View file

@ -0,0 +1 @@
{"x_spec": [3.2595692805152726e-05, 5.0, 613575], "t": [3.2595692805152726e-05, 3.2595692805152726e-05, 3.2595692805152726e-05, 3.2595692805152726e-05, 0.0001423024947075771, 0.00015800362803968106, 0.00017543716661470822, 0.00019479425764873777, 0.0002162871388378975, 0.00024015146540428407, 0.00026664889389955537, 0.00029606995109590574, 0.00032873721941990017, 0.0003650088738553592, 0.0004052826090950758, 0.00045000000000000004, 0.000499651343175437, 0.0005547810327489297, 0.0006159935292916862, 0.0006839599873288199, 0.0007594256141046668, 0.0008432178402871724, 0.0009362553921977272, 0.0010395583650374223, 0.0011542594075560205, 0.001281616140796111, 0.0014230249470757708, 0.001580036280396809, 0.0017543716661470824, 0.0019479425764873776, 0.002162871388378975, 0.0024015146540428403, 0.002666488938995554, 0.002960699510959057, 0.0032873721941990056, 0.0036500887385535925, 0.004052826090950754, 0.0045000000000000005, 0.00499651343175437, 0.005547810327489296, 0.006159935292916869, 0.0068395998732882, 0.007594256141046669, 0.008432178402871724, 0.009362553921977271, 0.010395583650374221, 0.011542594075560205, 0.012816161407961109, 0.014230249470757707, 0.01580036280396809, 0.017543716661470823, 0.01947942576487376, 0.02162871388378975, 0.024015146540428405, 0.026664889389955565, 0.02960699510959057, 0.03287372194199005, 0.036500887385535925, 0.04052826090950754, 0.045, 0.0499651343175437, 0.05547810327489296, 0.06159935292916863, 0.06839599873288206, 0.07594256141046668, 0.08432178402871732, 0.09362553921977272, 0.10395583650374222, 0.11542594075560206, 0.12816161407961107, 0.14230249470757705, 0.15800362803968088, 0.1754371666147082, 0.1947942576487376, 0.21628713883789774, 0.24015146540428406, 0.26664889389955565, 0.2960699510959057, 0.32873721941990053, 0.36500887385535924, 0.40528260909507535, 0.45, 0.499651343175437, 0.5547810327489296, 0.6159935292916868, 0.6839599873288206, 0.7594256141046669, 0.8432178402871732, 0.9362553921977271, 1.0395583650374223, 1.1542594075560206, 1.2816161407961109, 1.4230249470757708, 1.5800362803968104, 1.7543716661470823, 1.9479425764873777, 2.162871388378975, 2.4015146540428405, 2.6664889389955535, 2.960699510959057, 3.287372194199002, 3.6500887385535927, 4.052826090950758, 4.5, 5.0, 5.0, 5.0, 5.0], "c": [0.7720161468716866, -0.5547528253056444, 0.30706059086000753, 0.19422577014134906, -1.1954636661840032, 0.9215976941641111, -0.6668136393976918, -1.341269161156733, -0.16311330594842666, -1.7639636752234251, -1.238385544822954, -0.32649555618555554, -0.03086589610280171, -2.358195657381619, -0.5759152419849985, 0.1892225800004134, -1.8122889670546236, -0.8109120798216202, -0.5500991736738969, -4.680192969256771, -2.8007700704649876, 0.16866469558571784, -1.1040811840849307, -3.0243574268705546, -4.018139927365795, -4.100581028618109, -0.556354762846191, -7.414377514669229, 1.36396325920194, -6.002559557058508, -2.2113451390305365, -4.578944771104116, -4.372644849632638, -3.945339124673235, -4.778747958903158, -2.370174137632325, -5.7372466088109295, -4.707506574819875, -4.834404729330929, -5.005244244061701, -5.82644896783577, -4.717966026411524, -6.146374820241562, -4.972788381244952, -5.854957092953355, -5.702174935205885, -6.222035857079607, -6.2128389666872, -6.212821706753751, -6.253599689326325, -6.681685577659057, -6.372364384360678, -6.771223202540934, -6.856809137231159, -6.986412256164045, -7.190466178818742, -7.577896455149433, -7.515731696006047, -7.598155006351761, -7.824526916149126, -8.141496591776512, -8.36794927682997, -8.80307396767114, -8.828816533544659, -9.357524260470413, -9.658130054343863, -10.005768472049466, -10.499801262514108, -11.028689820560558, -11.413688641742898, -11.906162042727946, -12.232342460719975, -12.438432746733596, -13.088338100203112, -12.308710772618745, -11.685074853925329, -11.397838681243094, -12.265219694936695, -13.600359694898529, -14.031425961884718, -12.236885080485473, -13.527508426900974, -13.698402018452601, -13.397911198962568, -14.144410560196603, -13.905769594095293, -14.410874830544122, -14.531727635304264, -14.59275291853806, -14.35404826562502, -14.58670053318149, -14.432515268864977, -14.363428024828353, -14.429222027493264, -14.73947634127499, -14.717315405960353, -14.678539669792505, -14.825278423641382, -14.80936417940876, -14.943375264882789, -14.680885181815674, -14.54841244844906, -14.634365225950589, -14.609444790868906, 0.0, 0.0, 0.0, 0.0], "k": 3}

View file

@ -0,0 +1,111 @@
#!/usr/bin/env python
# coding: utf-8
import binascii
import struct
import numpy as np
import pydub
from dsss_demod_test_waveform_gen import load_noise_gen, modulate as dsss_modulate
np.set_printoptions(linewidth=240)
def generate_noisy_signal(
test_data=32,
test_nbits=5,
test_decimation=10,
test_signal_amplitude=20e-3,
noise_level=10e-3,
noise_spec='synth://grid_freq_psd_spl_108pt.json',
seed=0):
#test_data = np.random.RandomState(seed=0).randint(0, 2 * (2**test_nbits), test_duration)
#test_data = np.array([0, 1, 2, 3] * 50)
if isinstance(test_data, int):
test_data = np.array(range(test_data))
signal = np.repeat(dsss_modulate(test_data, test_nbits) * 2.0 - 1, test_decimation)
noise_gen, noise_params = load_noise_gen(noise_spec)
noise = noise_gen(seed, len(signal), **noise_params)
return np.absolute(noise + signal*test_signal_amplitude)
def write_raw_frequencies_bin(outfile, **kwargs):
with open(outfile, 'wb') as f:
for x in generate_noisy_signal(**kwargs):
f.write(struct.pack('f', x))
def synthesize_sine(freqs, freqs_sampling_rate=10.0, output_sampling_rate=44100):
duration = len(freqs) / freqs_sampling_rate # seconds
afreq_out = np.interp(np.linspace(0, duration, int(duration*output_sampling_rate)), np.linspace(0, duration, len(freqs)), freqs)
return np.sin(np.cumsum(2*np.pi * afreq_out / output_sampling_rate))
def write_flac(filename, signal, sampling_rate=44100):
signal -= np.min(signal)
signal /= np.max(signal)
signal -= 0.5
signal *= 2**16 - 1
le_bytes = signal.astype(np.int16).tobytes()
seg = pydub.AudioSegment(data=le_bytes, sample_width=2, frame_rate=sampling_rate, channels=1)
seg.export(filename, format='flac')
def write_synthetic_hum_flac(filename, output_sampling_rate=44100, freqs_sampling_rate=10.0, **kwargs):
signal = generate_noisy_signal(**kwargs)
print(signal)
write_flac(filename, synthesize_sine(signal, freqs_sampling_rate, output_sampling_rate),
sampling_rate=output_sampling_rate)
def emulate_adc_signal(adc_bits=12, adc_offset=0.4, adc_amplitude=0.25, freq_sampling_rate=10.0, output_sampling_rate=1000, **kwargs):
signal = synthesize_sine(generate_noisy_signal(), freq_sampling_rate, output_sampling_rate)
signal = signal*adc_amplitude + adc_offset
smin, smax = np.min(signal), np.max(signal)
if smin < 0.0 or smax > 1.0:
raise UserWarning('Amplitude or offset too large: Signal out of bounds with min/max [{smin}, {smax}] of ADC range')
signal *= 2**adc_bits -1
return signal
def save_adc_signal(fn, signal, dtype=np.uint16):
with open(fn, 'wb') as f:
f.write(signal.astype(dtype).tobytes())
def write_emulated_adc_signal_bin(filename, **kwargs):
save_adc_signal(filename, emulate_adc_signal(**kwargs))
def hum_cmd(args):
write_synthetic_hum_flac(args.out_flac,
output_sampling_rate=args.audio_sampling_rate,
freqs_sampling_rate=args.frequency_sampling_rate,
test_data = np.array(list(binascii.unhexlify(args.data))),
test_nbits = args.symbol_bits,
test_decimation = args.decimation,
test_signal_amplitude = args.signal_level/1e3,
noise_level = args.noise_level/1e3,
noise_spec=args.noise_spec,
seed = args.random_seed)
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
cmd_parser = parser.add_subparsers(required=True)
hum_parser = cmd_parser.add_parser('hum', help='Generated artificial modulated mains hum')
# output parameters
hum_parser.add_argument('-a', '--audio-sampling-rate', type=int, default=44100)
# modulation parameters
hum_parser.add_argument('-f', '--frequency-sampling-rate', type=float, default=10.0*100/128)
hum_parser.add_argument('-b', '--symbol-bits', type=int, default=5, help='bits per symbol (excluding sign bit)')
hum_parser.add_argument('-n', '--noise-level', type=float, default=1.0, help='Scale synthetic noise level')
hum_parser.add_argument('-s', '--signal-level', type=float, default=20.0, help='Synthetic noise level in mHz')
hum_parser.add_argument('-d', '--decimation', type=int, default=10, help='DSSS modulation decimation in frequency measurement cycles')
hum_parser.add_argument('-r', '--random-seed', type=int, default=0)
hum_parser.add_argument('--noise-spec', type=str, default='synth://grid_freq_psd_spl_108pt.json')
hum_parser.add_argument('out_flac', metavar='out.flac', help='FLAC output file')
hum_parser.add_argument('data', help='modulation data hex string')
hum_parser.set_defaults(func=hum_cmd)
args = parser.parse_args()
args.func(args)

View file

@ -0,0 +1,126 @@
import sys
import pyparsing as pp
from pyparsing import pyparsing_common as ppc
LPAREN, RPAREN, LBRACE, RBRACE, LBROK, RBROK, COLON, SEMICOLON, EQUALS, COMMA = map(pp.Suppress, '(){}<>:;=,')
parse_suffix_int = lambda lit: int(lit[:-1]) * (10**(3*(1 + 'kmgtpe'.find(lit[-1].lower()))))
si_suffix = pp.oneOf('k m g t p e', caseless=True)
numeric_literal = pp.Regex('0x[0-9a-fA-F]+').setName('hex int').setParseAction(pp.tokenMap(int, 16)) \
| (pp.Regex('[0-9]+[kKmMgGtTpPeE]')).setName('size int').setParseAction(pp.tokenMap(parse_suffix_int)) \
| pp.Word(pp.nums).setName('int').setParseAction(pp.tokenMap(int))
access_def = pp.Regex('[rR]?[wW]?[xX]?').setName('access literal').setParseAction(pp.tokenMap(str.lower))
origin_expr = pp.Suppress(pp.CaselessKeyword('ORIGIN')) + EQUALS + numeric_literal
length_expr = pp.Suppress(pp.CaselessKeyword('LENGTH')) + EQUALS + numeric_literal
mem_expr = pp.Group(ppc.identifier + LPAREN + access_def + RPAREN + COLON + origin_expr + COMMA + length_expr)
mem_contents = pp.ZeroOrMore(mem_expr)
mem_toplevel = pp.CaselessKeyword("MEMORY") + pp.Group(LBRACE + pp.Optional(mem_contents, []) + RBRACE)
glob = pp.Word(pp.alphanums + '._*')
match_expr = pp.Forward()
assignment = pp.Forward()
funccall = pp.Group(pp.Word(pp.alphas + '_') + LPAREN + (assignment | numeric_literal | match_expr | glob | ppc.identifier) + RPAREN + pp.Optional(SEMICOLON))
value = numeric_literal | funccall | ppc.identifier | '.'
formula = (value + pp.oneOf('+ = * / %') + value) | value
# suppress stray semicolons
assignment << (SEMICOLON | pp.Group((ppc.identifier | '.') + EQUALS + (formula | value) + pp.Optional(SEMICOLON)))
match_expr << (glob + LPAREN + pp.OneOrMore(funccall | glob) + RPAREN)
section_contents = pp.ZeroOrMore(assignment | funccall | match_expr);
section_name = pp.Regex('\.[a-zA-Z0-9_.]+')
section_def = pp.Group(section_name + pp.Optional(numeric_literal) + COLON + LBRACE + pp.Group(section_contents) +
RBRACE + pp.Optional(RBROK + ppc.identifier + pp.Optional('AT' + RBROK + ppc.identifier)))
sec_contents = pp.ZeroOrMore(section_def | assignment)
sections_toplevel = pp.Group(pp.CaselessKeyword("SECTIONS").suppress() + LBRACE + sec_contents + RBRACE)
toplevel_elements = mem_toplevel | funccall | sections_toplevel | assignment
ldscript = pp.Group(pp.ZeroOrMore(toplevel_elements))
ldscript.ignore(pp.cppStyleComment)
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('linker_script', type=argparse.FileType('r'))
args = parser.parse_args()
#print(mem_expr.parseString('FLASH (rx) : ORIGIN = 0x0800000, LENGTH = 512K', parseAll=True))
# print(ldscript.parseString('''
# /* Entry Point */
# ENTRY(Reset_Handler)
#
# /* Highest address of the user mode stack */
# _estack = 0x20020000; /* end of RAM */
# /* Generate a link error if heap and stack don't fit into RAM */
# _Min_Heap_Size = 0x200;; /* required amount of heap */
# _Min_Stack_Size = 0x400;; /* required amount of stack */
# ''', parseAll=True))
print(ldscript.parseFile(args.linker_script, parseAll=True))
#print(funccall.parseString('KEEP(*(.isr_vector))'))
#print(section_contents.parseString('''
# . = ALIGN(4);
# KEEP(*(.isr_vector)) /* Startup code */
# . = ALIGN(4);
# ''', parseAll=True))
#print(section_def.parseString('''
# .text :
# {
# . = ALIGN(4);
# *(.text) /* .text sections (code) */
# *(.text*) /* .text* sections (code) */
# *(.glue_7) /* glue arm to thumb code */
# *(.glue_7t) /* glue thumb to arm code */
# *(.eh_frame)
#
# KEEP (*(.init))
# KEEP (*(.fini))
#
# . = ALIGN(4);
# _etext = .; /* define a global symbols at end of code */
# } >FLASH
# ''', parseAll=True))
#print(section_def.parseString('.ARM.extab : { *(.ARM.extab* .gnu.linkonce.armextab.*) } >FLASH', parseAll=True))
#print(assignment.parseString('__preinit_array_start = .', parseAll=True))
#print(assignment.parseString('a = 23', parseAll=True))
#print(funccall.parseString('foo (a=23)', parseAll=True))
#print(funccall.parseString('PROVIDE_HIDDEN (__preinit_array_start = .);', parseAll=True))
#print(section_def.parseString('''
# .preinit_array :
# {
# PROVIDE_HIDDEN (__preinit_array_start = .);
# KEEP (*(.preinit_array*))
# PROVIDE_HIDDEN (__preinit_array_end = .);
# } >FLASH''', parseAll=True))
#print(match_expr.parseString('*(SORT(.init_array.*))', parseAll=True))
#print(funccall.parseString('KEEP (*(SORT(.init_array.*)))', parseAll=True))
#print(section_def.parseString('''
# .init_array :
# {
# PROVIDE_HIDDEN (__init_array_start = .);
# KEEP (*(SORT(.init_array.*)))
# KEEP (*(.init_array*))
# PROVIDE_HIDDEN (__init_array_end = .);
# } >FLASH
# ''', parseAll=True))
#print(match_expr.parseString('*(.ARM.extab* .gnu.linkonce.armextab.*)', parseAll=True))
#print(formula.parseString('. + _Min_Heap_Size', parseAll=True))
#print(assignment.parseString('. = . + _Min_Heap_Size;', parseAll=True))
#print(sections_toplevel.parseString('''
# SECTIONS
# {
# .ARMattributes : { }
# }
# ''', parseAll=True))
#sys.exit(0)

View file

@ -0,0 +1,276 @@
import tempfile
import os
from os import path
import sys
import re
import subprocess
from contextlib import contextmanager
from collections import defaultdict
import colorsys
import cxxfilt
from elftools.elf.elffile import ELFFile
from elftools.elf.enums import ENUM_ST_SHNDX
from elftools.elf.descriptions import describe_symbol_type, describe_sh_type
import libarchive
import matplotlib.cm
@contextmanager
def chdir(newdir):
old_cwd = os.getcwd()
try:
os.chdir(newdir)
yield
finally:
os.chdir(old_cwd)
def keep_last(it, first=None):
last = first
for elem in it:
yield last, elem
last = elem
def delim(start, end, it, first_only=True):
found = False
for elem in it:
if end(elem):
if first_only:
return
found = False
elif start(elem):
found = True
elif found:
yield elem
def delim_prefix(start, end, it):
yield from delim(lambda l: l.startswith(start), lambda l: end is not None and l.startswith(end), it)
def trace_source_files(linker, cmdline, trace_sections=[], total_sections=['.text', '.data', '.rodata']):
with tempfile.TemporaryDirectory() as tempdir:
out_path = path.join(tempdir, 'output.elf')
output = subprocess.check_output([linker, '-o', out_path, f'-Wl,--print-map', *cmdline])
lines = [ line.strip() for line in output.decode().splitlines() ]
# FIXME also find isr vector table references
defs = {}
objs = defaultdict(lambda: 0)
aliases = {}
sec_name = None
last_loc = None
last_sym = None
line_cont = None
for last_line, line in keep_last(delim_prefix('Linker script and memory map', 'OUTPUT', lines), first=''):
if not line or line.startswith('LOAD '):
sec_name = None
continue
# first part of continuation line
if m := re.match('^(\.[0-9a-zA-Z-_.]+)$', line):
line_cont = line
sec_name = None
continue
if line_cont:
line = line_cont + ' ' + line
line_cont = None
# -ffunction-sections/-fdata-sections section
if m := re.match('^(\.[0-9a-zA-Z-_.]+)\.([0-9a-zA-Z-_.]+)\s+(0x[0-9a-f]+)\s+(0x[0-9a-f]+)\s+(\S+)$', line):
sec, sym, loc, size, obj = m.groups()
*_, sym = sym.rpartition('.')
sym = cxxfilt.demangle(sym)
size = int(size, 16)
obj = path.abspath(obj)
if sec not in total_sections:
size = 0
objs[obj] += size
defs[sym] = (sec, size, obj)
sec_name, last_loc, last_sym = sec, loc, sym
continue
# regular (no -ffunction-sections/-fdata-sections) section
if m := re.match('^(\.[0-9a-zA-Z-_]+)\s+(0x[0-9a-f]+)\s+(0x[0-9a-f]+)\s+(\S+)$', line):
sec, _loc, size, obj = m.groups()
size = int(size, 16)
obj = path.abspath(obj)
if sec in total_sections:
objs[obj] += size
sec_name = sec
last_loc, last_sym = None, None
continue
# symbol def
if m := re.match('^(0x[0-9a-f]+)\s+(\S+)$', line):
loc, sym = m.groups()
sym = cxxfilt.demangle(sym)
loc = int(loc, 16)
if sym in defs:
continue
if loc == last_loc:
assert last_sym is not None
aliases[sym] = last_sym
else:
assert sec_name
defs[sym] = (sec_name, None, obj)
last_loc, last_sym = loc, sym
continue
refs = defaultdict(lambda: set())
for sym, (sec, size, obj) in defs.items():
fn, _, member = re.match('^([^()]+)(\((.+)\))?$', obj).groups()
fn = path.abspath(fn)
if member:
subprocess.check_call(['ar', 'x', '--output', tempdir, fn, member])
fn = path.join(tempdir, member)
with open(fn, 'rb') as f:
elf = ELFFile(f)
symtab = elf.get_section_by_name('.symtab')
symtab_demangled = { cxxfilt.demangle(nsym.name).replace(' ', ''): i
for i, nsym in enumerate(symtab.iter_symbols()) }
s = set()
sec_map = { sec.name: i for i, sec in enumerate(elf.iter_sections()) }
matches = [ i for name, i in sec_map.items() if re.match(f'\.rel\..*\.{sym}', name) ]
if matches:
sec = elf.get_section(matches[0])
for reloc in sec.iter_relocations():
refsym = symtab.get_symbol(reloc['r_info_sym'])
name = refsym.name if refsym.name else elf.get_section(refsym['st_shndx']).name.split('.')[-1]
s.add(name)
refs[sym] = s
for tsec in trace_sections:
matches = [ i for name, i in sec_map.items() if name == f'.rel{tsec}' ]
s = set()
if matches:
sec = elf.get_section(matches[0])
for reloc in sec.iter_relocations():
refsym = symtab.get_symbol(reloc['r_info_sym'])
s.add(refsym.name)
refs[tsec.replace('.', '_')] |= s
return objs, aliases, defs, refs
@contextmanager
def wrap(leader='', print=print, left='{', right='}'):
print(leader, left)
yield lambda *args, **kwargs: print(' ', *args, **kwargs)
print(right)
def mangle(name):
return re.sub('[^a-zA-Z0-9_]', '_', name)
hexcolor = lambda r, g, b, *_a: f'#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}'
def vhex(val):
r,g,b,_a = matplotlib.cm.viridis(1.0-val)
fc = hexcolor(r, g, b)
h,s,v = colorsys.rgb_to_hsv(r,g,b)
cc = '#000000' if v > 0.8 else '#ffffff'
return fc, cc
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--trace-sections', type=str, action='append', default=[])
parser.add_argument('--trim-stubs', type=str, action='append', default=[])
parser.add_argument('--highlight-subdirs', type=str, default=None)
parser.add_argument('linker_binary')
parser.add_argument('linker_args', nargs=argparse.REMAINDER)
args = parser.parse_args()
trace_sections = args.trace_sections
trace_sections_mangled = { sec.replace('.', '_') for sec in trace_sections }
objs, aliases, syms, refs = trace_source_files(args.linker_binary, args.linker_args, trace_sections)
clusters = defaultdict(lambda: [])
for sym, (sec, size, obj) in syms.items():
clusters[obj].append((sym, sec, size))
max_ssize = max(size or 0 for _sec, size, _obj in syms.values())
max_osize = max(objs.values())
subdir_prefix = path.abspath(args.highlight_subdirs) + '/' if args.highlight_subdirs else '### NO HIGHLIGHT ###'
first_comp = lambda le_path: path.dirname(le_path).partition(os.sep)[0]
subdir_colors = sorted({ first_comp(obj[len(subdir_prefix):]) for obj in objs if obj.startswith(subdir_prefix) })
subdir_colors = { path: hexcolor(*matplotlib.cm.Pastel1(i/len(subdir_colors))) for i, path in enumerate(subdir_colors) }
subdir_sizes = defaultdict(lambda: 0)
for obj, size in objs.items():
if not isinstance(size, int):
continue
if obj.startswith(subdir_prefix):
subdir_sizes[first_comp(obj[len(subdir_prefix):])] += size
else:
subdir_sizes['<others>'] += size
print('Subdir sizes:', file=sys.stderr)
for subdir, size in sorted(subdir_sizes.items(), key=lambda x: x[1]):
print(f'{subdir:>20}: {size:>6,d} B', file=sys.stderr)
def lookup_highlight(path):
if args.highlight_subdirs:
if obj.startswith(subdir_prefix):
highlight_head = first_comp(path[len(subdir_prefix):])
return subdir_colors[highlight_head], highlight_head
else:
return '#e0e0e0', None
else:
return '#ddf7f4', None
with wrap('digraph G', print) as lvl1print:
print('size="23.4,16.5!";')
print('graph [fontsize=40];')
print('node [fontsize=40];')
#print('ratio="fill";')
print('rankdir=LR;')
print('ranksep=5;')
print('nodesep=0.2;')
print()
for i, (obj, obj_syms) in enumerate(clusters.items()):
with wrap(f'subgraph cluster_{i}', lvl1print) as lvl2print:
print('style = "filled";')
highlight_color, highlight_head = lookup_highlight(obj)
print(f'bgcolor = "{highlight_color}";')
print('pencolor = none;')
fc, cc = vhex(objs[obj]/max_osize)
highlight_subdir_part = f'<font face="carlito" color="{cc}" point-size="40">{highlight_head} / </font>' if highlight_head else ''
lvl2print(f'label = <<table border="0"><tr><td border="0" cellpadding="5" bgcolor="{fc}">'
f'{highlight_subdir_part}'
f'<font face="carlito" color="{cc}"><b>{path.basename(obj)} ({objs[obj]}B)</b></font>'
f'</td></tr></table>>;')
lvl2print()
for sym, sec, size in obj_syms:
has_size = isinstance(size, int) and size > 0
size_s = f' ({size}B)' if has_size else ''
fc, cc = vhex(size/max_ssize) if has_size else ('#ffffff', '#000000')
shape = 'box' if sec == '.text' else 'oval'
lvl2print(f'{mangle(sym)}[label = "{sym}{size_s}", style="rounded,filled", shape="{shape}", fillcolor="{fc}", fontname="carlito", fontcolor="{cc}" color=none];')
lvl1print()
edges = set()
for start, ends in refs.items():
for end in ends:
end = aliases.get(end, end)
if (start in syms or start in trace_sections_mangled) and end in syms:
edges.add((start, end))
for start, end in edges:
lvl1print(f'{mangle(start)} -> {mangle(end)} [style="bold", color="#333333"];')
for sec in trace_sections:
lvl1print(f'{sec.replace(".", "_")} [label = "section {sec}", shape="box", style="filled,bold"];')

View file

@ -0,0 +1,62 @@
#!/usr/bin/env python3
def parse_linker_script(data):
pass
def link(groups):
defined_symbols = {}
undefined_symbols = set()
for group, files in groups:
while True:
found_something = False
for fn in files:
symbols = load_symbols(fn)
for symbol in symbols:
if symbol in defined_symbols:
if not group or not found_something:
break
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('-T', '--script', type=str, help='Linker script to use')
parser.add_argument('-o', '--output', type=str, help='Output file to produce')
args, rest = parser.parse_known_intermixed_args()
print(rest)
addprefix = lambda *xs: [ prefix + opt for opt in xs for prefix in ('', '-Wl,') ]
START_GROUP = addprefix('-(', '--start-group')
END_GROUP = addprefix('-)', '--end-group')
GROUP_OPTS = [*START_GROUP, *END_GROUP]
input_files = [ arg for arg in rest if not arg.startswith('-') or arg in GROUP_OPTS ]
def input_file_iter(input_files):
group = False
files = []
for arg in input_files:
if arg in START_GROUP:
assert not group
if files:
yield False, files # nested -Wl,--start-group
group, files = True, []
elif arg in END_GROUP:
assert group # missing -Wl,--start-group
if files:
yield True, files
group, files = False, []
else:
files.append(arg)
assert not group # missing -Wl,--end-group
if files:
yield False, files

View file

@ -0,0 +1,118 @@
#!/usr/bin/env python3
import re
import subprocess
import tempfile
import pprint
ARCHIVE_RE = r'([^(]*)(\([^)]*\))?'
def trace_source_files(linker, cmdline):
with tempfile.NamedTemporaryFile() as mapfile:
output = subprocess.check_output([linker, f'-Wl,--Map={mapfile.name}', *cmdline])
# intentionally use generator here
idx = 0
lines = [ line.rstrip() for line in mapfile.read().decode().splitlines() if line.strip() ]
for idx, line in enumerate(lines[idx:], start=idx):
#print('Dropping', line)
if line == 'Linker script and memory map':
break
idx += 1
objects = []
symbols = {}
sections = {}
current_object = None
last_offset = None
last_symbol = None
cont_sec = None
cont_ind = None
current_section = None
for idx, line in enumerate(lines[idx:], start=idx):
print(f'Processing >{line}')
if line.startswith('LOAD'):
_load, obj = line.split()
objects.append(obj)
continue
if line.startswith('OUTPUT'):
break
m = re.match(r'^( ?)([^ ]+)? +(0x[0-9a-z]+) +(0x[0-9a-z]+)?(.*)?$', line)
if m is None:
m = re.match(r'^( ?)([^ ]+)?$', line)
if m:
cont_ind, cont_sec = m.groups()
else:
cont_ind, cont_sec = None, None
last_offset, last_symbol = None, None
continue
indent, sec, offx, size, sym_or_src = m.groups()
if sec is None:
sec = cont_sec
ind = cont_ind
cont_sec = None
cont_ind = None
print(f'vals: indent={indent} sec={sec} offx={offx} size={size} sym_or_src={sym_or_src}')
if not re.match('^[a-zA-Z_0-9<>():*]+$', sym_or_src):
continue
if indent == '':
print(f'Section: {sec} 0x{size:x}')
current_section = sec
sections[sec] = size
last_offset = None
last_symbol = None
continue
if offx is not None:
offx = int(offx, 16)
if size is not None:
size = int(size, 16)
if size is not None and sym_or_src is not None:
# archive/object line
archive, _member = re.match(ARCHIVE_RE, sym_or_src).groups()
current_object = archive
last_offset = offx
else:
if sym_or_src is not None:
assert size is None
if last_offset is not None:
last_size = offx - last_offset
symbols[last_symbol] = (last_size, current_section)
print(f'Symbol: {last_symbol} 0x{last_size:x} @{current_section}')
last_offset = offx
last_symbol = sym_or_src
idx += 1
for idx, line in enumerate(lines[idx:], start=idx):
if line == 'Cross Reference Table':
break
idx += 1
# map which symbol was pulled from which object in the end
used_defs = {}
for line in lines:
*left, right = line.split()
archive, _member = re.match(ARCHIVE_RE, right).groups()
if left:
used_defs[''.join(left)] = archive
#pprint.pprint(symbols)
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('linker_binary')
parser.add_argument('linker_args', nargs=argparse.REMAINDER)
args = parser.parse_args()
source_files = trace_source_files(args.linker_binary, args.linker_args)

View file

@ -0,0 +1,129 @@
import re
from collections import defaultdict, namedtuple
Section = namedtuple('Section', ['name', 'offset', 'objects'])
ObjectEntry = namedtuple('ObjectEntry', ['filename', 'object', 'offset', 'size'])
FileEntry = namedtuple('FileEntry', ['section', 'object', 'offset', 'length'])
class Memory:
def __init__(self, name, origin, length, attrs=''):
self.name, self.origin, self.length, self.attrs = name, origin, length, attrs
self.sections = {}
self.files = defaultdict(lambda: [])
self.totals = defaultdict(lambda: 0)
def add_toplevel(self, name, offx, length):
self.sections[name] = Section(offx, length, [])
def add_obj(self, name, offx, length, fn, obj):
base_section, sep, subsec = name[1:].partition('.')
base_section = '.'+base_section
if base_section in self.sections:
sec = secname, secoffx, secobjs = self.sections[base_section]
secobjs.append(ObjectEntry(fn, obj, offx, length))
else:
sec = None
self.files[fn].append(FileEntry(sec, obj, offx, length))
self.totals[fn] += length
class MapFile:
def __init__(self, s):
self._lines = s.splitlines()
self.memcfg = {}
self.defaultmem = Memory('default', 0, 0xffffffffffffffff)
self._parse()
def __getitem__(self, offx_or_name):
''' Lookup a memory area by name or address '''
if offx_or_name in self.memcfg:
return self.memcfg[offx_or_name]
elif isinstance(offx_or_name, int):
for mem in self.memcfg.values():
if mem.origin <= offx_or_name < mem.origin+mem.length:
return mem
else:
return self.defaultmem
raise ValueError('Invalid argument type for indexing')
def _skip(self, regex):
matcher = re.compile(regex)
for l in self:
if matcher.match(l):
break
def __iter__(self):
while self._lines:
yield self._lines.pop(0)
def _parse(self):
self._skip('^Memory Configuration')
# Parse memory segmentation info
self._skip('^Name')
for l in self:
if not l:
break
name, origin, length, *attrs = l.split()
if not name.startswith('*'):
self.memcfg[name] = Memory(name, int(origin, 16), int(length, 16), attrs[0] if attrs else '')
# Parse section information
toplevel_m = re.compile('^(\.[a-zA-Z0-9_.]+)\s+(0x[0-9a-fA-F]+)\s+(0x[0-9a-fA-F]+)')
secondlevel_m = re.compile('^ (\.[a-zA-Z0-9_.]+)\s+(0x[0-9a-fA-F]+)\s+(0x[0-9a-fA-F]+)\s+(.*)$')
secondlevel_linebreak_m = re.compile('^ (\.[a-zA-Z0-9_.]+)\n')
filelike = re.compile('^(/?[^()]*\.[a-zA-Z0-9-_]+)(\(.*\))?')
linebreak_section = None
for l in self:
# Toplevel section
match = toplevel_m.match(l)
if match:
name, offx, length = match.groups()
offx, length = int(offx, 16), int(length, 16)
self[offx].add_toplevel(name, offx, length)
match = secondlevel_linebreak_m.match(l)
if match:
linebreak_section, = match.groups()
continue
if linebreak_section:
l = ' {} {}'.format(linebreak_section, l)
linebreak_section = None
# Second-level section
match = secondlevel_m.match(l)
if match:
name, offx, length, misc = match.groups()
match = filelike.match(misc)
if match:
fn, obj = match.groups()
obj = obj.strip('()') if obj else None
offx, length = int(offx, 16), int(length, 16)
self[offx].add_obj(name, offx, length, fn, obj)
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='Parser GCC map file')
parser.add_argument('mapfile', type=argparse.FileType('r'), help='The GCC .map file to parse')
parser.add_argument('-m', '--memory', type=str, help='The memory segments to print, comma-separated')
args = parser.parse_args()
mf = MapFile(args.mapfile.read())
args.mapfile.close()
mems = args.memory.split(',') if args.memory else mf.memcfg.keys()
for name in mems:
mem = mf.memcfg[name]
print('Symbols by file for memory', name)
for tot, fn in reversed(sorted( (tot, fn) for fn, tot in mem.totals.items() )):
print(' {:>8} {}'.format(tot, fn))
for length, offx, sec, obj in reversed(sorted(( (length, offx, sec, obj) for sec, obj, offx, length in
mem.files[fn] ), key=lambda e: e[0] )):
name = sec.name if sec else None
print(' {:>8} {:>#08x} {}'.format(length, offx, obj))
#print('{:>16} 0x{:016x} 0x{:016x} ({:>24}) {}'.format(name, origin, length, length, attrs))

View file

@ -0,0 +1,141 @@
#!/usr/bin/env python3
import os
import sys
import textwrap
import uuid
import hmac
import binascii
import time
from datetime import datetime
LINKING_KEY_SIZE = 15
PRESIG_VERSION = '000.001'
DOMAINS = ['all', 'country', 'region', 'vendor', 'series']
def format_hex(data, indent=4, wrap=True):
indent = ' '*indent
par = ', '.join(f'0x{b:02x}' for b in data)
par = textwrap.fill(par, width=120,
initial_indent=indent, subsequent_indent=indent,
replace_whitespace=False, drop_whitespace=False)
if wrap:
return f'{{\n{par}\n}}'
return par
def domain_string(domain, value):
return f'smart reset domain string v{PRESIG_VERSION}: domain:{domain}={value}'
def keygen_cmd(args):
if os.path.exists(args.keyfile) and not args.force:
print("Error: keyfile already exists. We won't overwrite it. Instead please remove it manually.",
file=sys.stderr)
return 1
root_key = os.urandom(LINKING_KEY_SIZE)
with open(args.keyfile, 'wb') as f:
f.write(binascii.hexlify(root_key))
f.write(b'\n')
return 0
def gen_at_height(domain, value, height, key):
# nanananananana BLOCKCHAIN!
ds = domain_string(domain, value).encode('utf-8')
for height in range(height+1):
key = hmac.digest(key, ds, 'sha512')[:LINKING_KEY_SIZE]
return key
def auth_cmd(args):
with open(args.keyfile, 'r') as f:
root_key = binascii.unhexlify(f.read().strip())
vals = [ (domain, getattr(args, domain)) for domain in DOMAINS if getattr(args, domain) is not None ]
if not vals:
vals = [('all', 'all')]
for domain, value in vals:
auth = gen_at_height(domain, value, args.height, root_key)
print(f'{domain}="{value}" @{args.height}: {binascii.hexlify(auth).decode()}')
def prekey_cmd(args):
with open(args.keyfile, 'r') as f:
root_key = binascii.unhexlify(f.read().strip())
print('#include <stdint.h>')
print('#include <assert.h>')
print()
print('#include "crypto.h"')
print()
bundle_id = uuid.uuid4().bytes
print(f'/* bundle id {binascii.hexlify(bundle_id).decode()} */')
print(f'uint8_t presig_bundle_id[16] = {format_hex(bundle_id)};')
print()
print(f'/* generated on {datetime.now()} */')
print(f'uint64_t bundle_timestamp = {int(time.time())};')
print()
print(f'int presig_height = {args.max_height};')
print()
print('const char *presig_domain_strings[_TRIGGER_DOMAIN_COUNT] = {')
for domain in DOMAINS:
ds = domain_string(domain, getattr(args, domain))
assert '"' not in ds
print(f' [TRIGGER_DOMAIN_{domain.upper()}] = "{ds}",')
print('};')
print()
print('uint8_t presig_keys[_TRIGGER_DOMAIN_COUNT][PRESIG_MSG_LEN] = {')
for domain in DOMAINS:
key = gen_at_height(domain, getattr(args, domain), args.max_height, root_key)
print(f' [TRIGGER_DOMAIN_{domain.upper()}] = {{{format_hex(key, indent=0, wrap=False)}}},')
print('};')
print()
print('static inline void __hack_asserts_only(void) {')
print(f' static_assert(_TRIGGER_DOMAIN_COUNT == {len(DOMAINS)});')
print(f' static_assert(PRESIG_MSG_LEN == {LINKING_KEY_SIZE});')
print('}')
print()
TEST_VENDOR = 'Darthenschmidt Cyberei und Verschleierungstechnik GmbH'
TEST_SERIES = 'Frobnicator v0.23.7'
TEST_REGION = 'Neuland'
TEST_COUNTRY = 'Germany'
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('keyfile', help='Key file to use')
subparsers = parser.add_subparsers(title='subcommands')
keygen_parser = subparsers.add_parser('keygen', help='Generate a new key')
keygen_parser.add_argument('-f', '--force', action='store_true', help='Force overwriting existing keyfile')
keygen_parser.set_defaults(func=keygen_cmd)
auth_parser = subparsers.add_parser('auth', help='Generate one-time authentication string')
auth_parser.add_argument('height', type=int, help='Authentication string height, counting from 0 (root key)')
auth_parser.set_defaults(func=auth_cmd)
auth_parser.add_argument('-a', '--all', action='store_const', const='all', help='Vendor name for vendor domain')
auth_parser.add_argument('-v', '--vendor', type=str, nargs='?', const=TEST_VENDOR, help='Vendor name for vendor domain')
auth_parser.add_argument('-s', '--series', type=str, nargs='?', const=TEST_SERIES, help='Series identifier for series domain')
auth_parser.add_argument('-r', '--region', type=str, nargs='?', const=TEST_REGION, help='Region name for region domain')
auth_parser.add_argument('-c', '--country', type=str, nargs='?', const=TEST_COUNTRY, help='Country name for country domain')
prekey_parser = subparsers.add_parser('prekey', help='Generate prekey data .C source code file')
prekey_parser.add_argument('-m', '--max-height', type=int, default=8, help='Height of generated prekey')
prekey_parser.add_argument('-v', '--vendor', type=str, default=TEST_VENDOR, help='Vendor name for vendor domain')
prekey_parser.add_argument('-s', '--series', type=str, default=TEST_SERIES, help='Series identifier for series domain')
prekey_parser.add_argument('-r', '--region', type=str, default=TEST_REGION, help='Region name for region domain')
prekey_parser.add_argument('-c', '--country', type=str, default=TEST_COUNTRY, help='Country name for country domain')
prekey_parser.set_defaults(func=prekey_cmd, all='all')
args = parser.parse_args()
sys.exit(args.func(args))

View file

@ -0,0 +1,91 @@
import os, sys
import ctypes as C
import argparse
import binascii
import numpy as np
import timeit
import statistics
lib = C.CDLL('rslib.so')
lib.rslib_encode.argtypes = [C.c_int, C.c_size_t, C.POINTER(C.c_char), C.POINTER(C.c_char)]
lib.rslib_decode.argtypes = [C.c_int, C.c_size_t, C.POINTER(C.c_char)]
lib.rslib_gexp.argtypes = [C.c_int, C.c_int]
lib.rslib_gexp.restype = C.c_int
lib.rslib_decode.restype = C.c_int
lib.rslib_npar.restype = C.c_size_t
def npar():
return lib.rslib_npar()
def encode(data: bytes, nbits=8):
out = C.create_string_buffer(len(data) + lib.rslib_npar())
lib.rslib_encode(nbits, len(data), data, out)
return out.raw
def decode(data: bytes, nbits=8):
inout = C.create_string_buffer(data)
lib.rslib_decode(nbits, len(data), inout)
return inout.raw[:-lib.rslib_npar() - 1]
def cmdline_func_test(args, print=lambda *args, **kwargs: None, benchmark=False):
st = np.random.RandomState(seed=args.seed)
lfsr = [lib.rslib_gexp(i, args.bits) for i in range(2**args.bits - 1)]
print('LFSR', len(set(lfsr)), lfsr)
assert all(0 < x < 2**args.bits for x in lfsr)
assert len(set(lfsr)) == 2**args.bits - 1
print('Seed', args.seed)
for i in range(args.repeat):
print(f'Run {i}')
test_data = bytes(st.randint(2**args.bits, size=args.message_length, dtype=np.uint8))
print(' Raw:', binascii.hexlify(test_data).decode())
encoded = encode(test_data, nbits=args.bits)
print(' Encoded:', binascii.hexlify(encoded).decode())
indices = st.permutation(len(encoded))
encoded = list(encoded)
for pos in indices[:args.errors]:
encoded[pos] = st.randint(2**args.bits)
encoded = bytes(encoded)
print(' Modified:', ''.join(f'\033[91m{b:02x}\033[0m' if pos in indices[:args.errors] else f'{b:02x}' for pos, b in enumerate(encoded)))
if benchmark:
rpt = 10000
delta = timeit.timeit('decode(encoded, nbits=args.bits)',
globals={'args': args, 'decode': decode, 'encoded': encoded},
number=rpt)/rpt
print(f'Decoding runtime: {delta*1e6:.3f}μs')
decoded = decode(encoded, nbits=args.bits)
print(' Decoded:', binascii.hexlify(decoded).decode())
print(' Delta:', binascii.hexlify(
bytes(x^y for x, y in zip(test_data, decoded))
).decode().replace('0', '.'))
assert test_data == decoded
def cmdline_func_encode(args, **kwargs):
data = np.frombuffer(binascii.unhexlify(args.hex_str), dtype=np.uint8)
# Map 8 bit input to 6 bit symbol string
data = np.packbits(np.pad(np.unpackbits(data).reshape((-1, 6)), ((0,0),(2, 0))).flatten())
encoded = encode(data.tobytes(), nbits=args.bits)
print('symbol array:', ', '.join(f'0x{x:02x}' for x in encoded))
print('hex string:', binascii.hexlify(encoded).decode())
if __name__ == '__main__':
parser = argparse.ArgumentParser()
cmd_parser = parser.add_subparsers(required=True)
test_parser = cmd_parser.add_parser('test', help='Test reed-solomon implementation')
test_parser.add_argument('-m', '--message-length', type=int, default=6, help='Test message (plaintext) length in bytes')
test_parser.add_argument('-e', '--errors', type=int, default=2, help='Number of byte errors to insert into simulation')
test_parser.add_argument('-r', '--repeat', type=int, default=1000, help='Repeat experiment -r times')
test_parser.add_argument('-b', '--bits', type=int, default=8, help='Symbol bit size')
test_parser.add_argument('-s', '--seed', type=int, default=0, help='Random seed')
test_parser.set_defaults(func=cmdline_func_test)
enc_parser = cmd_parser.add_parser('encode', help='RS-Encode given hex string')
enc_parser.set_defaults(func=cmdline_func_encode)
enc_parser.add_argument('-b', '--bits', type=int, default=8, help='Symbol bit size')
enc_parser.add_argument('hex_str', type=str, help='Input data as hex string')
args = parser.parse_args()
args.func(args)