Skip to content

Commit 294c92e

Browse files
authored
Fix Azure provider hooks ignoring cloud_environment connection extra (#65320)
1 parent 8a4984e commit 294c92e

4 files changed

Lines changed: 165 additions & 2 deletions

File tree

providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/base_azure.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from azure.common.client_factory import get_client_from_auth_file, get_client_from_json_dict
2222
from azure.common.credentials import ServicePrincipalCredentials
23-
from azure.identity import ClientSecretCredential, DefaultAzureCredential
23+
from azure.identity import AzureAuthorityHosts, ClientSecretCredential, DefaultAzureCredential
2424

2525
from airflow.providers.common.compat.sdk import AirflowException, BaseHook
2626
from airflow.providers.microsoft.azure.utils import (
@@ -34,6 +34,24 @@
3434

3535
from airflow.sdk import Connection
3636

37+
_AZURE_CLOUD_ENVIRONMENTS: dict[str, dict[str, Any]] = {
38+
"AzurePublicCloud": {
39+
"authority": AzureAuthorityHosts.AZURE_PUBLIC_CLOUD,
40+
"base_url": "https://management.azure.com",
41+
"credential_scopes": ["https://management.azure.com/.default"],
42+
},
43+
"AzureUSGovernment": {
44+
"authority": AzureAuthorityHosts.AZURE_GOVERNMENT,
45+
"base_url": "https://management.usgovcloudapi.net",
46+
"credential_scopes": ["https://management.usgovcloudapi.net/.default"],
47+
},
48+
"AzureChinaCloud": {
49+
"authority": AzureAuthorityHosts.AZURE_CHINA,
50+
"base_url": "https://management.chinacloudapi.cn",
51+
"credential_scopes": ["https://management.chinacloudapi.cn/.default"],
52+
},
53+
}
54+
3755

3856
class AzureBaseHook(BaseHook):
3957
"""
@@ -63,6 +81,9 @@ def get_connection_form_widgets(cls) -> dict[str, Any]:
6381
return {
6482
"tenantId": StringField(lazy_gettext("Azure Tenant ID"), widget=BS3TextFieldWidget()),
6583
"subscriptionId": StringField(lazy_gettext("Azure Subscription ID"), widget=BS3TextFieldWidget()),
84+
"cloud_environment": StringField(
85+
lazy_gettext("Azure Cloud Environment"), widget=BS3TextFieldWidget()
86+
),
6687
}
6788

6889
@classmethod
@@ -88,6 +109,7 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]:
88109
"password": "secret (token credentials auth)",
89110
"tenantId": "tenantId (token credentials auth)",
90111
"subscriptionId": "subscriptionId (token credentials auth)",
112+
"cloud_environment": "AzurePublicCloud (default) | AzureUSGovernment | AzureChinaCloud",
91113
},
92114
}
93115

@@ -163,11 +185,18 @@ def _get_client_secret_credential(
163185
extra_dejson = conn.extra_dejson
164186
tenant = extra_dejson.get("tenantId")
165187
use_azure_identity_object = extra_dejson.get("use_azure_identity_object", False)
188+
189+
cloud_env_name = extra_dejson.get("cloud_environment", "AzurePublicCloud")
190+
cloud_env = _AZURE_CLOUD_ENVIRONMENTS.get(
191+
cloud_env_name, _AZURE_CLOUD_ENVIRONMENTS["AzurePublicCloud"]
192+
)
193+
166194
if use_azure_identity_object:
167195
return ClientSecretCredential(
168196
client_id=conn.login, # type: ignore[arg-type]
169197
client_secret=conn.password, # type: ignore[arg-type]
170198
tenant_id=tenant, # type: ignore[arg-type]
199+
authority=cloud_env["authority"],
171200
)
172201
return ServicePrincipalCredentials(client_id=conn.login, secret=conn.password, tenant=tenant)
173202

providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/container_instance.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from azure.mgmt.containerinstance import ContainerInstanceManagementClient
2727

2828
from airflow.providers.common.compat.sdk import AirflowException
29-
from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook
29+
from airflow.providers.microsoft.azure.hooks.base_azure import _AZURE_CLOUD_ENVIRONMENTS, AzureBaseHook
3030
from airflow.providers.microsoft.azure.utils import get_sync_default_azure_credential
3131

3232
if TYPE_CHECKING:
@@ -82,13 +82,19 @@ def get_conn(self) -> Any:
8282
self.log.info("Getting connection using a JSON config.")
8383
return get_client_from_json_dict(client_class=self.sdk_client, config_dict=key_json)
8484

85+
cloud_env_name = conn.extra_dejson.get("cloud_environment", "AzurePublicCloud")
86+
cloud_env = _AZURE_CLOUD_ENVIRONMENTS.get(
87+
cloud_env_name, _AZURE_CLOUD_ENVIRONMENTS["AzurePublicCloud"]
88+
)
89+
8590
credential: ClientSecretCredential | DefaultAzureCredential
8691
if all([conn.login, conn.password, tenant]):
8792
self.log.info("Getting connection using specific credentials and subscription_id.")
8893
credential = ClientSecretCredential(
8994
client_id=cast("str", conn.login),
9095
client_secret=cast("str", conn.password),
9196
tenant_id=cast("str", tenant),
97+
authority=cloud_env["authority"],
9298
)
9399
else:
94100
self.log.info("Using DefaultAzureCredential as credential")
@@ -103,6 +109,8 @@ def get_conn(self) -> Any:
103109
return ContainerInstanceManagementClient(
104110
credential=credential,
105111
subscription_id=subscription_id,
112+
base_url=cloud_env["base_url"],
113+
credential_scopes=cloud_env["credential_scopes"],
106114
)
107115

108116
def create_or_update(self, resource_group: str, name: str, container_group: ContainerGroup) -> None:

providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_base_azure.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from unittest.mock import MagicMock, Mock, patch
2121

2222
import pytest
23+
from azure.identity import AzureAuthorityHosts
2324

2425
from airflow.models import Connection
2526
from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook
@@ -156,6 +157,7 @@ def test_get_credential_with_client_secret(self, mock_spc, mocked_connection):
156157
client_id=mocked_connection.login,
157158
client_secret=mocked_connection.password,
158159
tenant_id=mocked_connection.extra_dejson["tenantId"],
160+
authority=AzureAuthorityHosts.AZURE_PUBLIC_CLOUD,
159161
)
160162
assert cred == "foo-bar"
161163

@@ -221,3 +223,41 @@ def test_get_token_with_azure_default_credential(self, mock_spc, mocked_connecti
221223

222224
mock_spc.assert_called_once_with()
223225
assert token == "new-token"
226+
227+
@patch(f"{MODULE}.ClientSecretCredential")
228+
@pytest.mark.parametrize(
229+
("cloud_env", "expected_authority"),
230+
[
231+
pytest.param(None, AzureAuthorityHosts.AZURE_PUBLIC_CLOUD, id="default_public_cloud"),
232+
pytest.param(
233+
"AzurePublicCloud", AzureAuthorityHosts.AZURE_PUBLIC_CLOUD, id="explicit_public_cloud"
234+
),
235+
pytest.param("AzureUSGovernment", AzureAuthorityHosts.AZURE_GOVERNMENT, id="us_government"),
236+
pytest.param("AzureChinaCloud", AzureAuthorityHosts.AZURE_CHINA, id="china_cloud"),
237+
],
238+
)
239+
def test_get_credential_cloud_environment(
240+
self, mock_csc, cloud_env, expected_authority, create_mock_connection
241+
):
242+
extras = {"tenantId": "my_tenant", "use_azure_identity_object": True}
243+
if cloud_env is not None:
244+
extras["cloud_environment"] = cloud_env
245+
246+
create_mock_connection(
247+
Connection(
248+
conn_id="azure_default",
249+
login="my_login",
250+
password="my_password",
251+
extra=extras,
252+
)
253+
)
254+
mock_csc.return_value = "credential"
255+
cred = AzureBaseHook().get_credential()
256+
257+
mock_csc.assert_called_once_with(
258+
client_id="my_login",
259+
client_secret="my_password",
260+
tenant_id="my_tenant",
261+
authority=expected_authority,
262+
)
263+
assert cred == "credential"

providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_container_instance.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import pytest
2323
from azure.core.exceptions import ResourceNotFoundError
24+
from azure.identity import AzureAuthorityHosts
2425
from azure.mgmt.containerinstance.models import (
2526
Logs,
2627
ResourceRequests,
@@ -165,4 +166,89 @@ def test_get_conn_fallback_to_default_azure_credential(
165166
mock_client_cls.assert_called_once_with(
166167
credential=mock_credential,
167168
subscription_id="subscription_id",
169+
base_url="https://management.azure.com",
170+
credential_scopes=["https://management.azure.com/.default"],
171+
)
172+
173+
174+
class TestAzureContainerInstanceHookCloudEnvironment:
175+
@pytest.mark.parametrize(
176+
("cloud_env", "expected_authority", "expected_base_url", "expected_scopes"),
177+
[
178+
pytest.param(
179+
None,
180+
AzureAuthorityHosts.AZURE_PUBLIC_CLOUD,
181+
"https://management.azure.com",
182+
["https://management.azure.com/.default"],
183+
id="default_public_cloud",
184+
),
185+
pytest.param(
186+
"AzurePublicCloud",
187+
AzureAuthorityHosts.AZURE_PUBLIC_CLOUD,
188+
"https://management.azure.com",
189+
["https://management.azure.com/.default"],
190+
id="explicit_public_cloud",
191+
),
192+
pytest.param(
193+
"AzureUSGovernment",
194+
AzureAuthorityHosts.AZURE_GOVERNMENT,
195+
"https://management.usgovcloudapi.net",
196+
["https://management.usgovcloudapi.net/.default"],
197+
id="us_government",
198+
),
199+
pytest.param(
200+
"AzureChinaCloud",
201+
AzureAuthorityHosts.AZURE_CHINA,
202+
"https://management.chinacloudapi.cn",
203+
["https://management.chinacloudapi.cn/.default"],
204+
id="china_cloud",
205+
),
206+
],
207+
)
208+
@patch("airflow.providers.microsoft.azure.hooks.container_instance.ContainerInstanceManagementClient")
209+
@patch("airflow.providers.microsoft.azure.hooks.container_instance.ClientSecretCredential")
210+
def test_get_conn_cloud_environment(
211+
self,
212+
mock_credential_cls,
213+
mock_client_cls,
214+
cloud_env,
215+
expected_authority,
216+
expected_base_url,
217+
expected_scopes,
218+
create_mock_connection,
219+
):
220+
extras = {
221+
"tenantId": "my-tenant",
222+
"subscriptionId": "my-subscription",
223+
}
224+
if cloud_env is not None:
225+
extras["cloud_environment"] = cloud_env
226+
227+
mock_connection = create_mock_connection(
228+
Connection(
229+
conn_id="azure_container_instance_cloud_test",
230+
conn_type="azure_container_instances",
231+
login="my-client-id",
232+
password="my-secret",
233+
extra=extras,
234+
)
235+
)
236+
237+
mock_credential_cls.return_value = MagicMock()
238+
mock_client_cls.return_value = MagicMock()
239+
240+
hook = AzureContainerInstanceHook(azure_conn_id=mock_connection.conn_id)
241+
hook.get_conn()
242+
243+
mock_credential_cls.assert_called_once_with(
244+
client_id="my-client-id",
245+
client_secret="my-secret",
246+
tenant_id="my-tenant",
247+
authority=expected_authority,
248+
)
249+
mock_client_cls.assert_called_once_with(
250+
credential=mock_credential_cls.return_value,
251+
subscription_id="my-subscription",
252+
base_url=expected_base_url,
253+
credential_scopes=expected_scopes,
168254
)

0 commit comments

Comments
 (0)