dotfiles/hm/scripts/jlab
2024-10-31 20:21:52 +01:00

266 lines
7.6 KiB
Python
Executable file

#!/usr/bin/env cached-nix-shell
#! nix-shell -i python3
#! nix-shell -p python3 python3Packages.pydantic
# vim: filetype=python
"""
glab wrapper for Jujutsu (jj),
with some convenience features.
"""
import re
import subprocess
import sys
import typing
import pydantic
import typing_extensions
class GitLabMR(pydantic.BaseModel):
    """
    The subset of a GitLab merge request that this script needs,
    as parsed from `glab mr view --output json`.

    Only merge requests whose project, source project and target project
    are all the same are supported.
    """

    title: str
    source_branch: str
    target_branch: str
    project_id: int
    source_project_id: int
    target_project_id: int

    @pydantic.model_validator(mode="after")
    def same_project(self) -> typing_extensions.Self:
        """Refuse merge requests spanning several projects (not handled)."""
        distinct_ids = {self.project_id, self.source_project_id, self.target_project_id}
        if len(distinct_ids) != 1:
            raise NotImplementedError("Different project ids")
        return self
def glab_get_mr(branch: str) -> GitLabMR:
    """
    Fetch a merge request via `glab mr view` and parse its JSON output.

    :param branch: MR identifier passed straight to `glab mr view`.
    :returns: the parsed merge request.
    :raises subprocess.CalledProcessError: if glab exits with a non-zero status.
    """
    completed = subprocess.run(
        ["glab", "mr", "view", branch, "--output", "json"],
        stdout=subprocess.PIPE,
        check=True,
    )
    return GitLabMR.model_validate_json(completed.stdout)
class JujutsuType:
    """
    Utility to work with Template types.
    https://martinvonz.github.io/jj/latest/templates/

    Every scalar value is rendered followed by FIELD_SEPARATOR, so the whole
    output can be split on it and consumed field by field with `parse`.
    """

    # Separator as it appears in the rendered output (a real NUL character).
    FIELD_SEPARATOR: typing.ClassVar[str] = "\0"
    # The same separator written as an escape inside a jj template string literal.
    ESCAPED_SEPARATOR: typing.ClassVar[str] = r"\0"

    @staticmethod
    def template(base: str, type_: typing.Type) -> str:
        """
        Generate a --template string that is machine-parseable for a given type.

        :param base: jj template expression producing a value of ``type_``.
        :param type_: ``list[X]``, a JujutsuObject subclass, or a scalar
            rendered as plain text.
        :returns: a jj template expression where each scalar ends with the
            field separator.
        """
        if typing.get_origin(type_) is list:
            # If we have a list, prepend it with the number of items
            # so we know how many fields we should consume.
            (subtype,) = typing.get_args(type_)
            subtype = typing.cast(typing.Type, subtype)
            item_template = JujutsuType.template("l", subtype)
            # A ListTemplate from map() renders its items separated by a space,
            # which would prefix every item after the first. Each item already
            # ends with the field separator, so join with nothing.
            return (
                f'{base}.len()++"{JujutsuType.ESCAPED_SEPARATOR}"'
                f'++{base}.map(|l| {item_template}).join("")'
            )
        if isinstance(type_, type) and issubclass(type_, JujutsuObject):
            return type_.template(base)
        return f'{base}++"{JujutsuType.ESCAPED_SEPARATOR}"'

    @staticmethod
    def parse(stack: list[str], type_: typing.Type) -> typing.Any:
        """
        Unserialize the result of a template to a given type.

        :param stack: rendered template output, split on FIELD_SEPARATOR.
            The fields needed for ``type_`` are consumed (popped) from the front.
        :param type_: the same type that was passed to `template`.
        :returns: the parsed value (a str for scalars).
        :raises ValueError: if a list length field is not an integer.
        :raises IndexError: if the stack runs out of fields.
        """
        if typing.get_origin(type_) is list:
            (subtype,) = typing.get_args(type_)
            subtype = typing.cast(typing.Type, subtype)
            # The list is prefixed with its item count (see `template`).
            count = int(stack.pop(0))
            return [JujutsuType.parse(stack, subtype) for _ in range(count)]
        if isinstance(type_, type) and issubclass(type_, JujutsuObject):
            return type_.parse(stack)
        return stack.pop(0)
class JujutsuObject(pydantic.BaseModel):
    """
    Base for models mirroring a jj template object.

    Each field ``foo`` is read through the template method ``foo()``.
    Fields are rendered and parsed in the order of ``model_fields``.
    """

    @classmethod
    def template(cls, base: str) -> str:
        """Build a template that renders every field of ``base``, joined with ``++``."""
        return "++".join(
            JujutsuType.template(f"{base}.{field_name}()", field_info.annotation)
            for field_name, field_info in cls.model_fields.items()
        )

    @classmethod
    def parse(cls, stack: list[str]) -> typing_extensions.Self:
        """Consume this model's fields from the front of ``stack`` and build an instance."""
        values = {
            field_name: JujutsuType.parse(stack, field_info.annotation)
            for field_name, field_info in cls.model_fields.items()
        }
        return cls(**values)
class JujutsuShortestIdPrefix(JujutsuObject):
    """Mirror of jj's ShortestIdPrefix template type (``prefix()`` and ``rest()``)."""

    prefix: str
    rest: str

    @property
    def full(self) -> str:
        """The prefix immediately followed by the rest."""
        return "".join((self.prefix, self.rest))
class JujutsuChangeId(JujutsuObject):
    """Mirror of jj's ChangeId template type; only ``shortest()`` is fetched."""

    shortest: JujutsuShortestIdPrefix

    @property
    def full(self) -> str:
        """The change ID rebuilt from its shortest-prefix split."""
        split_id = self.shortest
        return split_id.full
class JujutsuRefName(JujutsuObject):
    """Mirror of jj's RefName template type; only ``name()`` is fetched."""

    name: str
class JujutsuCommit(JujutsuObject):
    """The subset of jj's Commit template type that this script reads."""

    # Rendered through `<commit>.change_id()`.
    change_id: JujutsuChangeId
    # Rendered through `<commit>.bookmarks()`.
    # NOTE(review): jj's bookmarks() may also list remote bookmarks; confirm intended.
    bookmarks: list[JujutsuRefName]
class Jujutsu:
    """
    Represents a Jujutsu repo.
    Since there's no need for multi-repo, this is just the one in the current directory.
    """

    def run(self, *args: str, **kwargs: typing.Any) -> subprocess.CompletedProcess:
        """
        Run ``jj`` with ``args`` and fail on a non-zero exit code.

        stdout is captured by default. Extra keyword arguments are forwarded to
        subprocess.run; they were previously silently ignored.

        :raises subprocess.CalledProcessError: if jj exits with a non-zero status.
        """
        kwargs.setdefault("stdout", subprocess.PIPE)
        sp = subprocess.run(["jj", *args], **kwargs)
        sp.check_returncode()
        return sp

    def log(self, revset: str = "@") -> list[JujutsuCommit]:
        """
        Return the commits matched by ``revset``, parsed into JujutsuCommit.

        :param revset: any jj revset expression (defaults to the working copy).
        :raises RuntimeError: if jj's output does not end with the field separator.
        """
        sp = self.run(
            "log",
            # A user config forcing colors would inject ANSI codes into the
            # fields and break parsing.
            "--color=never",
            "-r",
            revset,
            "--no-graph",
            "-T",
            JujutsuCommit.template("self"),
        )
        stack = sp.stdout.decode().split(JujutsuType.FIELD_SEPARATOR)
        # Every field ends with the separator, so the split leaves a trailing
        # empty string. Check explicitly, because `assert` disappears under -O.
        if stack.pop() != "":
            raise RuntimeError("No trailing NUL byte")
        commits = []
        while stack:
            commits.append(JujutsuCommit.parse(stack))
        return commits
# Module-wide handle on the jj repository in the current working directory.
jj = Jujutsu()
def current_bookmark() -> JujutsuRefName | None:
    """
    Stand-in for git's "current branch" concept, which jj does not have.

    Collects the bookmarks on every commit in ``reachable(@, trunk()..)``.
    Returns the only one found, or None if there are none.
    Needed for convenience features, such as not requiring the MR number / branch
    for `glab mr`, or automatically advancing the bookmark to the head before pushing.

    :raises NotImplementedError: if more than one bookmark is found.
    """
    candidates = [
        bookmark
        for commit in jj.log("reachable(@, trunk()..)")
        for bookmark in commit.bookmarks
    ]
    if len(candidates) > 1:
        raise NotImplementedError("Multiple bookmarks on trunk branch")  # TODO
    # Ideas for choosing among several candidates:
    # If there's a split in the tree: TBD
    # If there's no bookmark ahead: the bookmark behind
    # If there's a bookmark ahead: that one
    # (needs adjusting of push so it doesn't advance anything then)
    return candidates[0] if candidates else None
def to_glab() -> None:
    """
    Forward every command-line argument after the script name to glab unchanged,
    then exit with glab's return code.
    """
    forwarded = ["glab", *sys.argv[1:]]
    sys.exit(subprocess.run(forwarded).returncode)
# Command dispatch: intercept the commands that need jj awareness and
# hand everything else to glab unchanged.
if len(sys.argv) <= 1:
    to_glab()
elif sys.argv[1] in ("merge", "mr"):
    if len(sys.argv) <= 2:
        to_glab()
    elif sys.argv[2] == "checkout":
        if len(sys.argv) <= 3:
            # No MR given: let glab report its own usage error.
            to_glab()
        else:
            # Bypass the original checkout command so it doesn't run git commands.
            # If there's no commit on the branch, add one with the MR title
            # so jj has a current bookmark.
            mr = glab_get_mr(sys.argv[3])
            jj.run("git", "fetch")
            # The union is a single commit only when source and target
            # resolve to the same commit, i.e. the branch has no own commits.
            if len(jj.log(f"{mr.source_branch} | {mr.target_branch}")) == 1:
                title = re.sub(r"^(WIP|Draft): ", "", mr.title)
                jj.run("new", mr.source_branch)
                jj.run("describe", "-m", title)
                jj.run("bookmark", "move", mr.source_branch)
            else:
                # `jj bookmark` has no `edit` subcommand. `jj edit` makes the
                # bookmarked commit the working-copy commit.
                jj.run("edit", mr.source_branch)
    elif sys.argv[2] in (
        # If no MR number/branch is given, insert the current bookmark,
        # as the current branch concept doesn't exist in jj.
        # ("checkout" is handled above and deliberately absent here.)
        "approve",
        "approvers",
        "close",
        "delete",
        "diff",
        "issues",
        "merge",
        "note",
        "rebase",
        "revoke",
        "subscribe",
        "todo",
        "unsubscribe",
        "update",
        "view",
    ):
        if len(sys.argv) <= 3 or sys.argv[3].startswith("-"):
            bookmark = current_bookmark()
            if bookmark:
                sys.argv.insert(3, bookmark.name)
        to_glab()
    else:
        to_glab()
elif sys.argv[1] == "push":
    # Advance the current branch to the head and push
    bookmark = current_bookmark()
    if not bookmark:
        raise RuntimeError("Couldn't find a current branch")
    heads = jj.log("heads(@::)")
    if len(heads) != 1:
        raise RuntimeError("Multiple heads")  # Or none if something goes horribly wrong
    head = heads[0]
    jj.run("bookmark", "set", bookmark.name, "-r", head.change_id.full)
    jj.run("git", "push", "--bookmark", bookmark.name)
    # TODO Sign https://github.com/martinvonz/jj/issues/4712
else:
    to_glab()
# TODO Autocomplete