Skip to content

Commit b6472b2

Browse files
dandanseo123hseo36claude
authored
Fix amqps:// SSL config and celery_config_options bypass (#64392)
* Fix broker_use_ssl not applied for amqps:// broker URLs * addressed copilot comments Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * fix import sort order --------- Co-authored-by: hseo36 <hseo36@bloomberg.net> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 87b3611 commit b6472b2

3 files changed

Lines changed: 162 additions & 2 deletions

File tree

providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from collections.abc import Collection, Mapping, MutableMapping, Sequence
3434
from concurrent.futures import ProcessPoolExecutor
3535
from functools import cache
36+
from importlib import import_module
3637
from typing import TYPE_CHECKING, Any
3738

3839
from celery import Celery, states as celery_states
@@ -124,7 +125,10 @@ def create_celery_app(team_conf: ExecutorConf | AirflowConfigParser) -> Celery:
124125
:param team_conf: ExecutorConf instance with team-specific configuration, or global conf
125126
:return: Celery app instance
126127
"""
127-
from airflow.providers.celery.executors.default_celery import get_default_celery_config
128+
from airflow.providers.celery.executors.default_celery import (
129+
DEFAULT_CELERY_CONFIG,
130+
get_default_celery_config,
131+
)
128132

129133
celery_app_name = team_conf.get("celery", "CELERY_APP_NAME")
130134

@@ -138,6 +142,15 @@ def create_celery_app(team_conf: ExecutorConf | AirflowConfigParser) -> Celery:
138142

139143
config = get_default_celery_config(team_conf)
140144

145+
# Apply user-provided celery_config_options on top of team config.
146+
# Skip if it resolves to DEFAULT_CELERY_CONFIG (built from global conf, not team-aware).
147+
configured_path = team_conf.get("celery", "celery_config_options", fallback=None)
148+
if configured_path:
149+
module_path, _, attr_name = configured_path.rpartition(".")
150+
user_config = getattr(import_module(module_path), attr_name)
151+
if user_config is not DEFAULT_CELERY_CONFIG and isinstance(user_config, dict):
152+
config.update(user_config)
153+
141154
celery_app = Celery(celery_app_name, config_source=config)
142155

143156
# Register tasks with this app

providers/celery/src/airflow/providers/celery/executors/default_celery.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def get_default_celery_config(team_conf) -> dict[str, Any]:
141141

142142
try:
143143
if celery_ssl_active:
144-
if broker_url and "amqp://" in broker_url:
144+
if broker_url and re.search(r"amqps?://", broker_url):
145145
broker_use_ssl = {
146146
"keyfile": team_conf.get("celery", "SSL_KEY"),
147147
"certfile": team_conf.get("celery", "SSL_CERT"),

providers/celery/tests/unit/celery/executors/test_celery_executor.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@
5454

5555
if AIRFLOW_V_3_0_PLUS:
5656
from airflow.models.dag_version import DagVersion
57+
if AIRFLOW_V_3_2_PLUS:
58+
from airflow.executors.base_executor import ExecutorConf
5759
if AIRFLOW_V_3_1_PLUS:
5860
from airflow.sdk import BaseOperator, timezone
5961
else:
@@ -814,3 +816,148 @@ def test_execute_workload_ignores_already_running_task():
814816
"""
815817
with pytest.raises(Ignore):
816818
execute_workload_unwrapped(workload_json)
819+
820+
821+
class TestAmqpsSslConfig:
822+
"""Tests for amqps:// broker URL SSL configuration (Fix for substring match bug)."""
823+
824+
@conf_vars(
825+
{
826+
("celery", "BROKER_URL"): "amqps://guest:guest@rabbitmq:5671//",
827+
("celery", "SSL_ACTIVE"): "True",
828+
("celery", "SSL_KEY"): "/path/to/key.pem",
829+
("celery", "SSL_CERT"): "/path/to/cert.pem",
830+
("celery", "SSL_CACERT"): "/path/to/ca.pem",
831+
}
832+
)
833+
def test_amqps_broker_url_builds_ssl_config(self):
834+
"""Test that amqps:// broker URLs correctly build broker_use_ssl with AMQP param names."""
835+
import importlib
836+
import ssl
837+
838+
importlib.reload(default_celery)
839+
840+
config = default_celery.DEFAULT_CELERY_CONFIG
841+
assert "broker_use_ssl" in config, "broker_use_ssl should be set for amqps:// URLs"
842+
broker_ssl = config["broker_use_ssl"]
843+
assert broker_ssl["keyfile"] == "/path/to/key.pem"
844+
assert broker_ssl["certfile"] == "/path/to/cert.pem"
845+
assert broker_ssl["ca_certs"] == "/path/to/ca.pem"
846+
assert broker_ssl["cert_reqs"] == ssl.CERT_REQUIRED
847+
# Must NOT have ssl_ prefixed keys (those are for Redis)
848+
assert "ssl_keyfile" not in broker_ssl
849+
assert "ssl_certfile" not in broker_ssl
850+
851+
@conf_vars(
852+
{
853+
("celery", "BROKER_URL"): "amqp://guest:guest@rabbitmq:5672//",
854+
("celery", "SSL_ACTIVE"): "True",
855+
("celery", "SSL_KEY"): "/path/to/key.pem",
856+
("celery", "SSL_CERT"): "/path/to/cert.pem",
857+
("celery", "SSL_CACERT"): "/path/to/ca.pem",
858+
}
859+
)
860+
def test_amqp_broker_url_still_builds_ssl_config(self):
861+
"""Test that amqp:// (non-TLS) broker URLs still build SSL config correctly (no regression)."""
862+
import importlib
863+
import ssl
864+
865+
importlib.reload(default_celery)
866+
867+
config = default_celery.DEFAULT_CELERY_CONFIG
868+
assert "broker_use_ssl" in config
869+
broker_ssl = config["broker_use_ssl"]
870+
assert broker_ssl["keyfile"] == "/path/to/key.pem"
871+
assert broker_ssl["cert_reqs"] == ssl.CERT_REQUIRED
872+
873+
@conf_vars(
874+
{
875+
("celery", "BROKER_URL"): "amqps://guest:guest@rabbitmq:5671//",
876+
("celery", "SSL_ACTIVE"): "False",
877+
}
878+
)
879+
def test_amqps_broker_url_no_ssl_when_inactive(self):
880+
"""Test that amqps:// broker URLs don't get SSL config when SSL_ACTIVE is False."""
881+
import importlib
882+
883+
importlib.reload(default_celery)
884+
885+
config = default_celery.DEFAULT_CELERY_CONFIG
886+
assert "broker_use_ssl" not in config
887+
888+
889+
class TestCreateCeleryAppTeamIsolation:
890+
"""Tests for create_celery_app() multi-team config isolation."""
891+
892+
@pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="ExecutorConf requires Airflow 3.2+")
893+
def test_custom_celery_config_options_applied(self):
894+
"""User-provided celery_config_options (non-default) should be merged into team config."""
895+
custom_config = {"worker_concurrency": 42, "broker_url": "redis://custom:6379/0"}
896+
custom_path = "my_custom_module.CELERY_CONFIG"
897+
898+
team_conf = ExecutorConf(team_name="team_alpha")
899+
original_get = team_conf.get
900+
901+
def mock_get(section, key, **kwargs):
902+
if section == "celery" and key == "celery_config_options":
903+
return custom_path
904+
return original_get(section, key, **kwargs)
905+
906+
mock_module = mock.MagicMock()
907+
mock_module.CELERY_CONFIG = custom_config
908+
909+
with (
910+
mock.patch.object(team_conf, "get", side_effect=mock_get),
911+
mock.patch.object(celery_executor_utils, "import_module", return_value=mock_module),
912+
):
913+
celery_app = celery_executor_utils.create_celery_app(team_conf)
914+
assert celery_app.conf.worker_concurrency == 42
915+
assert celery_app.conf.broker_url == "redis://custom:6379/0"
916+
917+
def test_default_celery_config_options_skipped_via_identity_check(self):
918+
"""When celery_config_options resolves to DEFAULT_CELERY_CONFIG (same object),
919+
it must be skipped — re-applying it would overwrite team-specific config
920+
since DEFAULT_CELERY_CONFIG is built from global conf."""
921+
original_get = conf.get
922+
# Path just needs a dot for rpartition and attr name matching DEFAULT_CELERY_CONFIG.
923+
# import_module is mocked to return default_celery module regardless of path.
924+
celery_config_path = "any.module.DEFAULT_CELERY_CONFIG"
925+
926+
def mock_get(section, key, **kwargs):
927+
if section == "celery" and key == "celery_config_options":
928+
return celery_config_path
929+
return original_get(section, key, **kwargs)
930+
931+
with (
932+
mock.patch.object(conf, "get", side_effect=mock_get),
933+
mock.patch.object(celery_executor_utils, "import_module") as mock_import,
934+
):
935+
mock_import.return_value = default_celery
936+
celery_app = celery_executor_utils.create_celery_app(conf)
937+
# import_module called (path is non-None), but override skipped (same object)
938+
mock_import.assert_called_once()
939+
default_config = default_celery.get_default_celery_config(conf)
940+
assert celery_app.conf.broker_url == default_config["broker_url"]
941+
942+
@pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="ExecutorConf requires Airflow 3.2+")
943+
def test_team_specific_broker_not_overwritten(self):
944+
"""Team-specific BROKER_URL set via ExecutorConf must survive create_celery_app()."""
945+
team_conf = ExecutorConf(team_name="team_alpha")
946+
947+
original_get = team_conf.get
948+
949+
def mock_team_get(section, key, **kwargs):
950+
if section == "celery" and key == "BROKER_URL":
951+
return "amqps://team-alpha-rabbit:5671//"
952+
return original_get(section, key, **kwargs)
953+
954+
with mock.patch.object(team_conf, "get", side_effect=mock_team_get):
955+
celery_app = celery_executor_utils.create_celery_app(team_conf)
956+
assert celery_app.conf.broker_url == "amqps://team-alpha-rabbit:5671//"
957+
958+
@pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="ExecutorConf requires Airflow 3.2+")
959+
def test_team_app_name_includes_team_name(self):
960+
"""Each team gets a unique Celery app name for broker isolation."""
961+
team_conf = ExecutorConf(team_name="team_beta")
962+
celery_app = celery_executor_utils.create_celery_app(team_conf)
963+
assert "team_beta" in celery_app.main

0 commit comments

Comments
 (0)