Skip to content

Commit 60271c5

Browse files
authored
fix(post-gen-hook): populate project.yml with more accurate details (#26)
1 parent a4709d5 commit 60271c5

1 file changed

Lines changed: 44 additions & 30 deletions

File tree

hooks/post_gen_project.py

Lines changed: 44 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,11 @@
99
import pprint
1010
import subprocess
1111
import sys
12-
13-
# Used indirectly in the below Jinja2 block
14-
from collections import OrderedDict # pylint: disable=unused-import
12+
from collections import OrderedDict
1513
from logging import basicConfig, getLogger
1614
from pathlib import Path
1715

18-
import git
1916
import yaml
20-
from cookiecutter.repository import expand_abbreviations
2117

2218
LOG_FORMAT = json.dumps(
2319
{
@@ -35,38 +31,31 @@
3531

3632
def get_context() -> dict:
3733
"""Return the context as a dict"""
34+
import git
35+
from cookiecutter.repository import expand_abbreviations
36+
3837
cookiecutter = None
3938
timestamp = datetime.datetime.now(datetime.UTC).isoformat(timespec="seconds")
4039

4140
##############
4241
# This section leverages cookiecutter's jinja interpolation
43-
# pylint: disable-next=unhashable-member
4442
cookiecutter_context_ordered: OrderedDict[str, str] = {{cookiecutter | pprint}} # type: ignore
4543
cookiecutter_context: dict[str, str] = dict(cookiecutter_context_ordered)
46-
47-
project_name = cookiecutter_context["project_slug"] # pylint: disable=unsubscriptable-object
48-
project_description = cookiecutter_context["project_short_description"] # pylint: disable=unsubscriptable-object
49-
template = cookiecutter_context["_template"] # pylint: disable=unsubscriptable-object
50-
output = cookiecutter_context["_output_dir"] # pylint: disable=unsubscriptable-object
5144
##############
5245

53-
try:
54-
if Path(template).is_absolute():
55-
template_path: Path = Path(template).resolve()
56-
else:
57-
output_path: Path = Path(output).resolve()
58-
template_path: Path = output_path.joinpath(template)
59-
60-
# IMPORTANT: If the specified template is remote (http/git/ssh) this SHOULD raise an exception. The remote logic is in the except block
61-
repo: git.Repo = git.Repo(template_path)
46+
project_name = cookiecutter_context["project_slug"]
47+
project_description = cookiecutter_context["project_short_description"]
48+
template = cookiecutter_context["_template"]
49+
output = cookiecutter_context["_output_dir"]
50+
# Get the branch specified via --checkout, but fall back to main
51+
branch = cookiecutter_context.get("_checkout") or "main"
6252

63-
# Expect this is a local template
64-
branch: str = str(repo.active_branch)
65-
dirty: bool = repo.is_dirty(untracked_files=True)
66-
template_commit_hash = git.cmd.Git().ls_remote(template_path, "HEAD")[:40]
67-
except (git.exc.InvalidGitRepositoryError, git.exc.NoSuchPathError):
68-
# This exception handling occurs every time the template repo is remote
53+
# Check if template is a remote URL or abbreviation
54+
is_remote_template = any(
55+
template.startswith(prefix) for prefix in ["http://", "https://", "git@", "gh:", "gl:", "bb:"]
56+
)
6957

58+
if is_remote_template:
7059
# From https://github.com/cookiecutter/cookiecutter/blob/b4451231809fb9e4fc2a1e95d433cb030e4b9e06/cookiecutter/config.py#L22
7160
abbreviations: dict[str, str] = {
7261
"gh": "https://github.com/{0}.git",
@@ -75,11 +64,36 @@ def get_context() -> dict:
7564
}
7665
template_repo: str = expand_abbreviations(template, abbreviations)
7766

78-
# This currently assumes main until https://github.com/cookiecutter/cookiecutter/issues/1759 is resolved
79-
branch: str = "main"
8067
dirty: bool = False
8168

69+
# For remote templates, get the commit hash from the remote
8270
template_commit_hash = git.cmd.Git().ls_remote(template_repo, branch)[:40]
71+
# Store the expanded URL as the template location
72+
template_location = template_repo
73+
else:
74+
# This is a local template path
75+
if Path(template).is_absolute():
76+
template_path: Path = Path(template).resolve()
77+
else:
78+
output_path: Path = Path(output).resolve()
79+
template_path: Path = output_path.joinpath(template).resolve()
80+
81+
try:
82+
repo: git.Repo = git.Repo(template_path)
83+
84+
# Get info from the local repository
85+
branch: str = str(repo.active_branch)
86+
dirty: bool = repo.is_dirty(untracked_files=True)
87+
# Get the actual commit hash from the local repository
88+
template_commit_hash = repo.head.commit.hexsha
89+
# Store the fully qualified template path for local templates
90+
template_location = str(template_path)
91+
except (git.exc.InvalidGitRepositoryError, git.exc.NoSuchPathError):
92+
# Not a git repository, fall back to unknown values
93+
branch = "unknown"
94+
dirty = False
95+
template_commit_hash = "unknown"
96+
template_location = str(template_path)
8397

8498
context: dict[str, str | dict[str, str | bool | dict[str, str | bool | dict[str, str]]]] = {}
8599
context["name"] = project_name
@@ -91,12 +105,12 @@ def get_context() -> dict:
91105
context["origin"]["template"]["branch"] = branch
92106
context["origin"]["template"]["commit hash"] = template_commit_hash
93107
context["origin"]["template"]["dirty"] = dirty
94-
context["origin"]["template"]["location"] = template
108+
context["origin"]["template"]["location"] = template_location
95109
context["origin"]["template"]["cookiecutter"] = {}
96110
context["origin"]["template"]["cookiecutter"] = cookiecutter_context
97111

98112
# Filter out unwanted cookiecutter context
99-
del cookiecutter_context["_output_dir"] # pylint: disable=unsubscriptable-object
113+
del cookiecutter_context["_output_dir"]
100114

101115
return context
102116

0 commit comments

Comments
 (0)