Skip to content
This repository was archived by the owner on Mar 6, 2026. It is now read-only.

Commit 8269318

Browse files
feat: Add mTLS configuration for async session in google-auth
Signed-off-by: Radhika Agrawal <agrawalradhika@google.com>
1 parent 8110a6f commit 8269318

3 files changed

Lines changed: 251 additions & 0 deletions

File tree

google/auth/aio/transport/mtls.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,13 @@
1717
"""
1818

1919
import asyncio
20+
import contextlib
2021
import logging
22+
import os
2123
from os import getenv, path
24+
import ssl
25+
import tempfile
26+
from typing import Optional
2227

2328
from google.auth import exceptions
2429
import google.auth.transport._mtls_helper
@@ -27,6 +32,61 @@
2732
_LOGGER = logging.getLogger(__name__)
2833

2934

35+
@contextlib.contextmanager
36+
def _create_temp_file(content: bytes):
37+
"""Creates a temporary file with the given content.
38+
39+
Args:
40+
content (bytes): The content to write to the file.
41+
42+
Yields:
43+
str: The path to the temporary file.
44+
"""
45+
# Create a temporary file that is readable only by the owner.
46+
fd, path = tempfile.mkstemp()
47+
try:
48+
with os.fdopen(fd, "wb") as f:
49+
f.write(content)
50+
yield path
51+
finally:
52+
# Securely delete the file after use.
53+
if os.path.exists(path):
54+
os.remove(path)
55+
56+
57+
def make_client_cert_ssl_context(
58+
cert_bytes: bytes, key_bytes: bytes, passphrase: Optional[bytes] = None
59+
) -> ssl.SSLContext:
60+
"""Creates an SSLContext with the given client certificate and key.
61+
This function writes the certificate and key to temporary files so that
62+
ssl.create_default_context can load them, as the ssl module requires
63+
file paths for client certificates.
64+
Args:
65+
cert_bytes (bytes): The client certificate content in PEM format.
66+
key_bytes (bytes): The client private key content in PEM format.
67+
passphrase (Optional[bytes]): The passphrase for the private key, if any.
68+
Returns:
69+
ssl.SSLContext: The configured SSL context with client certificate.
70+
71+
Raises:
72+
google.auth.exceptions.TransportError: If there is an error loading the certificate.
73+
"""
74+
try:
75+
context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
76+
77+
# Write cert and key to temp files because ssl.load_cert_chain requires paths
78+
with _create_temp_file(cert_bytes) as cert_path:
79+
with _create_temp_file(key_bytes) as key_path:
80+
context.load_cert_chain(
81+
certfile=cert_path, keyfile=key_path, password=passphrase
82+
)
83+
return context
84+
except (ssl.SSLError, OSError) as exc:
85+
raise exceptions.TransportError(
86+
"Failed to load client certificate and key for mTLS."
87+
) from exc
88+
89+
3090
def _check_config_path(config_path):
3191
"""Checks for config file path. If it exists, returns the absolute path with user expansion;
3292
otherwise returns None.

google/auth/aio/transport/sessions.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@
2121
from google.auth import _exponential_backoff, exceptions
2222
from google.auth.aio import transport
2323
from google.auth.aio.credentials import Credentials
24+
from google.auth.aio.transport import mtls
2425
from google.auth.exceptions import TimeoutError
26+
import google.auth.transport._mtls_helper
2527

2628
try:
29+
import aiohttp
2730
from google.auth.aio.transport.aiohttp import Request as AiohttpRequest
2831

2932
AIOHTTP_INSTALLED = True
@@ -124,12 +127,70 @@ def __init__(
124127
_auth_request = auth_request
125128
if not _auth_request and AIOHTTP_INSTALLED:
126129
_auth_request = AiohttpRequest()
130+
self._is_mtls = False
131+
self._cached_Cert = None
127132
if _auth_request is None:
128133
raise exceptions.TransportError(
129134
"`auth_request` must either be configured or the external package `aiohttp` must be installed to use the default value."
130135
)
131136
self._auth_request = _auth_request
132137

138+
async def configure_mtls_channel(self, client_cert_callback=None):
139+
"""Configure the client certificate and key for SSL connection.
140+
141+
The function does nothing unless `GOOGLE_API_USE_CLIENT_CERTIFICATE` is
142+
explicitly set to `true`. In this case if client certificate and key are
143+
successfully obtained (from the given client_cert_callback or from application
144+
default SSL credentials), the underlying transport will be reconfigured
145+
to use mTLS.
146+
147+
Args:
148+
client_cert_callback (Optional[Callable[[], (bytes, bytes)]]):
149+
The optional callback returns the client certificate and private
150+
key bytes both in PEM format.
151+
If the callback is None, application default SSL credentials
152+
will be used.
153+
154+
Raises:
155+
google.auth.exceptions.MutualTLSChannelError: If mutual TLS channel
156+
creation failed for any reason.
157+
"""
158+
# Run the blocking check in an executor
159+
use_client_cert = await mtls._run_in_executor(
160+
google.auth.transport._mtls_helper.check_use_client_cert
161+
)
162+
if not use_client_cert:
163+
self._is_mtls = False
164+
return
165+
166+
try:
167+
(
168+
self._is_mtls,
169+
cert,
170+
key,
171+
) = await mtls.get_client_cert_and_key(client_cert_callback)
172+
173+
if self._is_mtls:
174+
self._cached_cert = cert
175+
ssl_context = await mtls._run_in_executor(
176+
mtls.make_client_cert_ssl_context, cert, key
177+
)
178+
179+
# Re-create the auth request with the new SSL context
180+
if isinstance(self._auth_request, AiohttpRequest):
181+
connector = aiohttp.TCPConnector(ssl=ssl_context)
182+
new_session = aiohttp.ClientSession(connector=connector)
183+
await self._auth_request.close()
184+
self._auth_request = AiohttpRequest(session=new_session)
185+
186+
except (
187+
exceptions.ClientCertError,
188+
ImportError,
189+
OSError,
190+
) as caught_exc:
191+
new_exc = exceptions.MutualTLSChannelError(caught_exc)
192+
raise new_exc from caught_exc
193+
133194
async def request(
134195
self,
135196
method: str,
@@ -174,6 +235,8 @@ async def request(
174235
retries = _exponential_backoff.AsyncExponentialBackoff(
175236
total_attempts=transport.DEFAULT_MAX_RETRY_ATTEMPTS
176237
)
238+
if headers is None:
239+
headers = {}
177240
async with timeout_guard(max_allowed_time) as with_timeout:
178241
await with_timeout(
179242
# Note: before_request will attempt to refresh credentials if expired.
@@ -261,6 +324,11 @@ async def delete(
261324
"DELETE", url, data, headers, max_allowed_time, timeout, **kwargs
262325
)
263326

327+
@property
328+
def is_mtls(self):
329+
"""Indicates if mutual TLS is enabled."""
330+
return self._is_mtls
331+
264332
async def close(self) -> None:
265333
"""
266334
Close the underlying auth request session.
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import json
16+
import ssl
17+
from unittest import mock
18+
19+
import pytest
20+
21+
from google.auth import exceptions
22+
from google.auth.aio import credentials
23+
from google.auth.aio.transport import sessions
24+
25+
# This is the valid "workload" format the library expects
26+
VALID_WORKLOAD_CONFIG = {
27+
"version": 1,
28+
"cert_configs": {
29+
"workload": {"cert_path": "/tmp/mock_cert.pem", "key_path": "/tmp/mock_key.pem"}
30+
},
31+
}
32+
33+
34+
class TestSessionsMtls:
35+
@mock.patch("os.path.exists")
36+
@mock.patch(
37+
"builtins.open",
38+
new_callable=mock.mock_open,
39+
read_data=json.dumps(VALID_WORKLOAD_CONFIG),
40+
)
41+
@mock.patch("google.auth.transport._mtls_helper._get_workload_cert_and_key")
42+
@mock.patch("ssl.create_default_context")
43+
@pytest.mark.asyncio
44+
async def test_configure_mtls_channel(
45+
self, mock_ssl, mock_helper, mock_file, mock_exists
46+
):
47+
"""
48+
Tests that the mTLS channel configures correctly when a
49+
valid workload config is mocked.
50+
"""
51+
mock_exists.return_value = True
52+
mock_helper.return_value = (b"fake_cert_data", b"fake_key_data")
53+
54+
mock_context = mock.Mock(spec=ssl.SSLContext)
55+
mock_ssl.return_value = mock_context
56+
57+
mock_creds = mock.Mock(spec=credentials.Credentials)
58+
session = sessions.AsyncAuthorizedSession(mock_creds)
59+
60+
await session.configure_mtls_channel()
61+
62+
assert session._is_mtls is True
63+
assert mock_context.load_cert_chain.called
64+
65+
@mock.patch("os.path.exists")
66+
@pytest.mark.asyncio
67+
async def test_configure_mtls_channel_disabled(self, mock_exists):
68+
"""
69+
Tests behavior when the config file does not exist.
70+
"""
71+
mock_exists.return_value = False
72+
mock_creds = mock.Mock(spec=credentials.Credentials)
73+
74+
try:
75+
session = sessions.AsyncAuthorizedSession(mock_creds)
76+
except AttributeError:
77+
session = sessions.Session()
78+
await session.configure_mtls_channel()
79+
80+
# If the file doesn't exist, it shouldn't error; it just won't use mTLS
81+
assert session._is_mtls is False
82+
83+
@mock.patch("os.path.exists")
84+
@mock.patch(
85+
"builtins.open", new_callable=mock.mock_open, read_data='{"invalid": "format"}'
86+
)
87+
@pytest.mark.asyncio
88+
async def test_configure_mtls_channel_invalid_format(self, mock_file, mock_exists):
89+
"""
90+
Verifies that the MutualTLSChannelError is raised for bad formats.
91+
"""
92+
mock_exists.return_value = True
93+
mock_creds = mock.Mock(spec=credentials.Credentials)
94+
95+
try:
96+
session = sessions.AsyncAuthorizedSession(mock_creds)
97+
except AttributeError:
98+
session = sessions.Session()
99+
with pytest.raises(
100+
exceptions.MutualTLSChannelError, match="is in an invalid format"
101+
):
102+
await session.configure_mtls_channel()
103+
104+
@pytest.mark.asyncio
105+
@mock.patch(
106+
"google.auth.aio.transport.mtls.has_default_client_cert_source",
107+
return_value=True,
108+
)
109+
async def test_configure_mtls_channel_mock_callback(self, mock_has_cert):
110+
"""
111+
Tests mTLS configuration using bytes-returning callback.
112+
"""
113+
114+
def mock_callback():
115+
return (b"fake_cert_bytes", b"fake_key_bytes")
116+
117+
mock_creds = mock.Mock(spec=credentials.Credentials)
118+
119+
with mock.patch("ssl.SSLContext.load_cert_chain"):
120+
session = sessions.AsyncAuthorizedSession(mock_creds)
121+
await session.configure_mtls_channel(client_cert_callback=mock_callback)
122+
123+
assert session._is_mtls is True

0 commit comments

Comments
 (0)