diff --git a/crates/core/src/functions.rs b/crates/core/src/functions.rs index c32134054..44135a56b 100644 --- a/crates/core/src/functions.rs +++ b/crates/core/src/functions.rs @@ -93,6 +93,13 @@ fn array_cat(exprs: Vec<PyExpr>) -> PyExpr { array_concat(exprs) } +#[pyfunction] +fn make_map(keys: Vec<PyExpr>, values: Vec<PyExpr>) -> PyExpr { + let keys = keys.into_iter().map(|x| x.into()).collect(); + let values = values.into_iter().map(|x| x.into()).collect(); + datafusion::functions_nested::map::map(keys, values).into() +} + #[pyfunction] #[pyo3(signature = (array, element, index=None))] fn array_position(array: PyExpr, element: PyExpr, index: Option) -> PyExpr { @@ -665,6 +672,12 @@ array_fn!(cardinality, array); array_fn!(flatten, array); array_fn!(range, start stop step); +// Map Functions +array_fn!(map_keys, map); +array_fn!(map_values, map); +array_fn!(map_extract, map key); +array_fn!(map_entries, map); + aggregate_function!(array_agg); aggregate_function!(max); aggregate_function!(min); @@ -1124,6 +1137,13 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(flatten))?; m.add_wrapped(wrap_pyfunction!(cardinality))?; + // Map Functions + m.add_wrapped(wrap_pyfunction!(make_map))?; + m.add_wrapped(wrap_pyfunction!(map_keys))?; + m.add_wrapped(wrap_pyfunction!(map_values))?; + m.add_wrapped(wrap_pyfunction!(map_extract))?; + m.add_wrapped(wrap_pyfunction!(map_entries))?; + // Window Functions m.add_wrapped(wrap_pyfunction!(lead))?; m.add_wrapped(wrap_pyfunction!(lag))?; diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index f062cbfce..71a6140d1 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -137,6 +137,7 @@ "degrees", "dense_rank", "digest", + "element_at", "empty", "encode", "ends_with", @@ -200,6 +201,12 @@ "make_array", "make_date", "make_list", + "make_map", + "map", + "map_entries", + "map_extract", + "map_keys", + "map_values", "max", "md5", "mean", @@ -3338,6 
+3345,157 @@ def empty(array: Expr) -> Expr: return array_empty(array) +# map functions + + +def map(*args: Any) -> Expr: + """Returns a map expression. + + Supports three calling conventions: + + - ``map({"a": 1, "b": 2})`` — from a Python dictionary. + - ``map([keys], [values])`` — from a list of keys and a list of + their associated values. Both lists must be the same length. + - ``map(k1, v1, k2, v2, ...)`` — from alternating keys and their + associated values. + + Keys and values that are not already :py:class:`~datafusion.expr.Expr` + are automatically converted to literal expressions. + + Examples: + From a dictionary: + + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1]}) + >>> result = df.select( + ... dfn.functions.map({"a": 1, "b": 2}).alias("m")) + >>> result.collect_column("m")[0].as_py() + [('a', 1), ('b', 2)] + + From two lists: + + >>> df = ctx.from_pydict({"key": ["x", "y"], "val": [10, 20]}) + >>> df = df.select( + ... dfn.functions.map( + ... [dfn.col("key")], [dfn.col("val")] + ... ).alias("m")) + >>> df.collect_column("m")[0].as_py() + [('x', 10)] + + From alternating keys and values: + + >>> df = ctx.from_pydict({"a": [1]}) + >>> result = df.select( + ... 
dfn.functions.map("x", 1, "y", 2).alias("m")) + >>> result.collect_column("m")[0].as_py() + [('x', 1), ('y', 2)] + """ + if len(args) == 1 and isinstance(args[0], dict): + key_list = list(args[0].keys()) + value_list = list(args[0].values()) + elif ( + len(args) == 2 # noqa: PLR2004 + and isinstance(args[0], list) + and isinstance(args[1], list) + ): + key_list = args[0] + value_list = args[1] + elif len(args) >= 2 and len(args) % 2 == 0: # noqa: PLR2004 + key_list = list(args[0::2]) + value_list = list(args[1::2]) + else: + msg = "map expects a dict, two lists, or an even number of key-value arguments" + raise ValueError(msg) + + key_exprs = [k if isinstance(k, Expr) else Expr.literal(k) for k in key_list] + val_exprs = [v if isinstance(v, Expr) else Expr.literal(v) for v in value_list] + return Expr(f.make_map([k.expr for k in key_exprs], [v.expr for v in val_exprs])) + + +def make_map(*args: Any) -> Expr: + """Returns a map expression. + + See Also: + This is an alias for :py:func:`map`. + """ + return map(*args) + + +def map_keys(map: Expr) -> Expr: + """Returns a list of all keys in the map. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1]}) + >>> df = df.select( + ... dfn.functions.map({"x": 1, "y": 2}).alias("m")) + >>> result = df.select( + ... dfn.functions.map_keys(dfn.col("m")).alias("keys")) + >>> result.collect_column("keys")[0].as_py() + ['x', 'y'] + """ + return Expr(f.map_keys(map.expr)) + + +def map_values(map: Expr) -> Expr: + """Returns a list of all values in the map. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1]}) + >>> df = df.select( + ... dfn.functions.map({"x": 1, "y": 2}).alias("m")) + >>> result = df.select( + ... 
dfn.functions.map_values(dfn.col("m")).alias("vals")) + >>> result.collect_column("vals")[0].as_py() + [1, 2] + """ + return Expr(f.map_values(map.expr)) + + +def map_extract(map: Expr, key: Expr) -> Expr: + """Returns a list with the value for the given key, or ``[None]`` if absent. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1]}) + >>> df = df.select( + ... dfn.functions.map({"x": 1, "y": 2}).alias("m")) + >>> result = df.select( + ... dfn.functions.map_extract( + ... dfn.col("m"), dfn.lit("x") + ... ).alias("val")) + >>> result.collect_column("val")[0].as_py() + [1] + """ + return Expr(f.map_extract(map.expr, key.expr)) + + +def map_entries(map: Expr) -> Expr: + """Returns a list of all entries (key-value struct pairs) in the map. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1]}) + >>> df = df.select( + ... dfn.functions.map({"x": 1, "y": 2}).alias("m")) + >>> result = df.select( + ... dfn.functions.map_entries(dfn.col("m")).alias("entries")) + >>> result.collect_column("entries")[0].as_py() + [{'key': 'x', 'value': 1}, {'key': 'y', 'value': 2}] + """ + return Expr(f.map_entries(map.expr)) + + +def element_at(map: Expr, key: Expr) -> Expr: + """Returns a list with the value for the given key, or ``[None]`` if absent. + + See Also: + This is an alias for :py:func:`map_extract`. 
+ """ + return map_extract(map, key) + + # aggregate functions def approx_distinct( expression: Expr, diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 37d349c58..8dec52c1a 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -668,6 +668,154 @@ def test_array_function_obj_tests(stmt, py_expr): assert a == b +def test_map_from_dict(): + ctx = SessionContext() + batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"]) + df = ctx.create_dataframe([[batch]]) + + result = df.select(f.map({"x": 1, "y": 2}).alias("m")).collect()[0].column(0) + assert result[0].as_py() == [("x", 1), ("y", 2)] + + +def test_map_from_dict_with_expr_values(): + ctx = SessionContext() + batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"]) + df = ctx.create_dataframe([[batch]]) + + result = ( + df.select(f.map({"x": literal(1), "y": literal(2)}).alias("m")) + .collect()[0] + .column(0) + ) + assert result[0].as_py() == [("x", 1), ("y", 2)] + + +def test_map_from_two_lists(): + ctx = SessionContext() + batch = pa.RecordBatch.from_arrays( + [ + pa.array(["k1", "k2", "k3"]), + pa.array([10, 20, 30]), + ], + names=["keys", "vals"], + ) + df = ctx.create_dataframe([[batch]]) + + m = f.map([column("keys")], [column("vals")]) + result = df.select(f.map_keys(m).alias("k")).collect()[0].column(0) + for i, expected in enumerate(["k1", "k2", "k3"]): + assert result[i].as_py() == [expected] + + result = df.select(f.map_values(m).alias("v")).collect()[0].column(0) + for i, expected in enumerate([10, 20, 30]): + assert result[i].as_py() == [expected] + + +def test_map_from_variadic_pairs(): + ctx = SessionContext() + batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"]) + df = ctx.create_dataframe([[batch]]) + + result = df.select(f.map("x", 1, "y", 2).alias("m")).collect()[0].column(0) + assert result[0].as_py() == [("x", 1), ("y", 2)] + + +def test_map_variadic_with_exprs(): + ctx = SessionContext() + batch = 
pa.RecordBatch.from_arrays([pa.array([1])], names=["a"]) + df = ctx.create_dataframe([[batch]]) + + result = ( + df.select(f.map(literal("x"), literal(1), literal("y"), literal(2)).alias("m")) + .collect()[0] + .column(0) + ) + assert result[0].as_py() == [("x", 1), ("y", 2)] + + +def test_map_odd_args_raises(): + with pytest.raises(ValueError, match="map expects"): + f.map("x", 1, "y") + + +def test_make_map_is_alias(): + ctx = SessionContext() + batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"]) + df = ctx.create_dataframe([[batch]]) + + result = df.select(f.make_map({"x": 1, "y": 2}).alias("m")).collect()[0].column(0) + assert result[0].as_py() == [("x", 1), ("y", 2)] + + +def test_map_keys(): + ctx = SessionContext() + batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"]) + df = ctx.create_dataframe([[batch]]) + + m = f.map({"x": 1, "y": 2}) + result = df.select(f.map_keys(m).alias("keys")).collect()[0].column(0) + assert result[0].as_py() == ["x", "y"] + + +def test_map_values(): + ctx = SessionContext() + batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"]) + df = ctx.create_dataframe([[batch]]) + + m = f.map({"x": 1, "y": 2}) + result = df.select(f.map_values(m).alias("vals")).collect()[0].column(0) + assert result[0].as_py() == [1, 2] + + +def test_map_extract(): + ctx = SessionContext() + batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"]) + df = ctx.create_dataframe([[batch]]) + + m = f.map({"x": 1, "y": 2}) + result = ( + df.select(f.map_extract(m, literal("x")).alias("val")).collect()[0].column(0) + ) + assert result[0].as_py() == [1] + + +def test_map_extract_missing_key(): + ctx = SessionContext() + batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"]) + df = ctx.create_dataframe([[batch]]) + + m = f.map({"x": 1}) + result = ( + df.select(f.map_extract(m, literal("z")).alias("val")).collect()[0].column(0) + ) + assert result[0].as_py() == [None] + + +def test_map_entries(): + ctx = 
SessionContext() + batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"]) + df = ctx.create_dataframe([[batch]]) + + m = f.map({"x": 1, "y": 2}) + result = df.select(f.map_entries(m).alias("entries")).collect()[0].column(0) + assert result[0].as_py() == [ + {"key": "x", "value": 1}, + {"key": "y", "value": 2}, + ] + + +def test_element_at(): + ctx = SessionContext() + batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"]) + df = ctx.create_dataframe([[batch]]) + + m = f.map({"a": 10, "b": 20}) + result = ( + df.select(f.element_at(m, literal("b")).alias("val")).collect()[0].column(0) + ) + assert result[0].as_py() == [20] + + @pytest.mark.parametrize( ("function", "expected_result"), [