Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions src/snowflake/snowpark/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1955,6 +1955,7 @@ class TimeTravelConfig(NamedTuple):
timestamp: Optional[str] = None
timestamp_type: Optional[str] = None
stream: Optional[str] = None
version: Optional[int] = None

@staticmethod
def validate_and_normalize_params(
Expand All @@ -1964,6 +1965,7 @@ def validate_and_normalize_params(
timestamp: Optional[Union[str, datetime.datetime]] = None,
timestamp_type: Optional[Union[str, "TimestampTimeZone"]] = None,
stream: Optional[str] = None,
version: Optional[int] = None,
) -> Optional["TimeTravelConfig"]:
"""
Validates and normalizes time travel parameters.
Expand All @@ -1986,7 +1988,7 @@ def validate_and_normalize_params(
ValueError: If parameters are invalid.
"""
time_travel_arg_count = sum(
arg is not None for arg in (statement, offset, timestamp, stream)
arg is not None for arg in (statement, offset, timestamp, stream, version)
)

# Validate mode
Expand All @@ -2003,10 +2005,28 @@ def validate_and_normalize_params(
f"Invalid time travel mode: {time_travel_mode}. Must be 'at' or 'before'."
)

# version (Iceberg snapshot id) only works with 'at' mode — matches
# Snowflake's ``AT(VERSION => <id>)`` grammar and Spark Iceberg's
# ``snapshot-id`` option semantics ("read snapshot N", not "before N").
if version is not None and time_travel_mode.lower() != "at":
raise ValueError(
"Iceberg snapshot version time travel can only be used with "
"time_travel_mode='at', not 'before'."
)

# Validate version type — snapshot IDs are 64-bit integers in Iceberg.
# Reject bool explicitly because ``isinstance(True, int)`` is True in Python.
if version is not None and (
not isinstance(version, int) or isinstance(version, bool)
):
raise ValueError(
f"'version' must be an int Iceberg snapshot id, got {type(version).__name__}."
)

# Validate exactly one parameter is provided
if time_travel_arg_count != 1:
raise ValueError(
"Exactly one of 'statement', 'offset', 'timestamp', or 'stream' must be provided."
"Exactly one of 'statement', 'offset', 'timestamp', 'stream', or 'version' must be provided."
)

# Normalize timestamp
Expand Down Expand Up @@ -2040,6 +2060,7 @@ def validate_and_normalize_params(
timestamp=normalized_timestamp,
timestamp_type=timestamp_type,
stream=stream,
version=version,
)

def generate_sql_clause(self) -> str:
Expand All @@ -2048,7 +2069,8 @@ def generate_sql_clause(self) -> str:
Args:
config: Time travel configuration.
Returns:
SQL clause like " AT (TIMESTAMP => TO_TIMESTAMP_NTZ('...'))"
SQL clause like " AT (TIMESTAMP => TO_TIMESTAMP_NTZ('...'))" or
" AT (VERSION => 1234567890)" for Iceberg snapshot id time travel.
"""
clause = f" {self.time_travel_mode.upper()} "

Expand All @@ -2058,6 +2080,8 @@ def generate_sql_clause(self) -> str:
clause += f"(OFFSET => {self.offset})"
elif self.stream is not None:
clause += f"(STREAM => '{self.stream}')"
elif self.version is not None:
clause += f"(VERSION => {self.version})"
elif self.timestamp is not None:
if self.timestamp_type is not None:
timestamp_type = self.timestamp_type.upper()
Expand Down
58 changes: 56 additions & 2 deletions src/snowflake/snowpark/dataframe_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,13 @@ def _extract_time_travel_from_options(options: dict) -> dict:
- Automatically sets time_travel_mode to 'at'
- Cannot be used with time_travel_mode='before' (raises error)
- Cannot be mixed with regular 'timestamp' option (raises error)

Special handling for 'SNAPSHOT-ID' / 'SNAPSHOT_ID' (Spark Iceberg
compatibility) — both aliases map to the internal ``version`` time
travel parameter:
- Automatically set time_travel_mode to 'at'
(Iceberg snapshot ids only support ``AT(VERSION => N)``, not ``BEFORE``)
- Cannot be used with time_travel_mode='before' (raises error)
"""
result = {}
excluded_keys = set()
Expand All @@ -183,6 +190,35 @@ def _extract_time_travel_from_options(options: dict) -> dict:
result["timestamp"] = options["AS-OF-TIMESTAMP"]
excluded_keys.add("TIMESTAMP")

# Handle Iceberg snapshot id (Spark ``snapshot-id`` / ``snapshot_id``).
# Auto-sets mode='at' since ``AT(VERSION => N)`` is the only valid form.
snapshot_id_value = options.get("SNAPSHOT-ID")
snapshot_id_source = "snapshot-id"
if snapshot_id_value is None:
snapshot_id_value = options.get("SNAPSHOT_ID")
snapshot_id_source = "snapshot_id"
Comment thread
sfc-gh-igarish marked this conversation as resolved.
if snapshot_id_value is not None:
if (
"TIME_TRAVEL_MODE" in options
and options["TIME_TRAVEL_MODE"].lower() == "before"
):
raise ValueError(
f"Cannot use '{snapshot_id_source}' option with "
"time_travel_mode='before'. Iceberg snapshot id time travel "
"only supports time_travel_mode='at'."
)
# Coerce string snapshot ids (Spark accepts both string and long
# literals via .option(); we normalize to int so the SQL emits an
# unquoted long).
try:
result["version"] = int(snapshot_id_value)
except (TypeError, ValueError):
raise ValueError(
f"'{snapshot_id_source}' must be a 64-bit integer Iceberg "
f"snapshot id, got {snapshot_id_value!r}."
)
result["time_travel_mode"] = "at"

for option_key, param_name in _TIME_TRAVEL_OPTIONS_PARAMS_MAP.items():
if option_key in options and option_key not in excluded_keys:
result[param_name] = options[option_key]
Expand Down Expand Up @@ -549,6 +585,7 @@ def table(
timestamp: Optional[Union[str, datetime]] = None,
timestamp_type: Optional[Union[str, TimestampTimeZone]] = None,
stream: Optional[str] = None,
**kwargs,
) -> Table:
"""Returns a Table that points to the specified table.

Expand Down Expand Up @@ -605,6 +642,15 @@ def table(
... .option("offset", -60) # This will be IGNORED
... .table("my_table", time_travel_mode="at", offset=-3600)) # Only this is used
"""
# ``version`` (Iceberg snapshot id) is intentionally not in the public
# signature — it's consumed by Snowpark Connect and may be removed
# once a first-class API lands. Accept it through **kwargs so direct
# callers can still pass it without us advertising it.
version = kwargs.pop("version", None)
if kwargs:
raise TypeError(
f"table() got unexpected keyword arguments: {sorted(kwargs)}"
)

# AST.
stmt = None
Expand All @@ -626,14 +672,22 @@ def table(
if stream is not None:
ast.stream.value = stream

if time_travel_mode is not None:
if time_travel_mode is not None or version is not None:
# If version is provided without mode, default to 'at' (snapshot ids
# only make sense with AT — symmetric with iceberg_tag handling).
effective_mode = (
Comment thread
sfc-gh-aling marked this conversation as resolved.
time_travel_mode
if time_travel_mode
else ("at" if version is not None else None)
)
time_travel_params = {
"time_travel_mode": time_travel_mode,
"time_travel_mode": effective_mode,
"statement": statement,
"offset": offset,
"timestamp": timestamp,
"timestamp_type": timestamp_type,
"stream": stream,
"version": version,
}
else:
# if time_travel_mode is not provided, extract time travel config from options
Expand Down
12 changes: 12 additions & 0 deletions src/snowflake/snowpark/session.py
Comment thread
sfc-gh-igarish marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -2728,6 +2728,7 @@ def table(
timestamp: Optional[Union[str, datetime.datetime]] = None,
timestamp_type: Optional[Union[str, TimestampTimeZone]] = None,
stream: Optional[str] = None,
**kwargs,
) -> Table:
"""
Returns a Table that points the specified table.
Expand Down Expand Up @@ -2775,6 +2776,16 @@ def table(
# timestamp_type remains "NTZ" (user's explicit choice respected)
>>> table2 = session.read.table("my_table", time_travel_mode="at", timestamp=tz_aware, timestamp_type="NTZ") # doctest: +SKIP
"""
# ``version`` (Iceberg snapshot id) is intentionally not in the public
# signature — it's consumed by Snowpark Connect and may be removed
# once a first-class API lands. Accept it through **kwargs so direct
# callers can still pass it without us advertising it.
version = kwargs.pop("version", None)
if kwargs:
raise TypeError(
f"table() got unexpected keyword arguments: {sorted(kwargs)}"
)

if _emit_ast:
stmt = self._ast_batch.bind()
ast = with_src_position(stmt.expr.table, stmt)
Expand Down Expand Up @@ -2811,6 +2822,7 @@ def table(
timestamp=timestamp,
timestamp_type=timestamp_type,
stream=stream,
version=version,
)
# Replace API call origin for table
set_api_call_source(t, "Session.table")
Expand Down
12 changes: 12 additions & 0 deletions src/snowflake/snowpark/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,18 @@ def __init__(
timestamp: Optional[Union[str, datetime.datetime]] = None,
timestamp_type: Optional[Union[str, TimestampTimeZone]] = None,
stream: Optional[str] = None,
**kwargs,
) -> None:
# ``version`` (Iceberg snapshot id) is intentionally not in the public
# signature — it's consumed by Snowpark Connect and may be removed
# once a first-class API lands. Accept it through **kwargs so direct
# callers can still pass it without us advertising it.
version = kwargs.pop("version", None)
if kwargs:
raise TypeError(
f"Table() got unexpected keyword arguments: {sorted(kwargs)}"
)

if _ast_stmt is None and session is not None and _emit_ast:
_ast_stmt = session._ast_batch.bind()
ast = with_src_position(_ast_stmt.expr.table, _ast_stmt)
Expand All @@ -328,6 +339,7 @@ def __init__(
timestamp=timestamp,
timestamp_type=timestamp_type,
stream=stream,
version=version,
)

snowflake_table_plan = SnowflakeTable(
Expand Down
74 changes: 74 additions & 0 deletions tests/integ/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8326,3 +8326,77 @@ def test_time_travel_comprehensive_coverage(session):
finally:
Utils.drop_table(session, table1_name)
Utils.drop_table(session, table2_name)


# ----------------------------------------------------------------------
# Iceberg snapshot id (``version=``) time travel.
#
# TODO(SNOW-3525585): Wire these up to a CI test account that has:
# * a Catalog-Linked Database (CLD) such as cldUnity / cldglue, AND
# * an unmanaged Iceberg table inside it with at least two snapshots
# readable through ``INFORMATION_SCHEMA.GET_TABLE_VERSIONS(...)``.
#
# Snowflake's ``AT(VERSION => N)`` syntax requires the
# ``FEATURE_ICEBERG_TIME_TRAVEL`` server parameter to be enabled on the
# account and is currently scoped to **unmanaged** Iceberg tables in CLDs.
# Because the existing snowpark-python integ accounts don't have a CLD with
# a multi-snapshot Iceberg table provisioned, these tests are skipped by
# default and run manually against ``sfctest0`` (see
# ``tests/sas_tests/test_iceberg_snapshot_id_sample.py`` in the
# snowflake-eng/sas repo for the manual reproducer).
# ----------------------------------------------------------------------
@pytest.mark.skip(
reason=(
"Requires a CLD-linked unmanaged Iceberg table with multiple "
"snapshots and FEATURE_ICEBERG_TIME_TRAVEL enabled on the account. "
"Tested manually; see TODO above."
)
)
def test_iceberg_snapshot_id_time_travel_session_table(session):
"""End-to-end: ``Session.table(..., version=<snapshot_id>)`` returns the
table state at the requested Iceberg snapshot."""
table_fqn = "CLDUNITY.scosschema.snapshot_demo"

snapshot_ids = [
row["SNAPSHOT_ID"]
for row in session.sql(
f"SELECT SNAPSHOT_ID FROM "
f"TABLE(INFORMATION_SCHEMA.GET_TABLE_VERSIONS('{table_fqn}')) "
"ORDER BY SNAPSHOT_TIMESTAMP"
).collect()
]
assert len(snapshot_ids) >= 2, "Demo table needs at least 2 snapshots"

first_snapshot = session.table(
table_fqn, time_travel_mode="at", version=snapshot_ids[0]
).collect()
latest = session.table(table_fqn).collect()
assert len(first_snapshot) <= len(latest)


@pytest.mark.skip(
reason=(
"Requires a CLD-linked unmanaged Iceberg table with multiple "
"snapshots and FEATURE_ICEBERG_TIME_TRAVEL enabled on the account. "
"Tested manually; see TODO above."
)
)
def test_iceberg_snapshot_id_time_travel_dataframe_reader_option(session):
"""End-to-end: ``session.read.option('snapshot-id', N).table(...)``
routes through the Spark Iceberg-compat alias and produces the same
result as the explicit ``version=`` kwarg."""
table_fqn = "CLDUNITY.scosschema.snapshot_demo"

snapshot_id = session.sql(
f"SELECT SNAPSHOT_ID FROM "
f"TABLE(INFORMATION_SCHEMA.GET_TABLE_VERSIONS('{table_fqn}')) "
"ORDER BY SNAPSHOT_TIMESTAMP LIMIT 1"
).collect()[0]["SNAPSHOT_ID"]

via_kwarg = session.read.table(
table_fqn, time_travel_mode="at", version=snapshot_id
).collect()
via_option = (
session.read.option("snapshot-id", snapshot_id).table(table_fqn).collect()
)
assert via_kwarg == via_option
Loading
Loading