From 3a364d566c9c2feacabd72ec5b514e78acdf0d00 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 31 Mar 2026 12:09:20 -0400 Subject: [PATCH 1/2] Add missing scalar functions: get_field, union_extract, union_tag, arrow_metadata, version, row Expose upstream DataFusion scalar functions that were not yet available in the Python API. Closes #1453. - get_field: extracts a field from a struct or map by name - union_extract: extracts a value from a union type by field name - union_tag: returns the active field name of a union type - arrow_metadata: returns Arrow field metadata (all or by key) - version: returns the DataFusion version string - row: alias for the struct constructor Note: arrow_try_cast was listed in the issue but does not exist in DataFusion 53, so it is not included. Co-Authored-By: Claude Opus 4.6 (1M context) --- crates/core/src/functions.rs | 26 ++++++++++ python/datafusion/functions.py | 86 ++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+) diff --git a/crates/core/src/functions.rs b/crates/core/src/functions.rs index c32134054..2c1c7f1fc 100644 --- a/crates/core/src/functions.rs +++ b/crates/core/src/functions.rs @@ -631,8 +631,29 @@ expr_fn_vec!(named_struct); expr_fn!(from_unixtime, unixtime); expr_fn!(arrow_typeof, arg_1); expr_fn!(arrow_cast, arg_1 datatype); +expr_fn_vec!(arrow_metadata); +expr_fn!(union_tag, arg1); expr_fn!(random); +#[pyfunction] +fn get_field(expr: PyExpr, name: PyExpr) -> PyExpr { + functions::core::get_field() + .call(vec![expr.into(), name.into()]) + .into() +} + +#[pyfunction] +fn union_extract(union_expr: PyExpr, field_name: PyExpr) -> PyExpr { + functions::core::union_extract() + .call(vec![union_expr.into(), field_name.into()]) + .into() +} + +#[pyfunction] +fn version() -> PyExpr { + functions::core::version().call(vec![]).into() +} + // Array Functions array_fn!(array_append, array element); array_fn!(array_to_string, array delimiter); @@ -940,6 +961,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(array_agg))?; m.add_wrapped(wrap_pyfunction!(arrow_typeof))?; m.add_wrapped(wrap_pyfunction!(arrow_cast))?; + m.add_wrapped(wrap_pyfunction!(arrow_metadata))?; m.add_wrapped(wrap_pyfunction!(ascii))?; m.add_wrapped(wrap_pyfunction!(asin))?; m.add_wrapped(wrap_pyfunction!(asinh))?; @@ -1063,6 +1085,10 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(trim))?; m.add_wrapped(wrap_pyfunction!(trunc))?; m.add_wrapped(wrap_pyfunction!(upper))?; + m.add_wrapped(wrap_pyfunction!(get_field))?; + m.add_wrapped(wrap_pyfunction!(union_extract))?; + m.add_wrapped(wrap_pyfunction!(union_tag))?; + m.add_wrapped(wrap_pyfunction!(version))?; m.add_wrapped(wrap_pyfunction!(self::uuid))?; // Use self to avoid name collision m.add_wrapped(wrap_pyfunction!(var_pop))?; m.add_wrapped(wrap_pyfunction!(var_sample))?; diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index f062cbfce..46e69be8a 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -90,6 +90,7 @@ "array_to_string", "array_union", "arrow_cast", + "arrow_metadata", "arrow_typeof", "ascii", "asin", @@ -149,6 +150,7 @@ "floor", "from_unixtime", "gcd", + "get_field", "in_list", "initcap", "isnan", @@ -242,6 +244,7 @@ "reverse", "right", "round", + "row", "row_number", "rpad", "rtrim", @@ -282,12 +285,15 @@ "translate", "trim", "trunc", + "union_extract", + "union_tag", "upper", "uuid", "var", "var_pop", "var_samp", "var_sample", + "version", "when", # Window Functions "window", @@ -2492,6 +2498,86 @@ def arrow_cast(expr: Expr, data_type: Expr) -> Expr: return Expr(f.arrow_cast(expr.expr, data_type.expr)) +def arrow_metadata(*args: Expr) -> Expr: + """Returns the metadata of the input expression. + + If called with one argument, returns a Map of all metadata key-value pairs. + If called with two arguments, returns the value for the specified metadata key. + + Args: + args: An expression, optionally followed by a metadata key string. + + Returns: + A Map of metadata or a specific metadata value. + """ + args = [arg.expr for arg in args] + return Expr(f.arrow_metadata(*args)) + + +def get_field(expr: Expr, name: Expr) -> Expr: + """Extracts a field from a struct or map by name. + + Args: + expr: A struct or map expression. + name: The field name to extract. + + Returns: + The value of the named field. + """ + return Expr(f.get_field(expr.expr, name.expr)) + + +def union_extract(union_expr: Expr, field_name: Expr) -> Expr: + """Extracts a value from a union type by field name. + + Returns the value of the named field if it is the currently selected + variant, otherwise returns NULL. + + Args: + union_expr: A union-typed expression. + field_name: The name of the field to extract. + + Returns: + The extracted value or NULL. + """ + return Expr(f.union_extract(union_expr.expr, field_name.expr)) + + +def union_tag(union_expr: Expr) -> Expr: + """Returns the tag (active field name) of a union type. + + Args: + union_expr: A union-typed expression. + + Returns: + The name of the currently selected field in the union. + """ + return Expr(f.union_tag(union_expr.expr)) + + +def version() -> Expr: + """Returns the DataFusion version string. + + Returns: + A string describing the DataFusion version. + """ + return Expr(f.version()) + + +def row(*args: Expr) -> Expr: + """Returns a struct with the given arguments. + + This is an alias for :py:func:`struct`. + + Args: + args: The expressions to include in the struct. + + Returns: + A struct expression. + """ + return struct(*args) + + def random() -> Expr: """Returns a random value in the range ``0.0 <= x < 1.0``. From 192593f297ebd37554198c4b5f04f98422a84d2c Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 31 Mar 2026 12:14:47 -0400 Subject: [PATCH 2/2] Add tests for new scalar functions Tests for get_field, arrow_metadata, version, row, union_tag, and union_extract. Also fix codespell skip paths in pyproject.toml. Co-Authored-By: Claude Opus 4.6 (1M context) --- pyproject.toml | 6 +-- python/tests/test_functions.py | 70 ++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d05a64083..256bbc61a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -172,10 +172,10 @@ extend-allowed-calls = ["datafusion.lit", "lit"] [tool.codespell] skip = [ - "./python/tests/test_functions.py", - "./target", + "python/tests/test_functions.py", + "target", "uv.lock", - "./examples/tpch/answers_sf1/*", + "examples/tpch/answers_sf1/*", ] count = true ignore-words-list = ["IST", "ans"] diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 37d349c58..556dfdc48 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -1435,3 +1435,73 @@ def test_coalesce(df): assert result.column(0) == pa.array( ["Hello", "fallback", "!"], type=pa.string_view() ) + + +def test_get_field(df): + df = df.with_column( + "s", + f.named_struct( + [ + ("x", column("a")), + ("y", column("b")), + ] + ), + ) + result = df.select( + f.get_field(column("s"), string_literal("x")).alias("x_val"), + f.get_field(column("s"), string_literal("y")).alias("y_val"), + ).collect()[0] + + assert result.column(0) == pa.array(["Hello", "World", "!"], type=pa.string_view()) + assert result.column(1) == pa.array([4, 5, 6]) + + +def test_arrow_metadata(df): + result = df.select( + f.arrow_metadata(column("a")).alias("meta"), + ).collect()[0] + # The metadata column should be returned as a map type (possibly empty) + assert result.column(0).type == pa.map_(pa.utf8(), pa.utf8()) + + +def test_version(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1]}) + result = df.select(f.version().alias("v")).collect()[0] + version_str = result.column(0)[0].as_py() + assert "Apache DataFusion" in version_str + + +def test_row(df): + result = df.select( + f.row(column("a"), column("b")).alias("r"), + f.struct(column("a"), column("b")).alias("s"), + ).collect()[0] + # row is an alias for struct, so they should produce the same output + assert result.column(0) == result.column(1) + + +def test_union_tag(): + ctx = SessionContext() + types = pa.array([0, 1, 0], type=pa.int8()) + offsets = pa.array([0, 0, 1], type=pa.int32()) + children = [pa.array([1, 2]), pa.array(["hello"])] + arr = pa.UnionArray.from_dense(types, offsets, children, ["int", "str"], [0, 1]) + df = ctx.create_dataframe([[pa.RecordBatch.from_arrays([arr], names=["u"])]]) + + result = df.select(f.union_tag(column("u")).alias("tag")).collect()[0] + assert result.column(0).to_pylist() == ["int", "str", "int"] + + +def test_union_extract(): + ctx = SessionContext() + types = pa.array([0, 1, 0], type=pa.int8()) + offsets = pa.array([0, 0, 1], type=pa.int32()) + children = [pa.array([1, 2]), pa.array(["hello"])] + arr = pa.UnionArray.from_dense(types, offsets, children, ["int", "str"], [0, 1]) + df = ctx.create_dataframe([[pa.RecordBatch.from_arrays([arr], names=["u"])]]) + + result = df.select( + f.union_extract(column("u"), string_literal("int")).alias("val") + ).collect()[0] + assert result.column(0).to_pylist() == [1, None, 2]