#!/usr/bin/env cached-nix-shell
#! nix-shell -i python3
#! nix-shell -p python3 python3Packages.pydantic

"""
glab wrapper for jujutsu,
with some convenience features.
"""

import re
import subprocess
import sys
import typing

import pydantic
import typing_extensions


class GitLabMR(pydantic.BaseModel):
    """
    The subset of a GitLab Merge Request description used by this script.
    """

    title: str
    source_branch: str
    target_branch: str
    project_id: int
    source_project_id: int
    target_project_id: int

    @pydantic.model_validator(mode="after")
    def same_project(self) -> typing_extensions.Self:
        """
        Reject merge requests whose three project ids are not all identical.
        """
        involved = {self.project_id, self.source_project_id, self.target_project_id}
        if len(involved) > 1:
            raise NotImplementedError("Different project ids")
        return self


def glab_get_mr(branch: str) -> GitLabMR:
    """
    Ask `glab mr view` for the JSON description of a MR and parse it.

    Raises subprocess.CalledProcessError when glab exits with a non-zero status.
    """
    command = ["glab", "mr", "view", branch, "--output", "json"]
    completed = subprocess.run(command, stdout=subprocess.PIPE, check=True)
    return GitLabMR.model_validate_json(completed.stdout)


class JujutsuType:
    """
    Utility to work with Template types.
    https://martinvonz.github.io/jj/latest/templates/

    Values are serialized as a flat sequence of separator-terminated fields:
    a plain value is one field, a list is a length field followed by its items,
    and a JujutsuObject is whatever its own template()/parse() produce.
    """

    FIELD_SEPARATOR: typing.ClassVar[str] = "\0"
    # The same separator, spelled as an escape sequence for a template string literal.
    ESCAPED_SEPARATOR: typing.ClassVar[str] = r"\0"

    @staticmethod
    def template(base: str, type_: typing.Type) -> str:
        """
        Generates a --template string that is machine-parseable for a given type.

        `base` is the template expression yielding the value,
        `type_` the Python type it should later be parsed into.
        """
        if typing.get_origin(type_) is list:
            # If we have a list, prepend it with the number of items
            # so we know how many fields we should consume.
            (subtype,) = typing.get_args(type_)
            subtype = typing.cast(typing.Type, subtype)
            item = JujutsuType.template("l", subtype)
            # Join the mapped items with an empty separator: otherwise the list
            # is rendered with its default separator between items, which would
            # end up at the start of every field following the first item.
            # NOTE(review): default separator is a space per the jj template
            # docs - confirm against the installed jj version.
            return (
                f'{base}.len()++"{JujutsuType.ESCAPED_SEPARATOR}"'
                f'++{base}.map(|l| {item}).join("")'
            )
        elif issubclass(type_, JujutsuObject):
            return type_.template(base)
        else:
            return f'{base}++"{JujutsuType.ESCAPED_SEPARATOR}"'

    @staticmethod
    def parse(stack: list[str], type_: typing.Type) -> typing.Any:
        """
        Unserialize the result of a template to a given type.
        Needs to be provided the template output as a list split by the field separator.
        It will consume the fields it needs from the front of `stack` (in place).

        Raises ValueError if a list length field is not an integer,
        IndexError if the stack runs out of fields.
        """
        if typing.get_origin(type_) is list:
            (subtype,) = typing.get_args(type_)
            subtype = typing.cast(typing.Type, subtype)
            # The first field of a list is its number of items.
            count = int(stack.pop(0))
            return [JujutsuType.parse(stack, subtype) for _ in range(count)]
        elif issubclass(type_, JujutsuObject):
            return type_.parse(stack)
        else:
            return stack.pop(0)


class JujutsuObject(pydantic.BaseModel):
    """
    Base for models mirroring a jj template type.
    Every field is read through the template method of the same name,
    in declaration order.
    """

    @classmethod
    def template(cls, base: str) -> str:
        """
        Build the template expression serializing each field of `base`.
        """
        return "++".join(
            JujutsuType.template(f"{base}.{name}()", field.annotation)
            for name, field in cls.model_fields.items()
        )

    @classmethod
    def parse(cls, stack: list[str]) -> typing_extensions.Self:
        """
        Build an instance by consuming, field after field, what it needs from `stack`.
        """
        values = {
            name: JujutsuType.parse(stack, field.annotation)
            for name, field in cls.model_fields.items()
        }
        return cls(**values)


class JujutsuShortestIdPrefix(JujutsuObject):
    """
    An id split in two parts, read through the `prefix()` and `rest()`
    template methods.
    """

    prefix: str
    rest: str

    @property
    def full(self) -> str:
        """
        The two parts glued back together.
        """
        return f"{self.prefix}{self.rest}"


class JujutsuChangeId(JujutsuObject):
    """
    A change id, read through its `shortest()` template method.
    """

    shortest: JujutsuShortestIdPrefix

    @property
    def full(self) -> str:
        """
        The whole id, rebuilt from its prefix/rest split.
        """
        split_id = self.shortest
        return split_id.full


class JujutsuRefName(JujutsuObject):
    """
    A named ref, as found in the bookmarks list of a JujutsuCommit.
    """

    # Serialized as `<base>.name()` by JujutsuObject.template.
    name: str


class JujutsuCommit(JujutsuObject):
    """
    The parts of a commit this script reads from `jj log`.
    """

    # Field order matters: it is both the serialization and the parsing order.
    change_id: JujutsuChangeId
    # A list: serialized as its length followed by its items.
    bookmarks: list[JujutsuRefName]


class Jujutsu:
    """
    Represents a Jujutsu repo.
    Since there's no need for multi-repo, this is just the one in the current directory.
    """

    def run(self, *args: str, **kwargs: typing.Any) -> subprocess.CompletedProcess:
        """
        Run `jj` with the given arguments and return the completed process.

        Extra keyword arguments are forwarded to subprocess.run;
        stdout is captured unless the caller asks for something else.
        Raises subprocess.CalledProcessError if jj exits with a non-zero status.
        """
        # Capturing stdout stays the default so existing callers see no change.
        kwargs.setdefault("stdout", subprocess.PIPE)
        sp = subprocess.run(["jj", *args], **kwargs)
        sp.check_returncode()
        return sp

    def log(self, revset: str = "@") -> list[JujutsuCommit]:
        """
        Return the commits matched by `revset`, in the order `jj log` prints them.

        Raises RuntimeError if the output doesn't end with the field separator,
        which would mean the template wasn't rendered as expected.
        """
        cmd = [
            "log",
            "-r",
            revset,
            "--no-graph",
            "-T",
            JujutsuCommit.template("self"),
        ]
        sp = self.run(*cmd, stdout=subprocess.PIPE)
        stack = sp.stdout.decode().split(JujutsuType.FIELD_SEPARATOR)
        # Every field is terminated by the separator, so splitting always leaves
        # an empty last element (also true for an empty output).
        # Explicit raise rather than assert so the check survives `python -O`.
        if stack.pop() != "":
            raise RuntimeError("No trailing NUL byte")
        commits = []
        # Each parse consumes exactly one commit worth of fields.
        while stack:
            commits.append(JujutsuCommit.parse(stack))
        return commits


# Repo handle shared by current_bookmark() and the command dispatch below.
jj = Jujutsu()


def current_bookmark() -> JujutsuRefName | None:
    """
    Stand-in for git's "current branch", which jj doesn't have.
    Used for convenience features: not having to type the MR number / branch
    for `glab mr`, and advancing the bookmark to the head before pushing.

    Returns the only bookmark found on `reachable(@, trunk()..)`,
    or None when there isn't any.
    """
    found = [
        ref
        for commit in jj.log("reachable(@, trunk()..)")
        for ref in commit.bookmarks
    ]

    if not found:
        return None

    if len(found) > 1:
        # TODO
        # If there's a split in the tree: TBD
        # If there's no bookmark ahead: the bookmark behind
        # If there's a bookmark ahead: that one
        # (needs adjusting of push so it doesn't advance anything then)
        raise NotImplementedError("Multiple bookmarks on trunk branch")

    return found[0]


def to_glab() -> None:
    """
    Hand the command-line arguments over to glab, then exit with its return code.
    """
    forwarded = sys.argv[1:]
    completed = subprocess.run(["glab", *forwarded])
    sys.exit(completed.returncode)


# Command dispatch: `mr`/`merge` and `push` get special treatment,
# anything else is handed to glab untouched.
if len(sys.argv) <= 1:
    to_glab()
elif sys.argv[1] in ("merge", "mr"):
    if len(sys.argv) <= 2:
        to_glab()
    elif sys.argv[2] == "checkout":
        # Bypass the original checkout command so it doesn't run git commands.
        # If there's no commit on the branch, add one with the MR title
        # so jj has a current bookmark.
        if len(sys.argv) <= 3:
            # Fail with a readable message rather than an IndexError traceback.
            sys.exit("mr checkout: missing MR number or branch")
        mr = glab_get_mr(sys.argv[3])
        jj.run("git", "fetch")
        # Source and target resolving to a single commit is taken to mean
        # that the MR has no commit of its own yet.
        if len(jj.log(f"{mr.source_branch} | {mr.target_branch}")) == 1:
            title = re.sub(r"^(WIP|Draft): ", "", mr.title)
            jj.run("new", mr.source_branch)
            jj.run("describe", "-m", title)
            jj.run("bookmark", "move", mr.source_branch)
        else:
            # NOTE(review): `jj bookmark edit` doesn't look like an existing
            # subcommand, `jj edit <bookmark>` was presumably meant - confirm.
            jj.run("bookmark", "edit", mr.source_branch)
    elif sys.argv[2] in (
        # If no MR number/branch is given, insert the current bookmark,
        # as the current branch concept doesn't exist in jj
        # ("checkout" is not listed: it is always handled by the branch above)
        "approve",
        "approvers",
        "close",
        "delete",
        "diff",
        "issues",
        "merge",
        "note",
        "rebase",
        "revoke",
        "subscribe",
        "todo",
        "unsubscribe",
        "update",
        "view",
    ):
        if len(sys.argv) <= 3 or sys.argv[3].startswith("-"):
            bookmark = current_bookmark()
            if bookmark:
                sys.argv.insert(3, bookmark.name)
        to_glab()
    else:
        to_glab()
elif sys.argv[1] == "push":
    # Advance the current branch to the head and push
    bookmark = current_bookmark()
    if not bookmark:
        raise RuntimeError("Couldn't find a current branch")
    heads = jj.log("heads(@::)")
    if len(heads) != 1:
        raise RuntimeError("Multiple heads")  # Or none if something goes horribly wrong
    head = heads[0]
    jj.run("bookmark", "set", bookmark.name, "-r", head.change_id.full)
    jj.run("git", "push", "--bookmark", bookmark.name)
    # TODO Sign https://github.com/martinvonz/jj/issues/4712
else:
    to_glab()

# TODO Autocomplete