safety-reset/controller/fw/tools/dsss_demod_test_runner.py
2020-04-17 17:59:08 +02:00

207 lines
9 KiB
Python

#!/usr/bin/env python3
import os
from os import path
import subprocess
import json
from collections import namedtuple, defaultdict
from tqdm import tqdm
import uuid
import multiprocessing
import sqlite3
import time
from urllib.parse import urlparse
import functools
import tempfile
import itertools
import numpy as np
# Widen numpy's print line width to 240 characters so printed arrays wrap less.
np.set_printoptions(linewidth=240)
from dsss_demod_test_waveform_gen import load_noise_meas_params, load_noise_synth_params,\
mains_noise_measured, mains_noise_synthetic, modulate as dsss_modulate
def build_test_binary(nbits, thf, decimation, symbols, cachedir):
    """Compile a dsss_demod_test binary for one decoder parameter set.

    A fresh subdirectory named by a random UUID is created under *cachedir*.
    The decoder parameters are handed to ``make`` through environment
    variables. make's stdout/stderr are captured into ``make_stdout.txt`` and
    ``make_stderr.txt`` inside that subdirectory.

    Returns the name of the build subdirectory (the UUID string).
    Raises subprocess.CalledProcessError if make exits non-zero.
    """
    build_id = str(uuid.uuid4())
    builddir = path.join(cachedir, build_id)
    os.mkdir(builddir)

    build_env = dict(os.environ)
    build_env.update({
        'BUILDDIR': path.abspath(builddir),
        'DSSS_GOLD_CODE_NBITS': str(nbits),
        'DSSS_DECIMATION': str(decimation),
        'DSSS_THRESHOLD_FACTOR': str(thf),
        # Wavelet width and LUT size are derived from the decimation factor.
        'DSSS_WAVELET_WIDTH': str(0.73 * decimation),
        'DSSS_WAVELET_LUT_SIZE': str(10 * decimation),
        'TRANSMISSION_SYMBOLS': str(symbols),
    })

    # make is invoked from the parent of this script's directory.
    make_dir = path.join(path.dirname(__file__), '..')
    target = os.path.abspath(path.join(builddir, 'tools/dsss_demod_test'))
    stdout_log = path.join(builddir, 'make_stdout.txt')
    stderr_log = path.join(builddir, 'make_stderr.txt')
    with open(stdout_log, 'w') as out_f, open(stderr_log, 'w') as err_f:
        subprocess.run(['make', 'clean', target], env=build_env, cwd=make_dir,
                       check=True, stdout=out_f, stderr=err_f)
    return build_id
@functools.lru_cache()
def load_noise_gen(url):
    """Resolve a noise-source URL into a ``(noise_function, params)`` pair.

    *url* has the form ``<schema>://<path>``. A relative path is resolved
    against this script's directory. Supported schemas:

    * ``meas``  -> ``(mains_noise_measured, load_noise_meas_params(path))``
    * ``synth`` -> ``(mains_noise_synthetic, load_noise_synth_params(path))``

    Results are memoized with ``functools.lru_cache`` (default size limit).

    Raises ValueError for an unknown schema, or when *url* does not split
    into exactly two parts on ``://``.
    """
    schema, refpath = url.split('://')
    if not path.isabs(refpath):
        refpath = path.abspath(path.join(path.dirname(__file__), refpath))

    if schema == 'meas':
        noise_fn, param_loader = mains_noise_measured, load_noise_meas_params
    elif schema == 'synth':
        noise_fn, param_loader = mains_noise_synthetic, load_noise_synth_params
    else:
        raise ValueError('Invalid schema', schema)
    return noise_fn, param_loader(refpath)
def sequence_matcher(test_data, decoded, max_shift=3):
    """Return the symbol error rate of *decoded* against *test_data*.

    To tolerate an offset between the two sequences, they are compared at
    every relative shift in ``-max_shift .. +max_shift`` (inclusive). The
    alignment with the fewest errors wins:

    * shift < 0: the first ``-shift`` reference symbols are skipped, and each
      counts as one error. The rest are compared against *decoded* from its
      start.
    * shift >= 0: the first ``shift`` decoded symbols are discarded.

    Reference symbols without a decoded counterpart count as errors.
    Decoded symbols beyond the end of the reference are ignored.

    Returns ``min(errors) / len(test_data)``. *test_data* must be non-empty.
    """
    error_counts = []
    # Inclusive upper bound, so that a shift of +max_shift is tried as well.
    for shift in range(-max_shift, max_shift + 1):
        if shift < 0:
            failures = -shift  # skipped reference symbols count as errors
            ref_seq, found_seq = test_data[-shift:], decoded
        else:
            failures = 0
            ref_seq, found_seq = test_data, decoded[shift:]
        for ref, found in itertools.zip_longest(ref_seq, found_seq):
            if ref is None:  # reference exhausted; trailing decoded symbols are ignored
                break
            if ref != found:
                failures += 1
        error_counts.append(failures)
    return min(error_counts) / len(test_data)
# Parameters identifying one decoder run; run_test pairs each instance with
# the symbol error rate measured for it.
ResultParams = namedtuple('ResultParams', 'nbits thf decimation symbols seed amplitude background')
def _parse_received_sequences(stdout):
    """Return the symbol lists the decoder reported in its *stdout* text.

    Every line whose stripped text starts with ``data sequence received:``
    contributes the comma-separated integers between its first ``[`` and the
    following ``]``. An empty bracket pair yields an empty list.
    """
    sequences = []
    for line in stdout.splitlines():
        if not line.strip().startswith('data sequence received:'):
            continue
        inner = line.partition('[')[2].partition(']')[0]
        # ''.split(',') == [''], so blank fields are dropped rather than passed to int('').
        sequences.append([int(field) for field in inner.split(',') if field.strip()])
    return sequences


def run_test(seed, amplitude_spec, background, nbits, decimation, symbols, thfs, lookup_binary, cachedir):
    """Generate one test transmission at several amplitudes and decode it.

    1. A symbol sequence is drawn from ``np.random.RandomState(seed)``.
    2. The sequence is modulated with dsss_modulate and rescaled with
       ``x * 2 - 1`` (presumably 0/1 chips to +-1; confirm in the waveform
       generator). Each sample is repeated *decimation* times.
    3. The result is scaled by each amplitude and added to noise from
       *background* (a ``meas://`` or ``synth://`` URL, see load_noise_gen).
    4. Each waveform is written to a temporary file in *cachedir* and decoded
       by the binary ``lookup_binary(nbits, thf, decimation, symbols)`` for
       every threshold factor in *thfs*.

    *amplitude_spec* is ``(base, decades, count)``: ``count`` amplitudes
    spaced logarithmically from ``base`` to ``base * 10**decades``.

    Returns a list of ``(ResultParams, ser)`` pairs, one per
    (amplitude, threshold) combination. ``ser`` is the lowest symbol error
    rate over all sequences the decoder reported, or None if it reported none.

    Raises subprocess.CalledProcessError if a decoder exits non-zero.
    """
    noise_gen, noise_params = load_noise_gen(background)
    test_data = np.random.RandomState(seed=seed).randint(0, 2 * (2**nbits), symbols)
    signal = np.repeat(dsss_modulate(test_data, nbits) * 2.0 - 1, decimation)
    # We're re-using the seed here. This is not a problem.
    noise = noise_gen(seed, len(signal), *noise_params)
    amplitudes = amplitude_spec[0] * 10 ** np.linspace(0, amplitude_spec[1], amplitude_spec[2])

    output = []
    for amp in amplitudes:
        with tempfile.NamedTemporaryFile(dir=cachedir) as f:
            waveform = signal*amp + noise
            f.write(waveform.astype('float').tobytes())
            # The decoder opens the file by name, so flush Python's buffer first.
            f.flush()
            for thf in thfs:
                cmdline = [lookup_binary(nbits, thf, decimation, symbols), f.name]
                # check=True raises CalledProcessError (carrying the exit status and command);
                # SystemError is reserved for interpreter-internal failures.
                proc = subprocess.run(cmdline, stdout=subprocess.PIPE, text=True, check=True)
                received = _parse_received_sequences(proc.stdout)
                ser = min(sequence_matcher(test_data, seq) for seq in received) if received else None
                rpars = ResultParams(nbits, thf, decimation, symbols, seed, amp, background)
                output.append((rpars, ser))
                print(f'ran {rpars} {ser=} {" ".join(cmdline)}')
    return output
def parallel_generator(db, table, columns, builder, param_list, desc, context=None, params_mapper=lambda *args: args):
    """Yield ``(params, result)`` for every entry of *param_list*, computing
    missing results in a process pool and caching them in SQLite.

    Cache lookup:
        For each ``params``, the key ``params_mapper(*params)`` is matched
        against the key *columns* of *table*. The table must also have
        ``result`` and ``timestamp`` columns. Cache hits are yielded first,
        JSON-decoded.

    Computing misses:
        Every miss is submitted to a multiprocessing pool as
        ``builder(*params, **context)``. Results are yielded in submission
        order as they complete. Each result is stored in *table* as JSON,
        together with a millisecond wall-clock timestamp.

    Note: cached results round-trip through JSON (tuples come back as lists).
    Freshly computed results are yielded exactly as *builder* returned them.

    *table* and *columns* are interpolated into the SQL text, so they must be
    trusted identifiers. Only the key values are bound as parameters.
    """
    context = {} if context is None else context
    key_cols = ','.join(columns)
    placeholders = ','.join('?' * len(columns))
    select_sql = f'SELECT result FROM {table} WHERE ({key_cols}) = ({placeholders})'
    # Name the target columns explicitly instead of relying on the table's column order.
    # Size the placeholders from the key columns, not len(params): params_mapper may
    # reshape params.
    insert_sql = f'INSERT INTO {table} ({key_cols},result,timestamp) VALUES ({placeholders},?,?)'

    with multiprocessing.Pool(multiprocessing.cpu_count()) as pool:
        jobs = []
        with db as conn:
            for params in param_list:
                row = conn.execute(select_sql, params_mapper(*params)).fetchone()
                if row is not None:
                    yield params, json.loads(row[0])
                else:
                    jobs.append((params, pool.apply_async(builder, params, context)))
        pool.close()  # all tasks submitted; workers exit once the queue drains
        print('Using', len(param_list) - len(jobs), 'cached jobs', flush=True)

        with tqdm(total=len(jobs), desc=desc) as tq:
            for params, res in jobs:
                result = res.get()
                tq.update(1)  # advance only after the job has actually finished
                with db as conn:
                    # Timestamp computed here so this helper does not depend on a
                    # global defined only when the file runs as a script.
                    conn.execute(insert_sql, (*params_mapper(*params), json.dumps(result),
                                              int(time.time() * 1000)))
                yield params, result
        pool.join()
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--dump', help='Write results to JSON file')
    parser.add_argument('-c', '--cachedir', default='dsss_test_cache', help='Directory to store build output and data in')
    args = parser.parse_args()

    # ---- Stage 1: one decoder binary per decoder parameter set ----------------
    # DecoderParams instances are pickled as pool task arguments, so the class
    # must stay a module-level name matching its typename.
    DecoderParams = namedtuple('DecoderParams', ['nbits', 'thf', 'decimation', 'symbols'])
    dec_paramses = [
        DecoderParams(nbits=nb, thf=th, decimation=dec, symbols=20)
        for nb, th, dec in itertools.product([5, 6], [4.5, 4.0, 5.0], [10, 5, 22])
    ]
    # Larger sweep (disabled):
    # dec_paramses = [ DecoderParams(nbits=nbits, thf=thf, decimation=decimation, symbols=100)
    #     for nbits in [5, 6, 7, 8]
    #     for thf in [1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5, 8.0, 8.5, 9.0, 9.5, 10.0]
    #     for decimation in [1, 2, 3, 4, 5, 6, 9, 10, 11, 12, 16, 22, 30, 40, 50] ]

    build_cache_dir = path.join(args.cachedir, 'builds')
    data_cache_dir = path.join(args.cachedir, 'data')
    for cache_subdir in (build_cache_dir, data_cache_dir):
        os.makedirs(cache_subdir, exist_ok=True)

    build_db = sqlite3.connect(path.join(args.cachedir, 'build_db.sqlite3'))
    build_db.execute('CREATE TABLE IF NOT EXISTS builds (nbits, thf, decimation, symbols, result, timestamp)')

    # Millisecond wall-clock timestamp; defined at module level so it is visible as a global.
    def timestamp():
        return int(time.time()*1000)

    builds = dict(parallel_generator(
        build_db,
        table='builds',
        columns=['nbits', 'thf', 'decimation', 'symbols'],
        builder=build_test_binary,
        param_list=dec_paramses,
        desc='Building decoders',
        context=dict(cachedir=build_cache_dir),
    ))
    print('Done building decoders.')

    # ---- Stage 2: generate waveforms and decode them --------------------------
    GeneratorParams = namedtuple('GeneratorParams', ['seed', 'amplitude_spec', 'background'])
    backgrounds = ['meas://fmeas_export_ocxo_2day.bin', 'synth://grid_freq_psd_spl_108pt.json']
    gen_params = [GeneratorParams(rep, (5e-3, 1, 5), bg)
                  # alternative amplitude sweep: GeneratorParams(rep, (0.05e-3, 3.5, 50), bg)
                  for rep, bg in itertools.product(range(30), backgrounds)]

    data_db = sqlite3.connect(path.join(args.cachedir, 'data_db.sqlite3'))
    data_db.execute('CREATE TABLE IF NOT EXISTS waveforms'
                    '(seed, amplitude_spec, background, nbits, decimation, symbols, thresholds, result, timestamp)')

    # Group threshold factors by (nbits, decimation, symbols): each generated
    # waveform is decoded with every threshold factor of its group.
    thresholds_by_decoder = defaultdict(list)
    for dp in dec_paramses:
        thresholds_by_decoder[(dp.nbits, dp.decimation, dp.symbols)].append(dp.thf)

    waveform_params = [(*gp, *group_key, thfs)
                       for gp in gen_params
                       for group_key, thfs in thresholds_by_decoder.items()]
    print(f'Generated {len(waveform_params)} parameter sets')

    # Passed to pool workers (pickled by reference), so it must stay module-level.
    def lookup_binary(*params):
        build_id = builds[tuple(params)]
        return path.join(build_cache_dir, build_id, 'tools/dsss_demod_test')

    # sqlite3 cannot bind tuples/lists, so the sequence-valued keys are
    # flattened into comma-separated strings.
    def params_mapper(seed, amplitude_spec, background, nbits, decimation, symbols, thresholds):
        return (seed, ','.join(map(str, amplitude_spec)), background,
                nbits, decimation, symbols, ','.join(map(str, thresholds)))

    waveform_key_cols = ['seed', 'amplitude_spec', 'background', 'nbits', 'decimation', 'symbols', 'thresholds']
    results = []
    for _params, chunk in parallel_generator(data_db, 'waveforms', waveform_key_cols,
                                             params_mapper=params_mapper,
                                             builder=run_test,
                                             param_list=waveform_params,
                                             desc='Generating waveforms',
                                             context=dict(cachedir=data_cache_dir, lookup_binary=lookup_binary)):
        results.extend(chunk)

    if args.dump:
        with open(args.dump, 'w') as f:
            json.dump(results, f)