Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 66 additions & 73 deletions muninn/store/sqlite_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
and consolidation state. SQLite provides ACID guarantees and zero-config operation.
"""

import sqlite3
import json
import time
import math
import logging
import math
import sqlite3
import time
from pathlib import Path
from typing import Optional, List, Dict, Any
from typing import Any, Dict, List, Optional

from muninn.core.types import MemoryRecord, MemoryType, Provenance
from muninn.store.lock import get_store_lock
Expand Down Expand Up @@ -218,41 +218,30 @@ def _initialize(self):
self._ensure_column_exists(conn, "memories", "scope", "TEXT NOT NULL DEFAULT 'project'")
self._ensure_column_exists(conn, "memories", "media_type", "TEXT DEFAULT 'text'")
self._ensure_column_exists(conn, "memories", "archived", "INTEGER DEFAULT 0")
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_memories_scope ON memories(scope);"
)
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_memories_media_type ON memories(media_type);"
)
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_memories_archived ON memories(archived);"
)
conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_scope ON memories(scope);")
conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_media_type ON memories(media_type);")
conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_archived ON memories(archived);")
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_feedback_scope_time ON retrieval_feedback(user_id, namespace, project, created_at DESC);"
)
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_profile_policy_events_created ON profile_policy_events(created_at DESC);"
)
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_user_profiles_updated_at ON user_profiles(updated_at DESC);"
)
conn.execute("CREATE INDEX IF NOT EXISTS idx_user_profiles_updated_at ON user_profiles(updated_at DESC);")
# Mimir interop relay tables — imported locally to prevent circular
# imports at module load time (muninn.mimir.store → muninn.mimir →
# muninn.mimir.relay → ... none of which import sqlite_metadata).
# The try/except allows the base store to initialise cleanly in
# environments where the mimir sub-package is absent.
try:
from muninn.mimir.store import MIMIR_DDL_STATEMENTS as _mimir_ddl
for _stmt in _mimir_ddl:
conn.execute(_stmt)

conn.executescript("\n".join(_mimir_ddl))

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Avoid partial commits during Mimir DDL initialization

Replacing the per-statement loop with `executescript()` changes the failure semantics: `executescript()` implicitly commits any pending transaction before running the script, and statements that ran before an error are not rolled back automatically. In `_initialize()`, this means a single bad Mimir statement (for example, on a drifted or partially corrupt existing DB) can persist the earlier schema/index changes and some Mimir objects, then raise before the `schema_meta` version row is written. That leaves the database in a partially initialized state, which did not occur with the previous transactional `execute()` loop.

Useful? React with 👍 / 👎.

logger.debug("Mimir DDL applied (%d statements).", len(_mimir_ddl))
except ImportError:
logger.debug("muninn.mimir not available; skipping Mimir DDL.")

conn.execute(
"INSERT OR IGNORE INTO schema_meta (key, value) VALUES (?, ?)",
("version", str(SCHEMA_VERSION))
)
conn.execute("INSERT OR IGNORE INTO schema_meta (key, value) VALUES (?, ?)", ("version", str(SCHEMA_VERSION)))
conn.commit()
logger.info(f"SQLite metadata store initialized at {self.db_path}")

Expand Down Expand Up @@ -293,7 +282,7 @@ def _user_id_param(self, user_id: str) -> str:
"""Return parameter value matching `_user_id_condition`."""
if self._json1_available:
return user_id
return f'%\"user_id\": \"{user_id}\"%'
return f'%"user_id": "{user_id}"%'

def _row_to_record(self, row: sqlite3.Row) -> MemoryRecord:
d = dict(row)
Expand All @@ -311,11 +300,12 @@ def _row_to_record(self, row: sqlite3.Row) -> MemoryRecord:
# v3.11.0: normalize scope — treat NULL or unknown values as "project" (backward compat)
if d.get("scope") not in ("project", "global"):
d["scope"] = "project"

# v3.20.0: normalize media_type
if d.get("media_type"):
try:
from muninn.core.types import MediaType

d["media_type"] = MediaType(d.get("media_type"))
except ValueError:
d["media_type"] = "text"
Expand All @@ -324,7 +314,6 @@ def _row_to_record(self, row: sqlite3.Row) -> MemoryRecord:

return MemoryRecord(**d)


def set_meta(self, key: str, value: str) -> None:
conn = self._get_conn()
conn.execute(
Expand Down Expand Up @@ -425,9 +414,7 @@ def get_profile_policy_event_stats_since(self, *, since_epoch: float) -> Dict[st
""",
(float(since_epoch),),
).fetchone()
distinct_sources = (
int(distinct_row["n"]) if distinct_row and distinct_row["n"] is not None else 0
)
distinct_sources = int(distinct_row["n"]) if distinct_row and distinct_row["n"] is not None else 0

top_row = conn.execute(
"""
Expand Down Expand Up @@ -843,26 +830,26 @@ def get_memory_retrieval_utility(
""",
(memory_id, cutoff_ts),
).fetchall()

if not rows:
return 0.0

estimator_mode = (estimator or "snips").strip().lower()
use_snips = estimator_mode == "snips"

safe_propensity_floor = max(1e-4, min(1.0, float(propensity_floor)))
safe_default_sampling_prob = max(safe_propensity_floor, min(1.0, float(default_sampling_prob)))

total_outcome = 0.0
total_weight = 0.0
snips_positive = 0.0
snips_sum_w = 0.0

for row in rows:
outcome = max(0.0, min(1.0, float(row["outcome"])))
rank = row["rank"]
sampling_prob = row["sampling_prob"]

rank_propensity = 1.0
if rank is not None:
try:
Expand All @@ -872,7 +859,7 @@ def get_memory_retrieval_utility(
rank_propensity = 1.0 / math.log2(rank_value + 1.0)
except (TypeError, ValueError):
pass

base_prob = safe_default_sampling_prob
if sampling_prob is not None:
try:
Expand All @@ -881,15 +868,15 @@ def get_memory_retrieval_utility(
base_prob = min(1.0, parsed_prob)
except (TypeError, ValueError):
pass

propensity = max(safe_propensity_floor, min(1.0, base_prob * rank_propensity))
ipw = 1.0 / propensity

total_outcome += outcome
total_weight += 1.0
snips_positive += outcome * ipw
snips_sum_w += ipw

if use_snips and snips_sum_w > 0:
return max(0.0, min(1.0, snips_positive / snips_sum_w))
elif total_weight > 0:
Expand Down Expand Up @@ -933,6 +920,7 @@ def get_batch_retrieval_utility(

# Accumulate per-memory stats
from collections import defaultdict

pos_w: Dict[str, float] = defaultdict(float)
sum_w: Dict[str, float] = defaultdict(float)
plain_pos: Dict[str, float] = defaultdict(float)
Expand Down Expand Up @@ -961,7 +949,7 @@ def get_batch_retrieval_utility(
base = min(1.0, p)
except (TypeError, ValueError):
pass

propensity = max(safe_floor, min(1.0, base * rank_prop))
ipw = 1.0 / propensity
pos_w[mid] += outcome * ipw
Expand Down Expand Up @@ -1014,7 +1002,6 @@ def count_user_scope_backfill_failures(self) -> int:
row = conn.execute("SELECT COUNT(*) FROM user_scope_backfill_failures").fetchone()
return row[0] if row else 0


def get_missing_user_id_records(self, limit: int = 500) -> List[MemoryRecord]:
"""Fetch a batch of records that do not have metadata.user_id set."""
conn = self._get_conn()
Expand Down Expand Up @@ -1048,9 +1035,7 @@ def count_missing_user_id(self) -> int:
"SELECT COUNT(*) FROM memories WHERE json_extract(metadata, '$.user_id') IS NULL"
).fetchone()
else:
row = conn.execute(
"SELECT COUNT(*) FROM memories WHERE metadata NOT LIKE '%\"user_id\"%'"
).fetchone()
row = conn.execute("SELECT COUNT(*) FROM memories WHERE metadata NOT LIKE '%\"user_id\"%'").fetchone()
return row[0] if row else 0

# --- Legacy Sources Cache ---
Expand Down Expand Up @@ -1120,7 +1105,7 @@ def get_legacy_sources_cache(
placeholders = ",".join("?" for _ in providers)
conditions.append(f"provider IN ({placeholders})")
params.extend(providers)

if not include_ignored:
conditions.append("ignored = 0")

Expand All @@ -1134,29 +1119,27 @@ def get_legacy_sources_cache(
def get_legacy_sources_stats(self) -> Dict[str, Any]:
"""Return statistics about cached legacy sources."""
conn = self._get_conn()

# Count total
row = conn.execute("SELECT COUNT(*) FROM legacy_sources_cache").fetchone()
total = row[0] if row else 0

# Count "new" (seen for the first time in the last 24h)
cutoff = time.time() - 86400
row = conn.execute(
"SELECT COUNT(*) FROM legacy_sources_cache WHERE first_seen_at > ?",
(cutoff,),
).fetchone()
new_24h = row[0] if row else 0

# Count non-ignored
row = conn.execute(
"SELECT COUNT(*) FROM legacy_sources_cache WHERE ignored = 0"
).fetchone()
row = conn.execute("SELECT COUNT(*) FROM legacy_sources_cache WHERE ignored = 0").fetchone()
active = row[0] if row else 0

# MAX(last_seen_at) as sync timestamp
row = conn.execute("SELECT MAX(last_seen_at) FROM legacy_sources_cache").fetchone()
last_sync = row[0] if row and row[0] else 0

return {
"total_cached": total,
"new_last_24h": new_24h,
Expand All @@ -1175,6 +1158,7 @@ def add(self, record: MemoryRecord) -> str:
# Initialize Elo rating if not present
if "elo_rating" not in record.metadata:
from muninn.scoring.elo import INITIAL_ELO

record.metadata["elo_rating"] = INITIAL_ELO

conn.execute(
Expand All @@ -1186,17 +1170,31 @@ def add(self, record: MemoryRecord) -> str:
consolidation_gen, metadata, scope, media_type
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
(
record.id, record.content, record.memory_type.value,
record.importance, record.recency_score, record.access_count,
record.novelty_score, record.created_at, record.ingested_at,
record.last_accessed, record.expires_at,
record.source_agent, record.project, record.branch,
record.namespace, record.provenance.value,
record.vector_id, record.embedding_model,
int(record.consolidated), record.parent_id,
record.consolidation_gen, json.dumps(record.metadata),
record.scope, record.media_type.value
)
record.id,
record.content,
record.memory_type.value,
record.importance,
record.recency_score,
record.access_count,
record.novelty_score,
record.created_at,
record.ingested_at,
record.last_accessed,
record.expires_at,
record.source_agent,
record.project,
record.branch,
record.namespace,
record.provenance.value,
record.vector_id,
record.embedding_model,
int(record.consolidated),
record.parent_id,
record.consolidation_gen,
json.dumps(record.metadata),
record.scope,
record.media_type.value,
),
)
conn.commit()
return record.id
Expand Down Expand Up @@ -1233,7 +1231,7 @@ def update_elo_rating(self, memory_id: str, new_rating: float) -> bool:
record = self.get(memory_id)
if not record:
return False

metadata = record.metadata or {}
metadata["elo_rating"] = float(new_rating)
return self.update(memory_id, metadata=metadata)
Expand Down Expand Up @@ -1316,7 +1314,7 @@ def get_all(
if media_type is not None:
conditions.append("media_type = ?")
params.append(media_type)

# v3.24.0: Cognitive Optimization filter
if archived is not None:
conditions.append("archived = ?")
Expand All @@ -1343,9 +1341,7 @@ def get_by_ids(self, memory_ids: List[str]) -> List[MemoryRecord]:
for i in range(0, len(ids), self._SQLITE_MAX_VARS):
chunk = ids[i : i + self._SQLITE_MAX_VARS]
placeholders = ",".join("?" for _ in chunk)
rows = conn.execute(
f"SELECT * FROM memories WHERE id IN ({placeholders})", chunk
).fetchall()
rows = conn.execute(f"SELECT * FROM memories WHERE id IN ({placeholders})", chunk).fetchall()
records.extend(self._row_to_record(row) for row in rows)
return records

Expand All @@ -1367,7 +1363,7 @@ def record_access(self, memory_id: str):
conn = self._get_conn()
conn.execute(
"UPDATE memories SET access_count = access_count + 1, last_accessed = ? WHERE id = ?",
(time.time(), memory_id)
(time.time(), memory_id),
)
conn.commit()

Expand All @@ -1380,7 +1376,7 @@ def record_access_batch(self, memory_ids: List[str]):
placeholders = ",".join("?" for _ in memory_ids)
conn.execute(
f"UPDATE memories SET access_count = access_count + 1, last_accessed = ? WHERE id IN ({placeholders})",
[now] + memory_ids
[now] + memory_ids,
)
conn.commit()

Expand Down Expand Up @@ -1438,13 +1434,10 @@ def search_content(self, query: str, limit: int = 20) -> List[MemoryRecord]:
def get_random(self, limit: int = 10) -> List[MemoryRecord]:
"""Fetch a random sample of memory records."""
conn = self._get_conn()
rows = conn.execute(
"SELECT * FROM memories ORDER BY RANDOM() LIMIT ?",
(limit,)
).fetchall()
rows = conn.execute("SELECT * FROM memories ORDER BY RANDOM() LIMIT ?", (limit,)).fetchall()
return [self._row_to_record(row) for row in rows]

def close(self):
if self._conn:
self._conn.close()
self._conn = None
self._conn = None
Loading