33from collections import OrderedDict
44from dataclasses import InitVar , dataclass , field
55from datetime import datetime , timedelta
6+ from functools import partial
7+ from typing import Generic , Type , TypeVar
68
79import yaml
810from google .protobuf import duration_pb2 , field_mask_pb2 , json_format
911from grpc import ChannelConnectivity
1012from grpc .aio import Channel
11- from jumpstarter_protocol import client_pb2 , client_pb2_grpc , jumpstarter_pb2_grpc , kubernetes_pb2
13+ from jumpstarter_protocol import client_pb2 , client_pb2_grpc , jumpstarter_pb2_grpc , kubernetes_pb2 , router_pb2_grpc
1214from pydantic import BaseModel , ConfigDict , Field , field_serializer
1315
1416from jumpstarter .common .grpc import translate_grpc_exceptions
@@ -254,18 +256,19 @@ async def DeleteLease(self, *, name: str):
254256 )
255257
256258
259+ T = TypeVar ("T" )
260+
261+
257262@dataclass (frozen = True , slots = True )
258- class SmartExporterServiceStub :
263+ class SmartStub ( Generic [ T ]) :
259264 channels : InitVar [list [Channel ]]
265+ cls : InitVar [Type ]
260266
261- __stubs : dict [Channel , jumpstarter_pb2_grpc .ExporterServiceStub ] = field (
262- init = False ,
263- default_factory = OrderedDict ,
264- )
267+ __stubs : dict [Channel , T ] = field (init = False , default_factory = OrderedDict )
265268
266- def __post_init__ (self , channels ):
269+ def __post_init__ (self , channels , cls ):
267270 for channel in channels :
268- self .__stubs [channel ] = jumpstarter_pb2_grpc . ExporterServiceStub (channel )
271+ self .__stubs [channel ] = cls (channel )
269272
270273 def __getattr__ (self , name ):
271274 for channel , stub in self .__stubs .items ():
@@ -274,3 +277,13 @@ def __getattr__(self, name):
274277 return getattr (stub , name )
275278 # or fallback to the last channel (via router)
276279 return getattr (next (reversed (self .__stubs .values ())), name )
280+
281+
282+ SmartExporterServiceStub = partial (
283+ SmartStub [jumpstarter_pb2_grpc .ExporterServiceStub ],
284+ cls = jumpstarter_pb2_grpc .ExporterServiceStub ,
285+ )
286+ SmartRouterServiceStub = partial (
287+ SmartStub [router_pb2_grpc .RouterServiceStub ],
288+ cls = router_pb2_grpc .RouterServiceStub ,
289+ )
0 commit comments