Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/model_ledger/backends/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,15 @@ def flush(self) -> None:
def _flush_models(self) -> None:
if not self._model_buffer:
return
# Dedup by model_hash (last write wins). A single Ledger.add() pass can
# buffer the same new model twice — register() saves it, then
# update_model() saves it again. The MERGE is idempotent only once the
# target row exists; for a brand-new model the empty-target INSERT fires
# per source row, so an undeduped buffer produces duplicate rows.
deduped: dict[str, ModelRef] = {}
for model in self._model_buffer:
deduped[model.model_hash] = model
self._model_buffer = list(deduped.values())
if self._flush_models_pandas():
self._model_buffer.clear()
return
Expand Down
37 changes: 37 additions & 0 deletions tests/test_backends/test_snowflake_ledger.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,40 @@ def fake_exec(session, sql: str, *args, **kwargs):

with pytest.raises(RuntimeError, match="insufficient privileges"):
SnowflakeLedgerBackend(connection=object(), schema="TEST.LEDGER")


def test_flush_dedups_model_buffer_by_hash():
"""register() and update_model() both buffer the same new model in one
Ledger.add() pass. Without dedup, the MERGE inserts both copies because the
target row doesn't exist yet (empty-target INSERT fires per source row),
producing duplicate rows. _flush_models must collapse the buffer by
model_hash so each model reaches the MERGE exactly once.
"""
from model_ledger.backends.snowflake import SnowflakeLedgerBackend

seen_hashes: list[str] = []

class RecordingSession:
"""Captures every model_hash that appears in a MODELS MERGE source."""

def sql(self, query: str, params: Any = None) -> MockCollectResult:
if "MERGE INTO" in query.upper() and ".MODELS " in query.upper():
for m in re.finditer(r"'([^']+)'\s+AS\s+model_hash", query):
seen_hashes.append(m.group(1))
return MockCollectResult([])

backend = SnowflakeLedgerBackend(schema="TEST_SCHEMA", connection=RecordingSession())
model = ModelRef(
name="fraud_scorer",
owner="risk-team",
model_type="scoring_model",
tier="unclassified",
purpose="",
)
backend.save_model(model) # register() path
backend.save_model(model) # update_model() path (same hash)
backend.flush()

assert seen_hashes.count(model.model_hash) == 1, (
f"model written {seen_hashes.count(model.model_hash)}x to MERGE source, expected 1"
)
Loading