Skip to content
This repository was archived by the owner on Jan 23, 2026. It is now read-only.

Commit f2cea90

Browse files
committed
Allow serving alternative endpoints on exporter
1 parent 892b71a commit f2cea90

9 files changed

Lines changed: 175 additions & 24 deletions

File tree

packages/jumpstarter/jumpstarter/client/client.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import grpc
77
from anyio.from_thread import BlockingPortal
88
from google.protobuf import empty_pb2
9-
from jumpstarter_protocol import jumpstarter_pb2_grpc
109

10+
from .grpc import SmartExporterStub
1111
from jumpstarter.client import DriverClient
1212
from jumpstarter.common.importlib import import_class
1313

@@ -26,13 +26,31 @@ async def client_from_channel(
2626
stack: ExitStack,
2727
allow: list[str],
2828
unsafe: bool,
29+
use_alternative_endpoints: bool = True,
2930
) -> DriverClient:
3031
topo = defaultdict(list)
3132
last_seen = {}
3233
reports = {}
3334
clients = OrderedDict()
3435

35-
response = await jumpstarter_pb2_grpc.ExporterServiceStub(channel).GetReport(empty_pb2.Empty())
36+
response = await SmartExporterStub([channel]).GetReport(empty_pb2.Empty())
37+
38+
channels = [channel]
39+
if use_alternative_endpoints:
40+
for endpoint in response.alternative_endpoints:
41+
if endpoint.certificate:
42+
channels.append(
43+
grpc.aio.secure_channel(
44+
endpoint.endpoint,
45+
grpc.ssl_channel_credentials(
46+
root_certificates=endpoint.certificate.encode(),
47+
private_key=endpoint.client_private_key.encode(),
48+
certificate_chain=endpoint.client_certificate.encode(),
49+
),
50+
)
51+
)
52+
53+
stub = SmartExporterStub(list(reversed(channels)))
3654

3755
for index, report in enumerate(response.reports):
3856
topo[index] = []
@@ -52,7 +70,7 @@ async def client_from_channel(
5270
client = client_class(
5371
uuid=UUID(report.uuid),
5472
labels=report.labels,
55-
channel=channel,
73+
stub=stub,
5674
portal=portal,
5775
stack=stack.enter_context(ExitStack()),
5876
children={reports[k].labels["jumpstarter.dev/name"]: clients[k] for k in topo[index]},

packages/jumpstarter/jumpstarter/client/core.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
import logging
66
from contextlib import asynccontextmanager
77
from dataclasses import dataclass, field
8+
from typing import Any
89

910
from anyio import create_task_group
1011
from google.protobuf import empty_pb2
1112
from grpc import StatusCode
12-
from grpc.aio import AioRpcError, Channel
13+
from grpc.aio import AioRpcError
1314
from jumpstarter_protocol import jumpstarter_pb2, jumpstarter_pb2_grpc, router_pb2_grpc
1415

1516
from jumpstarter.common import Metadata
@@ -60,16 +61,14 @@ class AsyncDriverClient(
6061
Backing implementation of blocking driver client.
6162
"""
6263

63-
channel: Channel
64+
stub: Any
6465

6566
log_level: str = "INFO"
6667
logger: logging.Logger = field(init=False)
6768

6869
def __post_init__(self):
6970
if hasattr(super(), "__post_init__"):
7071
super().__post_init__()
71-
jumpstarter_pb2_grpc.ExporterServiceStub.__init__(self, self.channel)
72-
router_pb2_grpc.RouterServiceStub.__init__(self, self.channel)
7372
self.logger = logging.getLogger(self.__class__.__name__)
7473
self.logger.setLevel(self.log_level)
7574

@@ -89,7 +88,7 @@ async def call_async(self, method, *args):
8988
)
9089

9190
try:
92-
response = await self.DriverCall(request)
91+
response = await self.stub.DriverCall(request)
9392
except AioRpcError as e:
9493
match e.code():
9594
case StatusCode.UNIMPLEMENTED:
@@ -113,7 +112,7 @@ async def streamingcall_async(self, method, *args):
113112
)
114113

115114
try:
116-
async for response in self.StreamingDriverCall(request):
115+
async for response in self.stub.StreamingDriverCall(request):
117116
yield decode_value(response.result)
118117
except AioRpcError as e:
119118
match e.code():
@@ -128,7 +127,7 @@ async def streamingcall_async(self, method, *args):
128127

129128
@asynccontextmanager
130129
async def stream_async(self, method):
131-
context = self.Stream(
130+
context = self.stub.Stream(
132131
metadata=StreamRequestMetadata.model_construct(request=DriverStreamRequest(uuid=self.uuid, method=method))
133132
.model_dump(mode="json", round_trip=True)
134133
.items(),
@@ -142,7 +141,7 @@ async def resource_async(
142141
self,
143142
stream,
144143
):
145-
context = self.Stream(
144+
context = self.stub.Stream(
146145
metadata=StreamRequestMetadata.model_construct(request=ResourceStreamRequest(uuid=self.uuid))
147146
.model_dump(mode="json", round_trip=True)
148147
.items(),
@@ -160,7 +159,7 @@ def __log(self, level: int, msg: str):
160159
@asynccontextmanager
161160
async def log_stream_async(self):
162161
async def log_stream():
163-
async for response in self.LogStream(empty_pb2.Empty()):
162+
async for response in self.stub.LogStream(empty_pb2.Empty()):
164163
self.__log(logging.getLevelName(response.severity), response.message)
165164

166165
async with create_task_group() as tg:

packages/jumpstarter/jumpstarter/client/grpc.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
from __future__ import annotations
22

3-
from dataclasses import dataclass, field
3+
from collections import OrderedDict
4+
from dataclasses import InitVar, dataclass, field
45
from datetime import datetime, timedelta
6+
from types import SimpleNamespace
7+
from typing import Any
58

69
import yaml
710
from google.protobuf import duration_pb2, field_mask_pb2, json_format
11+
from grpc import ChannelConnectivity
812
from grpc.aio import Channel
9-
from jumpstarter_protocol import client_pb2, client_pb2_grpc, kubernetes_pb2
13+
from jumpstarter_protocol import client_pb2, client_pb2_grpc, jumpstarter_pb2_grpc, kubernetes_pb2, router_pb2_grpc
1014
from pydantic import BaseModel, ConfigDict, Field, field_serializer
1115

1216
from jumpstarter.common.grpc import translate_grpc_exceptions
@@ -250,3 +254,25 @@ async def DeleteLease(self, *, name: str):
250254
name="namespaces/{}/leases/{}".format(self.namespace, name),
251255
)
252256
)
257+
258+
259+
@dataclass(frozen=True, slots=True)
260+
class SmartExporterStub:
261+
channels: InitVar[list[Channel]]
262+
263+
__stubs: dict[Channel, Any] = field(init=False, default_factory=OrderedDict)
264+
265+
def __post_init__(self, channels):
266+
for channel in channels:
267+
stub = SimpleNamespace()
268+
jumpstarter_pb2_grpc.ExporterServiceStub.__init__(stub, channel)
269+
router_pb2_grpc.RouterServiceStub.__init__(stub, channel)
270+
self.__stubs[channel] = stub
271+
272+
def __getattr__(self, name):
273+
for channel, stub in self.__stubs.items():
274+
# find the first channel that's ready
275+
if channel.get_state(try_to_connect=True) == ChannelConnectivity.READY:
276+
return getattr(stub, name)
277+
# or fallback to the last channel (via router)
278+
return getattr(next(reversed(self.__stubs.values())), name)

packages/jumpstarter/jumpstarter/config/exporter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from contextlib import asynccontextmanager, contextmanager, suppress
44
from pathlib import Path
5-
from typing import Any, ClassVar, Literal, Optional, Self
5+
from typing import Any, ClassVar, List, Literal, Optional, Self
66

77
import grpc
88
import yaml
@@ -83,6 +83,8 @@ class ExporterConfigV1Alpha1(BaseModel):
8383
token: str
8484
grpcOptions: dict[str, str | int] | None = Field(default_factory=dict)
8585

86+
alternative_endpoints: List[str] = Field(default_factory=list)
87+
8688
export: dict[str, ExporterConfigV1Alpha1DriverInstance] = Field(default_factory=dict)
8789

8890
path: Path | None = Field(default=None)
@@ -171,6 +173,7 @@ def channel_factory():
171173
device_factory=ExporterConfigV1Alpha1DriverInstance(children=self.export).instantiate,
172174
tls=self.tls,
173175
grpc_options=self.grpcOptions,
176+
alternative_endpoints=self.alternative_endpoints,
174177
) as exporter:
175178
await exporter.serve()
176179

packages/jumpstarter/jumpstarter/exporter/exporter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class Exporter(AbstractAsyncContextManager, Metadata):
2525
channel_factory: Callable[[], grpc.aio.Channel]
2626
device_factory: Callable[[], Driver]
2727
lease_name: str = field(init=False, default="")
28+
alternative_endpoints: list[str] = field(default_factory=list)
2829
tls: TLSConfigV1Alpha1 = field(default_factory=TLSConfigV1Alpha1)
2930
grpc_options: dict[str, str] = field(default_factory=dict)
3031

@@ -50,7 +51,7 @@ async def session(self):
5051
labels=self.labels,
5152
root_device=self.device_factory(),
5253
) as session:
53-
async with session.serve_unix_async() as path:
54+
async with session.serve_unix_async(alternative_endpoints=self.alternative_endpoints) as path:
5455
async with grpc.aio.secure_channel(
5556
f"unix://{path}", grpc.local_channel_credentials(grpc.LocalConnectionType.UDS)
5657
) as channel:

packages/jumpstarter/jumpstarter/exporter/session.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
)
1616

1717
from .logging import LogHandler
18+
from .tls import with_alternative_endpoints
1819
from jumpstarter.common import Metadata, TemporarySocket
1920
from jumpstarter.common.streams import StreamRequestMetadata
2021
from jumpstarter.driver import Driver
@@ -53,12 +54,16 @@ def __init__(self, *args, root_device, **kwargs):
5354

5455
self._logging_queue = deque(maxlen=32)
5556
self._logging_handler = LogHandler(self._logging_queue)
57+
self._alternative_endpoints = []
5658

5759
@asynccontextmanager
58-
async def serve_port_async(self, port):
60+
async def serve_ports_async(self, port, alternative_endpoints: list[str] | None = None):
5961
server = grpc.aio.server()
6062
server.add_insecure_port(port)
6163

64+
if alternative_endpoints is not None:
65+
self._alternative_endpoints = with_alternative_endpoints(server, alternative_endpoints)
66+
6267
jumpstarter_pb2_grpc.add_ExporterServiceServicer_to_server(self, server)
6368
router_pb2_grpc.add_RouterServiceServicer_to_server(self, server)
6469

@@ -69,15 +74,15 @@ async def serve_port_async(self, port):
6974
await server.stop(grace=None)
7075

7176
@asynccontextmanager
72-
async def serve_unix_async(self):
77+
async def serve_unix_async(self, alternative_endpoints: list[str] | None = None):
7378
with TemporarySocket() as path:
74-
async with self.serve_port_async(f"unix://{path}"):
79+
async with self.serve_ports_async(f"unix://{path}", alternative_endpoints):
7580
yield path
7681

7782
@contextmanager
78-
def serve_unix(self):
83+
def serve_unix(self, alternative_endpoints: list[str] | None = None):
7984
with start_blocking_portal() as portal:
80-
with portal.wrap_async_context_manager(self.serve_unix_async()) as path:
85+
with portal.wrap_async_context_manager(self.serve_unix_async(alternative_endpoints)) as path:
8186
yield path
8287

8388
def __getitem__(self, key: UUID):
@@ -92,6 +97,7 @@ async def GetReport(self, request, context):
9297
instance.report(parent=parent, name=name)
9398
for (_, parent, name, instance) in self.root_device.enumerate()
9499
],
100+
alternative_endpoints=self._alternative_endpoints,
95101
)
96102

97103
async def DriverCall(self, request, context):
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
from datetime import datetime, timedelta
2+
from ipaddress import IPv4Address, IPv6Address, ip_address
3+
4+
import grpc
5+
from cryptography import x509
6+
from cryptography.hazmat.backends import default_backend
7+
from cryptography.hazmat.primitives import hashes, serialization
8+
from cryptography.hazmat.primitives.asymmetric import rsa
9+
from jumpstarter_protocol import jumpstarter_pb2
10+
11+
12+
def parse_endpoint(endpoint):
13+
host, sep, port = endpoint.rpartition(":")
14+
15+
if sep == "":
16+
raise ValueError("port not specified in endpoint {}".format(endpoint))
17+
18+
host = host.strip("[]") # strip brackets from ipv6 addresses
19+
20+
try:
21+
port = int(port)
22+
if port < 0 or port > 65535:
23+
raise ValueError("port number {} out of range".format(port))
24+
except ValueError as e:
25+
raise ValueError("invalid port {} in endpoint {}".format(port, endpoint)) from e
26+
27+
try:
28+
return ip_address(host), port
29+
except ValueError:
30+
return host, port
31+
32+
33+
def with_alternative_endpoints(server, endpoints: list[str]):
34+
sans = []
35+
for endpoint in endpoints:
36+
host, port = parse_endpoint(endpoint)
37+
match host:
38+
case str():
39+
sans.append(x509.DNSName(host))
40+
case IPv4Address() | IPv6Address():
41+
sans.append(x509.IPAddress(host))
42+
43+
key = rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend())
44+
client_key = rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend())
45+
46+
crt = (
47+
x509.CertificateBuilder()
48+
.subject_name(x509.Name([]))
49+
.issuer_name(x509.Name([]))
50+
.public_key(key.public_key())
51+
.serial_number(x509.random_serial_number())
52+
.not_valid_before(datetime.now())
53+
.not_valid_after(datetime.now() + timedelta(days=365))
54+
.add_extension(x509.SubjectAlternativeName(sans), critical=False)
55+
.sign(private_key=key, algorithm=hashes.SHA256(), backend=default_backend())
56+
)
57+
client_crt = (
58+
x509.CertificateBuilder()
59+
.subject_name(x509.Name([]))
60+
.issuer_name(x509.Name([]))
61+
.public_key(client_key.public_key())
62+
.serial_number(x509.random_serial_number())
63+
.not_valid_before(datetime.now())
64+
.not_valid_after(datetime.now() + timedelta(days=365))
65+
.sign(private_key=client_key, algorithm=hashes.SHA256(), backend=default_backend())
66+
)
67+
68+
pem_crt = crt.public_bytes(serialization.Encoding.PEM)
69+
pem_key = key.private_bytes(
70+
encoding=serialization.Encoding.PEM,
71+
format=serialization.PrivateFormat.TraditionalOpenSSL,
72+
encryption_algorithm=serialization.NoEncryption(),
73+
)
74+
75+
pem_client_crt = client_crt.public_bytes(serialization.Encoding.PEM)
76+
pem_client_key = client_key.private_bytes(
77+
encoding=serialization.Encoding.PEM,
78+
format=serialization.PrivateFormat.TraditionalOpenSSL,
79+
encryption_algorithm=serialization.NoEncryption(),
80+
)
81+
82+
server_credentials = grpc.ssl_server_credentials(
83+
[(pem_key, pem_crt)], root_certificates=pem_client_crt, require_client_auth=True
84+
)
85+
86+
endpoints_pb = []
87+
for endpoint in endpoints:
88+
server.add_secure_port(endpoint, server_credentials)
89+
endpoints_pb.append(
90+
jumpstarter_pb2.Endpoint(
91+
endpoint=endpoint,
92+
certificate=pem_crt,
93+
client_certificate=pem_client_crt,
94+
client_private_key=pem_client_key,
95+
),
96+
)
97+
98+
return endpoints_pb

packages/jumpstarter/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ dependencies = [
1616
"anyio>=4.4.0,!=4.6.2",
1717
"aiohttp>=3.10.5",
1818
"tqdm>=4.66.5",
19+
"cryptography>=43.0.3",
1920
"pydantic>=2.8.2"
2021
]
2122

@@ -25,7 +26,6 @@ dev = [
2526
"pytest-cov>=6.0.0",
2627
"pytest-anyio>=0.0.0",
2728
"pytest-asyncio>=0.0.0",
28-
"cryptography>=43.0.3",
2929
"jumpstarter-driver-power",
3030
"jumpstarter-driver-network",
3131
"jumpstarter-driver-composite"

0 commit comments

Comments
 (0)