|
| 1 | +from datetime import datetime, timedelta |
| 2 | +from ipaddress import IPv4Address, IPv6Address, ip_address |
| 3 | + |
| 4 | +import grpc |
| 5 | +from cryptography import x509 |
| 6 | +from cryptography.hazmat.backends import default_backend |
| 7 | +from cryptography.hazmat.primitives import hashes, serialization |
| 8 | +from cryptography.hazmat.primitives.asymmetric import rsa |
| 9 | +from jumpstarter_protocol import jumpstarter_pb2 |
| 10 | + |
| 11 | + |
| 12 | +def parse_endpoint(endpoint): |
| 13 | + host, sep, port = endpoint.rpartition(":") |
| 14 | + |
| 15 | + if sep == "": |
| 16 | + raise ValueError("port not specified in endpoint {}".format(endpoint)) |
| 17 | + |
| 18 | + host = host.strip("[]") # strip brackets from ipv6 addresses |
| 19 | + |
| 20 | + try: |
| 21 | + port = int(port) |
| 22 | + if port < 0 or port > 65535: |
| 23 | + raise ValueError("port number {} out of range".format(port)) |
| 24 | + except ValueError as e: |
| 25 | + raise ValueError("invalid port {} in endpoint {}".format(port, endpoint)) from e |
| 26 | + |
| 27 | + try: |
| 28 | + return ip_address(host), port |
| 29 | + except ValueError: |
| 30 | + return host, port |
| 31 | + |
| 32 | + |
| 33 | +def with_alternative_endpoints(server, endpoints: list[str]): |
| 34 | + sans = [] |
| 35 | + for endpoint in endpoints: |
| 36 | + host, port = parse_endpoint(endpoint) |
| 37 | + match host: |
| 38 | + case str(): |
| 39 | + sans.append(x509.DNSName(host)) |
| 40 | + case IPv4Address() | IPv6Address(): |
| 41 | + sans.append(x509.IPAddress(host)) |
| 42 | + |
| 43 | + key = rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend()) |
| 44 | + |
| 45 | + crt = ( |
| 46 | + x509.CertificateBuilder() |
| 47 | + .subject_name(x509.Name([])) |
| 48 | + .issuer_name(x509.Name([])) |
| 49 | + .public_key(key.public_key()) |
| 50 | + .serial_number(x509.random_serial_number()) |
| 51 | + .not_valid_before(datetime.now()) |
| 52 | + .not_valid_after(datetime.now() + timedelta(days=365)) |
| 53 | + .add_extension(x509.SubjectAlternativeName(sans), critical=False) |
| 54 | + .sign(private_key=key, algorithm=hashes.SHA256(), backend=default_backend()) |
| 55 | + ) |
| 56 | + |
| 57 | + pem_crt = crt.public_bytes(serialization.Encoding.PEM) |
| 58 | + pem_key = key.private_bytes( |
| 59 | + encoding=serialization.Encoding.PEM, |
| 60 | + format=serialization.PrivateFormat.TraditionalOpenSSL, |
| 61 | + encryption_algorithm=serialization.NoEncryption(), |
| 62 | + ) |
| 63 | + |
| 64 | + server_credentials = grpc.ssl_server_credentials([(pem_key, pem_crt)]) |
| 65 | + |
| 66 | + endpoints_pb = [] |
| 67 | + for endpoint in endpoints: |
| 68 | + server.add_secure_port(endpoint, server_credentials) |
| 69 | + # FIXME: generate and check token |
| 70 | + endpoints_pb.append( |
| 71 | + jumpstarter_pb2.Endpoint(endpoint=endpoint, token="", certificate=pem_crt), |
| 72 | + ) |
| 73 | + |
| 74 | + return endpoints_pb |
0 commit comments