Skip to content
This repository was archived by the owner on Mar 26, 2026. It is now read-only.

Commit fbad3e1

Browse files
committed
use a context manager instead of passing args
1 parent 2537d98 commit fbad3e1

5 files changed

Lines changed: 59 additions & 65 deletions

File tree

gapic/cli/generate.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from gapic import generator
2424
from gapic.schema import api
2525
from gapic.utils import Options
26-
26+
from gapic.utils.cache import generation_cache_context
2727

2828
@click.command()
2929
@click.option(
@@ -56,15 +56,16 @@ def generate(request: typing.BinaryIO, output: typing.BinaryIO) -> None:
5656
[p.package for p in req.proto_file if p.name in req.file_to_generate]
5757
).rstrip(".")
5858

59-
# Build the API model object.
60-
# This object is a frozen representation of the whole API, and is sent
61-
# to each template in the rendering step.
62-
api_schema = api.API.build(req.proto_file, opts=opts, package=package)
59+
with generation_cache_context():
60+
# Build the API model object.
61+
# This object is a frozen representation of the whole API, and is sent
62+
# to each template in the rendering step.
63+
api_schema = api.API.build(req.proto_file, opts=opts, package=package)
6364

64-
# Translate into a protobuf CodeGeneratorResponse; this reads the
65-
# individual templates and renders them.
66-
# If there are issues, error out appropriately.
67-
res = generator.Generator(opts).get_response(api_schema, opts)
65+
# Translate into a protobuf CodeGeneratorResponse; this reads the
66+
# individual templates and renders them.
67+
# If there are issues, error out appropriately.
68+
res = generator.Generator(opts).get_response(api_schema, opts)
6869

6970
# Output the serialized response.
7071
output.write(res.SerializeToString())

gapic/schema/api.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ def build(
115115
prior_protos: Optional[Mapping[str, "Proto"]] = None,
116116
load_services: bool = True,
117117
all_resources: Optional[Mapping[str, wrappers.MessageType]] = None,
118-
context_cache: Optional[Dict] = None,
119118
) -> "Proto":
120119
"""Build and return a Proto instance.
121120
@@ -140,7 +139,6 @@ def build(
140139
prior_protos=prior_protos or {},
141140
load_services=load_services,
142141
all_resources=all_resources or {},
143-
context_cache=context_cache,
144142
).proto
145143

146144
@cached_property
@@ -456,7 +454,6 @@ def disambiguate_keyword_sanitize_fname(
456454
# type into the proto file that defines an LRO.
457455
# We just load all the APIs types first and then
458456
# load the services and methods with the full scope of types.
459-
context_cache = {}
460457
pre_protos: Dict[str, Proto] = dict(prior_protos or {})
461458
for fd in file_descriptors:
462459
fd.name = disambiguate_keyword_sanitize_fname(fd.name, pre_protos)
@@ -468,7 +465,6 @@ def disambiguate_keyword_sanitize_fname(
468465
prior_protos=pre_protos,
469466
# Ugly, ugly hack.
470467
load_services=False,
471-
context_cache=context_cache,
472468
)
473469

474470
# A file descriptor's file-level resources are NOT visible to any importers.
@@ -489,7 +485,6 @@ def disambiguate_keyword_sanitize_fname(
489485
opts=opts,
490486
prior_protos=pre_protos,
491487
all_resources=MappingProxyType(all_file_resources),
492-
context_cache=context_cache,
493488
)
494489
for name, proto in pre_protos.items()
495490
}
@@ -1108,7 +1103,6 @@ def __init__(
11081103
prior_protos: Optional[Mapping[str, Proto]] = None,
11091104
load_services: bool = True,
11101105
all_resources: Optional[Mapping[str, wrappers.MessageType]] = None,
1111-
context_cache: Optional[Dict] = None,
11121106
):
11131107
self.proto_messages: Dict[str, wrappers.MessageType] = {}
11141108
self.proto_enums: Dict[str, wrappers.EnumType] = {}
@@ -1117,7 +1111,6 @@ def __init__(
11171111
self.file_to_generate = file_to_generate
11181112
self.prior_protos = prior_protos or {}
11191113
self.opts = opts
1120-
self.context_cache = context_cache
11211114

11221115
# Iterate over the documentation and place it into a dictionary.
11231116
#
@@ -1227,28 +1220,20 @@ def proto(self) -> Proto:
12271220
if not self.file_to_generate:
12281221
return naive
12291222

1230-
global_collisions = frozenset(naive.names)
12311223
visited_messages: Set[wrappers.MessageType] = set()
1232-
self.context_cache = {}
12331224
# Return a context-aware proto object.
12341225
return dataclasses.replace(
12351226
naive,
12361227
all_enums=collections.OrderedDict(
1237-
(
1238-
k,
1239-
v.with_context(
1240-
collisions=global_collisions, context_cache=self.context_cache
1241-
),
1242-
)
1228+
(k, v.with_context(collisions=naive.names))
12431229
for k, v in naive.all_enums.items()
12441230
),
12451231
all_messages=collections.OrderedDict(
12461232
(
12471233
k,
12481234
v.with_context(
1249-
collisions=global_collisions,
1235+
collisions=naive.names,
12501236
visited_messages=visited_messages,
1251-
context_cache=self.context_cache,
12521237
),
12531238
)
12541239
for k, v in naive.all_messages.items()
@@ -1261,14 +1246,11 @@ def proto(self) -> Proto:
12611246
v.with_context(
12621247
collisions=v.names,
12631248
visited_messages=visited_messages,
1264-
context_cache=self.context_cache,
12651249
),
12661250
)
12671251
for k, v in naive.services.items()
12681252
),
1269-
meta=naive.meta.with_context(
1270-
collisions=naive.names, context_cache=self.context_cache
1271-
),
1253+
meta=naive.meta.with_context(collisions=naive.names),
12721254
)
12731255

12741256
@cached_property

gapic/schema/metadata.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ def resolve(self, selector: str) -> str:
362362

363363
@cached_proto_context
364364
def with_context(
365-
self, *, collisions: Set[str], context_cache: Optional[Dict] = None
365+
self, *, collisions: Set[str]
366366
) -> "Address":
367367
"""Return a derivative of this address with the provided context.
368368
@@ -404,7 +404,7 @@ def doc(self):
404404

405405
@cached_proto_context
406406
def with_context(
407-
self, *, collisions: Set[str], context_cache: Optional[Dict] = None
407+
self, *, collisions: Set[str]
408408
) -> "Metadata":
409409
"""Return a derivative of this metadata with the provided context.
410410
@@ -416,7 +416,7 @@ def with_context(
416416
dataclasses.replace(
417417
self,
418418
address=self.address.with_context(
419-
collisions=collisions, context_cache=context_cache
419+
collisions=collisions
420420
),
421421
)
422422
if collisions and collisions != self.address.collisions

gapic/schema/wrappers.py

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,6 @@ def with_context(
417417
*,
418418
collisions: Set[str],
419419
visited_messages: Optional[Set["MessageType"]] = None,
420-
context_cache: Optional[Dict] = None,
421420
) -> "Field":
422421
"""Return a derivative of this field with the provided context.
423422
@@ -434,20 +433,19 @@ def with_context(
434433
self.message in visited_messages if visited_messages else False
435434
),
436435
visited_messages=visited_messages,
437-
context_cache=context_cache,
438436
)
439437
if self.message
440438
else None
441439
),
442440
enum=(
443441
self.enum.with_context(
444-
collisions=collisions, context_cache=context_cache
442+
collisions=collisions
445443
)
446444
if self.enum
447445
else None
448446
),
449447
meta=self.meta.with_context(
450-
collisions=collisions, context_cache=context_cache
448+
collisions=collisions,
451449
),
452450
)
453451

@@ -751,7 +749,6 @@ def path_regex_str(self) -> str:
751749
def get_field(
752750
self,
753751
*field_path: str,
754-
context_cache: Optional[Dict] = None,
755752
collisions: Optional[Set[str]] = None,
756753
) -> Field:
757754
"""Return a field arbitrarily deep in this message's structure.
@@ -796,7 +793,6 @@ def get_field(
796793
return cursor.with_context(
797794
collisions=collisions,
798795
visited_messages=set({self}),
799-
context_cache=context_cache,
800796
)
801797

802798
# Quick check: If cursor is a repeated field, then raise an exception.
@@ -828,7 +824,6 @@ def with_context(
828824
collisions: Set[str],
829825
skip_fields: bool = False,
830826
visited_messages: Optional[Set["MessageType"]] = None,
831-
context_cache: Optional[Dict] = None,
832827
) -> "MessageType":
833828
"""Return a derivative of this message with the provided context.
834829
@@ -849,28 +844,26 @@ def with_context(
849844
k: v.with_context(
850845
collisions=collisions,
851846
visited_messages=visited_messages,
852-
context_cache=context_cache,
853847
)
854848
for k, v in self.fields.items()
855849
}
856850
if not skip_fields
857851
else self.fields
858852
),
859853
nested_enums={
860-
k: v.with_context(collisions=collisions, context_cache=context_cache)
854+
k: v.with_context(collisions=collisions)
861855
for k, v in self.nested_enums.items()
862856
},
863857
nested_messages={
864858
k: v.with_context(
865859
collisions=collisions,
866860
skip_fields=skip_fields,
867861
visited_messages=visited_messages,
868-
context_cache=context_cache,
869862
)
870863
for k, v in self.nested_messages.items()
871864
},
872865
meta=self.meta.with_context(
873-
collisions=collisions, context_cache=context_cache
866+
collisions=collisions
874867
),
875868
)
876869

@@ -965,7 +958,6 @@ def with_context(
965958
self,
966959
*,
967960
collisions: Set[str],
968-
context_cache: Optional[Dict] = None,
969961
) -> "EnumType":
970962
"""Return a derivative of this enum with the provided context.
971963
@@ -977,7 +969,7 @@ def with_context(
977969
dataclasses.replace(
978970
self,
979971
meta=self.meta.with_context(
980-
collisions=collisions, context_cache=context_cache
972+
collisions=collisions
981973
),
982974
)
983975
if collisions
@@ -1095,7 +1087,6 @@ def with_context(
10951087
*,
10961088
collisions: Set[str],
10971089
visited_messages: Optional[Set["MessageType"]] = None,
1098-
context_cache: Optional[Dict] = None,
10991090
) -> "ExtendedOperationInfo":
11001091
"""Return a derivative of this OperationInfo with the provided context.
11011092
@@ -1111,12 +1102,10 @@ def with_context(
11111102
request_type=self.request_type.with_context(
11121103
collisions=collisions,
11131104
visited_messages=visited_messages,
1114-
context_cache=context_cache,
11151105
),
11161106
operation_type=self.operation_type.with_context(
11171107
collisions=collisions,
11181108
visited_messages=visited_messages,
1119-
context_cache=context_cache,
11201109
),
11211110
)
11221111
)
@@ -1168,7 +1157,6 @@ def with_context(
11681157
*,
11691158
collisions: Set[str],
11701159
visited_messages: Optional[Set["MessageType"]] = None,
1171-
context_cache: Optional[Dict] = None,
11721160
) -> "OperationInfo":
11731161
"""Return a derivative of this OperationInfo with the provided context.
11741162
@@ -1181,12 +1169,10 @@ def with_context(
11811169
response_type=self.response_type.with_context(
11821170
collisions=collisions,
11831171
visited_messages=visited_messages,
1184-
context_cache=context_cache,
11851172
),
11861173
metadata_type=self.metadata_type.with_context(
11871174
collisions=collisions,
11881175
visited_messages=visited_messages,
1189-
context_cache=context_cache,
11901176
),
11911177
)
11921178

@@ -1982,7 +1968,6 @@ def with_context(
19821968
*,
19831969
collisions: Set[str],
19841970
visited_messages: Optional[Set["MessageType"]] = None,
1985-
context_cache: Optional[Dict] = None,
19861971
) -> "Method":
19871972
"""Return a derivative of this method with the provided context.
19881973
@@ -1996,7 +1981,6 @@ def with_context(
19961981
self.lro.with_context(
19971982
collisions=collisions,
19981983
visited_messages=visited_messages,
1999-
context_cache=context_cache,
20001984
)
20011985
if collisions
20021986
else self.lro
@@ -2006,7 +1990,6 @@ def with_context(
20061990
self.extended_lro.with_context(
20071991
collisions=collisions,
20081992
visited_messages=visited_messages,
2009-
context_cache=context_cache,
20101993
)
20111994
if self.extended_lro
20121995
else None
@@ -2019,15 +2002,13 @@ def with_context(
20192002
input=self.input.with_context(
20202003
collisions=collisions,
20212004
visited_messages=visited_messages,
2022-
context_cache=context_cache,
20232005
),
20242006
output=self.output.with_context(
20252007
collisions=collisions,
20262008
visited_messages=visited_messages,
2027-
context_cache=context_cache,
20282009
),
20292010
meta=self.meta.with_context(
2030-
collisions=collisions, context_cache=context_cache
2011+
collisions=collisions,
20312012
),
20322013
)
20332014

@@ -2410,7 +2391,6 @@ def with_context(
24102391
*,
24112392
collisions: Set[str],
24122393
visited_messages: Optional[Set["MessageType"]] = None,
2413-
context_cache: Optional[Dict] = None,
24142394
) -> "Service":
24152395
"""Return a derivative of this service with the provided context.
24162396
@@ -2426,12 +2406,11 @@ def with_context(
24262406
# that may conflict with module imports.
24272407
collisions=collisions | set(v.flattened_fields.keys()),
24282408
visited_messages=visited_messages,
2429-
context_cache=context_cache,
24302409
)
24312410
for k, v in self.methods.items()
24322411
},
24332412
meta=self.meta.with_context(
2434-
collisions=collisions, context_cache=context_cache
2413+
collisions=collisions,
24352414
),
24362415
)
24372416

0 commit comments

Comments
 (0)