|
54 | 54 |
|
55 | 55 | if AIRFLOW_V_3_0_PLUS: |
56 | 56 | from airflow.models.dag_version import DagVersion |
| 57 | +if AIRFLOW_V_3_2_PLUS: |
| 58 | + from airflow.executors.base_executor import ExecutorConf |
57 | 59 | if AIRFLOW_V_3_1_PLUS: |
58 | 60 | from airflow.sdk import BaseOperator, timezone |
59 | 61 | else: |
@@ -814,3 +816,148 @@ def test_execute_workload_ignores_already_running_task(): |
814 | 816 | """ |
815 | 817 | with pytest.raises(Ignore): |
816 | 818 | execute_workload_unwrapped(workload_json) |
| 819 | + |
| 820 | + |
| 821 | +class TestAmqpsSslConfig: |
| 822 | + """Tests for amqps:// broker URL SSL configuration (Fix for substring match bug).""" |
| 823 | + |
| 824 | + @conf_vars( |
| 825 | + { |
| 826 | + ("celery", "BROKER_URL"): "amqps://guest:guest@rabbitmq:5671//", |
| 827 | + ("celery", "SSL_ACTIVE"): "True", |
| 828 | + ("celery", "SSL_KEY"): "/path/to/key.pem", |
| 829 | + ("celery", "SSL_CERT"): "/path/to/cert.pem", |
| 830 | + ("celery", "SSL_CACERT"): "/path/to/ca.pem", |
| 831 | + } |
| 832 | + ) |
| 833 | + def test_amqps_broker_url_builds_ssl_config(self): |
| 834 | + """Test that amqps:// broker URLs correctly build broker_use_ssl with AMQP param names.""" |
| 835 | + import importlib |
| 836 | + import ssl |
| 837 | + |
| 838 | + importlib.reload(default_celery) |
| 839 | + |
| 840 | + config = default_celery.DEFAULT_CELERY_CONFIG |
| 841 | + assert "broker_use_ssl" in config, "broker_use_ssl should be set for amqps:// URLs" |
| 842 | + broker_ssl = config["broker_use_ssl"] |
| 843 | + assert broker_ssl["keyfile"] == "/path/to/key.pem" |
| 844 | + assert broker_ssl["certfile"] == "/path/to/cert.pem" |
| 845 | + assert broker_ssl["ca_certs"] == "/path/to/ca.pem" |
| 846 | + assert broker_ssl["cert_reqs"] == ssl.CERT_REQUIRED |
| 847 | + # Must NOT have ssl_ prefixed keys (those are for Redis) |
| 848 | + assert "ssl_keyfile" not in broker_ssl |
| 849 | + assert "ssl_certfile" not in broker_ssl |
| 850 | + |
| 851 | + @conf_vars( |
| 852 | + { |
| 853 | + ("celery", "BROKER_URL"): "amqp://guest:guest@rabbitmq:5672//", |
| 854 | + ("celery", "SSL_ACTIVE"): "True", |
| 855 | + ("celery", "SSL_KEY"): "/path/to/key.pem", |
| 856 | + ("celery", "SSL_CERT"): "/path/to/cert.pem", |
| 857 | + ("celery", "SSL_CACERT"): "/path/to/ca.pem", |
| 858 | + } |
| 859 | + ) |
| 860 | + def test_amqp_broker_url_still_builds_ssl_config(self): |
| 861 | + """Test that amqp:// (non-TLS) broker URLs still build SSL config correctly (no regression).""" |
| 862 | + import importlib |
| 863 | + import ssl |
| 864 | + |
| 865 | + importlib.reload(default_celery) |
| 866 | + |
| 867 | + config = default_celery.DEFAULT_CELERY_CONFIG |
| 868 | + assert "broker_use_ssl" in config |
| 869 | + broker_ssl = config["broker_use_ssl"] |
| 870 | + assert broker_ssl["keyfile"] == "/path/to/key.pem" |
| 871 | + assert broker_ssl["cert_reqs"] == ssl.CERT_REQUIRED |
| 872 | + |
| 873 | + @conf_vars( |
| 874 | + { |
| 875 | + ("celery", "BROKER_URL"): "amqps://guest:guest@rabbitmq:5671//", |
| 876 | + ("celery", "SSL_ACTIVE"): "False", |
| 877 | + } |
| 878 | + ) |
| 879 | + def test_amqps_broker_url_no_ssl_when_inactive(self): |
| 880 | + """Test that amqps:// broker URLs don't get SSL config when SSL_ACTIVE is False.""" |
| 881 | + import importlib |
| 882 | + |
| 883 | + importlib.reload(default_celery) |
| 884 | + |
| 885 | + config = default_celery.DEFAULT_CELERY_CONFIG |
| 886 | + assert "broker_use_ssl" not in config |
| 887 | + |
| 888 | + |
| 889 | +class TestCreateCeleryAppTeamIsolation: |
| 890 | + """Tests for create_celery_app() multi-team config isolation.""" |
| 891 | + |
| 892 | + @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="ExecutorConf requires Airflow 3.2+") |
| 893 | + def test_custom_celery_config_options_applied(self): |
| 894 | + """User-provided celery_config_options (non-default) should be merged into team config.""" |
| 895 | + custom_config = {"worker_concurrency": 42, "broker_url": "redis://custom:6379/0"} |
| 896 | + custom_path = "my_custom_module.CELERY_CONFIG" |
| 897 | + |
| 898 | + team_conf = ExecutorConf(team_name="team_alpha") |
| 899 | + original_get = team_conf.get |
| 900 | + |
| 901 | + def mock_get(section, key, **kwargs): |
| 902 | + if section == "celery" and key == "celery_config_options": |
| 903 | + return custom_path |
| 904 | + return original_get(section, key, **kwargs) |
| 905 | + |
| 906 | + mock_module = mock.MagicMock() |
| 907 | + mock_module.CELERY_CONFIG = custom_config |
| 908 | + |
| 909 | + with ( |
| 910 | + mock.patch.object(team_conf, "get", side_effect=mock_get), |
| 911 | + mock.patch.object(celery_executor_utils, "import_module", return_value=mock_module), |
| 912 | + ): |
| 913 | + celery_app = celery_executor_utils.create_celery_app(team_conf) |
| 914 | + assert celery_app.conf.worker_concurrency == 42 |
| 915 | + assert celery_app.conf.broker_url == "redis://custom:6379/0" |
| 916 | + |
| 917 | + def test_default_celery_config_options_skipped_via_identity_check(self): |
| 918 | + """When celery_config_options resolves to DEFAULT_CELERY_CONFIG (same object), |
| 919 | + it must be skipped — re-applying it would overwrite team-specific config |
| 920 | + since DEFAULT_CELERY_CONFIG is built from global conf.""" |
| 921 | + original_get = conf.get |
| 922 | + # Path just needs a dot for rpartition and attr name matching DEFAULT_CELERY_CONFIG. |
| 923 | + # import_module is mocked to return default_celery module regardless of path. |
| 924 | + celery_config_path = "any.module.DEFAULT_CELERY_CONFIG" |
| 925 | + |
| 926 | + def mock_get(section, key, **kwargs): |
| 927 | + if section == "celery" and key == "celery_config_options": |
| 928 | + return celery_config_path |
| 929 | + return original_get(section, key, **kwargs) |
| 930 | + |
| 931 | + with ( |
| 932 | + mock.patch.object(conf, "get", side_effect=mock_get), |
| 933 | + mock.patch.object(celery_executor_utils, "import_module") as mock_import, |
| 934 | + ): |
| 935 | + mock_import.return_value = default_celery |
| 936 | + celery_app = celery_executor_utils.create_celery_app(conf) |
| 937 | + # import_module called (path is non-None), but override skipped (same object) |
| 938 | + mock_import.assert_called_once() |
| 939 | + default_config = default_celery.get_default_celery_config(conf) |
| 940 | + assert celery_app.conf.broker_url == default_config["broker_url"] |
| 941 | + |
| 942 | + @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="ExecutorConf requires Airflow 3.2+") |
| 943 | + def test_team_specific_broker_not_overwritten(self): |
| 944 | + """Team-specific BROKER_URL set via ExecutorConf must survive create_celery_app().""" |
| 945 | + team_conf = ExecutorConf(team_name="team_alpha") |
| 946 | + |
| 947 | + original_get = team_conf.get |
| 948 | + |
| 949 | + def mock_team_get(section, key, **kwargs): |
| 950 | + if section == "celery" and key == "BROKER_URL": |
| 951 | + return "amqps://team-alpha-rabbit:5671//" |
| 952 | + return original_get(section, key, **kwargs) |
| 953 | + |
| 954 | + with mock.patch.object(team_conf, "get", side_effect=mock_team_get): |
| 955 | + celery_app = celery_executor_utils.create_celery_app(team_conf) |
| 956 | + assert celery_app.conf.broker_url == "amqps://team-alpha-rabbit:5671//" |
| 957 | + |
| 958 | + @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="ExecutorConf requires Airflow 3.2+") |
| 959 | + def test_team_app_name_includes_team_name(self): |
| 960 | + """Each team gets a unique Celery app name for broker isolation.""" |
| 961 | + team_conf = ExecutorConf(team_name="team_beta") |
| 962 | + celery_app = celery_executor_utils.create_celery_app(team_conf) |
| 963 | + assert "team_beta" in celery_app.main |
0 commit comments