Skip to content
This repository was archived by the owner on Mar 6, 2026. It is now read-only.

Commit 7f23594

Browse files
fix: Update based on gemini-assit comments to make robust callback by handling async and removing encrypted_key complication
Signed-off-by: Radhika Agrawal <agrawalradhika@google.com>
1 parent be10a50 commit 7f23594

2 files changed

Lines changed: 107 additions & 17 deletions

File tree

google/auth/aio/transport/mtls.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
"""
1818

1919
import asyncio
20+
import inspect
2021
import logging
2122
from os import getenv, path
2223

24+
from google.auth import exceptions
2325
import google.auth.transport._mtls_helper
2426

2527
CERTIFICATE_CONFIGURATION_DEFAULT_PATH = "~/.config/gcloud/certificate_config.json"
@@ -71,8 +73,35 @@ def has_default_client_cert_source():
7173
return False
7274

7375

76+
async def default_client_cert_source():
77+
"""Get a callback which returns the default client SSL credentials.
78+
79+
Returns:
80+
Callable[[], [bytes, bytes]]: A callback which returns the default
81+
client certificate bytes and private key bytes, both in PEM format.
82+
83+
Raises:
84+
google.auth.exceptions.DefaultClientCertSourceError: If the default
85+
client SSL credentials don't exist or are malformed.
86+
"""
87+
if not has_default_client_cert_source():
88+
raise exceptions.MutualTLSChannelError(
89+
"Default client cert source doesn't exist"
90+
)
91+
92+
async def callback():
93+
try:
94+
_, cert_bytes, key_bytes = await get_client_cert_and_key()
95+
except (OSError, RuntimeError, ValueError) as caught_exc:
96+
new_exc = exceptions.MutualTLSChannelError(caught_exc)
97+
raise new_exc from caught_exc
98+
99+
return cert_bytes, key_bytes
100+
101+
return callback
102+
103+
74104
async def get_client_ssl_credentials(
75-
generate_encrypted_key=False,
76105
certificate_config_path=None,
77106
):
78107
"""Returns the client side certificate, private key and passphrase.
@@ -82,10 +111,6 @@ async def get_client_ssl_credentials(
82111
Currently, only X.509 workload certificates are supported.
83112
84113
Args:
85-
generate_encrypted_key (bool): If set to True, encrypted private key
86-
and passphrase will be generated; otherwise, unencrypted private key
87-
will be generated and passphrase will be None. This option only
88-
affects keys obtained via context_aware_metadata.json.
89114
certificate_config_path (str): The certificate_config.json file path.
90115
91116
Returns:
@@ -131,10 +156,12 @@ async def get_client_cert_and_key(client_cert_callback=None):
131156
the cert and key.
132157
"""
133158
if client_cert_callback:
134-
cert, key = client_cert_callback()
159+
result = client_cert_callback()
160+
if inspect.isawaitable(result):
161+
cert, key = await result
162+
else:
163+
cert, key = result
135164
return True, cert, key
136165

137-
has_cert, cert, key, _ = await get_client_ssl_credentials(
138-
generate_encrypted_key=False
139-
)
166+
has_cert, cert, key, _ = await get_client_ssl_credentials()
140167
return has_cert, cert, key

tests/transport/test_aio_mtls_helper.py

Lines changed: 71 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,71 @@ def test__check_config_path_not_found(self, mock_exists):
4545
@mock.patch("google.auth.aio.transport.mtls._check_config_path")
4646
@mock.patch("google.auth.aio.transport.mtls.getenv")
4747
def test_has_default_client_cert_source_env_var(self, mock_getenv, mock_check):
48-
# Mocking so the default path fails but the env var path succeeds
4948
custom_path = "/custom/path.json"
5049
mock_check.side_effect = lambda x: custom_path if x == custom_path else None
5150
mock_getenv.return_value = custom_path
5251

5352
assert mtls.has_default_client_cert_source() is True
5453

54+
@mock.patch("google.auth.aio.transport.mtls._check_config_path")
55+
@mock.patch("google.auth.aio.transport.mtls.getenv")
56+
def test_has_default_client_cert_source_check_priority(
57+
self, mock_getenv, mock_check
58+
):
59+
mock_check.return_value = "/default/path.json"
60+
61+
assert mtls.has_default_client_cert_source() is True
62+
mock_getenv.assert_not_called()
63+
64+
@pytest.mark.asyncio
65+
@mock.patch(
66+
"google.auth.aio.transport.mtls.get_client_cert_and_key",
67+
new_callable=mock.AsyncMock,
68+
)
69+
@mock.patch("google.auth.aio.transport.mtls.has_default_client_cert_source")
70+
async def test_default_client_cert_source_success(
71+
self, mock_has_default, mock_get_cert_key
72+
):
73+
mock_has_default.return_value = True
74+
mock_get_cert_key.return_value = (True, CERT_DATA, KEY_DATA)
75+
76+
callback = await mtls.default_client_cert_source()
77+
78+
cert, key = await callback()
79+
80+
assert cert == CERT_DATA
81+
assert key == KEY_DATA
82+
mock_has_default.assert_called_once()
83+
mock_get_cert_key.assert_called_once()
84+
85+
@pytest.mark.asyncio
86+
@mock.patch(
87+
"google.auth.aio.transport.mtls.has_default_client_cert_source",
88+
return_value=False,
89+
)
90+
async def test_default_client_cert_source_not_found(self, mock_has_default):
91+
with pytest.raises(exceptions.MutualTLSChannelError, match="doesn't exist"):
92+
await mtls.default_client_cert_source()
93+
94+
@pytest.mark.asyncio
95+
@mock.patch(
96+
"google.auth.aio.transport.mtls.get_client_cert_and_key",
97+
new_callable=mock.AsyncMock,
98+
)
99+
@mock.patch(
100+
"google.auth.aio.transport.mtls.has_default_client_cert_source",
101+
return_value=True,
102+
)
103+
async def test_default_client_cert_source_callback_wraps_exception(
104+
self, mock_has, mock_get
105+
):
106+
mock_get.side_effect = ValueError("Format error")
107+
callback = await mtls.default_client_cert_source()
108+
109+
with pytest.raises(exceptions.MutualTLSChannelError) as excinfo:
110+
await callback()
111+
assert "Format error" in str(excinfo.value)
112+
55113
@pytest.mark.asyncio
56114
@mock.patch("google.auth.transport._mtls_helper._get_workload_cert_and_key")
57115
async def test_get_client_ssl_credentials_success(self, mock_workload):
@@ -64,6 +122,17 @@ async def test_get_client_ssl_credentials_success(self, mock_workload):
64122
assert key == KEY_DATA
65123
assert passphrase is None
66124

125+
@pytest.mark.asyncio
126+
@mock.patch("google.auth.aio.transport.mtls.get_client_ssl_credentials")
127+
async def test_get_client_cert_and_key_no_credentials_found(self, mock_get_ssl):
128+
mock_get_ssl.return_value = (False, None, None, None)
129+
130+
success, cert, key = await mtls.get_client_cert_and_key(None)
131+
132+
assert success is False
133+
assert cert is None
134+
assert key is None
135+
67136
@pytest.mark.asyncio
68137
async def test_get_client_cert_and_key_callback(self):
69138
# The callback should be tried first and return immediately
@@ -79,39 +148,33 @@ async def test_get_client_cert_and_key_callback(self):
79148
@pytest.mark.asyncio
80149
@mock.patch("google.auth.aio.transport.mtls.get_client_ssl_credentials")
81150
async def test_get_client_cert_and_key_default(self, mock_get_ssl):
82-
# If no callback, it should call get_client_ssl_credentials
83151
mock_get_ssl.return_value = (True, CERT_DATA, KEY_DATA, None)
84152

85153
success, cert, key = await mtls.get_client_cert_and_key(None)
86154

87155
assert success is True
88156
assert cert == CERT_DATA
89157
assert key == KEY_DATA
90-
mock_get_ssl.assert_called_with(generate_encrypted_key=False)
158+
mock_get_ssl.assert_called_once()
91159

92160
@pytest.mark.asyncio
93161
@mock.patch("google.auth.transport._mtls_helper._get_workload_cert_and_key")
94162
async def test_get_client_ssl_credentials_error(self, mock_workload):
95-
"""Tests that ClientCertError is propagated correctly."""
96-
# Setup the mock to raise the specific google-auth exception
97163
mock_workload.side_effect = exceptions.ClientCertError(
98164
"Failed to read metadata"
99165
)
100166

101-
# Verify that calling our function raises the same exception
102167
with pytest.raises(exceptions.ClientCertError, match="Failed to read metadata"):
103168
await mtls.get_client_ssl_credentials()
104169

105170
@pytest.mark.asyncio
106171
@mock.patch("google.auth.aio.transport.mtls.get_client_ssl_credentials")
107172
async def test_get_client_cert_and_key_exception_propagation(self, mock_get_ssl):
108-
"""Tests that get_client_cert_and_key propagates errors from its internal calls."""
109173
mock_get_ssl.side_effect = exceptions.ClientCertError(
110174
"Underlying credentials failed"
111175
)
112176

113177
with pytest.raises(
114178
exceptions.ClientCertError, match="Underlying credentials failed"
115179
):
116-
# Pass None for callback so it attempts to call get_client_ssl_credentials
117180
await mtls.get_client_cert_and_key(client_cert_callback=None)

0 commit comments

Comments
 (0)