188 lines
4.9 KiB
Python
188 lines
4.9 KiB
Python
#!/usr/bin/env python
|
|
|
|
import time
|
|
from pprint import pprint
|
|
from enum import Enum
|
|
from functools import cache
|
|
from dataclasses import dataclass, fields, astuple
|
|
import struct
|
|
import binascii
|
|
|
|
import numpy as np
|
|
import click
|
|
import serial
|
|
from cobs import cobs
|
|
|
|
class CobsSerial:
    """Packet transport over a serial port using COBS framing.

    Every outgoing packet is COBS-encoded and terminated by a single zero
    byte; incoming frames are read up to the next zero byte and decoded.
    """

    def __init__(self, port, timeout):
        """Open *port* and resynchronize the framing.

        Args:
            port: device name passed to ``serial.Serial``.
            timeout: ``serial.Serial`` read timeout in seconds; in
                ``read_packet`` it bounds each single-byte read.
        """
        self.ser = serial.Serial(port, timeout=timeout)
        # Discard stale data from a previous session in both directions.
        self.ser.reset_output_buffer()
        self.ser.reset_input_buffer()
        self.ser.write(bytes([0]))  # synchronize: lone frame delimiter
        # flush() waits until the delimiter is transmitted. The former
        # flushOutput() is pyserial's deprecated alias of
        # reset_output_buffer(), which could discard it before it was sent.
        self.ser.flush()

    def write_packet(self, data):
        """COBS-encode *data*, send it followed by the zero delimiter, and
        block until the bytes have been transmitted."""
        self.ser.write(cobs.encode(data) + bytes([0]))
        # Must be flush(), not flushOutput(): the latter discards unsent output.
        self.ser.flush()

    def read_packet(self):
        """Read one zero-terminated frame and parse it.

        Reading stops at the first zero byte or when a single-byte read
        times out. Whatever was collected is COBS-decoded and handed to
        ``parse_packet``.

        Returns:
            The parsed packet, or ``None`` if no frame bytes were received.
        """
        frame = bytearray()  # bytearray avoids quadratic bytes concatenation
        while (byte := self.ser.read(1)):
            if byte[0] == 0:
                break
            frame += byte

        if not frame:
            return None
        return parse_packet(cobs.decode(bytes(frame)))

    def command(self, command, args=b''):
        """Send a request whose first byte is ``command.value`` followed by
        *args*, then return the parsed reply (see ``read_packet``)."""
        self.write_packet(bytes([command.value]) + args)
        return self.read_packet()
|
|
|
|
|
|
class SerializableEnum(Enum):
    """Enum whose members convert to their underlying value with ``int()``.

    ``Serialized`` packs fields through ``int`` unless an explicit uncast is
    given, so enums used in packet fields need this conversion.
    """

    def __int__(self):
        member_value = self.value
        return member_value
|
|
|
|
class PacketType(SerializableEnum):
    """Packet-type tag carried in the first byte of every packet.

    ``parse_packet`` dispatches on this byte, and ``CobsSerial.command``
    sends ``command.value`` as the first request byte.
    """

    USBP_GET_STATUS = 0x00        # reply decoded as StatusPacket
    USBP_GET_MEASUREMENTS = 0x01  # reply decoded as MeasurementPacket
    USBP_SET_MOTOR = 0x02         # sent as a MotorPacket by the `motor` command
|
|
|
|
class ErrorCode(Enum):
    """Error codes found in ``StatusPacket.last_uart_error`` and
    ``StatusPacket.last_bus_error``; ``ERR_SUCCESS`` is code 0."""

    ERR_SUCCESS = 0x0
    ERR_TIMEOUT = 0x1
    ERR_PHYSICAL_LAYER = 0x2
    ERR_FRAMING = 0x3
    ERR_PROTOCOL = 0x4
    ERR_DMA = 0x5
    ERR_BUSY = 0x6
    ERR_BUFFER_OVERFLOW = 0x7
    ERR_RX_OVERRUN = 0x8
    ERR_TX_OVERRUN = 0x9
|
|
|
|
class BoardConfig(Enum):
    """Board configuration found in ``StatusPacket.board_config``."""

    BCFG_UNCONFIGURED = 0x0
    BCFG_DISPLAY = 0x1
    BCFG_MOTOR = 0x2
    BCFG_MEAS = 0x3
|
|
|
|
class Serialized:
    """Mixin that packs/unpacks a dataclass with ``struct`` (little-endian).

    Each field annotation describes that field's wire encoding, either as

    * a bare ``struct`` format code such as ``'B'`` or ``'Q'``, or
    * a tuple ``(code, cast)`` or ``(code, cast, uncast)``.

    ``cast`` turns the unpacked value into the attribute value; ``uncast``
    turns the attribute value back into something ``struct.pack`` accepts.
    Whichever is omitted defaults to ``int``. Subclasses must be dataclasses,
    because the layout is read with ``dataclasses.fields``.
    """

    @classmethod
    def deserialize(kls, data):
        """Build an instance of *kls* from the packed bytes *data*.

        ``struct.unpack`` raises ``struct.error`` if ``len(data)`` does not
        match the layout.
        """
        raw_values = struct.unpack(kls._struct_format(), data)
        converters = kls._struct_casts()
        return kls(*(convert(raw) for convert, raw in zip(converters, raw_values)))

    def serialize(self):
        """Return this instance packed into bytes using the class layout."""
        wire_values = [
            to_wire(attr)
            for to_wire, attr in zip(self._struct_uncasts(), astuple(self))
        ]
        return struct.pack(self._struct_format(), *wire_values)

    @classmethod
    @cache
    def _struct_format(kls):
        """Cached ``struct`` format string for *kls*, e.g. ``'<BQ'``."""
        layout, _, _ = kls._parse_fields()
        return layout

    @classmethod
    @cache
    def _struct_casts(kls):
        """Cached per-field converters applied after unpacking."""
        _, to_attr, _ = kls._parse_fields()
        return to_attr

    @classmethod
    @cache
    def _struct_uncasts(kls):
        """Cached per-field converters applied before packing."""
        _, _, to_wire = kls._parse_fields()
        return to_wire

    @classmethod
    def _parse_fields(kls):
        """Return ``(format, casts, uncasts)`` derived from the dataclass fields."""
        codes, casts, uncasts = [], [], []
        for fld in fields(kls):
            spec = fld.type
            if isinstance(spec, tuple):
                # Appending int gives a 2-tuple its default uncast; any extra
                # trailing items are ignored.
                code, cast, uncast, *_ = (*spec, int)
            else:
                code, cast, uncast = spec, int, int
            codes.append(code)
            casts.append(cast)
            uncasts.append(uncast)
        return '<' + ''.join(codes), casts, uncasts
|
|
|
|
def timestamp(value):
    """Convert a microsecond count to seconds, as a float.

    Used as the ``cast`` of the ``'Q'`` timestamp fields in ``StatusPacket``.
    """
    microseconds = float(value)
    return microseconds / 1e6
|
|
|
|
def _seconds_to_us(seconds):
    """Inverse of ``timestamp``: seconds back to an integer microsecond count."""
    return round(seconds * 1e6)


def _enum_to_int(member):
    """Wire value of an enum member; works for plain ``Enum`` classes too."""
    return member.value


@dataclass
class StatusPacket(Serialized):
    """Status reply, decoded by ``parse_packet`` for ``USBP_GET_STATUS``.

    ``'Q'`` timestamp fields are unpacked through ``timestamp`` (microseconds
    to float seconds). They carry an explicit uncast so that ``serialize()``
    packs microseconds again. Enum fields uncast via ``.value`` because
    ``BoardConfig`` and ``ErrorCode`` do not support ``int()``.
    """

    packet_type: ('B', PacketType)
    sys_time_us: ('Q', timestamp, _seconds_to_us)  # holds seconds once decoded
    has_lcd: ('B', bool)
    has_adc: ('B', bool)
    board_config: ('B', BoardConfig, _enum_to_int)
    bus_addr: 'B'
    last_uart_error: ('B', ErrorCode, _enum_to_int)
    last_uart_error_timestamp: ('Q', timestamp, _seconds_to_us)
    last_uart_rx: ('Q', timestamp, _seconds_to_us)
    last_uart_tx: ('Q', timestamp, _seconds_to_us)
    last_bus_error: ('B', ErrorCode, _enum_to_int)
    last_bus_error_timestamp: ('Q', timestamp, _seconds_to_us)
|
|
|
|
@dataclass
class MotorPacket(Serialized):
    """``USBP_SET_MOTOR`` request sent by the ``motor`` command."""

    # Encodings spelled out in full; identical to the int defaults.
    packet_type: ('B', PacketType, int)  # 1-byte type tag
    speed_rpm: ('i', int, int)           # little-endian signed int32
|
|
|
|
def parse_packet(data):
    """Turn a COBS-decoded frame into a packet object.

    Args:
        data: decoded frame; its first byte is the ``PacketType`` value.

    Returns:
        A ``StatusPacket`` or ``MeasurementPacket``.

    Raises:
        ValueError: if *data* is empty, its first byte is not a valid
            ``PacketType``, or the type has no decoder here.
        struct.error: if the frame length does not match the packet layout
            (raised by ``deserialize``).
    """
    if not data:
        # An empty frame (COBS encoding 0x01) has no type byte to dispatch on.
        raise ValueError('Empty packet')
    packet_type = PacketType(data[0])

    # Built at call time because MeasurementPacket is defined below.
    decoders = {
        PacketType.USBP_GET_STATUS: StatusPacket,
        PacketType.USBP_GET_MEASUREMENTS: MeasurementPacket,
    }
    decoder = decoders.get(packet_type)
    if decoder is None:
        raise ValueError(f'Unsupported packet type {packet_type}')
    return decoder.deserialize(data)
|
|
|
|
@dataclass
class MeasurementPacket(Serialized):
    """Reply decoded by ``parse_packet`` for ``USBP_GET_MEASUREMENTS``."""

    packet_type: ('B', PacketType)
    num_channels: 'B'
    _num_samples_a: 'I'
    _num_samples_b: 'I'
    _measurements_raw: ('240s', bytes)  # 240 bytes = 60 int32 values

    @property
    def measurements(self):
        """Raw buffer viewed as little-endian int32, shaped ``(2, 2, -1)``.

        For the 240-byte wire field this is ``(2, 2, 15)``. The array is a
        read-only view over the bytes object.
        """
        flat = np.frombuffer(self._measurements_raw, dtype='<i4')
        return flat.reshape(2, 2, -1)

    @property
    def num_samples(self):
        """The two sample counters as a list ``[a, b]``.

        NOTE(review): ``probe`` uses ``num_samples[1]`` as the number of valid
        entries in ``measurements[1, 1]``; confirm the mapping for index 0.
        """
        counts = [self._num_samples_a, self._num_samples_b]
        return counts
|
|
|
|
@click.group()
def cli():
    # Root click group; subcommands attach via @cli.command().
    # Intentionally docstring-free: click would use a docstring as --help text.
    ...
|
|
|
|
@cli.command()
@click.argument('port')
@click.option('--timeout', type=float, default=1)
def probe(port, timeout):
    # Print the board status once, then poll measurements forever. For each
    # packet, print every valid sample of measurements[1, 1] with its count.
    # (Comments rather than a docstring: click would expose a docstring as --help.)
    ser = CobsSerial(port, timeout)
    pprint(ser.command(PacketType.USBP_GET_STATUS))

    while True:
        time.sleep(0.01)
        packet = ser.command(PacketType.USBP_GET_MEASUREMENTS)
        if packet is None:
            # No frame bytes arrived (read timed out): poll again instead of
            # crashing on attribute access.
            continue
        # Decode the raw buffer once per packet, not once per sample.
        samples = packet.measurements[1, 1]
        count = packet.num_samples[1]
        # Slicing clamps to the buffer length, so an oversized count no
        # longer raises IndexError.
        for value in samples[:count]:
            print(value, count)
|
|
|
|
@cli.command()
@click.argument('port')
@click.argument('speed_rpm', type=int, default=0)
@click.option('--timeout', type=float, default=1)
def motor(port, speed_rpm, timeout):
    # Fire-and-forget: send a single USBP_SET_MOTOR packet; no reply is read.
    link = CobsSerial(port, timeout)
    request = MotorPacket(packet_type=PacketType.USBP_SET_MOTOR, speed_rpm=speed_rpm)
    link.write_packet(request.serialize())
|
|
|
|
if __name__ == '__main__':
    # Script entry point: hand argv over to the click command group.
    cli()
|