Skip to content
This repository was archived by the owner on Jan 23, 2026. It is now read-only.

Commit 1d42653

Browse files
committed
Parallel SSL cert retrieval
When using insecure TLS, if the first address of a hostname DNS fails, our code does not try with the other IPs as gRPC does later, so we cannot connect. This patch attempts to retrieve the insecure cert in parallel over all the available IP addresses.
1 parent 2863bcd commit 1d42653

1 file changed

Lines changed: 116 additions & 23 deletions

File tree

  • packages/jumpstarter/jumpstarter/common

packages/jumpstarter/jumpstarter/common/grpc.py

Lines changed: 116 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import base64
3+
import logging
34
import os
45
import socket
56
import ssl
@@ -12,32 +13,124 @@
1213

1314
from jumpstarter.common.exceptions import ConfigurationError, ConnectionError
1415

16+
logger = logging.getLogger(__name__)
17+
18+
19+
async def _try_connect_and_extract_cert(
20+
ip_address: str, port: int, ssl_context: ssl.SSLContext, hostname: str, timeout: float
21+
) -> bytes:
22+
"""
23+
Try to connect to a single IP and extract its certificate chain.
24+
25+
Returns the certificate chain in PEM format as bytes.
26+
Raises exception on failure.
27+
"""
28+
logger.debug(f"Attempting TLS connection to {ip_address}:{port} (timeout={timeout}s)")
29+
_, writer = await asyncio.wait_for(
30+
asyncio.open_connection(ip_address, port, ssl=ssl_context, server_hostname=hostname),
31+
timeout=timeout,
32+
)
33+
logger.debug(f"Successfully connected to {ip_address}:{port}")
34+
try:
35+
# Extract certificates
36+
cert_chain = writer.get_extra_info("ssl_object")._sslobj.get_unverified_chain()
37+
root_certificates = ""
38+
for cert in cert_chain:
39+
root_certificates += cert.public_bytes()
40+
logger.debug(f"Successfully extracted {len(cert_chain)} certificate(s) from {ip_address}:{port}")
41+
42+
return root_certificates.encode()
43+
finally:
44+
writer.close()
45+
46+
47+
async def _ssl_channel_credentials_insecure(target: str, timeout: float) -> grpc.ChannelCredentials: # noqa: C901
48+
"""
49+
Extract TLS certificates from server without verification (insecure mode).
50+
51+
Tries to connect to all resolved IPs in parallel and returns credentials
52+
from the first successful connection.
53+
"""
54+
try:
55+
parsed = urlparse(f"//{target}")
56+
port = parsed.port if parsed.port else 443
57+
except ValueError as e:
58+
raise ConfigurationError(f"Failed parsing {target}") from e
59+
60+
try:
61+
with fail_after(timeout):
62+
ssl_context = ssl.create_default_context()
63+
ssl_context.check_hostname = False
64+
ssl_context.verify_mode = ssl.CERT_NONE
65+
66+
# Resolve all IP addresses for the hostname
67+
loop = asyncio.get_running_loop()
68+
addr_info = await loop.getaddrinfo(
69+
parsed.hostname, port, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
70+
)
71+
72+
# Log resolved IPs
73+
resolved_ips = [sockaddr[0] for _, _, _, _, sockaddr in addr_info]
74+
logger.debug(
75+
f"Resolved {parsed.hostname} to {len(resolved_ips)} IP(s): {', '.join(resolved_ips)}"
76+
)
77+
78+
# Try all IPs in parallel - race for first success
79+
# Wrap tasks to include IP info with results/exceptions
80+
async def try_with_ip(ip_address: str):
81+
"""Wrapper that returns (ip, result) on success or (ip, exception) on failure."""
82+
try:
83+
result = await _try_connect_and_extract_cert(
84+
ip_address, port, ssl_context, parsed.hostname, timeout
85+
)
86+
return (ip_address, result, None)
87+
except Exception as e:
88+
return (ip_address, None, e)
89+
90+
tasks = []
91+
for _family, _type, _proto, _canonname, sockaddr in addr_info:
92+
ip_address = sockaddr[0]
93+
task = asyncio.create_task(try_with_ip(ip_address))
94+
tasks.append(task)
95+
96+
# Process tasks as they complete
97+
errors = {}
98+
99+
try:
100+
for future in asyncio.as_completed(tasks):
101+
ip_address, root_certificates, error = await future
102+
103+
if error is None:
104+
# Success! Return immediately (cleanup in finally)
105+
logger.debug(f"Using certificates from {ip_address}:{port}")
106+
return grpc.ssl_channel_credentials(root_certificates=root_certificates)
107+
108+
# This IP failed - log and continue trying other IPs
109+
if isinstance(error, ssl.SSLError):
110+
logger.error(f"SSL error on {ip_address}:{port}: {error}")
111+
else:
112+
logger.warning(f"Failed to connect to {ip_address}:{port}: {type(error).__name__}: {error}")
113+
errors[ip_address] = error
114+
115+
# All IPs failed
116+
raise ConnectionError(
117+
f"Failed connecting to {parsed.hostname}:{port} - all IPs exhausted. Errors: {errors}"
118+
)
119+
finally:
120+
# Cancel any remaining tasks
121+
for task in tasks:
122+
if not task.done():
123+
task.cancel()
124+
except socket.gaierror as e:
125+
raise ConnectionError(f"Failed resolving {parsed.hostname}") from e
126+
except TimeoutError as e:
127+
raise ConnectionError(f"Timeout connecting to {parsed.hostname}:{port}") from e
128+
15129

16130
async def ssl_channel_credentials(target: str, tls_config, timeout=5):
131+
"""Get SSL channel credentials for gRPC connection."""
17132
if tls_config.insecure or os.getenv("JUMPSTARTER_GRPC_INSECURE") == "1":
18-
try:
19-
parsed = urlparse(f"//{target}")
20-
port = parsed.port if parsed.port else 443
21-
except ValueError as e:
22-
raise ConfigurationError(f"Failed parsing {target}") from e
23-
24-
try:
25-
with fail_after(timeout):
26-
ssl_context = ssl.create_default_context()
27-
ssl_context.check_hostname = False
28-
ssl_context.verify_mode = ssl.CERT_NONE
29-
_, writer = await asyncio.open_connection(parsed.hostname, port, ssl=ssl_context)
30-
root_certificates = ""
31-
for cert in writer.get_extra_info("ssl_object")._sslobj.get_unverified_chain():
32-
root_certificates += cert.public_bytes()
33-
return grpc.ssl_channel_credentials(root_certificates=root_certificates.encode())
34-
except socket.gaierror as e:
35-
raise ConnectionError(f"Failed resolving {parsed.hostname}") from e
36-
except ConnectionRefusedError as e:
37-
raise ConnectionError(f"Failed connecting to {parsed.hostname}:{port}") from e
38-
except TimeoutError as e:
39-
raise ConnectionError(f"Timeout connecting to {parsed.hostname}:{port}") from e
40-
133+
return await _ssl_channel_credentials_insecure(target, timeout)
41134
elif tls_config.ca != "":
42135
ca_certificate = base64.b64decode(tls_config.ca)
43136
return grpc.ssl_channel_credentials(ca_certificate)

0 commit comments

Comments
 (0)