# Source code for trs_interface.protocol.decoder

# Copyright 2021 Patrick C. Tapping
#
# This program is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program.  If not, see <http://www.gnu.org/licenses/>.

"""
The bulk of the ``decoder`` module contains functions to decode raw byte strings to dictionaries
containing the message type and data fields. The majority of these are automatically generated,
and consequently not documented.
Fortunately, a user of the ``protocol`` should not need to know anything about these methods,
and instead simply rely on the :class:`StreamDecoder` to split messages out from a continuous
byte stream (such as serial port data), decode them, and serve the decoded dictionaries.

For example:

.. code-block:: python

    from serial import Serial
    from trs_interface.protocol.decoder import StreamDecoder

    # Open a serial port connection
    serial_port = Serial("/dev/ttyACM0", timeout=0.1, write_timeout=0.1)
    # Create the decoder for received messages
    decoder = StreamDecoder(serial_port, on_error="warn")

    # Split out messages waiting on the serial port and return them
    for msg in decoder:
        print(f"Received message: id={msg.id}={msg.msg}, data={msg.data}, payload={msg.payload}")

Note though that an end user of the ``trs_interface`` should not even need to deal with anything
at the ``protocol`` layer, and instead interact with the device purely through the
:class:`TRSInterface <trs_interface.TRSInterface>` class. 
"""

import io
import struct
import functools
import logging
from collections import namedtuple
from typing import Dict, Any, Optional, Sequence

import numpy as np

from . import HEADER_SIZE, LONG_FORM, ID, Polarity

# Registry filled in by the ``_decoder`` decorator: message ID -> wrapped decoding function.
decoder_for_id = {}
"""Dictionary for looking up a function which can decode data corresponding to a given message ID value."""

# Module-level logger, named after this module.
_log = logging.getLogger(__name__)

def _decoder(msgid):
    """
    Decorator to indicate a function which decodes a message packet.

    The decorated function is wrapped so that the common message header is decoded and checked
    first, and the resulting wrapper is registered in ``decoder_for_id`` under ``msgid``.

    :param msgid: Identification code corresponding to the message.
    :returns: Decorator which wraps, registers and returns the decoding function.
    :raises ValueError: When the decorator is applied and ``msgid`` already has a registered decoder.
    """
    def decoder_decorator(func):

        @functools.wraps(func)
        def decoder_wrapper(data_raw: bytes) -> Dict[str, Any]:
            """
            Decode a message header, then pass ``data_raw`` on to the wrapped function.

            Messages start with the magic string ``"MSG:"``, followed by a two-byte message ID code,
            and then a 4-byte data field.

            :param data_raw: Raw bytes of a single, complete message packet.
            :returns: Dictionary with ``"msg"``, ``"id"``, ``"data"`` and ``"payload"`` keys, updated
                with any fields returned by the wrapped function.
            :raises RuntimeError: If the magic prefix is missing or the message ID is not ``msgid``.
            """
            # Decode the message header and ensure it looks sensible
            msg_magic, msg_id, msg_data = struct.unpack_from("<4sHi", data_raw)
            if msg_magic != b"MSG:":
                raise RuntimeError("Decoded message does not start with 'MSG:' prefix.")
            if msg_id != msgid:
                # Report the expected ``msgid`` (not the ``id`` builtin)
                raise RuntimeError(f"Decoded message id={msg_id} does not match expected value={msgid} for '{func.__name__}' messages.")
            # Look at any payload data and ensure it seems formatted correctly
            msg_payload = data_raw[HEADER_SIZE:]
            if len(msg_payload) and not (msg_id & LONG_FORM):
                _log.warning(f"Message packet for '{func.__name__}' id={msg_id:#06x} contains {len(msg_payload)} bytes of unexpected payload data, ignoring.")
                msg_payload = b""
            if len(msg_payload):
                # Long form message, check if fixed length or terminated version
                if msg_data <= 0:
                    # Terminated long form message type.
                    # A payload shorter than 4 bytes cannot contain the terminator at all.
                    terminated = (
                        len(msg_payload) >= 4
                        and struct.unpack_from("<i", msg_payload[-4:])[0] == msg_data
                    )
                    if not terminated:
                        _log.warning(f"Message payload for '{func.__name__}' id={msg_id} is not terminated by msg_data={msg_data}, appending.")
                        msg_payload += struct.pack("<i", msg_data)
                else:
                    # Fixed-length long form message type
                    if len(msg_payload) != msg_data:
                        _log.warning(f"Message payload size={len(msg_payload)} for '{func.__name__}' id={msg_id} doesn't match expected msg_data={msg_data:#010x}, correcting.")
                        msg_data = len(msg_payload)
            msg = {"msg": func.__name__, "id": msg_id, "data": msg_data, "payload": msg_payload}
            # Let the wrapped function decode any further info from the byte data.
            # Its fields take precedence over the generic ones set above.
            msg.update(func(data_raw))
            return msg

        # Add the message id and corresponding decode function to the lookup dictionary
        if msgid in decoder_for_id:
            raise ValueError(f"Duplicated message definition '{func.__name__}' for id={msgid:#x}='{decoder_for_id[msgid].__name__}'.")
        decoder_for_id[msgid] = decoder_wrapper
        return decoder_wrapper

    return decoder_decorator


class StreamDecoder:
    """
    Create a StreamDecoder to decode a byte stream into messages to or from the TRSInterface.

    The ``stream`` parameter should be an object which data can be sourced from.
    It should support the ``read()`` method. If no stream is given, an in-memory ``io.BytesIO``
    is created, which can be filled using :meth:`feed`.

    The ``on_error`` parameter selects the action to take if invalid data is detected.
    If set to ``"warn"`` (the default), bytes will be discarded if the byte sequence does not
    appear to be a valid message, and a warning message will be emitted.
    If set to ``"continue"``, the behaviour is identical, but no warning is emitted.
    To instead immediately abort the stream decoding and raise a ``RuntimeError``, set to ``"raise"``.

    The decoder is its own iterator. Iteration stops (``StopIteration``) when the stream does not
    currently supply enough data for a complete message; a ``BufferError`` is raised if a message
    would exceed ``max_message_size``.

    :param stream: A data stream from which data can be ``read()``.
    :param on_error: Action to take if invalid data is detected.
    :param max_message_size: Largest total message size (header plus payload), in bytes, which
        will be accepted.
    """

    def __init__(self, stream=None, on_error="warn", max_message_size=2**22):
        if stream is None:
            self._file = io.BytesIO()
        else:
            self._file = stream
        # Bytes read from the stream which have not yet been consumed by a decoded message
        self.buffer = b""
        self.max_message_size = max_message_size
        self.on_error = on_error

    def __iter__(self):
        return self

    def _decoding_error(self, message="Error decoding message from buffer."):
        """
        Take appropriate action if parsing of data stream fails.

        :param message: Warning or error message string.
        :raises RuntimeError: If ``on_error`` is ``"raise"``.
        """
        if self.on_error == "raise":
            raise RuntimeError(message)
        if self.on_error == "warn":
            _log.warning(message)
        # Discard first byte of buffer, it might decode better now...
        self.buffer = self.buffer[1:]

    def __next__(self):
        """
        Decode and return the next message available from the stream.

        :returns: A namedtuple, named after the message, holding the decoded message fields.
        :raises StopIteration: If there is not (yet) enough data to form a complete message.
        :raises BufferError: If the message would exceed ``max_message_size``.
        :raises RuntimeError: If invalid data is found and ``on_error`` is ``"raise"``.
        """
        # Basic message packet is HEADER_SIZE bytes, try to fill buffer to at least that size
        if len(self.buffer) < HEADER_SIZE:
            self.buffer += self._file.read(HEADER_SIZE - len(self.buffer))
        # Hopefully enough data in buffer now to try to decode a message
        while len(self.buffer) >= HEADER_SIZE:
            # Ensure the data follows the message format
            msg_magic, msg_id, msg_data = struct.unpack_from("<4sHi", self.buffer)
            if msg_magic != b"MSG:":
                self._decoding_error(f"Invalid message prefix='{msg_magic}'")
                continue
            if msg_id not in decoder_for_id:
                self._decoding_error(f"Invalid message with id={msg_id:#x}")
                continue
            # MSB of id indicates message is a long form type
            long_form = bool(msg_id & LONG_FORM)
            # Message header looks OK, break from loop and proceed
            break
        # If we got here, either the buffer was/shrank too small,
        # or we have the start of something that looks like a valid message
        if len(self.buffer) < HEADER_SIZE:
            # Not enough data to form a message packet
            raise StopIteration
        # Buffer contains enough for a short message, but maybe not a long form one
        length = 0
        if long_form and msg_data > 0:
            # Long form message, and data field encodes fixed payload length
            length = msg_data
            # Check if message would exceed limit
            if (HEADER_SIZE + length) > self.max_message_size:
                self.buffer = self.buffer[HEADER_SIZE:]
                raise BufferError("Expected message length exceeds maximum message size.")
            if len(self.buffer) < HEADER_SIZE + length:
                # Not enough data in buffer to decode long form message, attempt to read some more data
                self.buffer += self._file.read(length - len(self.buffer) + HEADER_SIZE)
                if len(self.buffer) < HEADER_SIZE + length:
                    # Still didn't receive enough data to decode message
                    raise StopIteration
        elif long_form and msg_data <= 0:
            # Long form message, and data field encodes (4-byte) terminator for variable length payload.
            # The "+ 4" stops the data field of the header itself being mistaken for the terminator.
            while (len(self.buffer) < HEADER_SIZE + 4) or not (struct.unpack_from("<i", self.buffer[-4:])[0] == msg_data):
                # Not enough data in buffer or terminator not found yet
                chunk = self._file.read(1)
                if not chunk:
                    # Another byte wasn't available
                    raise StopIteration
                self.buffer += chunk
                length = len(self.buffer) - HEADER_SIZE
                # Check if message size exceeds limit
                if (HEADER_SIZE + length + 4) > self.max_message_size:
                    self.buffer = self.buffer[HEADER_SIZE + length:]
                    raise BufferError("Variable length message has exceeded maximum message size.")
            # Terminator sits at the end of the buffer, so the whole buffer is the message.
            # Set this here too, in case the buffer was already complete and the loop body never ran.
            length = len(self.buffer) - HEADER_SIZE
        # Have enough data in buffer to decode the full message
        data = self.buffer[:length + HEADER_SIZE]
        # Can now remove the message data from the buffer
        self.buffer = self.buffer[length + HEADER_SIZE:]
        # Decode the message contents
        msg_dict = decoder_for_id[msg_id](data)
        return namedtuple(msg_dict["msg"], msg_dict.keys())(**msg_dict)

    def _reset(self):
        """
        Reset the receive buffer.
        """
        self.buffer = b""

    def feed(self, data: bytes):
        """
        Add byte data to the input stream.

        The input stream must support random access (using the ``seek()`` method), and thus is not
        applicable to sources such as serial port input.

        :param data: Byte array containing data to add.
        """
        # Remember the read position, append at the end of the stream, then restore the position
        pos = self._file.tell()
        self._file.seek(0, 2)
        self._file.write(data)
        self._file.seek(pos)
@_decoder(ID.GET_PING)  # id=0x0000
def get_ping(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``get_ping`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GOT_PING)  # id=0x0001
def got_ping(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``got_ping`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GOT_UNKNOWN_MSG)  # id=0x0011
def got_unknown_msg(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``got_unknown_msg`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GET_VERSION)  # id=0x0020
def get_version(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``get_version`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GOT_VERSION)  # id=0x0021 | LONG_FORM
def got_version(data_raw: bytes) -> Dict[str, Any]:
    """
    Decode a ``got_version`` message, extracting the version string from the payload.

    :param data_raw: Raw bytes of the complete message packet.
    :returns: Dictionary with the ASCII-decoded payload under the ``"version"`` key.
    """
    # Long form message: the data field of the header selects the payload style
    data_field = struct.unpack_from("<4sHi", data_raw)[2]
    if data_field <= 0:
        # Terminated style message, leave off the trailing 4-byte terminator
        text = data_raw[HEADER_SIZE:-4]
    else:
        # Fixed-length style message, everything after the header is payload
        text = data_raw[HEADER_SIZE:]
    return {"version": text.decode("ascii")}
@_decoder(ID.STORE_SETTINGS)  # id=0x08002
def store_settings(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``store_settings`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GET_LASER_SYNC_POLARITY)  # id=0x1000
def get_laser_sync_polarity(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``get_laser_sync_polarity`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GOT_LASER_SYNC_POLARITY)  # id=0x1001
def got_laser_sync_polarity(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``got_laser_sync_polarity`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.SET_LASER_SYNC_POLARITY)  # id=0x1002
def set_laser_sync_polarity(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``set_laser_sync_polarity`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GET_CHOPPER_SYNCIN_POLARITY)  # id=0x1010
def get_chopper_syncin_polarity(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``get_chopper_syncin_polarity`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GOT_CHOPPER_SYNCIN_POLARITY)  # id=0x1011
def got_chopper_syncin_polarity(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``got_chopper_syncin_polarity`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.SET_CHOPPER_SYNCIN_POLARITY)  # id=0x1012
def set_chopper_syncin_polarity(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``set_chopper_syncin_polarity`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GET_CHOPPER_SYNCOUT_POLARITY)  # id=0x1020
def get_chopper_syncout_polarity(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``get_chopper_syncout_polarity`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GOT_CHOPPER_SYNCOUT_POLARITY)  # id=0x1021
def got_chopper_syncout_polarity(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``got_chopper_syncout_polarity`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.SET_CHOPPER_SYNCOUT_POLARITY)  # id=0x1022
def set_chopper_syncout_polarity(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``set_chopper_syncout_polarity`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GET_DELAY_TRIG_POLARITY)  # id=0x1030
def get_delay_trig_polarity(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``get_delay_trig_polarity`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GOT_DELAY_TRIG_POLARITY)  # id=0x1031
def got_delay_trig_polarity(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``got_delay_trig_polarity`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.SET_DELAY_TRIG_POLARITY)  # id=0x1032
def set_delay_trig_polarity(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``set_delay_trig_polarity`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GET_CAMERA_TRIG_POLARITY)  # id=0x1050
def get_camera_trig_polarity(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``get_camera_trig_polarity`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GOT_CAMERA_TRIG_POLARITY)  # id=0x1051
def got_camera_trig_polarity(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``got_camera_trig_polarity`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.SET_CAMERA_TRIG_POLARITY)  # id=0x1052
def set_camera_trig_polarity(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``set_camera_trig_polarity`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GET_QUADRATURE_POLARITY)  # id=0x1060
def get_quadrature_polarity(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``get_quadrature_polarity`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GOT_QUADRATURE_POLARITY)  # id=0x1061
def got_quadrature_polarity(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``got_quadrature_polarity`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.SET_QUADRATURE_POLARITY)  # id=0x1062
def set_quadrature_polarity(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``set_quadrature_polarity`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GET_QUADRATURE_DIRECTION)  # id=0x1070
def get_quadrature_direction(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``get_quadrature_direction`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GOT_QUADRATURE_DIRECTION)  # id=0x1071
def got_quadrature_direction(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``got_quadrature_direction`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.SET_QUADRATURE_DIRECTION)  # id=0x1072
def set_quadrature_direction(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``set_quadrature_direction`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GET_CHOPPER_SYNC_DELAY)  # id=0x1100
def get_chopper_sync_delay(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``get_chopper_sync_delay`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GOT_CHOPPER_SYNC_DELAY)  # id=0x1101
def got_chopper_sync_delay(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``got_chopper_sync_delay`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.SET_CHOPPER_SYNC_DELAY)  # id=0x1102
def set_chopper_sync_delay(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``set_chopper_sync_delay`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GET_CHOPPER_SYNC_DURATION)  # id=0x1110
def get_chopper_sync_duration(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``get_chopper_sync_duration`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GOT_CHOPPER_SYNC_DURATION)  # id=0x1111
def got_chopper_sync_duration(data_raw: bytes) -> Dict[str, Any]:
    """
    Decode a ``got_chopper_sync_duration`` message.

    The 4-byte data field of the header is re-read as an unsigned integer.

    :param data_raw: Raw bytes of the complete message packet.
    :returns: Dictionary with the unsigned value under the ``"data"`` key.
    """
    # Data is unsigned int, located after the 4-byte magic and 2-byte id (offset 6).
    # Explicit little-endian ("<") matches the "<4sHi" header layout instead of
    # depending on the byte order of the host machine.
    return {"data": struct.unpack_from("<I", data_raw, offset=6)[0]}
@_decoder(ID.SET_CHOPPER_SYNC_DURATION)  # id=0x1112
def set_chopper_sync_duration(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``set_chopper_sync_duration`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GET_CAMERA_SYNC_DELAY)  # id=0x1120
def get_camera_sync_delay(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``get_camera_sync_delay`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GOT_CAMERA_SYNC_DELAY)  # id=0x1121
def got_camera_sync_delay(data_raw: bytes) -> Dict[str, Any]:
    """
    Decode a ``got_camera_sync_delay`` message.

    The 4-byte data field of the header is re-read as an unsigned integer.

    :param data_raw: Raw bytes of the complete message packet.
    :returns: Dictionary with the unsigned value under the ``"data"`` key.
    """
    # Data is unsigned int, located after the 4-byte magic and 2-byte id (offset 6).
    # Explicit little-endian ("<") matches the "<4sHi" header layout instead of
    # depending on the byte order of the host machine.
    return {"data": struct.unpack_from("<I", data_raw, offset=6)[0]}
@_decoder(ID.SET_CAMERA_SYNC_DELAY)  # id=0x1122
def set_camera_sync_delay(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``set_camera_sync_delay`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GET_CAMERA_SYNC_DURATION)  # id=0x1130
def get_camera_sync_duration(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``get_camera_sync_duration`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GOT_CAMERA_SYNC_DURATION)  # id=0x1131
def got_camera_sync_duration(data_raw: bytes) -> Dict[str, Any]:
    """
    Decode a ``got_camera_sync_duration`` message.

    The 4-byte data field of the header is re-read as an unsigned integer.

    :param data_raw: Raw bytes of the complete message packet.
    :returns: Dictionary with the unsigned value under the ``"data"`` key.
    """
    # Data is unsigned int, located after the 4-byte magic and 2-byte id (offset 6).
    # Explicit little-endian ("<") matches the "<4sHi" header layout instead of
    # depending on the byte order of the host machine.
    return {"data": struct.unpack_from("<I", data_raw, offset=6)[0]}
@_decoder(ID.SET_CAMERA_SYNC_DURATION)  # id=0x1132
def set_camera_sync_duration(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``set_camera_sync_duration`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GET_CHOPPER_DIVIDER)  # id=0x1200
def get_chopper_divider(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``get_chopper_divider`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GOT_CHOPPER_DIVIDER)  # id=0x1201
def got_chopper_divider(data_raw: bytes) -> Dict[str, Any]:
    """
    Decode a ``got_chopper_divider`` message.

    The 4-byte data field of the header is re-read as an unsigned integer.

    :param data_raw: Raw bytes of the complete message packet.
    :returns: Dictionary with the unsigned value under the ``"data"`` key.
    """
    # Data is unsigned int, located after the 4-byte magic and 2-byte id (offset 6).
    # Explicit little-endian ("<") matches the "<4sHi" header layout instead of
    # depending on the byte order of the host machine.
    return {"data": struct.unpack_from("<I", data_raw, offset=6)[0]}
@_decoder(ID.SET_CHOPPER_DIVIDER)  # id=0x1202
def set_chopper_divider(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``set_chopper_divider`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GET_QUADRATURE_VALUE)  # id=0x1210
def get_quadrature_value(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``get_quadrature_value`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GOT_QUADRATURE_VALUE)  # id=0x1211
def got_quadrature_value(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``got_quadrature_value`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.SET_QUADRATURE_VALUE)  # id=0x1212
def set_quadrature_value(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``set_quadrature_value`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.TRIGGER)  # id=0x2004
def trigger(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``trigger`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.ARM)  # id=0x2008
def arm(data_raw: bytes) -> Dict[str, Any]:
    """Decode an ``arm`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.START)  # id=0x2018
def start(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``start`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.STOP)  # id=0x2019
def stop(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``stop`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GOT_DATA)  # id=0x4001
def got_data(data_raw: bytes) -> Dict[str, Any]:
    """
    Decode a ``got_data`` message, splitting each 32-bit payload word into its two fields.

    :param data_raw: Raw bytes of the complete message packet.
    :returns: Dictionary with a ``"quadrature"`` integer array and a ``"chopper"`` boolean array,
        each holding one entry per 32-bit word of the payload.
    """
    # Long form message, decode payload data
    _, _, msg_data = struct.unpack_from("<4sHi", data_raw)
    # Omit terminator if terminated style message.
    # Explicit little-endian "<i4" matches the "<4sHi" header layout instead of
    # depending on the byte order of the host machine.
    msg_payload = np.frombuffer(data_raw[HEADER_SIZE:(-4 if msg_data <= 0 else None)], dtype="<i4")
    # Quadrature value is top 31 bits (2x), chopper on/off is last bit (even/odd)
    return {
        "quadrature": msg_payload >> 1,
        "chopper": (msg_payload & 0x1).astype(bool),
    }
@_decoder(ID.GET_LASER_SYNC_PERIOD)  # id=0x4100
def get_laser_sync_period(data_raw: bytes) -> Dict[str, Any]:
    """Decode a ``get_laser_sync_period`` message; it adds no fields of its own to the decoded result."""
    return {}
@_decoder(ID.GOT_LASER_SYNC_PERIOD)  # id=0x4101
def got_laser_sync_period(data_raw: bytes) -> Dict[str, Any]:
    """
    Decode a ``got_laser_sync_period`` message.

    The 4-byte data field of the header is re-read as an unsigned integer.

    :param data_raw: Raw bytes of the complete message packet.
    :returns: Dictionary with the unsigned value under the ``"data"`` key.
    """
    # Data is unsigned int, located after the 4-byte magic and 2-byte id (offset 6).
    # Explicit little-endian ("<") matches the "<4sHi" header layout instead of
    # depending on the byte order of the host machine.
    return {"data": struct.unpack_from("<I", data_raw, offset=6)[0]}