import logging
import struct
import time
import typing

import ha_mqtt_discoverable
import ha_mqtt_discoverable.sensors
import paho.mqtt.client
import usb.core
import usb.util


class Desk:
    """
    Controls my Linak desk, which is a CBD4P controller connected via USB2LIN06
    This particular combination doesn't seem to report desk height,
    so it is estimated from the physical controller that does work.
    """

    # Source of data:

    # https://github.com/UrbanskiDawid/usb2lin06-HID-in-linux-for-LINAK-Desk-Control-Cable
    # https://github.com/monofox/python-linak-desk-control
    # https://github.com/gryf/linak-ctrl

    # Desk Control Basic Software
    # https://www.linak-us.com/products/controls/desk-control-basic-software/
    # Says it's connected but doesn't report height and buttons do nothing
    # Expected, as manual says it only works with CBD4A or CBD6
    # Decompiled with ILSpy (easy), doesn't offer much though

    # CBD4+5 Configurator
    # https://www.linak.nl/technische-ondersteuning/#/cbd4-cbd6s-configurator
    # Connects, and settings can be changed.
    # Don't think there's much that would help with our problem.
    # Tried to decompile with Ghidra (hard), didn't go super far

    # USB vendor and product IDs used to find the device
    VEND = 0x12D3
    PROD = 0x0002
    # The official apps use a HID library. With pyhidapi, the best result was
    # reading the manufacturer and product strings once after a device reset.

    # Size in bytes of the HID reports exchanged with the device
    BUF_LEN = 64
    # Delay between two repeated move commands while moving to a position
    MOVE_CMD_REPEAT_INTERVAL = 0.2  # s
    STOP_CMD_INTERVAL = 1  # s
    # Longest gap between two reports before the estimation is reset
    MAX_EST_INTERVAL = 10  # s

    # Theoretical height values
    VALUE_MIN = 0x0000
    VALUE_MAX = 0x7FFE
    # Special values, used in place of a height as a destination
    VALUE_DOWN = 0x7FFF
    VALUE_UP = 0x8000
    VALUE_STOP = 0x8001

    # Measured values
    VALUE_BOT = 0x0001
    VALUE_TOP = 0x1A50
    HEIGHT_BOT = 68
    HEIGHT_TOP = 135
    FULL_RISE_TIME = 17.13  # s
    FULL_FALL_TIME = 16.64  # s

    # Computed values
    HEIGHT_OFFSET = HEIGHT_BOT  # cm
    HEIGHT_MULT = VALUE_TOP / (HEIGHT_TOP - HEIGHT_BOT)  # unit / cm
    # Should be 100 in theory (1 unit = 0.1 mm)
    # Average duration of a full travel, up and down taken together
    FULL_TIME = (FULL_FALL_TIME + FULL_RISE_TIME) / 2  # s
    SPEED_MARGIN = 0.9
    # Better estimate a bit slower
    SPEED = (VALUE_TOP - VALUE_BOT) / FULL_TIME * SPEED_MARGIN  # unit / s

    def _cmToUnit(self, height: float) -> int:
        return round((height - self.HEIGHT_OFFSET) * self.HEIGHT_MULT)

    def _unitToCm(self, height: int) -> float:
        return height / self.HEIGHT_MULT + self.HEIGHT_OFFSET

    def _get(self, typ: int, overflow_ok: bool = False) -> bytes:
        # Magic numbers: get class interface, HID get report
        raw = self._dev.ctrl_transfer(
            0xA1, 0x01, 0x300 + typ, 0, self.BUF_LEN
        ).tobytes()
        self.log.debug(f"Received {raw.hex()}")
        assert raw[0] == typ
        size = raw[1]
        end = 2 + size
        if not overflow_ok:
            assert end < self.BUF_LEN
        return raw[2:end]
        # Non-implemented types:
        # 1, 7: some kind of stream when the device isn't initialized?
        # size reduces the faster you poll, increases when buttons are held
        # 9: unknown, always report 0

    def _set(self, typ: int, buf: bytes) -> None:
        buf = bytes([typ]) + buf
        # The official apps pad, not that it doesn't seem to work without
        buf = buf + b"\x00" * (self.BUF_LEN - len(buf))
        self.log.debug(f"Sending {buf.hex()}")
        # Magic numbers: set class interface, HID set report
        self._dev.ctrl_transfer(0x21, 0x09, 0x300 + typ, 0, buf)
        # Non-implemented types:
        # Some stuff < 10

    def _reset_estimations(self) -> None:
        self.est_value: None | int = None
        self.est_value_bot = float(self.VALUE_BOT)
        self.est_value_top = float(self.VALUE_TOP)
        self.last_est: float = 0.0

    def _initialize(self) -> None:
        """
        Seems to bring the USB2LIN06 out of the "boot mode" (as the CBD4
        Controller calls it) it is in after a reset.
        Needed to control the desk and to read the report.
        """
        self._set(3, b"\x04\x00\xfb")
        # Presumably leaves the device time to switch mode
        time.sleep(0.5)

    def __init__(self) -> None:
        self.log = logging.getLogger("Desk")
        self._dev = usb.core.find(idVendor=Desk.VEND, idProduct=Desk.PROD)
        if not self._dev:
            raise ValueError(
                f"Device {Desk.VEND}:" f"{Desk.PROD:04d} " f"not found!"
            )

        if self._dev.is_kernel_driver_active(0):
            self._dev.detach_kernel_driver(0)

        self._initialize()
        self._reset_estimations()
        self.last_destination = None

        self.fetch_callback: typing.Callable[["Desk"], None] | None = None

    def _get_report(self) -> bytes:
        raw = self._get(4)
        assert len(raw) == 0x38
        return raw

    def _update_estimations(self) -> None:
        now = time.time()
        delta_s = now - self.last_est

        if delta_s > self.MAX_EST_INTERVAL:
            # Attempt at fixing the issue of
            # the service not working after the night
            self._initialize()
            self.log.warning(
                "Too long without getting a report, "
                "assuming the desk might be anywhere now."
            )
            self._reset_estimations()
        else:
            delta_u = delta_s * self.SPEED

            if self.destination == self.VALUE_STOP:
                pass
            elif self.destination == self.VALUE_UP:
                self.est_value_bot += delta_u
                self.est_value_top += delta_u
            elif self.destination == self.VALUE_DOWN:
                self.est_value_bot -= delta_u
                self.est_value_top -= delta_u
            else:

                def move_closer(start_val: float) -> float:
                    if start_val < self.destination:
                        end_val = start_val + delta_u
                        return min(end_val, self.destination)
                    else:
                        end_val = start_val - delta_u
                        return max(end_val, self.destination)

                self.est_value_bot = move_closer(self.est_value_bot)
                self.est_value_top = move_closer(self.est_value_top)

            # Clamp
            self.est_value_bot = max(self.VALUE_BOT, self.est_value_bot)
            self.est_value_top = min(self.VALUE_TOP, self.est_value_top)

            if self.est_value_top == self.est_value_bot:
                if self.est_value is None:
                    self.log.info("Height estimation converged")
                self.est_value = int(self.est_value_top)

        self.last_est = now

    def fetch(self) -> None:
        for _ in range(3):
            try:
                raw = self._get_report()
                break
            except usb.USBError as e:
                self.log.error(e)
        else:
            raw = self._get_report()

        # Allegedly, from decompiling:
        # https://www.linak-us.com/products/controls/desk-control-basic-software/
        # Never reports anything in practice
        self.value = struct.unpack("<H", raw[0:2])[0]
        unk = struct.unpack("<H", raw[2:4])[0]
        self.initalized = (unk & 0xF) != 0

        # From observation. Reliable
        self.destination = (struct.unpack("<H", raw[18:20])[0],)[0]

        if self.destination != self.last_destination:
            self.log.info(f"Destination changed to {self.destination:04x}")
            self.last_destination = self.destination

        self._update_estimations()
        if self.fetch_callback is not None:
            self.fetch_callback(self)

    def _move(self, position: int) -> None:
        buf = struct.pack("<H", position) * 4
        self._set(5, buf)

    def _move_to(self, position: int) -> None:
        """
        Move the desk to `position` (in controller units), then stop it.

        The move command is repeated every MOVE_CMD_REPEAT_INTERVAL seconds
        until the estimated height equals the position.
        """
        # Keep the position within the measured travel
        position = min(self.VALUE_TOP, max(self.VALUE_BOT, position))

        self.log.info(f"Start moving to {position:04x}")
        while True:
            self.fetch()
            if self.est_value == position:
                break
            self._move(position)
            time.sleep(self.MOVE_CMD_REPEAT_INTERVAL)
        self.stop()

    def move_to(self, position: float) -> None:
        """
        If any button is held during movement, the desk will stop moving,
        yet this will think it's still moving, throwing off the estimates.
        It's not a bug, it's a safety feature.
        Also if you try to make it move when it's already moving,
        it's going to keep moving while desyncing.
        That one is a bug.
        """
        # Would to stop for a while before reversing course, without being able
        # to read the actual height it's just too annoying to implement
        return self._move_to(self._cmToUnit(position))

    def stop(self) -> None:
        """
        Send the stop command to the desk, then wait for half a second.
        """
        pause_s = 0.5
        self.log.info("Stop moving")
        self._move(self.VALUE_STOP)
        time.sleep(pause_s)

    def get_height_bounds(self) -> tuple[float, float]:
        return (
            self._unitToCm(int(self.est_value_bot)),
            self._unitToCm(int(self.est_value_top)),
        )

    def get_height(self) -> float | None:
        if self.est_value is None:
            return None
        else:
            return self._unitToCm(self.est_value)


# Script entry point: exposes the desk to Home Assistant through MQTT
# discovery, as a settable height plus two diagnostic sensors.
if __name__ == "__main__":
    logging.basicConfig()
    log = logging.getLogger(__name__)

    desk = Desk()
    serial = "000C-34E7"

    # Configure the required parameters for the MQTT broker
    mqtt_settings = ha_mqtt_discoverable.Settings.MQTT(host="192.168.7.53")
    # Number of decimals kept when publishing heights
    ndigits = 1
    # Height requested over MQTT and not yet handled by the main loop
    target_height: float | None = None

    # Description of the device, shared by every entity below
    device_info = ha_mqtt_discoverable.DeviceInfo(
        name="Desk",
        identifiers=["Linak", serial],
        manufacturer="Linak",
        model="CBD4P",
        suggested_area="Desk",
        hw_version="77402",
        sw_version="1.91",
        serial_number=serial,
    )

    # Options shared by every entity below
    common_opts = {
        "device": device_info,
        "icon": "mdi:desk",
        "unit_of_measurement": "cm",
        "device_class": "distance",
        "expire_after": 10,
    }
    # TODO Implement proper availability in ha-mqtt-discoverable

    # Settable entity: the height of the desk, bounded by the measured range
    height_info = ha_mqtt_discoverable.sensors.NumberInfo(
        name="Height ",
        min=desk.HEIGHT_BOT,
        max=desk.HEIGHT_TOP,
        mode="slider",
        step=10 ** (-ndigits),
        unique_id="desk_height",
        **common_opts,
    )
    height_settings = ha_mqtt_discoverable.Settings(
        mqtt=mqtt_settings, entity=height_info
    )

    def height_callback(
        client: paho.mqtt.client.Client,
        user_data: None,
        message: paho.mqtt.client.MQTTMessage,
    ) -> None:
        """
        Store the height received in the MQTT message as the next target.

        The payload is decoded as text and parsed as a float; the main loop
        picks the value up from the global `target_height`.
        """
        global target_height
        requested = float(message.payload.decode())
        target_height = requested
        # Log the local value: the main loop may already have taken the
        # request and put the global back to None, which ":.1f" cannot format
        log.info(f"Requested height to {requested:.1f}")

    # Height entity; height_callback receives the values set by the user
    height = ha_mqtt_discoverable.sensors.Number(
        height_settings, height_callback
    )

    # Diagnostic sensor: upper bound of the height estimation
    height_max_info = ha_mqtt_discoverable.sensors.SensorInfo(
        name="Estimated height max",
        unique_id="desk_height_max",
        entity_category="diagnostic",
        **common_opts,
    )
    height_max_settings = ha_mqtt_discoverable.Settings(
        mqtt=mqtt_settings, entity=height_max_info
    )
    height_max = ha_mqtt_discoverable.sensors.Sensor(height_max_settings)

    # Diagnostic sensor: lower bound of the height estimation
    height_min_info = ha_mqtt_discoverable.sensors.SensorInfo(
        name="Estimated height min",
        unique_id="desk_height_min",
        entity_category="diagnostic",
        **common_opts,
    )
    height_min_settings = ha_mqtt_discoverable.Settings(
        mqtt=mqtt_settings, entity=height_min_info
    )
    height_min = ha_mqtt_discoverable.sensors.Sensor(height_min_settings)

    # Last (height, min, max) published, to publish only the changes
    last_published_state = None

    def fetch_callback(desk: Desk) -> None:
        """
        Publish the estimated height and its bounds, rounded to `ndigits`
        decimals, unless they are the same as the last ones published.
        """
        global last_published_state

        current = desk.get_height()
        lowest, highest = desk.get_height_bounds()

        snapshot = (current, lowest, highest)
        if snapshot == last_published_state:
            return
        last_published_state = snapshot

        # If none this will set as unknown
        # Also readings can be a bit outside the boundaries,
        # so this skips verification
        if isinstance(current, float):
            published = round(current, ndigits=ndigits)
        else:
            published = current
        height._update_state(published)

        height_max._update_state(round(highest, ndigits=ndigits))
        height_min._update_state(round(lowest, ndigits=ndigits))

    desk.fetch_callback = fetch_callback

    # Polling period in seconds; needs to be reactive
    interval = 0.2
    # Main loop: handle the pending height request if there is one,
    # otherwise keep polling the desk so the estimation stays up to date
    while True:
        # Compare against None rather than truthiness, so that a request
        # for 0 is handled like any other number instead of being dropped
        if target_height is not None:
            temp_target_height = target_height
            # Allows queuing of other instructions while moving
            target_height = None
            desk.move_to(temp_target_height)
        else:
            time.sleep(interval)
            desk.fetch()