-
Notifications
You must be signed in to change notification settings - Fork 40
Expand file tree
/
Copy pathcli.py
More file actions
executable file
·1553 lines (1284 loc) · 49.6 KB
/
cli.py
File metadata and controls
executable file
·1553 lines (1284 loc) · 49.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# stack-pr: a tool for working with stacked PRs on github.
#
# ---------------
# stack-pr submit
# ---------------
#
# Semantics:
# 1. Find merge-base (the most recent commit from 'main' in the current branch)
# 2. For each commit since merge base do:
# a. If it doesn't have stack info:
# - create a new head branch for it
# - create a new PR for it
# - base branch will be the previous commit in the stack
# b. If it has stack info: verify its correctness.
# 3. Make sure all commits in the stack are annotated with stack info
# 4. Push all the head branches
#
# If 'submit' succeeds, you'll get all commits annotated with links to the
# corresponding PRs and names of the head branches. All the branches will be
# pushed to remote, and PRs are properly created and interconnected. Base
# branch of each PR will be the head branch of the previous PR, or 'main' for
# the first PR in the stack.
#
# -------------
# stack-pr land
# -------------
#
# Semantics:
# 1. Find merge-base (the most recent commit from 'main' in the current branch)
# 2. Check that all commits in the stack have stack info. If not, bail.
# 3. Check that the stack info is valid. If not, bail.
# 4. For each commit in the stack, from oldest to newest:
# - set base branch to point to main
# - merge the corresponding PR
#
# If 'land' succeeds, all the PRs from the stack will be merged into 'main',
# all the corresponding remote and local branches deleted.
#
# ----------------
# stack-pr abandon
# ----------------
#
# Semantics:
# For all commits in the stack that have valid stack-info:
# Close the corresponding PR, delete the remote and local branch, remove the
# stack-info from commit message.
#
# ===----------------------------------------------------------------------=== #
from __future__ import annotations

import argparse
import configparser
import contextlib
import json
import logging
import os
import re
import sys
from collections.abc import Iterator
from dataclasses import dataclass
from functools import cache
from logging import getLogger
from pathlib import Path
from re import Pattern
from subprocess import SubprocessError

from stack_pr.git import (
    branch_exists,
    check_gh_installed,
    get_current_branch_name,
    get_gh_username,
    get_uncommitted_changes,
)
from stack_pr.shell_commands import (
    get_command_output,
    run_shell_command,
)
# Module-level logger, used by `log()` for messages that are not printed.
logger = getLogger(__name__)
# A bunch of regexps for parsing commit messages and PR descriptions.
# Every pattern is compiled with re.MULTILINE, so `^` and `$` anchor at each
# line of the searched text rather than only at its start and end.
# A line made up solely of lowercase hex digits (captured as `commit`).
RE_RAW_COMMIT_ID = re.compile(r"^(?P<commit>[a-f0-9]+)$", re.MULTILINE)
# An `author Name <email>` line; `author` captures "Name <email>" as a whole,
# `name` and `email` capture the two parts separately.
RE_RAW_AUTHOR = re.compile(
    r"^author (?P<author>(?P<name>[^<]+?) <(?P<email>[^>]+)>)", re.MULTILINE
)
# A `parent <hex>` line (captured as `commit`); may occur several times.
RE_RAW_PARENT = re.compile(r"^parent (?P<commit>[a-f0-9]+)$", re.MULTILINE)
# A `tree <value>` line (captured as `tree`).
RE_RAW_TREE = re.compile(r"^tree (?P<tree>.+)$", re.MULTILINE)
# A commit message line, recognised by its leading whitespace prefix; the text
# after the prefix is captured as `line`.
RE_RAW_COMMIT_MSG_LINE = re.compile(r"^ (?P<line>.*)$", re.MULTILINE)
# stack-info: PR: https://github.com/modularml/test-ghstack/pull/30, branch: mvz/stack/7
# Group 1 is the PR link, group 2 the branch name. The match includes the
# newline preceding the line and, if present, the one following it.
RE_STACK_INFO_LINE = re.compile(
    r"\n^stack-info: PR: (.+), branch: (.+)\n?", re.MULTILINE
)
# A "Stacked PRs:" header followed by zero or more ` * #<number>` bullet lines
# (optionally marked with `__->__`) and a terminating blank line. Both `\n`
# and `\r\n` line endings are accepted.
RE_PR_TOC = re.compile(
    r"^Stacked PRs:\r?\n(^ \* (__->__)?#\d+\r?\n)*\r?\n", re.MULTILINE
)
# Delimiter for PR body: separates the cross-links section from the rest.
CROSS_LINKS_DELIMETER = "--- --- ---"
# ===----------------------------------------------------------------------=== #
# Error message templates
# ===----------------------------------------------------------------------=== #
# All templates below are `str.format` templates. Placeholders in use:
#   {e}              - the stack entry being processed
#   {cmd}            - the command that failed
#   {target}         - the branch a PR was being rebased on
#   {d}              - the PR info dictionary decoded from `gh` JSON output
#   {required_field} - name of a field missing from that dictionary
#   {top_commit}, {stack_size} - used by the *_TIP templates to build revision
#                      ranges of the form `<top_commit>~<stack_size>`
# Several callers fill them with `.format(**locals())`, so the placeholder
# names must match local variable names at those call sites.

# --- Failures while running commands -----------------------------------------
ERROR_CANT_UPDATE_META = """Couldn't update stack metadata for
{e}
"""
ERROR_CANT_CREATE_PR = """Could not create a new PR for:
{e}
Failed trying to execute {cmd}
"""
ERROR_CANT_REBASE = """Could not rebase the PR on '{target}'. Failed to land PR:
{e}
Failed trying to execute {cmd}
"""
ERROR_CANT_CHECKOUT_REMOTE_BRANCH = """Could not checkout remote branch '{e.head}'. Failed to land PR:
{e}
Failed trying to execute {cmd}
"""
# --- Failures while verifying stack metadata ----------------------------------
ERROR_STACKINFO_MISSING = """A stack entry is missing some information:
{e}
If you wanted to land a part of the stack, please use -B and -H options to
specify base and head revisions.
If you wanted to land the entire stack, please use 'submit' first.
If you hit this error trying to submit, please report a bug!
"""
ERROR_STACKINFO_BAD_LINK = """Bad PR link in stack metadata!
{e}
"""
ERROR_STACKINFO_MALFORMED_RESPONSE = """Malformed response from GH!
Returned json object is missing a field {required_field}
PR info from github: {d}
Failed verification for:
{e}
"""
ERROR_STACKINFO_PR_NOT_OPEN = """Associated PR is not in 'OPEN' state!
{e}
PR info from github: {d}
"""
ERROR_STACKINFO_PR_NUMBER_MISMATCH = """PR number on github mismatches PR number in stack metadata!
{e}
PR info from github: {d}
"""
ERROR_STACKINFO_PR_HEAD_MISMATCH = """Head branch name on github mismatches head branch name in stack metadata!
{e}
PR info from github: {d}
"""
ERROR_STACKINFO_PR_BASE_MISMATCH = """Base branch name on github mismatches base branch name in stack metadata!
{e}
If you are trying land the stack, please update it first by calling 'submit'.
PR info from github: {d}
"""
ERROR_STACKINFO_PR_NOT_MERGEABLE = """Associated PR is not mergeable on GitHub!
{e}
Please fix the issues on GitHub.
PR info from github: {d}
"""
# --- Repository state ----------------------------------------------------------
ERROR_REPO_DIRTY = """There are uncommitted changes.
Please commit or stash them before working with stacks.
"""
# --- Hints suggesting a follow-up stack-pr command ------------------------------
UPDATE_STACK_TIP = """
If you'd like to push your local changes first, you can use the following command to update the stack:
$ stack-pr export -B {top_commit}~{stack_size} -H {top_commit}"""
EXPORT_STACK_TIP = """
You can use the following command to do that:
$ stack-pr export -B {top_commit}~{stack_size} -H {top_commit}
"""
LAND_STACK_TIP = """
To land it, you could run:
$ stack-pr land -B {top_commit}~{stack_size} -H {top_commit}
If you'd like to land stack except the top N commits, you could use the following command:
$ stack-pr land -B {top_commit}~{stack_size} -H {top_commit}~N
If you prefer to merge via the github web UI, please don't forget to edit commit message on the merge page!
If you use the default commit message filled by the web UI, links to other PRs from the stack will be included in the commit message.
"""
# ===----------------------------------------------------------------------=== #
# Class to work with git commit contents
# ===----------------------------------------------------------------------=== #
@dataclass
class CommitHeader:
    """
    Wraps one record of `git rev-list --header` output and exposes its fields.

    Every accessor re-parses `raw_header` with one of the module-level
    `RE_RAW_*` patterns on each call; nothing is cached.
    """

    # Unparsed text of a single record from `git rev-list --header`.
    raw_header: str

    def _search_group(self, regex: Pattern[str], group: str) -> str:
        """
        Return the named `group` of the first match of `regex` in the header.

        Raises ValueError when the pattern does not match at all.
        """
        found = regex.search(self.raw_header)
        if found is not None:
            return found.group(group)
        raise ValueError(
            f"Required field '{group}' not found in commit header: {self.raw_header}"
        )

    def tree(self) -> str:
        """Return the value of the `tree` line."""
        return self._search_group(RE_RAW_TREE, "tree")

    def title(self) -> str:
        """Return the first commit message line."""
        return self._search_group(RE_RAW_COMMIT_MSG_LINE, "line")

    def commit_id(self) -> str:
        """Return the first line that consists only of hex digits."""
        return self._search_group(RE_RAW_COMMIT_ID, "commit")

    def parents(self) -> list[str]:
        """Return the hashes from all `parent` lines, in order of appearance."""
        hashes = []
        for found in RE_RAW_PARENT.finditer(self.raw_header):
            hashes.append(found.group("commit"))
        return hashes

    def author(self) -> str:
        """Return the author as `Name <email>`."""
        return self._search_group(RE_RAW_AUTHOR, "author")

    def author_name(self) -> str:
        """Return only the name part of the author line."""
        return self._search_group(RE_RAW_AUTHOR, "name")

    def author_email(self) -> str:
        """Return only the email part of the author line."""
        return self._search_group(RE_RAW_AUTHOR, "email")

    def commit_msg(self) -> str:
        """Return all commit message lines joined with newlines."""
        message_lines = [
            found.group("line")
            for found in RE_RAW_COMMIT_MSG_LINE.finditer(self.raw_header)
        ]
        return "\n".join(message_lines)
# ===----------------------------------------------------------------------=== #
# Class to work with PR stack entries
# ===----------------------------------------------------------------------=== #
@dataclass
class StackEntry:
    """
    Represents an entry in a stack of PRs and contains associated info, such as
    linked PR, head and base branches, original git commit.

    `pr` and `head` raise ValueError when read before being set; use
    `has_pr()` / `has_head()` to probe. `base` may legitimately be None.
    """

    # Commit this entry was created from.
    commit: CommitHeader
    # PR link, head branch and base branch; None until known.
    _pr: str | None = None
    _base: str | None = None
    _head: str | None = None
    need_update: bool = False

    @property
    def pr(self) -> str:
        """PR link of this entry. Raises ValueError if it is not set."""
        if self._pr is None:
            raise ValueError("pr is not set")
        return self._pr

    @pr.setter
    def pr(self, pr: str) -> None:
        self._pr = pr

    def has_pr(self) -> bool:
        return self._pr is not None

    @property
    def head(self) -> str:
        """Head branch of this entry. Raises ValueError if it is not set."""
        if self._head is None:
            raise ValueError("head is not set")
        return self._head

    @head.setter
    def head(self, head: str) -> None:
        self._head = head

    def has_head(self) -> bool:
        return self._head is not None

    @property
    def base(self) -> str | None:
        """Base branch of this entry, or None if it is not set."""
        return self._base

    @base.setter
    def base(self, base: str | None) -> None:
        self._base = base

    def has_base(self) -> bool:
        return self._base is not None

    def has_missing_info(self) -> bool:
        """Return True if any of PR link, head branch or base branch is unset."""
        return None in (self._pr, self._head, self._base)

    def pprint(self, *, links: bool) -> str:
        """
        Render the entry as `<id> (<pr>[, '<head>' -> '<base>']): <title>`.

        Args:
            links: When true and the entry has a PR, wrap the whole string in
                a terminal hyperlink pointing at the PR.
        """
        s = b(self.commit.commit_id()[:8])

        # The PR part is always present: either "#<number>" or a "no PR" marker.
        details = blue("#" + last(self.pr)) if self.has_pr() else red("no PR")

        # The branch part is shown only when at least one branch is known; the
        # unknown side is rendered in red.
        if self._head or self._base:
            head_str = green(self._head) if self._head else red(str(self._head))
            base_str = green(self._base) if self._base else red(str(self._base))
            details += f", '{head_str}' -> '{base_str}'"

        s += f" ({details}): " + self.commit.title()

        if links and self.has_pr():
            s = link(self.pr, s)
        return s

    def __repr__(self) -> str:
        return self.pprint(links=False)

    def read_metadata(self) -> None:
        """
        Fill `pr` and `head` from the `stack-info:` line of the commit message.

        Leaves the entry untouched when the message has no such line.
        """
        found = RE_STACK_INFO_LINE.search(self.commit.commit_msg())
        if found is None:
            return
        self.pr = found.group(1)
        self.head = found.group(2)
# ===----------------------------------------------------------------------=== #
# Utils for color printing
# ===----------------------------------------------------------------------=== #
class ShellColors:
    # ANSI escape sequences for styling terminal output. A styled string is
    # built by prefixing text with one of the codes and appending ENDC.
    HEADER = "\033[95m"  # bright magenta
    OKBLUE = "\033[94m"  # bright blue
    OKCYAN = "\033[96m"  # bright cyan
    OKGREEN = "\033[92m"  # bright green
    WARNING = "\033[93m"  # bright yellow
    FAIL = "\033[91m"  # bright red
    ENDC = "\033[0m"  # reset all attributes
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
def b(s: str) -> str:
    """Return `s` wrapped in the bold escape sequence and a trailing reset."""
    return "".join((ShellColors.BOLD, s, ShellColors.ENDC))
def h(s: str) -> str:
    """Return `s` wrapped in the header colour and a trailing reset."""
    return "".join((ShellColors.HEADER, s, ShellColors.ENDC))
def green(s: str) -> str:
    """Return `s` wrapped in the green colour and a trailing reset."""
    return "".join((ShellColors.OKGREEN, s, ShellColors.ENDC))
def blue(s: str) -> str:
    """Return `s` wrapped in the blue colour and a trailing reset."""
    return "".join((ShellColors.OKBLUE, s, ShellColors.ENDC))
def red(s: str) -> str:
    """Return `s` wrapped in the failure (red) colour and a trailing reset."""
    return "".join((ShellColors.FAIL, s, ShellColors.ENDC))
# https://gist.github.com/egmontkob/eb114294efbcd5adb1944c9f3cb5feda
def link(location: str, text: str) -> str:
    """
    Wrap `text` in terminal hyperlink escape sequences pointing at `location`.

    File URIs are not handled properly; pass web URIs only.
    """
    opener = "\033]8;;"  # starts a hyperlink sequence; the target follows
    terminator = "\033\\"  # ends an escape sequence
    # An opener with an empty target closes the hyperlink.
    return f"{opener}{location}{terminator}{text}{opener}{terminator}"
def error(msg: str) -> None:
    """Print `msg` to stdout after a blank line and a red "ERROR: " prefix."""
    prefix = red("\nERROR: ")
    print(prefix + msg)
def log(msg: str, *, level: int = 1) -> None:
    """
    Emit `msg` either to stdout or to the module logger, depending on `level`.

    Args:
        msg: Text to emit.
        level: Verbosity level of the message. Levels up to and including 1
            are printed to stdout; levels of 2 and above go to
            `logger.debug`.
    """
    # The original also had an `elif level == 1` branch calling `logger.info`,
    # but it could never run because `level <= 1` already covers that value.
    # It is dropped here; which levels are printed is unchanged.
    if level <= 1:
        print(msg)
    elif level >= 2:  # noqa: PLR2004
        logger.debug(msg)
# ===----------------------------------------------------------------------=== #
# Common utility functions
# ===----------------------------------------------------------------------=== #
def split_header(s: str) -> list[CommitHeader]:
    """
    Split NUL-separated header text into one CommitHeader per record.

    The piece after the final NUL is always discarded.
    """
    records = s.split("\0")
    records.pop()  # drop whatever follows the last NUL separator
    return [CommitHeader(record) for record in records]
def last(ref: str, sep: str = "/") -> str:
    """
    Return the part of `ref` after the final `sep`.

    When `sep` does not occur, `ref` is returned unchanged.
    """
    _, _, tail = ref.rpartition(sep)
    return tail
# TODO: Move to 'modular.utils.git'
def is_ancestor(commit1: str, commit2: str, *, verbose: bool) -> bool:
    """
    Returns true if 'commit1' is an ancestor of 'commit2'.

    The answer is taken from the exit status of
    `git merge-base --is-ancestor`; any non-zero status counts as "no".
    """
    # TODO: We need to check returncode of this command more carefully, as the
    # command simply might fail (rc != 0 and rc != 1).
    cmd = ["git", "merge-base", "--is-ancestor", commit1, commit2]
    result = run_shell_command(cmd, check=False, quiet=not verbose)
    return result.returncode == 0
def is_repo_clean() -> bool:
    """
    Returns true if there are no uncommitted changes in the repo.

    Entries stored under the "??" key are ignored.
    """
    pending = get_uncommitted_changes()
    # We don't care about untracked files
    pending.pop("??", [])
    return not pending
def get_stack(base: str, head: str, *, verbose: bool) -> list[StackEntry]:
    """
    Build the list of stack entries for the commits in `base..head`.

    Exits the process with status 1 if `base` is not an ancestor of `head`.
    Entries are returned in the reverse of `git rev-list` output order, each
    with its metadata already read from the commit message.
    """
    if not is_ancestor(base, head, verbose=verbose):
        error(
            f"{base} is not an ancestor of {head}.\n"
            "Could not find commits for the stack."
        )
        sys.exit(1)

    # Find list of commits since merge base.
    raw_headers = get_command_output(
        ["git", "rev-list", "--header", "^" + base, head]
    )
    headers = split_header(raw_headers)
    st: list[StackEntry] = [StackEntry(header) for header in reversed(headers)]

    for entry in st:
        entry.read_metadata()
    return st
def set_base_branches(st: list[StackEntry], target: str) -> None:
    """
    Chain the entries: each entry's base becomes the previous entry's head.

    The first entry is based on `target`. Reading an entry's `head` happens
    before its `base` is assigned.
    """
    upcoming_base: str | None = target
    for entry in st:
        own_head = entry.head
        entry.base = upcoming_base
        upcoming_base = own_head
def verify(st: list[StackEntry], *, check_base: bool = False) -> None:
    """Check every stack entry against the state of its PR on GitHub.

    Each entry's PR is fetched with `gh pr view` and compared with the
    metadata stored in the entry. The first failed check is reported through
    `error()` and stops verification.

    Args:
        st: Stack entries to verify; the entry at index 0 is treated as the
            first one of the stack.
        check_base: When true, additionally require that the base branch on
            GitHub matches the entry's base, and that the first entry's PR is
            in a mergeable state.

    Raises:
        RuntimeError: If any check fails for any entry.
    """
    log(h("Verifying stack info"))
    for index, e in enumerate(st):
        if e.has_missing_info():
            error(ERROR_STACKINFO_MISSING.format(e=e))
            raise RuntimeError

        # The PR link must end in a PR number. (`str.split` never returns an
        # empty list, so a separate emptiness check would be dead code.)
        if not last(e.pr).isnumeric():
            error(ERROR_STACKINFO_BAD_LINK.format(e=e))
            raise RuntimeError

        ghinfo = get_command_output(
            [
                "gh",
                "pr",
                "view",
                e.pr,
                "--json",
                "baseRefName,headRefName,number,state,body,title,url,mergeStateStatus",
            ]
        )
        d = json.loads(ghinfo)

        # Mergeability is only inspected for the first entry, and only when
        # the base is being checked.
        check_mergeable = check_base and index == 0

        required_fields = ["state", "number", "baseRefName", "headRefName"]
        if check_mergeable:
            # Report a missing field as a malformed response instead of
            # failing later with a bare KeyError.
            required_fields.append("mergeStateStatus")
        for required_field in required_fields:
            if required_field not in d:
                error(
                    ERROR_STACKINFO_MALFORMED_RESPONSE.format(
                        required_field=required_field, d=d, e=e
                    )
                )
                raise RuntimeError

        if d["state"] != "OPEN":
            error(ERROR_STACKINFO_PR_NOT_OPEN.format(e=e, d=d))
            raise RuntimeError

        if int(last(e.pr)) != d["number"]:
            error(ERROR_STACKINFO_PR_NUMBER_MISMATCH.format(e=e, d=d))
            raise RuntimeError

        if e.head != d["headRefName"]:
            error(ERROR_STACKINFO_PR_HEAD_MISMATCH.format(e=e, d=d))
            raise RuntimeError

        # 'Base' branch might diverge when the stack is modified (e.g. when a
        # new commit is added to the middle of the stack). It is not an issue
        # if we're updating the stack (i.e. in 'submit'), but it is an issue if
        # we are trying to land it.
        if check_base and e.base != d["baseRefName"]:
            error(ERROR_STACKINFO_PR_BASE_MISMATCH.format(e=e, d=d))
            raise RuntimeError

        # The first entry on the stack needs to be actually mergeable on GitHub.
        if check_mergeable and d["mergeStateStatus"] not in [
            "CLEAN",
            "UNKNOWN",
            "UNSTABLE",
        ]:
            error(ERROR_STACKINFO_PR_NOT_MERGEABLE.format(e=e, d=d))
            raise RuntimeError
def print_stack(st: list[StackEntry], *, links: bool, level: int = 1) -> None:
    """
    Log a "Stack:" heading followed by one bullet per entry, last entry first.

    Args:
        st: Stack entries to show.
        links: Passed on to `StackEntry.pprint`.
        level: Log level used for every emitted line.
    """
    log(b("Stack:"), level=level)
    for entry in st[::-1]:
        rendered = entry.pprint(links=links)
        log(" * " + rendered, level=level)
def draft_bitmask_type(value: str) -> list[bool]:
    """
    Convert a string of 0s and 1s into a list of booleans.

    Raises argparse.ArgumentTypeError if any other character is present.
    An empty string yields an empty list.
    """
    # Reject anything that is not a plain "0" or "1".
    if any(bit not in ("0", "1") for bit in value):
        raise argparse.ArgumentTypeError("Bitmask must only contain 0s and 1s.")

    # After validation each character is exactly "0" or "1".
    return [bit == "1" for bit in value]
@contextlib.contextmanager
def maybe_stash_interactive_rebase() -> Iterator[None]:
    """
    If the user is in the middle of an interactive rebase, we stash the
    rebase state so that we can restore it later. This is useful when
    the user is trying to submit only part of their commit history.

    The rebase state directory `.git/rebase-merge` (relative to the current
    working directory) is renamed for the duration of the `with` block and
    renamed back afterwards, even if the block raises. When no such directory
    exists the context manager does nothing.

    Raises:
        RuntimeError: If a previously stashed state is already present. In
            that case nothing is renamed.
    """
    rebase_dir = ".git/rebase-merge"
    stashed_dir = ".git/rebase-merge-stashed"

    if not os.path.exists(rebase_dir):
        yield
        return

    # Check and stash *before* entering the try block: the restore step below
    # must only run when the stash actually happened. Otherwise a failure here
    # would trigger a rename of a stale stash over the live rebase state.
    if os.path.exists(stashed_dir):
        raise RuntimeError(
            f"Cannot stash rebase state: '{stashed_dir}' already exists."
        )
    os.rename(rebase_dir, stashed_dir)
    try:
        yield
    finally:
        os.rename(stashed_dir, rebase_dir)
# ===----------------------------------------------------------------------=== #
# SUBMIT
# ===----------------------------------------------------------------------=== #
def add_or_update_metadata(e: StackEntry, *, needs_rebase: bool, verbose: bool) -> bool:
    """
    Make sure the commit of `e` carries a `stack-info:` line.

    First the entry's head branch is either rebased onto its base
    (`needs_rebase` true) or simply checked out. If the commit message has no
    stack-info line yet, the commit is amended to add one.

    Returns:
        True when the commit was amended; otherwise the incoming
        `needs_rebase` value.

    Raises:
        RuntimeError: If a branch required for the chosen step is not set.
    """
    quiet = not verbose

    if needs_rebase:
        if not (e.has_base() and e.has_head()):
            error("Stack entry has no base or head branch")
            raise RuntimeError
        rebase_cmd = [
            "git",
            "rebase",
            e.base or "",
            e.head or "",
            "--committer-date-is-author-date",
        ]
        run_shell_command(rebase_cmd, quiet=quiet)
    else:
        if not e.has_head():
            error("Stack entry has no head branch")
            raise RuntimeError
        run_shell_command(["git", "checkout", e.head], quiet=quiet)

    message = e.commit.commit_msg()
    if RE_STACK_INFO_LINE.search(message):
        # Metadata is already there, skip this commit
        return needs_rebase

    # Add the stack info metadata to the commit message
    annotated = message + f"\n\nstack-info: PR: {e.pr}, branch: {e.head}"
    run_shell_command(
        ["git", "commit", "--amend", "-F", "-"],
        input=annotated.encode(),
        quiet=quiet,
    )
    return True
def fix_branch_name_template(branch_name_template: str) -> str:
    """Return the template unchanged if it mentions `$ID`, else with `/$ID` appended."""
    has_id_placeholder = "$ID" in branch_name_template
    return (
        branch_name_template if has_id_placeholder else f"{branch_name_template}/$ID"
    )
@cache
def get_branch_name_base(branch_name_template: str) -> str:
    """
    Substitute `$USERNAME` and `$BRANCH` in the template.

    `$USERNAME` becomes the value of `get_gh_username()` and `$BRANCH` the
    value of `get_current_branch_name()`. Results are memoised per template,
    so both lookups happen only on the first call for a given template.
    """
    gh_user = get_gh_username()
    checked_out_branch = get_current_branch_name()
    return branch_name_template.replace("$USERNAME", gh_user).replace(
        "$BRANCH", checked_out_branch
    )
def get_branch_id(branch_name_template: str, branch_name: str) -> str | None:
    """
    Extract the numeric ID from `branch_name` according to the template.

    The resolved template is searched for anywhere inside `branch_name`, with
    `$ID` standing for a run of digits.

    Args:
        branch_name_template: Branch name template containing `$ID`.
        branch_name: Branch or ref name to inspect.

    Returns:
        The digits found in the position of the first `$ID`, as a string, or
        None if `branch_name` does not match the template.
    """
    branch_name_base = get_branch_name_base(branch_name_template)
    # Everything except the `$ID` placeholder is literal text. Escape it so
    # that characters such as '.', '+' or '(' in user or branch names are not
    # interpreted as regex syntax.
    literal_parts = branch_name_base.split("$ID")
    pattern = r"(\d+)".join(re.escape(part) for part in literal_parts)
    match = re.search(pattern, branch_name)
    if match:
        return match.group(1)
    return None
def generate_branch_name(branch_name_template: str, branch_id: int) -> str:
    """Return the resolved template with `$ID` replaced by `branch_id`."""
    resolved_template = get_branch_name_base(branch_name_template)
    id_text = str(branch_id)
    return resolved_template.replace(r"$ID", id_text)
def get_taken_branch_ids(refs: list[str], branch_name_template: str) -> list[int]:
    """Return the IDs of all `refs` that match the template, in input order."""
    taken: list[int] = []
    for ref in refs:
        branch_id = get_branch_id(branch_name_template, ref)
        if branch_id is not None:
            taken.append(int(branch_id))
    return taken
def generate_available_branch_name(refs: list[str], branch_name_template: str) -> str:
    """
    Return a branch name whose ID is one above the highest ID used in `refs`.

    When no ref matches the template, the ID is 1.
    """
    taken_ids = get_taken_branch_ids(refs, branch_name_template)
    highest_taken = max(taken_ids, default=0)
    return generate_branch_name(branch_name_template, highest_taken + 1)
def get_available_branch_name(remote: str, branch_name_template: str) -> str:
    """
    Return the next unused branch name, judging by refs known for `remote`.

    Refs are listed with `git for-each-ref` using the resolved template with
    `$ID` turned into a `*` wildcard.
    """
    resolved_template = get_branch_name_base(branch_name_template)
    ref_glob = resolved_template.replace(r"$ID", "*")
    listing = get_command_output(
        [
            "git",
            "for-each-ref",
            f"refs/remotes/{remote}/{ref_glob}",
            "--format='%(refname)'",
        ]
    )
    # The format string wraps each ref name in single quotes; remove them.
    refs = [quoted_ref.strip("'") for quoted_ref in listing.split()]
    return generate_available_branch_name(refs, branch_name_template)
def get_next_available_branch_name(branch_name_template: str, name: str) -> str:
    """
    Return the branch name whose ID follows the ID found in `name`.

    If `name` carries no ID, the result uses ID 1.
    """
    current_id = get_branch_id(branch_name_template, name)
    following_id = int(current_id) + 1 if current_id else 1
    return generate_branch_name(branch_name_template, following_id)
def set_head_branches(
    st: list[StackEntry], remote: str, *, verbose: bool, branch_name_template: str
) -> None:
    """Set the head ref for each stack entry if it doesn't already have one.

    The remote is fetched (with pruning) first. Entries lacking a head then
    receive consecutive names starting from the first available one.
    """
    run_shell_command(["git", "fetch", "--prune", remote], quiet=not verbose)
    next_name = get_available_branch_name(remote, branch_name_template)
    for entry in st:
        if entry.has_head():
            continue
        entry.head = next_name
        next_name = get_next_available_branch_name(branch_name_template, next_name)
def init_local_branches(
    st: list[StackEntry], remote: str, *, verbose: bool, branch_name_template: str
) -> None:
    """
    Assign head branch names and point each local head branch at its commit.

    For every entry, `git checkout <commit> -B <head>` is run, which creates
    the branch or resets it if it already exists.
    """
    log(h("Initializing local branches"))
    set_head_branches(
        st, remote, verbose=verbose, branch_name_template=branch_name_template
    )
    quiet = not verbose
    for entry in st:
        checkout_cmd = ["git", "checkout", entry.commit.commit_id(), "-B", entry.head]
        run_shell_command(checkout_cmd, quiet=quiet)
def push_branches(st: list[StackEntry], remote: str, *, verbose: bool) -> None:
    """Force-push the head branch of every entry to `remote` in one `git push`."""
    log(h("Updating remote branches"))
    refspecs = [f"{entry.head}:{entry.head}" for entry in st]
    run_shell_command(["git", "push", "-f", remote, *refspecs], quiet=not verbose)
def _format_captured_output(raw: bytes | str | None) -> str | None:
    """Turn captured command output into printable text, or None if empty."""
    if not raw:
        return None
    if isinstance(raw, (bytes, bytearray)):
        # Never let an undecodable byte hide the failure being reported.
        text = raw.decode("utf-8", errors="replace")
    else:
        # Output captured in text mode is already a string.
        text = str(raw)
    # Expand literal backslash escapes so multi-line output reads naturally.
    return text.replace("\\n", "\n").replace("\\t", "\t")


def print_cmd_failure_details(exc: SubprocessError) -> None:
    """
    Print exit code, stdout and stderr of a failed subprocess.

    Args:
        exc: The exception raised by the subprocess call. Attributes that the
            concrete exception type does not provide are shown as "unknown"
            (exit code) or "None" (output streams).
    """
    # Not every SubprocessError subclass has these attributes.
    cmd_stdout = _format_captured_output(getattr(exc, "stdout", None))
    cmd_stderr = _format_captured_output(getattr(exc, "stderr", None))
    returncode = getattr(exc, "returncode", "unknown")

    print(f"Exitcode: {returncode}")
    print(f"Stdout: {cmd_stdout}")
    print(f"Stderr: {cmd_stderr}")
def create_pr(e: StackEntry, *, is_draft: bool, reviewer: str = "") -> None:
    """
    Create a GitHub PR for `e` with `gh pr create` unless it already has one.

    The PR title is the commit title and the PR body is the full commit
    message. On success the last whitespace-separated token of the command
    output is stored as the entry's PR link.

    Args:
        e: Stack entry to create the PR for.
        is_draft: Create the PR as a draft.
        reviewer: Passed to `--reviewer` when non-empty.

    Raises:
        RuntimeError: If the entry has no base or no head branch.
        Exception: Whatever the `gh` invocation raises, after reporting it.
    """
    # Don't do anything if the PR already exists
    if e.has_pr():
        return

    if not (e.has_base() and e.has_head()):
        error("Stack entry has no base or head branch")
        raise RuntimeError

    log(h("Creating PR " + green(f"'{e.head}' -> '{e.base}'")), level=1)

    cmd = ["gh", "pr", "create"]
    cmd += ["-B", e.base or ""]
    cmd += ["-H", e.head or ""]
    cmd += ["-t", e.commit.title()]
    cmd += ["-F", "-"]
    if reviewer:
        cmd += ["--reviewer", reviewer]
    if is_draft:
        cmd += ["--draft"]

    try:
        r = get_command_output(cmd, input=e.commit.commit_msg().encode())
    except Exception:
        error(ERROR_CANT_CREATE_PR.format(e=e, cmd=cmd))
        raise

    log(b("Created: ") + r, level=2)
    e.pr = r.split()[-1]
def generate_toc(st: list[StackEntry], current: str) -> str:
    """
    Build the "Stacked PRs:" table of contents for a PR description.

    Entries are listed last-to-first, one `#<number>` bullet each; the bullet
    whose number equals `current` is marked with `__->__`. The result ends
    with a blank line.
    """
    bullets = []
    for se in reversed(st):
        pr_id = last(se.pr)
        marker = "__->__" if pr_id == current else ""
        bullets.append(f" * {marker}#{pr_id}\n")
    return "Stacked PRs:\n" + "".join(bullets) + "\n"
def get_current_pr_body(e: StackEntry) -> str:
    """
    Fetch the body of the entry's PR via `gh pr view --json body`.

    Returns the body with surrounding whitespace stripped; a missing or empty
    body yields an empty string.
    """
    raw_json = get_command_output(["gh", "pr", "view", e.pr, "--json", "body"])
    pr_body = json.loads(raw_json)["body"]
    return str(pr_body or "").strip()
def add_cross_links(st: list[StackEntry], *, keep_body: bool, verbose: bool) -> None:
    """
    Rewrite title, body and base branch of every PR in the stack.

    Each PR body starts with the stack's table of contents and the cross-links
    delimiter. What follows depends on `keep_body`:

    * true: whatever the PR body currently contains after the delimiter (or
      the whole current body when there is no delimiter);
    * false: a `###` heading with the commit title, followed by the commit
      message without its title line and stack-info line.

    Raises:
        RuntimeError: If an entry has no base branch.
    """
    for e in st:
        toc = generate_toc(st, last(e.pr))
        title = e.commit.title()

        # Strip title from the body - we will print it separately.
        message_lines = e.commit.commit_msg().splitlines()
        body = "\n".join(message_lines[1:])
        # Strip stack-info from the body, nothing interesting there.
        body = RE_STACK_INFO_LINE.sub("", body)

        sections = [f"{toc}", f"{CROSS_LINKS_DELIMETER}\n"]
        if keep_body:
            # Keep current body of the PR after the cross links component
            existing_body = get_current_pr_body(e)
            kept_part = existing_body.split(CROSS_LINKS_DELIMETER, 1)[-1]
            sections.append(kept_part.lstrip())
        else:
            sections += [f"### {title}", "", f"{body}"]

        if not e.has_base():
            error("Stack entry has no base branch")
            raise RuntimeError

        run_shell_command(
            ["gh", "pr", "edit", e.pr, "-t", title, "-F", "-", "-B", e.base or ""],
            input="\n".join(sections).encode(),
            quiet=not verbose,
        )
# Temporarily set base branches of existing PRs to the bottom of the stack.
# This needs to be done to avoid PRs getting closed when commits are
# rearranged.
#
# For instance, if we first had
#
# Stack:
# * #2 (stack/2 -> stack/1) aaaa
# * #1 (stack/1 -> main) bbbb
#
# And then swapped the order of the commits locally and tried submitting again
# we would have:
#
# Stack:
# * #1 (stack/1 -> main) bbbb
# * #2 (stack/2 -> stack/1) aaaa
#
# Now we need to 1) change bases of the PRs, 2) push branches stack/1 and
# stack/2. If we push stack/1, then PR #2 gets automatically closed, since its
# head branch will contain all the commits from its base branch.
#
# To avoid this, we temporarily set all base branches to point to 'main' - once
# all the branches are pushed we can set the actual base branches.
def reset_remote_base_branches(
    st: list[StackEntry], target: str, *, verbose: bool
) -> None:
    """Point the base branch of every existing PR in the stack at `target`.

    Entries without a PR are skipped.
    """
    log(h("Resetting remote base branches"), level=1)
    for entry in st:
        if not entry.has_pr():
            continue
        run_shell_command(
            ["gh", "pr", "edit", entry.pr, "-B", target], quiet=not verbose
        )
# If local 'main' lags behind 'origin/main', and 'head' contains all commits
# from 'main' to 'origin/main', then we can just move 'main' forward.
#
# It is a common user mistake to not update their local branch, run 'submit',
# and end up with a huge stack of changes that are already merged.
# We could've told users to update their local branch in that scenario, but why
# not do it for them?
# In the very unlikely case when they indeed wanted to include changes that are
# already in remote into their stack, they can use a different notation for the
# base (e.g. explicit hash of the commit) - but most probably nobody ever would
# need that.
def should_update_local_base(
    head: str, base: str, remote: str, target: str, *, verbose: bool
) -> bool:
    """
    Decide whether the local `base` can simply be moved to `<remote>/<target>`.

    Returns True only if all of the following hold:
      * `base` is an ancestor of `<remote>/<target>`,
      * `<remote>/<target>` is an ancestor of `head`,
      * `base` and `<remote>/<target>` resolve to different commits.
    """
    remote_target = f"{remote}/{target}"
    base_hash = get_command_output(["git", "rev-parse", base])
    target_hash = get_command_output(["git", "rev-parse", remote_target])

    if not is_ancestor(base, remote_target, verbose=verbose):
        return False
    if not is_ancestor(remote_target, head, verbose=verbose):
        return False
    return base_hash != target_hash
def update_local_base(base: str, remote: str, target: str, *, verbose: bool) -> None:
    """Rebase the local branch `base` onto `<remote>/<target>`."""
    upstream = f"{remote}/{target}"
    log(h(f"Updating local branch {base} to {upstream}"), level=1)
    run_shell_command(["git", "rebase", upstream, base], quiet=not verbose)
@dataclass
class CommonArgs:
    """Typed bundle of the command line arguments shared by the commands.

    Exists to help type checkers and to keep the command implementations
    independent of the raw argparse namespace.
    """

    base: str
    head: str
    remote: str
    target: str
    hyperlinks: bool
    verbose: bool
    branch_name_template: str

    @classmethod
    def from_args(cls, args: argparse.Namespace) -> CommonArgs:
        """Build an instance from the same-named attributes of `args`."""
        return cls(
            base=args.base,
            head=args.head,
            remote=args.remote,
            target=args.target,
            hyperlinks=args.hyperlinks,
            verbose=args.verbose,
            branch_name_template=args.branch_name_template,
        )
def deduce_base(args: CommonArgs) -> CommonArgs:
    """Fill in the base revision when it was not given explicitly.

    If `args.base` is set, `args` is returned as is. Otherwise the base is
    the merge base of `args.head` and `<remote>/<target>`, and a new
    CommonArgs with that base (and all other fields copied) is returned.

    E.g. in the example below we want to include commits E and F into the
    stack, and to do that we pick B as our base:

        --> a ----> b ----> c ----> d
         (main)      \\           (origin/main)
                      \\
                       ---> e ----> f
                                  (head)
    """
    if args.base:
        return args

    merge_base = get_command_output(
        ["git", "merge-base", args.head, f"{args.remote}/{args.target}"]
    )
    return CommonArgs(
        base=merge_base,
        head=args.head,
        remote=args.remote,
        target=args.target,
        hyperlinks=args.hyperlinks,
        verbose=args.verbose,
        branch_name_template=args.branch_name_template,
    )
def print_tips_after_export(st: list[StackEntry], args: CommonArgs) -> None:
    """
    Log a hint on how to land the stack. Does nothing for an empty stack.

    The hint refers to the top of the stack by `args.head`, except that a
    literal "HEAD" is replaced with the current branch name.
    """
    if not st:
        return

    if args.head == "HEAD":
        top_commit = get_current_branch_name()
    else:
        top_commit = args.head

    log(b("\nOnce the stack is reviewed, it is ready to land!"), level=1)
    log(LAND_STACK_TIP.format(top_commit=top_commit, stack_size=len(st)))
# ===----------------------------------------------------------------------=== #
# Entry point for 'submit' command
# ===----------------------------------------------------------------------=== #
def command_submit(
args: CommonArgs,
*,
draft: bool,
reviewer: str,
keep_body: bool,
draft_bitmask: list[bool] | None = None,
) -> None:
"""Entry point for 'submit' command.
Args:
args: CommonArgs object containing command line arguments.
draft: Boolean flag indicating if the PRs should be created as drafts.
reviewer: String representing the reviewer of the PRs.
keep_body: Boolean flag indicating if the body of the PRs should be kept.
draft_bitmask: List of boolean values indicating if each PR should be created as
a draft.
"""
log(h("SUBMIT"), level=1)
current_branch = get_current_branch_name()
if should_update_local_base(
head=args.head,
base=args.base,
remote=args.remote,
target=args.target,
verbose=args.verbose,
):
update_local_base(
base=args.base, remote=args.remote, target=args.target, verbose=args.verbose
)
run_shell_command(["git", "checkout", current_branch], quiet=not args.verbose)
# Determine what commits belong to the stack
st = get_stack(base=args.base, head=args.head, verbose=args.verbose)
if not st:
log(h("Empty stack!"))
log(h(blue("SUCCESS!")))
return
if (draft_bitmask is not None) and (len(draft_bitmask) != len(st)):
log(h("Draft bitmask passed to 'submit' doesn't match number of PRs!"))
return
# Create local branches and initialize base and head fields in the stack
# elements
init_local_branches(
st,
args.remote,
verbose=args.verbose,
branch_name_template=args.branch_name_template,
)
set_base_branches(st, args.target)
print_stack(st, links=args.hyperlinks)
# If the current branch contains commits from the stack, we will need to
# rebase it in the end since the commits will be modified.
top_branch = st[-1].head
need_to_rebase_current = is_ancestor(
top_branch, current_branch, verbose=args.verbose
)
reset_remote_base_branches(st, target=args.target, verbose=args.verbose)
# Push local branches to remote