99import pprint
1010import subprocess
1111import sys
12-
13- # Used indirectly in the below Jinja2 block
14- from collections import OrderedDict # pylint: disable=unused-import
12+ from collections import OrderedDict
1513from logging import basicConfig , getLogger
1614from pathlib import Path
1715
18- import git
1916import yaml
20- from cookiecutter .repository import expand_abbreviations
2117
2218LOG_FORMAT = json .dumps (
2319 {
3531
3632def get_context () -> dict :
3733 """Return the context as a dict"""
34+ import git
35+ from cookiecutter .repository import expand_abbreviations
36+
3837 cookiecutter = None
3938 timestamp = datetime .datetime .now (datetime .UTC ).isoformat (timespec = "seconds" )
4039
4140 ##############
4241 # This section leverages cookiecutter's jinja interpolation
43- # pylint: disable-next=unhashable-member
4442 cookiecutter_context_ordered : OrderedDict [str , str ] = {{cookiecutter | pprint }} # type: ignore
4543 cookiecutter_context : dict [str , str ] = dict (cookiecutter_context_ordered )
46-
47- project_name = cookiecutter_context ["project_slug" ] # pylint: disable=unsubscriptable-object
48- project_description = cookiecutter_context ["project_short_description" ] # pylint: disable=unsubscriptable-object
49- template = cookiecutter_context ["_template" ] # pylint: disable=unsubscriptable-object
50- output = cookiecutter_context ["_output_dir" ] # pylint: disable=unsubscriptable-object
5144 ##############
5245
53- try :
54- if Path (template ).is_absolute ():
55- template_path : Path = Path (template ).resolve ()
56- else :
57- output_path : Path = Path (output ).resolve ()
58- template_path : Path = output_path .joinpath (template )
59-
60- # IMPORTANT: If the specified template is remote (http/git/ssh) this SHOULD raise an exception. The remote logic is in the except block
61- repo : git .Repo = git .Repo (template_path )
46+ project_name = cookiecutter_context ["project_slug" ]
47+ project_description = cookiecutter_context ["project_short_description" ]
48+ template = cookiecutter_context ["_template" ]
49+ output = cookiecutter_context ["_output_dir" ]
50+ # Get the branch specified via --checkout, but fall back to main
51+ branch = cookiecutter_context .get ("_checkout" ) or "main"
6252
63- # Expect this is a local template
64- branch : str = str (repo .active_branch )
65- dirty : bool = repo .is_dirty (untracked_files = True )
66- template_commit_hash = git .cmd .Git ().ls_remote (template_path , "HEAD" )[:40 ]
67- except (git .exc .InvalidGitRepositoryError , git .exc .NoSuchPathError ):
68- # This exception handling occurs every time the template repo is remote
53+ # Check if template is a remote URL or abbreviation
54+ is_remote_template = any (
55+ template .startswith (prefix ) for prefix in ["http://" , "https://" , "git@" , "gh:" , "gl:" , "bb:" ]
56+ )
6957
58+ if is_remote_template :
7059 # From https://github.com/cookiecutter/cookiecutter/blob/b4451231809fb9e4fc2a1e95d433cb030e4b9e06/cookiecutter/config.py#L22
7160 abbreviations : dict [str , str ] = {
7261 "gh" : "https://github.com/{0}.git" ,
@@ -75,11 +64,36 @@ def get_context() -> dict:
7564 }
7665 template_repo : str = expand_abbreviations (template , abbreviations )
7766
78- # This currently assumes main until https://github.com/cookiecutter/cookiecutter/issues/1759 is resolved
79- branch : str = "main"
8067 dirty : bool = False
8168
69+ # For remote templates, get the commit hash from the remote
8270 template_commit_hash = git .cmd .Git ().ls_remote (template_repo , branch )[:40 ]
71+ # Store the expanded URL as the template location
72+ template_location = template_repo
73+ else :
74+ # This is a local template path
75+ if Path (template ).is_absolute ():
76+ template_path : Path = Path (template ).resolve ()
77+ else :
78+ output_path : Path = Path (output ).resolve ()
79+ template_path : Path = output_path .joinpath (template ).resolve ()
80+
81+ try :
82+ repo : git .Repo = git .Repo (template_path )
83+
84+ # Get info from the local repository
85+ branch : str = str (repo .active_branch )
86+ dirty : bool = repo .is_dirty (untracked_files = True )
87+ # Get the actual commit hash from the local repository
88+ template_commit_hash = repo .head .commit .hexsha
89+ # Store the fully qualified template path for local templates
90+ template_location = str (template_path )
91+ except (git .exc .InvalidGitRepositoryError , git .exc .NoSuchPathError ):
92+ # Not a git repository, fall back to unknown values
93+ branch = "unknown"
94+ dirty = False
95+ template_commit_hash = "unknown"
96+ template_location = str (template_path )
8397
8498 context : dict [str , str | dict [str , str | bool | dict [str , str | bool | dict [str , str ]]]] = {}
8599 context ["name" ] = project_name
@@ -91,12 +105,12 @@ def get_context() -> dict:
91105 context ["origin" ]["template" ]["branch" ] = branch
92106 context ["origin" ]["template" ]["commit hash" ] = template_commit_hash
93107 context ["origin" ]["template" ]["dirty" ] = dirty
94- context ["origin" ]["template" ]["location" ] = template
108+ context ["origin" ]["template" ]["location" ] = template_location
95109 context ["origin" ]["template" ]["cookiecutter" ] = {}
96110 context ["origin" ]["template" ]["cookiecutter" ] = cookiecutter_context
97111
98112 # Filter out unwanted cookiecutter context
99- del cookiecutter_context ["_output_dir" ] # pylint: disable=unsubscriptable-object
113+ del cookiecutter_context ["_output_dir" ]
100114
101115 return context
102116
0 commit comments