#!/usr/bin/env nix-shell
#! nix-shell -i python3
#! nix-shell -p python3 python3Packages.coloredlogs python3Packages.configargparse python3Packages.filelock python3Packages.requests python3Packages.yt-dlp ffmpeg
# Also needs mpv, which is deliberately not listed above: listing it makes nix-shell provide its own mpv instead of the locally configured one


"""
Script that downloads videos linked as articles
in an RSS feed.
The common use case is a feed from an RSS aggregator
listing the unread items (non-video links are ignored).
"""

import datetime
import functools
import logging
import os
import pickle
import random
import re
import subprocess
import sys
import time
import typing

import coloredlogs
import configargparse
import filelock
import requests
import yt_dlp

# Logger used throughout this script; configure_logging() attaches its
# coloredlogs output.
log: logging.Logger = logging.getLogger(__name__)


def configure_logging(args: configargparse.Namespace) -> None:
    """
    Install coloredlogs handlers according to the requested verbosity.

    With an explicit --verbosity, coloredlogs is installed at that level with
    its default format and without a ``logger`` argument (so coloredlogs'
    default target logger is used). Without it, only this script's logger
    is configured, with a bare "%(message)s" format.
    """
    level = args.verbosity
    if not level:
        coloredlogs.install(fmt="%(message)s", logger=log)
        return
    coloredlogs.install(level=level)


class SaveInfoPP(yt_dlp.postprocessor.common.PostProcessor):
    """
    Post-processor that hands yt-dlp's final info dict back to an RVElement.

    The info dict returned by yt_dlp.process_ie_result() is not fully up to
    date: notably, the extension is still the one guessed before yt-dlp
    finds out the formats cannot be merged. Running as a post-processor lets
    us see the dict in its final form and record what we need from it
    (the dict itself is not serializable at that point).
    """

    def __init__(self, rvelement: "RVElement") -> None:
        super().__init__()
        # Element to notify once the download has been processed.
        self.rvelement = rvelement

    def run(self, info: dict) -> tuple[list, dict]:
        element = self.rvelement
        element.update_post_download(info)
        # No files for yt-dlp to delete; pass the info dict through unchanged.
        return [], info


def parse_duration(string: str) -> int:
    """
    Convert a duration string into a number of seconds.

    The string is an integer optionally followed by a case-insensitive unit
    suffix: ``s`` (seconds), ``m`` (minutes) or ``h`` (hours). Without a
    suffix the number is taken as seconds, e.g. ``"90"`` -> 90,
    ``"5m"`` -> 300, ``"2H"`` -> 7200.

    :raises ValueError: if the string is empty, ends with an unknown suffix,
        or its numeric part is not an integer.
    """
    DURATION_MULTIPLIERS = {"s": 1, "m": 60, "h": 3600, "": 1}

    if not string:
        raise ValueError("Empty duration")
    mult_index = string[-1].lower()
    if mult_index.isdigit():
        mult_index = ""
    else:
        string = string[:-1]
    try:
        multiplier = DURATION_MULTIPLIERS[mult_index]
    except KeyError:
        # Dict lookups raise KeyError; report unknown suffixes as ValueError,
        # like malformed numbers.
        raise ValueError(f"Unknown duration multiplier: {mult_index}") from None

    return int(string) * multiplier


def compare_duration(compstr: str) -> typing.Callable[[int], bool]:
    """
    Build a predicate on durations (in seconds) from a comparison string.

    The string is an optional operator followed by a duration accepted by
    parse_duration():

    - ``<`` or ``-``: strictly shorter than;
    - ``>`` or ``+``: strictly longer than;
    - ``=``: exactly equal to;
    - no operator: at most (shorter than or equal to).

    E.g. ``compare_duration("<10m")`` returns a function that is True for 300.

    :raises ValueError: if the string is empty, starts with an unknown
        operator, or its duration part is invalid.
    """
    DURATION_COMPARATORS = {
        "<": int.__lt__,
        "-": int.__lt__,
        ">": int.__gt__,
        "+": int.__gt__,
        "=": int.__eq__,
        "": int.__le__,
    }

    if not compstr:
        raise ValueError("Empty duration comparison")
    comp_index = compstr[0]
    if comp_index.isdigit():
        comp_index = ""
    else:
        compstr = compstr[1:]
    try:
        comparator = DURATION_COMPARATORS[comp_index]
    except KeyError:
        # Dict lookups raise KeyError; surface unknown operators as ValueError.
        raise ValueError(f"Unknown duration comparator: {comp_index}") from None

    duration = parse_duration(compstr)

    return lambda d: comparator(d, duration)


def format_duration(duration: int) -> str:
    """
    Format a number of seconds as ``HH:MM:SS``.

    Hours are not wrapped at 24, so totals of a day or more stay correct
    (e.g. 90000 -> ``"25:00:00"``). Below one day the output is the same as
    ``time.strftime("%H:%M:%S", time.gmtime(duration))``; fractional seconds
    are dropped.
    """
    minutes, seconds = divmod(int(duration), 60)
    hours, minutes = divmod(minutes, 60)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d}"


class RVElement:
    """
    One feed item plus the metadata cached for it on disk.

    Metadata files live in the parent database's METADATA_FOLDER and are
    named ``<sid>.<extension>``:

    - ``item``: the feed item itself;
    - ``ytdl``: the sanitized yt-dlp info dict (None when extraction failed);
    - ``path``: the file path recorded after a download;
    - ``lock``: file lock held while downloading.
    """

    parent: "RVDatabase"
    # Feed item dict; keys read here: id, title, canonical, origin,
    # timestampUsec / crawlTimeMsec / published.
    item: dict

    # Cached yt-dlp info older than this is refreshed before downloading.
    RERESEARCH_AFTER = datetime.timedelta(hours=1)

    def __init__(self, parent: "RVDatabase", item: dict) -> None:
        self.parent = parent
        self.item = item

    @property
    def id(self) -> str:
        """Full item id, as given by the aggregator."""
        return self.item["id"]

    @property
    def sid(self) -> str:
        """Short id (last ``/``-separated part of the id), names metafiles."""
        return self.id.split("/")[-1]

    def metafile(self, extension: str) -> str:
        """Path of this element's metadata file with the given extension."""
        return os.path.join(self.parent.METADATA_FOLDER, f"{self.sid}.{extension}")

    def metafile_read(self, extension: str) -> typing.Any:
        """Unpickle this element's metadata file with the given extension."""
        return self.parent.metafile_read(f"{self.sid}.{extension}")

    def metafile_write(self, extension: str, data: typing.Any) -> None:
        """Pickle ``data`` into this element's metadata file (not in dry runs)."""
        return self.parent.metafile_write(f"{self.sid}.{extension}", data)

    def save(self) -> None:
        """Persist the feed item so the element can be rebuilt from the cache."""
        self.metafile_write("item", self.item)

    @property
    def title(self) -> str:
        return self.item["title"]

    @property
    def link(self) -> str:
        return self.item["canonical"][0]["href"]

    @property
    def creator(self) -> str:
        return self.item["origin"]["title"]

    @property
    def date(self) -> datetime.datetime:
        """
        Local datetime of the item: the first non-zero of ``timestampUsec``
        (microseconds), ``crawlTimeMsec`` (milliseconds), then ``published``
        (passed to fromtimestamp as is, presumably seconds).
        """
        timestamp = (
            int(self.item.get("timestampUsec", "0")) / 1000000
            or int(self.item.get("crawlTimeMsec", "0")) / 1000
            or self.item["published"]
        )
        return datetime.datetime.fromtimestamp(timestamp)

    @property
    def is_researched(self) -> bool:
        """Whether a yt-dlp metadata file exists for this element."""
        metafile = self.metafile("ytdl")
        return os.path.isfile(metafile)

    def _describe(self, duration: str) -> str:
        """One-line summary of the element using the given duration text."""
        return (
            f"{self.date.strftime('%y-%m-%d %H:%M')} ({duration}) "
            f"{self.creator if self.creator else '?'} "
            f"– {self.title} "
            f"– {self.link}"
        )

    def __str__(self) -> str:
        # ??:??:?? = not researched yet; --:--:-- = researched, not a video.
        if not self.is_researched:
            duration = "??:??:??"
        elif self.is_video:
            duration = format_duration(self.duration)
        else:
            duration = "--:--:--"
        return self._describe(duration)

    @property
    def downloaded(self) -> bool:
        """Whether the video file is currently present on disk."""
        if not self.is_researched:
            return False
        return os.path.isfile(self.filepath)

    @functools.cached_property
    def ytdl_infos(self) -> typing.Optional[dict]:
        """
        yt-dlp info dict for the link.

        Read from the ``ytdl`` metafile; if that is missing or cannot be
        unpickled, the link is researched and the result written back.
        None when yt-dlp raised a DownloadError for the link.
        """
        try:
            return self.metafile_read("ytdl")
        except (FileNotFoundError, TypeError, AttributeError, EOFError):
            infos = self._ytdl_infos()
            self.metafile_write("ytdl", infos)
        return infos

    def _ytdl_infos(self) -> typing.Optional[dict]:
        """Ask yt-dlp (without downloading) for the link's info dict."""
        # Describe the element as not researched yet instead of using
        # __str__: when an unreadable ytdl metafile exists, __str__ would
        # treat the element as researched, access ytdl_infos again and
        # recurse without end.
        log.info(f"Researching: {self._describe('??:??:??')}")
        try:
            infos = self.parent.ytdl_dry.extract_info(self.link, download=False)
        except yt_dlp.utils.DownloadError as e:
            # TODO Still raise in case of temporary network issue
            log.warning(e)
            infos = None
        if infos:
            infos = self.parent.ytdl_dry.sanitize_info(infos)
        return infos

    @property
    def duration(self) -> int:
        """Video duration in seconds (only valid for videos)."""
        assert self.is_video
        assert self.ytdl_infos
        return int(self.ytdl_infos["duration"])

    @property
    def is_video(self) -> bool:
        """Whether yt-dlp info exists and carries a duration."""
        # Duration might be missing in playlists and stuff
        return self.ytdl_infos is not None and "duration" in self.ytdl_infos

    @functools.cached_property
    def downloaded_filepath(self) -> typing.Optional[str]:
        """Path recorded in the ``path`` metafile after a download, if any."""
        try:
            return self.metafile_read("path")
        except (FileNotFoundError, TypeError, AttributeError, EOFError):
            # Missing or unreadable (same cases ytdl_infos tolerates):
            # callers then fall back to the predicted file name.
            return None

    @property
    def was_downloaded(self) -> bool:
        """Whether a download was ever recorded (the path metafile exists)."""
        metafile = self.metafile("path")
        return os.path.exists(metafile)

    @property
    def filepath(self) -> str:
        """Recorded download path, or the file name yt-dlp would use."""
        assert self.is_video
        if self.downloaded_filepath:
            return self.downloaded_filepath
        return self.parent.ytdl_dry.prepare_filename(self.ytdl_infos)

    @property
    def basename(self) -> str:
        """filepath without its extension; prefix of all related files."""
        assert self.is_video
        return os.path.splitext(self.filepath)[0]

    def expire_info(self) -> None:
        """Drop the cached yt-dlp info if it is older than RERESEARCH_AFTER."""
        metafile = self.metafile("ytdl")
        if os.path.isfile(metafile):
            stat = os.stat(metafile)
            mtime = datetime.datetime.fromtimestamp(stat.st_mtime)
            diff = datetime.datetime.now() - mtime
            if diff > self.RERESEARCH_AFTER:
                os.unlink(metafile)
                # Forget the in-memory value too. pop() rather than `del`,
                # which raises AttributeError if it was never computed.
                self.__dict__.pop("ytdl_infos", None)

    def download(self) -> None:
        """
        Download the video with the user's yt-dlp options, unless already
        present. Only logs in dry-run mode.
        """
        assert self.is_video
        if self.downloaded:
            return
        self.expire_info()
        if not self.is_video:
            # Refreshed info no longer describes a video (e.g. it was removed);
            # yt-dlp cannot process a missing info dict.
            log.warning(f"No longer a downloadable video: {self}")
            return
        log.info(f"Downloading: {self}")
        lockfile = self.metafile("lock")
        with filelock.FileLock(lockfile):
            if not self.parent.args.dryrun:
                with yt_dlp.YoutubeDL(self.parent.ytdl_opts) as ydl:
                    ydl.add_post_processor(SaveInfoPP(self))
                    ydl.process_ie_result(self.ytdl_infos, download=True)

    def update_post_download(self, info: dict) -> None:
        """Record the final file path computed from yt-dlp's final info dict."""
        self.downloaded_filepath = self.parent.ytdl_dry.prepare_filename(info)
        assert self.downloaded_filepath
        assert self.downloaded_filepath.startswith(self.basename)
        self.metafile_write("path", self.downloaded_filepath)

    @property
    def watched(self) -> bool:
        """Downloaded at some point but the file is gone (watch() deletes it)."""
        if not self.is_researched:
            return False
        return self.was_downloaded and not self.downloaded

    def matches_filter(self, args: configargparse.Namespace) -> bool:
        """
        Whether the element passes the command-line filters.

        Checks based on feed data run first; the video/duration checks may
        trigger a yt-dlp research.
        """
        # Inexpensive filters
        if args.seen != "any" and (args.seen == "seen") != self.watched:
            log.debug(f"Not {args.seen}: {self}")
            return False
        if args.title and not re.search(args.title, self.title):
            log.debug(f"Title not matching {args.title}: {self}")
            return False
        if args.link and not re.search(args.link, self.link):
            log.debug(f"Link not matching {args.link}: {self}")
            return False
        if args.creator and (
            not self.creator or not re.search(args.creator, self.creator)
        ):
            log.debug(f"Creator not matching {args.creator}: {self}")
            return False

        # Expensive filters
        if not self.is_video:
            log.debug(f"Not a video: {self}")
            return False
        if args.duration and not compare_duration(args.duration)(self.duration):
            log.debug(f"Duration {self.duration} not matching {args.duration}: {self}")
            return False

        return True

    def watch(self) -> None:
        """Download, play with mpv, then delete the file and mark it read."""
        self.download()

        cmd = ["mpv", self.filepath]
        log.debug(f"Running {cmd}")
        if not self.parent.args.dryrun:
            # Raises CalledProcessError when mpv exits with a non-zero status.
            subprocess.run(cmd, check=True)

        self.undownload()
        self.try_mark_read()

    def clean_file(self, folder: str, basename: str) -> None:
        """Remove files in ``folder`` starting with ``basename`` (not in dry runs)."""
        for file in os.listdir(folder):
            if file.startswith(basename):
                path = os.path.join(folder, file)
                log.debug(f"Removing file: {path}")
                if not self.parent.args.dryrun:
                    os.unlink(path)

    def undownload(self) -> None:
        """Remove the downloaded video and related files from the videos folder."""
        assert self.is_video
        log.info(f"Removing gone video: {self.basename}*")
        self.clean_file(".", self.basename)

    def clean(self) -> None:
        """Remove every local trace of the element: video files and metadata."""
        if self.is_researched and self.is_video:
            self.undownload()
        log.info(f"Removing gone metadata: {self.sid}*")
        self.clean_file(self.parent.METADATA_FOLDER, self.sid)

    def mark_read(self) -> None:
        """
        Mark the item read on the aggregator, then clean local files.

        :raises RuntimeError: if the aggregator does not answer "OK".
        """
        log.debug(f"Marking {self} read")
        if self.parent.args.dryrun:
            return
        r = requests.post(
            f"{self.parent.args.url}/reader/api/0/edit-tag",
            data={
                "i": self.id,
                "a": "user/-/state/com.google/read",
                "ac": "edit",
                "token": self.parent.feed_token,
            },
            headers=self.parent.auth_headers,
        )
        r.raise_for_status()
        if r.text.strip() != "OK":
            raise RuntimeError(f"Couldn't mark {self} as read: {r.text}")
        log.info(f"Marked {self} as read")
        self.clean()

    def try_mark_read(self) -> None:
        """Best-effort mark_read(): connection errors are only logged."""
        try:
            self.mark_read()
        except requests.ConnectionError:
            log.warning(f"Couldn't mark {self} as read")


class RVDatabase:
    """
    Set of feed elements plus the helpers they share.

    Elements come either from the aggregator's Google Reader style API
    (read_feed) or from the ``.item`` files of the metadata cache
    (read_cache). The database also handles metafile I/O, authentication
    with the aggregator, and yt-dlp options/instances.
    """

    METADATA_FOLDER = ".metadata"

    args: configargparse.Namespace
    elements: list[RVElement]

    def __init__(self, args: configargparse.Namespace) -> None:
        self.args = args

    def metafile_read(self, name: str) -> typing.Any:
        """Unpickle the metadata file ``name`` (relative to METADATA_FOLDER)."""
        # Only ever reads files written by metafile_write: unpickling
        # untrusted data would be unsafe.
        path = os.path.join(self.METADATA_FOLDER, name)
        log.debug(f"Reading {path}")
        with open(path, "rb") as mf:
            return pickle.load(mf)

    def metafile_write(self, name: str, data: typing.Any) -> None:
        """Pickle ``data`` into the metadata file ``name`` (not in dry runs)."""
        path = os.path.join(self.METADATA_FOLDER, name)
        log.debug(f"Writing {path}")
        if not self.args.dryrun:
            with open(path, "wb") as mf:
                pickle.dump(data, mf)

    def clean_cache(self, cache: "RVDatabase") -> None:
        """Clean elements of ``cache`` that are absent from this database."""
        log.debug("Cleaning cache")
        fresh_ids = {el.id for el in self.elements}
        for el in cache.elements:
            if el.id not in fresh_ids:
                el.clean()

    def _auth_headers(self) -> dict[str, str]:
        """
        Log in through ClientLogin and build the Authorization header.

        :raises RuntimeError: if the response contains no ``Auth=`` line.
        """
        r = requests.get(
            f"{self.args.url}/accounts/ClientLogin",
            params={"Email": self.args.email, "Passwd": self.args.passwd},
        )
        r.raise_for_status()
        for line in r.text.split("\n"):
            if line.lower().startswith("auth="):
                # Everything after the first "=", in case the token has some.
                val = line.split("=", 1)[1]
                return {"Authorization": f"GoogleLogin auth={val}"}
        raise RuntimeError("Couldn't find auth= key")

    @functools.cached_property
    def auth_headers(self) -> dict[str, str]:
        """Auth headers, cached in the ``.auth_headers`` metafile once obtained."""
        try:
            return self.metafile_read(".auth_headers")
        except FileNotFoundError:
            headers = self._auth_headers()
            self.metafile_write(".auth_headers", headers)
            return headers

    def fetch_feed_elements(self) -> typing.Generator[dict, None, None]:
        """Yield the feed items not tagged read, following continuation pages."""
        log.info("Fetching RSS feed")
        continuation: typing.Optional[str] = None
        with requests.Session() as s:

            def next_page() -> typing.Generator[dict, None, None]:
                nonlocal continuation
                r = s.get(
                    f"{self.args.url}/reader/api/0/stream/contents",
                    params={
                        "xt": "user/-/state/com.google/read",
                        "c": continuation,
                    },
                    headers=self.auth_headers,
                )
                r.raise_for_status()
                json = r.json()
                yield from json["items"]
                continuation = json.get("continuation")

            yield from next_page()
            while continuation:
                yield from next_page()

    def fetch_cache_elements(self) -> typing.Generator[dict, None, None]:
        """Yield the feed items saved in the metadata folder."""
        log.info("Fetching from cache")
        for file in os.listdir(self.METADATA_FOLDER):
            if not file.endswith(".item"):
                continue
            yield self.metafile_read(file)

    def build_list(self, items: typing.Iterable[dict], save: bool = False) -> None:
        """
        Wrap ``items`` into RVElements, stored in reverse input order.

        :param save: also persist each item to its ``.item`` metafile.
        """
        self.elements = []
        for item in items:
            element = RVElement(self, item)
            self.elements.append(element)
            log.debug(f"Known: {element}")
            if save:
                element.save()
        # Appending then reversing once gives the same order as inserting
        # each element at the front, without the quadratic cost.
        self.elements.reverse()

    def read_feed(self) -> None:
        """Load elements from the aggregator and cache their items."""
        self.build_list(self.fetch_feed_elements(), save=True)

    def read_cache(self) -> None:
        """Load elements from the metadata cache."""
        self.build_list(self.fetch_cache_elements())

    def clean_folder(self, folder: str, basenames: set[str]) -> None:
        """
        Remove regular, non-hidden files of ``folder`` whose name starts with
        none of ``basenames`` (not in dry runs).
        """
        prefixes = tuple(basenames)
        for file in os.listdir(folder):
            path = os.path.join(folder, file)
            if not os.path.isfile(path) or file.startswith("."):
                continue
            if file.startswith(prefixes):
                continue
            log.info(f"Removing unknown file: {path}")
            if not self.args.dryrun:
                os.unlink(path)

    def clean(self) -> None:
        """Remove videos and metadata files not belonging to any element."""
        log.debug("Cleaning")
        filenames = {el.basename for el in self.elements if el.is_video}
        self.clean_folder(".", filenames)
        ids = {el.sid for el in self.elements}
        self.clean_folder(self.METADATA_FOLDER, ids)

    @property
    def ytdl_opts(self) -> dict:
        """yt-dlp options from the user/system configuration only."""
        # Get user/system options: parse_options() looks at sys.argv, so show
        # it an empty command line (presumably chosen over passing argv
        # explicitly so yt-dlp still loads its config files), and always
        # restore the real argv, even if parsing fails.
        prev_argv = sys.argv
        sys.argv = ["yt-dlp"]
        try:
            _, _, _, ydl_opts = yt_dlp.parse_options()
        finally:
            sys.argv = prev_argv
        return ydl_opts

    @property
    def ytdl_dry_opts(self) -> dict:
        """ytdl_opts with output silenced, for research and file naming."""
        opts = self.ytdl_opts.copy()
        opts.update({"quiet": True})
        return opts

    @property
    def ytdl_dry(self) -> yt_dlp.YoutubeDL:
        """A new quiet YoutubeDL instance (created on every access)."""
        return yt_dlp.YoutubeDL(self.ytdl_dry_opts)

    def filter(self, args: configargparse.Namespace) -> typing.Iterable[RVElement]:
        """
        Sort and filter the elements according to the command-line options.

        Orders based on feed data (old/new/title/creator/link/random) are
        applied before filtering; duration orders (short/long) after it,
        since filtering has researched the remaining elements. With
        --total-duration, elements are then picked greedily, in order,
        while they are shorter than the remaining budget.
        """
        elements_src = self.elements.copy()
        # Orders sorted after filtering (short/long) start from the plain copy.
        elements: typing.Iterable[RVElement] = elements_src
        # Inexpensive sort
        if args.order == "new":
            elements = sorted(elements_src, key=lambda el: el.date, reverse=True)
        elif args.order == "old":
            elements = sorted(elements_src, key=lambda el: el.date)
        elif args.order == "title":
            elements = sorted(elements_src, key=lambda el: el.title)
        elif args.order == "creator":
            elements = sorted(elements_src, key=lambda el: el.creator or "")
        elif args.order == "link":
            elements = sorted(elements_src, key=lambda el: el.link)
        elif args.order == "random":
            # Shuffles the copy in place; `elements` already refers to it.
            random.shuffle(elements_src)

        # Possibly expensive filtering
        elements = filter(lambda el: el.matches_filter(args), elements)

        # Expensive sort
        if args.order == "short":
            elements = sorted(
                elements, key=lambda el: el.duration if el.is_video else 0
            )
        elif args.order == "long":
            elements = sorted(
                elements, key=lambda el: el.duration if el.is_video else 0, reverse=True
            )

        # Post sorting filtering
        if args.total_duration:
            rem = parse_duration(args.total_duration)
            selected: list[RVElement] = []
            # A single pass suffices: an element skipped because it did not
            # fit cannot fit later, as the remaining budget only shrinks.
            for el in elements:
                if el.duration < rem:
                    selected.append(el)
                    rem -= el.duration
            elements = selected

        return elements

    @functools.cached_property
    def feed_token(self) -> str:
        """Token required by the aggregator's edit API."""
        r = requests.get(
            f"{self.args.url}/reader/api/0/token",
            headers=self.auth_headers,
        )
        r.raise_for_status()
        return r.text.strip()

    def try_mark_watched_read(self) -> None:
        """Best-effort mark as read every element that has been watched."""
        for element in self.elements:
            if element.watched:
                element.try_mark_read()


def get_args() -> configargparse.Namespace:
    """
    Parse the command line, config file and environment variables.

    The default config file is ``rssVideos`` in ``$XDG_CONFIG_HOME``,
    falling back to ``~/.config/``. The non-standard ``XDG_CONFIG_PATH``
    variable is still honoured when XDG_CONFIG_HOME is unset or empty.

    ``args.videos`` is returned as a user-expanded, resolved real path.
    """
    config_home = (
        os.getenv("XDG_CONFIG_HOME") or os.getenv("XDG_CONFIG_PATH") or "~/.config/"
    )
    default_config_path = os.path.join(os.path.expanduser(config_home), "rssVideos")

    parser = configargparse.ArgParser(
        description="Download videos in unread articles from a feed aggregator",
        default_config_files=[default_config_path],
    )

    # Runtime settings
    parser.add_argument(
        "-v",
        "--verbosity",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        default=None,
        help="Verbosity of log messages",
    )
    parser.add(
        "-c", "--config", required=False, is_config_file=True, help="Configuration file"
    )
    parser.add(
        "-n",
        "--dryrun",
        help="Only pretend to do actions",
        action="store_const",
        const=True,
        default=False,
    )

    # Input/Output
    parser.add(
        "--url",
        help="URL of the Google Reader API of the aggregator",
        env_var="RSS_VIDEOS_URL",
        required=True,
    )
    parser.add(
        "--email",
        help="E-mail / user to connect to the aggregator",
        env_var="RSS_VIDEOS_EMAIL",
        required=True,
    )
    parser.add(
        "--passwd",
        help="Password to connect to the aggregator",
        env_var="RSS_VIDEOS_PASSWD",
        required=True,
    )
    parser.add(
        "--no-refresh",
        dest="refresh",
        help="Don't fetch feed",
        action="store_false",
    )
    parser.add(
        "--videos",
        help="Directory to store videos",
        env_var="RSS_VIDEOS_VIDEO_DIR",
        required=True,
    )

    # Which videos
    parser.add(
        "--order",
        choices=("old", "new", "title", "creator", "link", "short", "long", "random"),
        default="old",
        help="Sorting mechanism",
    )
    parser.add("--creator", help="Regex to filter by creator")
    parser.add("--title", help="Regex to filter by title")
    parser.add("--link", help="Regex to filter by link")
    parser.add("--duration", help="Comparative to filter by duration")
    # TODO Date selector
    parser.add(
        "--seen",
        choices=("seen", "unseen", "any"),
        default="unseen",
        help="Only include seen/unseen/any videos",
    )
    parser.add(
        "--total-duration",
        help="Use videos that fit under the total given",
    )
    # TODO Environment variables
    # TODO Allow to ask

    parser.add(
        "action",
        nargs="?",
        choices=(
            "download",
            "list",
            "watch",
            "binge",
        ),
        default="download",
    )

    args = parser.parse_args()
    args.videos = os.path.realpath(os.path.expanduser(args.videos))

    return args


def get_database(args: configargparse.Namespace) -> RVDatabase:
    """
    Load the element database.

    The metadata cache is always read. With --no-refresh it is returned as
    is; otherwise the feed is fetched into a fresh database, and cached
    elements missing from the feed are cleaned up.
    """
    cache = RVDatabase(args)
    cache.read_cache()
    if args.refresh:
        fresh = RVDatabase(args)
        fresh.read_feed()
        fresh.clean_cache(cache)
        return fresh
    return cache


def main() -> None:
    """Prepare the folders, load the database and run the requested action."""
    args = get_args()
    configure_logging(args)

    # Everything below runs from inside the videos folder, with metadata in
    # a subfolder of it.
    metadata_dir = os.path.join(args.videos, RVDatabase.METADATA_FOLDER)
    for directory in (args.videos, metadata_dir):
        os.makedirs(directory, exist_ok=True)
    os.chdir(args.videos)

    database = get_database(args)

    log.debug("Running action")
    action = args.action
    total = 0
    for element in database.filter(args):
        if element.is_video:
            total += element.duration
        if action == "download":
            element.download()
        elif action == "list":
            print(element)
        elif action in ("watch", "binge"):
            element.watch()
            # "watch" plays a single video; "binge" goes through all of them.
            if action == "watch":
                break
        else:
            raise NotImplementedError(f"Unimplemented action: {args.action}")
    log.info(f"Total duration: {format_duration(total)}")
    database.try_mark_watched_read()
    database.clean()


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()