#!/usr/bin/env python3

import datetime
import enum
import ipaddress
import json
import logging
import os
import random
import socket
import subprocess
import time

import coloredlogs
import mpd
import notmuch
import psutil
import pulsectl

from frobar.display import *
from frobar.updaters import *

# Coloured log output, showing everything down to DEBUG level.
coloredlogs.install(level="DEBUG", fmt="%(levelname)s %(message)s")
# Module-wide logger (the root logger).
log = logging.getLogger()

# TODO Generator class (for I3WorkspacesProvider, NetworkProvider and later
# PulseaudioProvider and MpdProvider): presumably a common base for providers
# that manage one sub-section per item, as the first two already do.


def humanSize(num):
    """
    Format a byte count (or byte rate) as a compact human-readable string.

    Returns a string of nominal width 3+3: a 3-character number followed by a
    3-character binary unit ("B  ", "KiB", ..., "YiB"). Values below 10 are
    shown with one decimal, larger ones are truncated to an integer.
    Quantities too large for "ZiB" are expressed in "YiB", whatever their
    magnitude (the number may then be wider than 3 characters).

    :param num: quantity in bytes (int or float)
    :returns: the formatted string
    """
    units = ("B  ", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB")
    for unit in units:
        # The last unit always formats, so huge values can't fall through.
        if abs(num) < 1000 or unit == units[-1]:
            if num >= 10:
                return "{:3d}{}".format(int(num), unit)
            return "{:.1f}{}".format(num, unit)
        num /= 1024.0


def randomColor(seed=0):
    """
    Return a deterministic pseudo-random "#rrggbb" colour derived from *seed*.

    The same seed always yields the same colour (used e.g. to colour a key
    from its fingerprint). A private ``random.Random`` instance is used, so
    the state of the module-level ``random`` generator is left untouched.

    :param seed: any value accepted as a seed by ``random.Random``
        (int, str, bytes, ...)
    """
    rng = random.Random(seed)
    return "#{:02x}{:02x}{:02x}".format(*(rng.randint(0, 255) for _ in range(3)))


class TimeProvider(StatefulSection, PeriodicUpdater):
    """
    Clock section: each state selects a more detailed strftime format.
    """

    FORMATS = ["%H:%M", "%m-%d %H:%M:%S", "%a %y-%m-%d %H:%M:%S"]
    NUMBER_STATES = len(FORMATS)
    DEFAULT_STATE = 1

    def __init__(self, theme=None):
        PeriodicUpdater.__init__(self)
        StatefulSection.__init__(self, theme)
        # One-second refresh, needed by the formats that show seconds.
        self.changeInterval(1)  # TODO OPTI When state < 1

    def fetcher(self):
        current = datetime.datetime.now()
        pattern = self.FORMATS[self.state]
        return current.strftime(pattern)


class AlertLevel(enum.Enum):
    """Severity of a monitored quantity, as computed by AlertingSection.getLevel."""

    NORMAL = 0  # at or below the warning threshold
    WARNING = 1  # above the warning threshold
    DANGER = 2  # above the danger threshold


class AlertingSection(StatefulSection):
    """
    Section whose theme reflects how close a quantity is to its limits.

    Subclasses call ``updateLevel(quantity)``; the section then switches to
    the theme that ``THEMES`` maps to the resulting AlertLevel. The default
    thresholds (0.75 warning / 0.90 danger) suit ratios in [0, 1]; subclasses
    may overwrite ``warningThresold`` and ``dangerThresold``. A threshold set
    to None is ignored (psutil reports unknown sensor limits as None).
    """

    # TODO EASE Correct settings for themes
    THEMES = {AlertLevel.NORMAL: 2, AlertLevel.WARNING: 3, AlertLevel.DANGER: 1}
    PERSISTENT = True

    def getLevel(self, quantity):
        """Return the AlertLevel of *quantity* against the current thresholds."""
        if self.dangerThresold is not None and quantity > self.dangerThresold:
            return AlertLevel.DANGER
        if self.warningThresold is not None and quantity > self.warningThresold:
            return AlertLevel.WARNING
        return AlertLevel.NORMAL

    def updateLevel(self, quantity):
        """Recompute ``self.level`` from *quantity* and apply the matching theme."""
        self.level = self.getLevel(quantity)
        self.updateTheme(self.THEMES[self.level])
        # TODO Temporarily update state when the level is not NORMAL

    def __init__(self, theme):
        StatefulSection.__init__(self, theme)
        # Attribute names keep their historical spelling: subclasses set them.
        self.dangerThresold = 0.90
        self.warningThresold = 0.75


class CpuProvider(AlertingSection, PeriodicUpdater):
    """
    CPU usage section.

    The alert level always follows overall usage. State 0 returns no text,
    state 1 a single usage ramp, state 2 one ramp per CPU.
    """

    NUMBER_STATES = 3
    ICON = ""

    def __init__(self, theme=None):
        AlertingSection.__init__(self, theme)
        PeriodicUpdater.__init__(self)
        self.changeInterval(1)

    def fetcher(self):
        usage = psutil.cpu_percent(percpu=False) / 100
        self.updateLevel(usage)
        if self.state < 1:
            return None
        if self.state < 2:
            return Section.ramp(usage)
        perCpu = psutil.cpu_percent(percpu=True)
        return "".join(Section.ramp(p / 100) for p in perCpu)


class RamProvider(AlertingSection, PeriodicUpdater):
    """
    Shows used RAM.

    ``psutil.virtual_memory().percent`` is the *used* percentage, so the ramp
    and the alert level grow as memory fills up. State 1 shows the ramp,
    state 2 adds the used amount and state 3 the total amount.
    """

    NUMBER_STATES = 4
    ICON = ""

    def fetcher(self):
        mem = psutil.virtual_memory()
        usedRatio = mem.percent / 100
        self.updateLevel(usedRatio)

        if self.state < 1:
            return None

        text = Text(Section.ramp(usedRatio))
        if self.state >= 2:
            # Memory that is not "available" to processes counts as used.
            text.append(humanSize(mem.total - mem.available))
        if self.state >= 3:
            text.append("/", humanSize(mem.total))

        return text

    def __init__(self, theme=None):
        AlertingSection.__init__(self, theme)
        PeriodicUpdater.__init__(self)
        self.changeInterval(1)


class TemperatureProvider(AlertingSection, PeriodicUpdater):
    """
    CPU temperature from the first "coretemp" sensor reported by psutil.

    The icon is a ramp of the current temperature relative to the sensor's
    "high" threshold. The alert level uses "high" as the warning threshold
    and "critical" as the danger threshold. State 1 adds the temperature
    in °C.
    """

    NUMBER_STATES = 2
    RAMP = ""
    # Thresholds used when the sensor reports neither "high" nor "critical"
    # (psutil gives None for unknown limits).
    # NOTE(review): arbitrary values, adjust to the hardware if needed.
    FALLBACK_HIGH = 80.0
    FALLBACK_CRITICAL = 100.0

    def fetcher(self):
        allTemp = psutil.sensors_temperatures()
        sensors = allTemp.get("coretemp")
        if not sensors:
            # TODO Opti Remove interval
            return ""
        temp = sensors[0]

        # Either threshold may be None (or 0): fall back on the other one,
        # then on the class defaults, so neither the level comparison nor
        # the ramp division below ever sees None or zero.
        high = temp.high or temp.critical or self.FALLBACK_HIGH
        critical = temp.critical or temp.high or self.FALLBACK_CRITICAL

        self.warningThresold = high
        self.dangerThresold = critical
        self.updateLevel(temp.current)

        self.icon = Section.ramp(temp.current / high, self.RAMP)
        if self.state >= 1:
            return "{:.0f}°C".format(temp.current)

    def __init__(self, theme=None):
        AlertingSection.__init__(self, theme)
        PeriodicUpdater.__init__(self)
        self.changeInterval(5)


class BatteryProvider(AlertingSection, PeriodicUpdater):
    """
    Battery section; shows nothing when psutil reports no battery.

    The icon combines an AC-plugged indicator with a charge ramp. The alert
    level rises as the battery empties. State 1 adds the charge percentage;
    state 2 adds the remaining time, when psutil can estimate it.
    """

    # TODO Support ACPID for events
    NUMBER_STATES = 3
    RAMP = ""

    def fetcher(self):
        bat = psutil.sensors_battery()
        if not bat:
            self.icon = None
            return None

        plugIcon = "" if bat.power_plugged else ""
        self.icon = plugIcon + Section.ramp(bat.percent / 100, self.RAMP)

        # Alert on the missing charge: the emptier, the more dangerous.
        self.updateLevel(1 - bat.percent / 100)

        if self.state < 1:
            return None

        t = Text("{:.0f}%".format(bat.percent))

        # psutil reports these sentinels instead of a duration when the time
        # left is unlimited (on AC) or unknown. Formatting them as a duration
        # would display a bogus "(0:00)".
        noEstimate = (psutil.POWER_TIME_UNLIMITED, psutil.POWER_TIME_UNKNOWN)
        if self.state < 2 or bat.secsleft in noEstimate:
            return t

        hours, rest = divmod(int(bat.secsleft), 3600)
        t.append(" ({:d}:{:02d})".format(hours, rest // 60))
        return t

    def __init__(self, theme=None):
        AlertingSection.__init__(self, theme)
        PeriodicUpdater.__init__(self)
        self.changeInterval(5)


class PulseaudioProvider(StatefulSection, ThreadedUpdater):
    """
    One icon per PulseAudio sink, refreshed on sink events.

    The icon depends on the sink's active port. It is drawn in #333333 when
    the sink is muted, and in #FF0000 when the volume exceeds 100 %.
    State 1 adds a volume ramp for unmuted sinks: floor(volume) + 1 glyphs,
    each covering 100 %. State 2 shows the volume percentage instead.
    """

    NUMBER_STATES = 3
    DEFAULT_STATE = 1

    def __init__(self, theme=None):
        ThreadedUpdater.__init__(self)
        StatefulSection.__init__(self, theme)
        # Connection dedicated to event listening; fetcher() opens its own
        # short-lived connection to query the sinks.
        self.pulseEvents = pulsectl.Pulse("event-handler")

        self.pulseEvents.event_mask_set(pulsectl.PulseEventMaskEnum.sink)
        self.pulseEvents.event_callback_set(self.handleEvent)
        self.start()
        self.refreshData()

    @staticmethod
    def _sinkIcon(sink):
        """Pick the icon of *sink* from the name of its active port."""
        # A sink may have no active port (port_active is None), presumably
        # for port-less sinks such as virtual ones.
        port = sink.port_active.name if sink.port_active else None
        if port == "analog-output-headphones":
            return ""
        if port == "analog-output-speaker":
            return "" if sink.mute else ""
        if port == "headset-output":
            return ""
        return "?"

    def fetcher(self):
        sinks = []
        with pulsectl.Pulse("list-sinks") as pulse:
            for sink in pulse.sink_list():
                vol = pulse.volume_get_all_chans(sink)
                if sink.mute:
                    fg = "#333333"
                elif vol > 1:
                    fg = "#FF0000"
                else:
                    fg = None

                t = Text(self._sinkIcon(sink), fg=fg)
                sinks.append(t)

                if self.state < 1:
                    continue

                if self.state >= 2:
                    t.append(" {:2.0f}%".format(vol * 100))
                elif not sink.mute:
                    # One ramp glyph per started 100 % slice of volume.
                    ramp = " "
                    while vol >= 0:
                        ramp += self.ramp(min(vol, 1))
                        vol -= 1
                    t.append(ramp)

        return Text(*sinks)

    def loop(self):
        self.pulseEvents.event_listen()

    def handleEvent(self, ev):
        self.refreshData()


class NetworkProviderSection(StatefulSection, Updater):
    """
    Section showing a single network interface, owned by a NetworkProvider.

    All psutil data (``stats``, ``addrs``, ``IO``, ``prevIO``, ``dt``) is
    read from the parent NetworkProvider. Loopback, down and address-less
    interfaces show nothing. Each state adds a field: SSID (1), IPv4
    address/prefix (2), receive/send rates (3), total bytes received/sent (4).
    """

    NUMBER_STATES = 5
    DEFAULT_STATE = 1

    def actType(self):
        """Set ``self.icon`` from the interface name, and ``self.ssid`` for Wi-Fi."""
        self.ssid = None
        iface = self.iface
        if iface.startswith(("eth", "enp")):
            # NOTE(review): "u" presumably marks a USB adapter (e.g. enp0s20u1)
            self.icon = "" if "u" in iface else ""
        elif iface.startswith(("wlan", "wl")):
            self.icon = ""
            if self.showSsid:
                self.ssid = self._readSsid()
        elif iface.startswith(("tun", "tap")):
            self.icon = ""
        elif iface.startswith("docker"):
            self.icon = ""
        elif iface.startswith("veth"):
            self.icon = ""
        elif iface.startswith("vboxnet"):
            self.icon = ""
        else:
            self.icon = "?"

    def _readSsid(self):
        """Return the SSID reported by ``iwgetid``, or None if it can't be run."""
        cmd = ["iwgetid", self.iface, "--raw"]
        try:
            p = subprocess.run(cmd, stdout=subprocess.PIPE)
        except FileNotFoundError:
            # iwgetid is not installed: simply don't show the SSID.
            return None
        # SSIDs are arbitrary bytes: don't crash on non-UTF-8 ones.
        return p.stdout.strip().decode(errors="replace")

    def getAddresses(self):
        """Return the (last) IPv4 and IPv6 address entries of the interface."""
        ipv4 = None
        ipv6 = None
        # The interface may have stats but no address entry.
        for address in self.parent.addrs.get(self.iface, ()):
            if address.family == socket.AF_INET:
                ipv4 = address
            elif address.family == socket.AF_INET6:
                ipv6 = address
        return ipv4, ipv6

    def _rates(self):
        """
        Return (received, sent) bytes per second since the parent's previous
        fetch, or None when there is no previous sample for this interface
        (e.g. it just appeared) or no elapsed time.
        """
        cur = self.parent.IO.get(self.iface)
        prev = self.parent.prevIO.get(self.iface)
        dt = self.parent.dt
        if cur is None or prev is None or dt <= 0:
            return None
        return (
            (cur.bytes_recv - prev.bytes_recv) / dt,
            (cur.bytes_sent - prev.bytes_sent) / dt,
        )

    def fetcher(self):
        self.icon = None
        self.persistent = False
        stats = self.parent.stats
        if (
            self.iface not in stats
            or not stats[self.iface].isup
            or self.iface.startswith("lo")
        ):
            return None

        ipv4, ipv6 = self.getAddresses()
        if ipv4 is None and ipv6 is None:
            return None

        text = []
        self.persistent = True
        self.actType()

        if self.showSsid and self.ssid:
            text.append(self.ssid)

        if self.showAddress and ipv4:
            if ipv4.netmask:
                # Convert the dotted netmask into a prefix length.
                netStrFull = "{}/{}".format(ipv4.address, ipv4.netmask)
                addr = ipaddress.IPv4Network(netStrFull, strict=False)
                text.append("{}/{}".format(ipv4.address, addr.prefixlen))
            else:
                text.append(ipv4.address)
            # TODO IPV6

        if self.showSpeed:
            rates = self._rates()
            if rates is not None:
                recv, sent = rates
                text.append("↓{}↑{}".format(humanSize(recv), humanSize(sent)))

        if self.showTransfer:
            io = self.parent.IO.get(self.iface)
            if io is not None:
                text.append(
                    "⇓{}⇑{}".format(
                        humanSize(io.bytes_recv),
                        humanSize(io.bytes_sent),
                    )
                )

        return " ".join(text)

    def onChangeState(self, state):
        self.showSsid = state >= 1
        self.showAddress = state >= 2
        self.showSpeed = state >= 3
        self.showTransfer = state >= 4

    def __init__(self, iface, parent):
        Updater.__init__(self)
        StatefulSection.__init__(self, theme=parent.theme)
        self.iface = iface
        self.parent = parent


class NetworkProvider(Section, PeriodicUpdater):
    """
    Container of one NetworkProviderSection per network interface.

    Each fetch snapshots psutil's per-interface stats, addresses and I/O
    counters. The previous counters and the elapsed time (``dt``) are kept
    so the sections can compute rates. Sections for newly seen interfaces
    are inserted after this one, in name order, and every section is then
    refreshed. This section itself displays no text.
    """

    def fetchData(self):
        # Keep the previous snapshot for rate computations.
        self.prev, self.prevIO = self.last, self.IO

        self.stats = psutil.net_if_stats()
        self.addrs = psutil.net_if_addrs()
        self.IO = psutil.net_io_counters(pernic=True)
        self.ifaces = self.stats.keys()

        self.last = time.perf_counter()
        self.dt = self.last - self.prev

    def fetcher(self):
        self.fetchData()

        # Create sections for new interfaces, chained in name order.
        previous = self
        for iface in sorted(self.ifaces):
            section = self.sections.get(iface)
            if section is None:
                section = NetworkProviderSection(iface, self)
                previous.appendAfter(section)
                self.sections[iface] = section
            previous = section

        # Let every section recompute its text.
        for section in self.sections.values():
            section.refreshData()

        return None

    def addParent(self, parent):
        self.parents.add(parent)
        self.refreshData()

    def __init__(self, theme=None):
        PeriodicUpdater.__init__(self)
        Section.__init__(self, theme)

        self.sections = dict()
        self.last = 0
        self.IO = dict()
        self.fetchData()
        self.changeInterval(5)


class RfkillProvider(Section, PeriodicUpdater):
    """
    Shows an icon for every rfkill device that is soft- or hard-blocked.

    Hard-blocked devices are drawn in #CCCCCC, soft-only-blocked ones in
    #FF0000. Nothing is shown when ``PATH`` does not exist.
    """

    # TODO FEAT rfkill doesn't seem to indicate that the hardware switch is
    # toggled
    PATH = "/sys/class/rfkill"

    def _read(self, device, name):
        """Return the stripped raw content of attribute file *name* of *device*."""
        with open(os.path.join(self.PATH, device, name), "rb") as f:
            return f.read().strip()

    def fetcher(self):
        t = Text()
        try:
            devices = os.listdir(self.PATH)
        except FileNotFoundError:
            # No rfkill directory on this system: nothing can be blocked.
            return t

        for device in devices:
            try:
                softBlocked = self._read(device, "soft") != b"0"
                hardBlocked = self._read(device, "hard") != b"0"
                if not (hardBlocked or softBlocked):
                    continue
                typ = self._read(device, "type")
            except FileNotFoundError:
                # The device disappeared while being read (e.g. unplugged).
                continue

            # At least one of the two is set at this point.
            fg = "#CCCCCC" if hardBlocked else "#FF0000"
            if typ == b"wlan":
                icon = ""
            elif typ == b"bluetooth":
                icon = ""
            else:
                icon = "?"

            t.append(Text(icon, fg=fg))
        return t

    def __init__(self, theme=None):
        PeriodicUpdater.__init__(self)
        Section.__init__(self, theme)
        self.changeInterval(5)


class SshAgentProvider(PeriodicUpdater):
    """
    One key icon per identity loaded in ssh-agent.

    Each icon is coloured from the key fingerprint (second field of an
    ``ssh-add -l`` line), so a given key always gets the same colour.
    Returns None when ``ssh-add -l`` exits with a non-zero status.
    """

    def __init__(self):
        PeriodicUpdater.__init__(self)
        self.changeInterval(5)

    def fetcher(self):
        proc = subprocess.run(
            ["ssh-add", "-l"], stdout=subprocess.PIPE, stderr=subprocess.DEVNULL
        )
        if proc.returncode != 0:
            return None
        keys = Text()
        # filter(None, ...) drops the empty lines.
        for entry in filter(None, proc.stdout.split(b"\n")):
            keys.append(Text("", fg=randomColor(seed=entry.split()[1])))
        return keys


class GpgAgentProvider(PeriodicUpdater):
    """
    One key icon per key whose passphrase is cached in gpg-agent.

    Parses the ``S KEYINFO <keygrip> <type> <serialno> <idstr> <cached> ...``
    status lines of ``gpg-connect-agent "keyinfo --list" /bye``. Each icon is
    coloured from its keygrip. Returns None if the command fails.
    """

    def fetcher(self):
        cmd = ["gpg-connect-agent", "keyinfo --list", "/bye"]
        proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
        if proc.returncode != 0:
            return None
        text = Text()
        for line in proc.stdout.split(b"\n"):
            fields = line.split()
            # Only "S KEYINFO" lines describe keys; skip "OK", "ERR ..." and
            # anything too short to hold the "cached" flag.
            if len(fields) < 7 or fields[:2] != [b"S", b"KEYINFO"]:
                continue
            if fields[6] != b"1":
                continue
            keygrip = fields[2]
            text.append(Text("", fg=randomColor(seed=keygrip)))
        return text

    def __init__(self):
        PeriodicUpdater.__init__(self)
        self.changeInterval(5)


class KeystoreProvider(Section, MergedUpdater):
    """Merges the ssh-agent and gpg-agent key lists into a single section."""

    # TODO OPTI+FEAT Use ColorCountsSection and not MergedUpdater, this is useless
    ICON = ""

    def __init__(self, theme=None):
        sshKeys = SshAgentProvider()
        gpgKeys = GpgAgentProvider()
        MergedUpdater.__init__(self, sshKeys, gpgKeys)
        Section.__init__(self, theme)


class NotmuchUnreadProvider(ColorCountsSection, InotifyUpdater):
    """
    Unread mail counter, with one coloured count per account.

    Accounts are the non-hidden entries of the mail root *dir*; each must
    contain a ``color`` file. Counts come from the notmuch query
    ``folder:/<account>/ and tag:unread``. The notmuch Xapian directory is
    registered with ``addPath`` to trigger refreshes.
    """

    COLORABLE_ICON = ""

    def subfetcher(self):
        """Return (count, colour) pairs for accounts with unread mail."""
        db = notmuch.Database(mode=notmuch.Database.MODE.READ_ONLY, path=self.dir)
        counts = []
        for account in self.accounts:
            queryStr = "folder:/{}/ and tag:unread".format(account)
            nbMsgs = notmuch.Query(db, queryStr).count_messages()
            if nbMsgs < 1:
                continue
            counts.append((nbMsgs, self.colors[account]))
        # db.close() is left disabled.
        # NOTE(review): the reason is not documented; confirm before enabling.
        return counts

    def __init__(self, dir="~/.mail/", theme=None):
        # This class is an InotifyUpdater (like TodoProvider), so initialise
        # that base rather than PeriodicUpdater.
        InotifyUpdater.__init__(self)
        ColorCountsSection.__init__(self, theme)

        self.dir = os.path.realpath(os.path.expanduser(dir))
        assert os.path.isdir(self.dir)

        # Fetching account list
        self.accounts = sorted(
            [a for a in os.listdir(self.dir) if not a.startswith(".")]
        )
        # Fetching colors
        self.colors = dict()
        for account in self.accounts:
            filename = os.path.join(self.dir, account, "color")
            with open(filename, "r") as f:
                self.colors[account] = f.read().strip()

        self.addPath(os.path.join(self.dir, ".notmuch", "xapian"))


class TodoProvider(ColorCountsSection, InotifyUpdater):
    """
    Todo counter backed by todoman's ``todo`` command.

    Each calendar is a sub-directory of *dir* holding ``displayname`` and
    ``color`` files, and is registered with ``addPath``. From state 2
    upwards, one coloured count is shown per calendar. Below state 2 only a
    single total is computed, because every ``todo`` call is expensive.
    """

    # TODO OPT/UX Maybe we could get more data from the todoman python module
    # TODO OPT Specific callback for specific directory
    COLORABLE_ICON = ""

    def _readCalendarFile(self, calendar, name):
        """Return the stripped content of file *name* in *calendar*'s directory."""
        with open(os.path.join(self.dir, calendar, name), "r") as f:
            return f.read().strip()

    def updateCalendarList(self):
        """Register the calendars that appeared in ``self.dir`` since last call."""
        calendars = sorted(os.listdir(self.dir))
        for calendar in calendars:
            if calendar in self.calendars:
                continue
            self.addPath(os.path.join(self.dir, calendar), refresh=False)
            self.names[calendar] = self._readCalendarFile(calendar, "displayname")
            self.colors[calendar] = self._readCalendarFile(calendar, "color")
        self.calendars = calendars

    def __init__(self, dir, theme=None):
        """
        :param str dir: [main]path value in todoman.conf
        """
        InotifyUpdater.__init__(self)
        ColorCountsSection.__init__(self, theme=theme)
        self.dir = os.path.realpath(os.path.expanduser(dir))
        assert os.path.isdir(self.dir)

        self.calendars = []
        self.colors = dict()
        self.names = dict()
        self.updateCalendarList()
        self.refreshData()

    def countUndone(self, calendar):
        """
        Return the number of entries listed by ``todo --porcelain list``,
        restricted to *calendar* when it is truthy. Returns 0 (and logs a
        warning) if the command exits with an error.
        """
        cmd = ["todo", "--porcelain", "list"]
        if calendar:
            cmd.append(self.names[calendar])
        proc = subprocess.run(cmd, stdout=subprocess.PIPE)
        if proc.returncode != 0:
            # The output of a failed run is not valid JSON.
            log.warning("%s exited with status %d", cmd, proc.returncode)
            return 0
        return len(json.loads(proc.stdout))

    def subfetcher(self):
        counts = []

        # TODO This an ugly optimisation that cuts on features, but todoman
        # calls are very expensive so we keep that in the meanwhile
        if self.state < 2:
            c = self.countUndone(None)
            if c > 0:
                # The zero-count second colour presumably keeps
                # ColorCountsSection from collapsing to a single colour.
                counts.append((c, "#000000"))
                counts.append((0, "#FFFFFF"))
            return counts
        # Optimisation ends here

        for calendar in self.calendars:
            c = self.countUndone(calendar)
            if c <= 0:
                continue
            counts.append((c, self.colors[calendar]))
        return counts


class I3WindowTitleProvider(Section, I3Updater):
    """Shows the name of the container carried by the latest i3 "window" event."""

    # TODO FEAT To make this available from start, we need to find the
    # `focused=True` element following the `focus` array
    # TODO Feat Make this output dependant if wanted

    def __init__(self, theme=None):
        I3Updater.__init__(self)
        Section.__init__(self, theme=theme)
        self.on("window", self.on_window)

    def on_window(self, i3, e):
        title = e.container.name
        self.updateText(title)


class I3WorkspacesProviderSection(Section):
    """
    Button for one i3 workspace, owned by an I3WorkspacesProvider (``parent``).

    Shows the custom full name when focused and the raw name otherwise; the
    theme reflects urgency, then focus. A left click switches to the
    workspace. ``focused`` and ``urgent`` are not set by ``__init__``; they
    must be assigned before ``show()`` is called.
    """

    # TODO On mode change the state (shown / hidden) gets overriden so every
    # tab is shown

    def __init__(self, name, parent):
        Section.__init__(self)
        self.parent = parent
        self.setName(name)
        self.setDecorators(clickLeft=self.switchTo)
        self.tempText = None

    def selectTheme(self):
        if self.urgent:
            theme = self.parent.themeUrgent
        elif self.focused:
            theme = self.parent.themeFocus
        else:
            theme = self.parent.themeNormal
        return theme

    def show(self):
        self.updateTheme(self.selectTheme())
        label = self.fullName if self.focused else self.shortName
        self.updateText(label)

    def changeState(self, focused, urgent):
        self.focused, self.urgent = focused, urgent
        self.show()

    def setName(self, name):
        self.shortName = name
        self.fullName = self.parent.customNames.get(name, name)

    def switchTo(self):
        self.parent.i3.command("workspace {}".format(self.shortName))

    def empty(self):
        self.updateTheme(self.parent.themeNormal)
        self.updateText(None)

    def tempShow(self):
        self.updateText(self.tempText)

    def tempEmpty(self):
        # NOTE(review): dstText[1] presumably holds the text currently shown;
        # it is saved so tempShow() can restore it.
        self.tempText = self.dstText[1]
        self.updateText(None)


class I3WorkspacesProvider(Section, I3Updater):
    """
    One I3WorkspacesProviderSection per i3 workspace, after a section that
    shows the current binding mode.

    Workspace sections are keyed by workspace number and kept in number
    order. i3 gives named workspaces without a number the number -1.
    """

    # TODO FEAT Multi-screen
    # TODO Named workspaces without a number all share num -1 and thus the
    # same entry of self.sections

    def initialPopulation(self, parent):
        """
        Called on init
        Can't reuse addWorkspace since i3.get_workspaces() gives dict and not
        ConObjects
        """
        workspaces = self.i3.get_workspaces()
        lastSection = self.modeSection
        for workspace in workspaces:
            # if parent.display != workspace["display"]:
            #     continue

            section = I3WorkspacesProviderSection(workspace.name, self)
            section.focused = workspace.focused
            section.urgent = workspace.urgent
            section.show()
            parent.addSectionAfter(lastSection, section)
            self.sections[workspace.num] = section

            lastSection = section

    def _sectionOf(self, workspace):
        """Return the section tracking i3 container *workspace*, or None."""
        if workspace is None:
            return None
        return self.sections.get(workspace.num)

    def _precedingSection(self, num):
        """Return the section after which a new workspace *num* is inserted."""
        # Closest existing lower-numbered workspace; otherwise (including
        # num <= 0, i.e. unnumbered workspaces) right after the mode section.
        while num > 0 and num not in self.sections:
            num -= 1
        return self.sections[num] if num > 0 else self.modeSection

    def on_workspace_init(self, i3, e):
        workspace = e.current
        section = self.sections.get(workspace.num)
        if section is None:
            section = I3WorkspacesProviderSection(workspace.name, self)
            self._precedingSection(workspace.num).appendAfter(section)
            self.sections[workspace.num] = section
        section.focused = workspace.focused
        section.urgent = workspace.urgent
        section.show()

    def on_workspace_empty(self, i3, e):
        section = self._sectionOf(e.current)
        if section is not None:
            section.empty()

    def on_workspace_focus(self, i3, e):
        # "old" may be missing (None) or a workspace that isn't tracked.
        old = self._sectionOf(e.old)
        if old is not None:
            old.focused = False
            old.show()
        current = self._sectionOf(e.current)
        if current is not None:
            current.focused = True
            current.show()

    def on_workspace_urgent(self, i3, e):
        section = self._sectionOf(e.current)
        if section is not None:
            section.urgent = e.current.urgent
            section.show()

    def on_workspace_rename(self, i3, e):
        # Workspace events carry no "name" attribute: the new name is that
        # of the "current" container.
        section = self._sectionOf(e.current)
        if section is None:
            # TODO A rename can change the workspace number; not handled
            return
        section.setName(e.current.name)
        section.show()

    def on_mode(self, i3, e):
        if e.change == "default":
            self.modeSection.updateText(None)
            for section in self.sections.values():
                section.tempShow()
        else:
            self.modeSection.updateText(e.change)
            for section in self.sections.values():
                section.tempEmpty()

    def __init__(
        self, theme=0, themeFocus=3, themeUrgent=1, themeMode=2, customNames=None
    ):
        I3Updater.__init__(self)
        Section.__init__(self)
        self.themeNormal = theme
        self.themeFocus = themeFocus
        self.themeUrgent = themeUrgent
        # Fresh dict per instance (avoids a shared mutable default).
        self.customNames = customNames if customNames is not None else dict()

        self.sections = dict()
        self.on("workspace::init", self.on_workspace_init)
        self.on("workspace::focus", self.on_workspace_focus)
        self.on("workspace::empty", self.on_workspace_empty)
        self.on("workspace::urgent", self.on_workspace_urgent)
        self.on("workspace::rename", self.on_workspace_rename)
        # TODO Un-handled/tested: reload, restored, move

        self.on("mode", self.on_mode)
        self.modeSection = Section(theme=themeMode)

    def addParent(self, parent):
        self.parents.add(parent)
        parent.addSection(self.modeSection)
        self.initialPopulation(parent)


class MpdProvider(Section, ThreadedUpdater):
    """
    Currently playing MPD song ("title - album - artist", truncated to
    MAX_LENGTH characters). Shows nothing when playback is stopped.

    ``loop`` waits for MPD "player" events and then refreshes the text,
    reconnecting if the connection is lost.
    """

    # TODO FEAT More informations and controls

    MAX_LENGTH = 50
    # Seconds to wait after a failed reconnection attempt
    RECONNECT_DELAY = 5

    def connect(self):
        self.mpd.connect("localhost", 6600)

    def reconnect(self):
        """
        Drop the current MPD connection and open a new one.

        python-mpd2 refuses to connect() a client that still holds a socket
        ("Already connected"), so the old connection is closed first. If MPD
        is unreachable, the error is logged and the method sleeps
        RECONNECT_DELAY seconds, so that a retrying caller doesn't spin.
        """
        import time

        try:
            self.mpd.disconnect()
        except Exception:
            # Best effort: the old socket is most likely dead already.
            log.debug("Error while disconnecting from MPD", exc_info=True)
        try:
            self.connect()
        except (OSError, mpd.base.ConnectionError) as e:
            log.warning("Cannot reconnect to MPD: %s", e)
            time.sleep(self.RECONNECT_DELAY)

    def __init__(self, theme=None):
        ThreadedUpdater.__init__(self)
        Section.__init__(self, theme)

        self.mpd = mpd.MPDClient()
        self.connect()
        self.refreshData()
        self.start()

    def fetcher(self):
        stat = self.mpd.status()
        if not len(stat) or stat["state"] == "stop":
            return None

        cur = self.mpd.currentsong()
        if not len(cur):
            return None

        infos = [cur[field] for field in ("title", "album", "artist") if field in cur]

        infosStr = " - ".join(infos)
        if len(infosStr) > MpdProvider.MAX_LENGTH:
            infosStr = infosStr[: MpdProvider.MAX_LENGTH - 1] + "…"

        return " {}".format(infosStr)

    def loop(self):
        try:
            self.mpd.idle("player")
            self.refreshData()
        except (mpd.base.ConnectionError, OSError) as e:
            # Lost connection (closed or reset socket): reconnect.
            log.warning(e, exc_info=True)
            self.reconnect()
        except Exception as e:
            # Exception rather than BaseException, so KeyboardInterrupt and
            # SystemExit still propagate.
            log.error(e, exc_info=True)