Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions airflow-core/src/airflow/serialization/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,21 +71,27 @@ def sort_dict_recursively(obj: Any) -> Any:
return tuple(sort_dict_recursively(item) for item in obj)
return obj

def handle_not_jsonable_object(obj: Any) -> Any:
    """
    Turn one non-JSON-able object into a stand-in value.

    The object's own ``serialize()`` result is used when that call succeeds.
    If the call raises ``AttributeError``, a callable is rendered as
    ``"<callable {fully qualified name}>"`` and anything else is rendered
    with ``str()``.

    :param obj: the object to convert
    :return: whatever ``obj.serialize()`` returns, or a string fallback
    """
    try:
        converted = obj.serialize()
    except AttributeError:
        # Covers a missing ``serialize`` attribute as well as an
        # AttributeError raised while ``serialize()`` itself is running.
        converted = f"<callable {qualname(obj, True)}>" if callable(obj) else str(obj)
    return converted

max_length = conf.getint("core", "max_templated_field_length")

if not is_jsonable(template_field):
try:
serialized = template_field.serialize()
except AttributeError:
if callable(template_field):
full_qualified_name = qualname(template_field, True)
serialized = f"<callable {full_qualified_name}>"
else:
serialized = str(template_field)
if len(serialized) > max_length:
rendered = redact(serialized, name)
if isinstance(template_field, (list, tuple)):
not_jsonable_serialized = [handle_not_jsonable_object(item) for item in template_field]
else:
not_jsonable_serialized = handle_not_jsonable_object(template_field)
if len(str(not_jsonable_serialized)) > max_length:
rendered = redact(not_jsonable_serialized, name)
return truncate_rendered_value(str(rendered), max_length)
return serialized
return not_jsonable_serialized

if not template_field and not isinstance(template_field, tuple):
# Avoid unnecessary serialization steps for empty fields unless they are tuples
# and need to be converted to lists
Expand All @@ -95,9 +101,9 @@ def sort_dict_recursively(obj: Any) -> Any:
# This prevents hash inconsistencies when dict ordering varies
if isinstance(template_field, dict):
template_field = sort_dict_recursively(template_field)
serialized = str(template_field)
if len(serialized) > max_length:
rendered = redact(serialized, name)
jsonable_serialized = str(template_field)
if len(jsonable_serialized) > max_length:
rendered = redact(jsonable_serialized, name)
return truncate_rendered_value(str(rendered), max_length)
return template_field

Expand Down
64 changes: 64 additions & 0 deletions airflow-core/tests/unit/dags/test_dag_decorator_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from datetime import datetime

from airflow.sdk import dag, task, task_group


# Fixture DAG: ``get_data`` produces a list, the ``transform`` task group is
# mapped over that list (``to_exp`` followed by ``trunc``), and ``save``
# receives the mapped results.
# Plain comments are used instead of docstrings so the decorated callables
# themselves stay exactly as they were.
@dag(
    dag_id="TEST_DTM",
    dag_display_name="TEST DTM",
    schedule=None,
    default_args={"owner": "airflow", "email": ""},
    start_date=datetime(2024, 1, 25),
)
def dtm_test(
    exponent: int = 2,
):
    # Source of the values the task group is expanded over.
    @task
    def get_data():
        return [20, 100, 200, 222, 242, 272]

    # Raise ``number`` to the power ``exponent``.
    @task
    def to_exp(number: int, exponent: int) -> float:
        return number**exponent

    # Divide by 22 and round to ``digits`` decimal places.
    @task
    def trunc(number: float, digits: int) -> float:
        return round(number / 22, digits)

    # Print every value it is given.
    @task
    def save(number: list[float]):
        for n in number:
            print(f"Got number: {n}")

    # Chain ``to_exp`` -> ``trunc`` for a single input value.
    @task_group  # type: ignore[type-var]
    def transform(number: int, exponent: int) -> float:
        powered = to_exp(number, exponent)
        truncated = trunc(powered, 2)
        return truncated

    data = get_data()
    # ``exponent`` is fixed via ``partial``; ``number`` is mapped over ``data``.
    result = transform.partial(exponent=exponent).expand(number=data)
    save(result)  # type: ignore[arg-type]


# Call the decorated function at module level and keep the result in ``instance``.
instance = dtm_test()
12 changes: 12 additions & 0 deletions airflow-core/tests/unit/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
from airflow.serialization.serialized_objects import (
BaseSerialization,
DagSerialization,
LazyDeserializedDAG,
OperatorSerialization,
_XComRef,
)
Expand Down Expand Up @@ -114,6 +115,7 @@
cron_timetable,
delta_timetable,
)
from unit.models import TEST_DAGS_FOLDER

if TYPE_CHECKING:
from airflow.sdk.definitions.context import Context
Expand Down Expand Up @@ -702,6 +704,16 @@ def test_deserialization_across_process(self):
for dag_id in stringified_dags:
self.validate_deserialized_dag(stringified_dags[dag_id], dags[dag_id])

@conf_vars({("core", "load_examples"): "false"})
def test_reserialize_should_make_equal_hash_with_dag_processor(self):
    """Loading the same DAG file twice must give equal ``LazyDeserializedDAG`` hashes."""
    dag_file = TEST_DAGS_FOLDER / "test_dag_decorator_version.py"

    # First parse of the file.
    first_bag = DagBag(dag_file)
    first_hash = LazyDeserializedDAG.from_dag(next(iter(first_bag.dags.values()))).hash

    # Second, independent parse of the very same file.
    second_bag = DagBag(dag_file)
    second_hash = LazyDeserializedDAG.from_dag(next(iter(second_bag.dags.values()))).hash

    assert first_hash == second_hash

@skip_if_force_lowest_dependencies_marker
@pytest.mark.db_test
def test_roundtrip_provider_example_dags(self):
Expand Down
Loading