diff --git a/src/snowflake/snowpark/modin/plugin/extensions/utils.py b/src/snowflake/snowpark/modin/plugin/extensions/utils.py index 86d702a90a..57cc78b03b 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/utils.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/utils.py @@ -823,6 +823,25 @@ def pandas_to_snowflake( "to_snowflake", ", ".join(type(t).__name__ for t in unsupported_types) ) + name_parts = [name] if isinstance(name, str) else list(name) + if len(name_parts) > 3: + raise ValueError( + f"name must have at most 3 parts (database, schema, table), got {len(name_parts)}: {name_parts}" + ) + table_name_converted = _convert_to_snowflake_table_name_to_write_pandas_table_name( + name_parts[-1] + ) + schema_converted = ( + _convert_to_snowflake_table_name_to_write_pandas_table_name(name_parts[-2]) + if len(name_parts) >= 2 + else None + ) + database_converted = ( + _convert_to_snowflake_table_name_to_write_pandas_table_name(name_parts[0]) + if len(name_parts) == 3 + else None + ) + pd.session.write_pandas( # use set_axis() this way so that we can also flatten the tuple column # labels of a column multi-index, e.g. if `pandas_frame` has columns @@ -836,7 +855,9 @@ def pandas_to_snowflake( # column identifiers ourselves, we get the correct column names and we # don't have to modify the table name, but the snowflake connector seems # to incorrectly insert null data. - table_name=_convert_to_snowflake_table_name_to_write_pandas_table_name(name), + table_name=table_name_converted, + database=database_converted, + schema=schema_converted, auto_create_table=True, overwrite=if_exists != "append", table_type=table_type, diff --git a/tests/integ/modin/frame/test_to_snowflake.py b/tests/integ/modin/frame/test_to_snowflake.py index 569d323325..d95301f597 100644 --- a/tests/integ/modin/frame/test_to_snowflake.py +++ b/tests/integ/modin/frame/test_to_snowflake.py @@ -8,6 +8,7 @@ import modin.pandas as pd import pandas as native_pd import pytest +from modin.config import context as config_context import snowflake.snowpark.modin.plugin # noqa: F401 from tests.integ.modin.utils import ( @@ -386,3 +387,65 @@ def test_special_chars_unquoted(self, valid_unquoted_identifier_table_name): assert_snowpark_pandas_equals_to_pandas_without_dtypecheck( written, native_df.rename(str, axis=1) ) + + +class TestWritePandasParquetPath: + @pytest.fixture(autouse=True) + def use_starting_backend_and_parquet_threshold(self): + with config_context(Backend="Pandas", PandasToSnowflakeParquetThresholdBytes=0): + yield + + def test_table_name_only(self, session, test_table_name): + native_df = native_pd.DataFrame({"a": [1]}) + df = pd.DataFrame(native_df) + with to_snowflake_counter(dataset=df, if_exists="replace"): + session.write_pandas( + df, test_table_name, auto_create_table=True, overwrite=True + ) + written = pd.read_snowflake(test_table_name) + assert_snowpark_pandas_equals_to_pandas_without_dtypecheck( + written, native_df.rename(str, axis=1) + ) + + def test_table_name_with_schema(self, session, test_table_name): + native_df = native_pd.DataFrame({"a": [1]}) + df = pd.DataFrame(native_df) + schema = session.get_current_schema().strip('"') + with to_snowflake_counter(dataset=df, if_exists="replace"): + session.write_pandas( + df, test_table_name, schema=schema, auto_create_table=True, overwrite=True + ) + written = pd.read_snowflake(test_table_name) + assert_snowpark_pandas_equals_to_pandas_without_dtypecheck( + written, native_df.rename(str, axis=1) + ) + + def test_table_name_with_database_and_schema(self, session, test_table_name): + native_df = native_pd.DataFrame({"a": [1]}) + df = pd.DataFrame(native_df) + database = session.get_current_database().strip('"') + schema = session.get_current_schema().strip('"') + with to_snowflake_counter(dataset=df, if_exists="replace"): + session.write_pandas( + df, + test_table_name, + database=database, + schema=schema, + auto_create_table=True, + overwrite=True, + ) + written = pd.read_snowflake(test_table_name) + assert_snowpark_pandas_equals_to_pandas_without_dtypecheck( + written, native_df.rename(str, axis=1) + ) + + def test_too_many_parts_raises(self, test_table_name): + native_df = native_pd.DataFrame({"a": [1]}) + df = pd.DataFrame(native_df) + with SqlCounter(query_count=0): + with pytest.raises(ValueError, match="at most 3 parts"): + df.to_snowflake( + ["extra", "db", "schema", test_table_name], + if_exists="replace", + index=False, + )