Skip to content
This repository was archived by the owner on Jan 23, 2026. It is now read-only.

Commit 81fa7db

Browse files
committed
Generic smart stub
1 parent 71dc421 commit 81fa7db

2 files changed

Lines changed: 24 additions & 11 deletions

File tree

packages/jumpstarter/jumpstarter/client/client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from google.protobuf import empty_pb2
99
from jumpstarter_protocol import jumpstarter_pb2_grpc
1010

11+
from .grpc import SmartExporterServiceStub
1112
from jumpstarter.client import DriverClient
12-
from jumpstarter.common.grpc import aio_secure_channel
1313
from jumpstarter.common.importlib import import_class
1414

1515

@@ -34,12 +34,12 @@ async def client_from_channel(
3434
reports = {}
3535
clients = OrderedDict()
3636

37-
response = await jumpstarter_pb2_grpc.ExporterServiceStub(channel).GetReport(empty_pb2.Empty())
37+
response = await SmartExporterServiceStub([channel]).GetReport(empty_pb2.Empty())
3838

3939
if use_alternative_endpoints:
4040
for endpoint in response.alternative_endpoints:
4141
if endpoint.certificate:
42-
attempted_channel = aio_secure_channel(
42+
attempted_channel = grpc.aio.secure_channel(
4343
endpoint.endpoint,
4444
grpc.ssl_channel_credentials(
4545
root_certificates=endpoint.certificate.encode(),

packages/jumpstarter/jumpstarter/client/grpc.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
from collections import OrderedDict
44
from dataclasses import InitVar, dataclass, field
55
from datetime import datetime, timedelta
6+
from functools import partial
7+
from typing import Generic, Type, TypeVar
68

79
import yaml
810
from google.protobuf import duration_pb2, field_mask_pb2, json_format
911
from grpc import ChannelConnectivity
1012
from grpc.aio import Channel
11-
from jumpstarter_protocol import client_pb2, client_pb2_grpc, jumpstarter_pb2_grpc, kubernetes_pb2
13+
from jumpstarter_protocol import client_pb2, client_pb2_grpc, jumpstarter_pb2_grpc, kubernetes_pb2, router_pb2_grpc
1214
from pydantic import BaseModel, ConfigDict, Field, field_serializer
1315

1416
from jumpstarter.common.grpc import translate_grpc_exceptions
@@ -254,18 +256,19 @@ async def DeleteLease(self, *, name: str):
254256
)
255257

256258

259+
T = TypeVar("T")
260+
261+
257262
@dataclass(frozen=True, slots=True)
258-
class SmartExporterServiceStub:
263+
class SmartStub(Generic[T]):
259264
channels: InitVar[list[Channel]]
265+
cls: InitVar[Type]
260266

261-
__stubs: dict[Channel, jumpstarter_pb2_grpc.ExporterServiceStub] = field(
262-
init=False,
263-
default_factory=OrderedDict,
264-
)
267+
__stubs: dict[Channel, T] = field(init=False, default_factory=OrderedDict)
265268

266-
def __post_init__(self, channels):
269+
def __post_init__(self, channels, cls):
267270
for channel in channels:
268-
self.__stubs[channel] = jumpstarter_pb2_grpc.ExporterServiceStub(channel)
271+
self.__stubs[channel] = cls(channel)
269272

270273
def __getattr__(self, name):
271274
for channel, stub in self.__stubs.items():
@@ -274,3 +277,13 @@ def __getattr__(self, name):
274277
return getattr(stub, name)
275278
# or fallback to the last channel (via router)
276279
return getattr(next(reversed(self.__stubs.values())), name)
280+
281+
282+
SmartExporterServiceStub = partial(
283+
SmartStub[jumpstarter_pb2_grpc.ExporterServiceStub],
284+
cls=jumpstarter_pb2_grpc.ExporterServiceStub,
285+
)
286+
SmartRouterServiceStub = partial(
287+
SmartStub[router_pb2_grpc.RouterServiceStub],
288+
cls=router_pb2_grpc.RouterServiceStub,
289+
)

0 commit comments

Comments
 (0)