Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions sqlmesh/core/model/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,16 @@ def column_hashes(self) -> t.Dict[str, str]:
}

def read(self, batch_size: t.Optional[int] = None) -> t.Generator[pd.DataFrame, None, None]:
import pandas as pd

df = self._get_df()

batch_size = batch_size or df.size
batch_start = 0
while batch_start < df.shape[0]:
yield df.iloc[batch_start : batch_start + batch_size, :].copy()
batch_start += batch_size
with pd.option_context("mode.copy_on_write", True):
while batch_start < df.shape[0]:
yield df.iloc[batch_start : batch_start + batch_size, :]
batch_start += batch_size

def _get_df(self) -> pd.DataFrame:
import pandas as pd
Expand Down
16 changes: 12 additions & 4 deletions tests/core/test_seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,18 @@ def test_read_returns_independent_batches():
seed = Seed(content=content)
seed_reader = seed.reader()

batches = list(seed_reader.read(batch_size=1))
batches[0].at[0, "value"] = "changed"

assert [df["value"].tolist() for df in batches] == [["changed"], ["two"]]
# Keep the generator open so the copy_on_write context inside read() stays active.
gen = seed_reader.read(batch_size=1)
first_batch = next(gen)
# Mutate while the generator (and therefore the CoW context) is still open.
# CoW ensures only first_batch gets a private copy; the cached _df is unchanged.
first_batch.at[0, "value"] = "changed"
# second_batch is fetched while CoW is still active, so it still sees the original data.
second_batch = next(gen)

assert first_batch["value"].tolist() == ["changed"]
assert second_batch["value"].tolist() == ["two"]
# CoW prevented the mutation from reaching the cached _df, so a fresh read returns original data.
assert next(seed_reader.read())["value"].tolist() == ["one", "two"]


Expand Down
Loading