From 4590e9b3148d6e199fc6391251ba4b7c51d6f733 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Wed, 13 May 2026 23:05:07 +0000 Subject: [PATCH] perf: Use executescript for Mimir DDL initialization Replaced the loop executing Mimir DDL statements individually with a single `conn.executescript("\n".join(_mimir_ddl))` call in `muninn/store/sqlite_metadata.py` to improve initialization performance. Co-authored-by: wjohns989 <56205870+wjohns989@users.noreply.github.com> --- muninn/store/sqlite_metadata.py | 139 +++++++++++++++----------------- 1 file changed, 66 insertions(+), 73 deletions(-) diff --git a/muninn/store/sqlite_metadata.py b/muninn/store/sqlite_metadata.py index fc451af..407fc19 100644 --- a/muninn/store/sqlite_metadata.py +++ b/muninn/store/sqlite_metadata.py @@ -5,13 +5,13 @@ and consolidation state. SQLite provides ACID guarantees and zero-config operation. """ -import sqlite3 import json -import time -import math import logging +import math +import sqlite3 +import time from pathlib import Path -from typing import Optional, List, Dict, Any +from typing import Any, Dict, List, Optional from muninn.core.types import MemoryRecord, MemoryType, Provenance from muninn.store.lock import get_store_lock @@ -218,24 +218,16 @@ def _initialize(self): self._ensure_column_exists(conn, "memories", "scope", "TEXT NOT NULL DEFAULT 'project'") self._ensure_column_exists(conn, "memories", "media_type", "TEXT DEFAULT 'text'") self._ensure_column_exists(conn, "memories", "archived", "INTEGER DEFAULT 0") - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_memories_scope ON memories(scope);" - ) - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_memories_media_type ON memories(media_type);" - ) - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_memories_archived ON memories(archived);" - ) + conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_scope ON memories(scope);") + conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_media_type ON memories(media_type);") + conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_archived ON memories(archived);") conn.execute( "CREATE INDEX IF NOT EXISTS idx_feedback_scope_time ON retrieval_feedback(user_id, namespace, project, created_at DESC);" ) conn.execute( "CREATE INDEX IF NOT EXISTS idx_profile_policy_events_created ON profile_policy_events(created_at DESC);" ) - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_user_profiles_updated_at ON user_profiles(updated_at DESC);" - ) + conn.execute("CREATE INDEX IF NOT EXISTS idx_user_profiles_updated_at ON user_profiles(updated_at DESC);") # Mimir interop relay tables — imported locally to prevent circular # imports at module load time (muninn.mimir.store → muninn.mimir → # muninn.mimir.relay → ... none of which import sqlite_metadata). @@ -243,16 +235,13 @@ def _initialize(self): # environments where the mimir sub-package is absent. try: from muninn.mimir.store import MIMIR_DDL_STATEMENTS as _mimir_ddl - for _stmt in _mimir_ddl: - conn.execute(_stmt) + + conn.executescript("\n".join(_mimir_ddl)) logger.debug("Mimir DDL applied (%d statements).", len(_mimir_ddl)) except ImportError: logger.debug("muninn.mimir not available; skipping Mimir DDL.") - conn.execute( - "INSERT OR IGNORE INTO schema_meta (key, value) VALUES (?, ?)", - ("version", str(SCHEMA_VERSION)) - ) + conn.execute("INSERT OR IGNORE INTO schema_meta (key, value) VALUES (?, ?)", ("version", str(SCHEMA_VERSION))) conn.commit() logger.info(f"SQLite metadata store initialized at {self.db_path}") @@ -293,7 +282,7 @@ def _user_id_param(self, user_id: str) -> str: """Return parameter value matching `_user_id_condition`.""" if self._json1_available: return user_id - return f'%\"user_id\": \"{user_id}\"%' + return f'%"user_id": "{user_id}"%' def _row_to_record(self, row: sqlite3.Row) -> MemoryRecord: d = dict(row) @@ -311,11 +300,12 @@ def _row_to_record(self, row: sqlite3.Row) -> MemoryRecord: # v3.11.0: normalize scope — treat NULL or unknown values as "project" (backward compat) if d.get("scope") not in ("project", "global"): d["scope"] = "project" - + # v3.20.0: normalize media_type if d.get("media_type"): try: from muninn.core.types import MediaType + d["media_type"] = MediaType(d.get("media_type")) except ValueError: d["media_type"] = "text" @@ -324,7 +314,6 @@ def _row_to_record(self, row: sqlite3.Row) -> MemoryRecord: return MemoryRecord(**d) - def set_meta(self, key: str, value: str) -> None: conn = self._get_conn() conn.execute( @@ -425,9 +414,7 @@ def get_profile_policy_event_stats_since(self, *, since_epoch: float) -> Dict[st """, (float(since_epoch),), ).fetchone() - distinct_sources = ( - int(distinct_row["n"]) if distinct_row and distinct_row["n"] is not None else 0 - ) + distinct_sources = int(distinct_row["n"]) if distinct_row and distinct_row["n"] is not None else 0 top_row = conn.execute( """ @@ -843,26 +830,26 @@ def get_memory_retrieval_utility( """, (memory_id, cutoff_ts), ).fetchall() - + if not rows: return 0.0 estimator_mode = (estimator or "snips").strip().lower() use_snips = estimator_mode == "snips" - + safe_propensity_floor = max(1e-4, min(1.0, float(propensity_floor))) safe_default_sampling_prob = max(safe_propensity_floor, min(1.0, float(default_sampling_prob))) - + total_outcome = 0.0 total_weight = 0.0 snips_positive = 0.0 snips_sum_w = 0.0 - + for row in rows: outcome = max(0.0, min(1.0, float(row["outcome"]))) rank = row["rank"] sampling_prob = row["sampling_prob"] - + rank_propensity = 1.0 if rank is not None: try: @@ -872,7 +859,7 @@ def get_memory_retrieval_utility( rank_propensity = 1.0 / math.log2(rank_value + 1.0) except (TypeError, ValueError): pass - + base_prob = safe_default_sampling_prob if sampling_prob is not None: try: @@ -881,15 +868,15 @@ def get_memory_retrieval_utility( base_prob = min(1.0, parsed_prob) except (TypeError, ValueError): pass - + propensity = max(safe_propensity_floor, min(1.0, base_prob * rank_propensity)) ipw = 1.0 / propensity - + total_outcome += outcome total_weight += 1.0 snips_positive += outcome * ipw snips_sum_w += ipw - + if use_snips and snips_sum_w > 0: return max(0.0, min(1.0, snips_positive / snips_sum_w)) elif total_weight > 0: @@ -933,6 +920,7 @@ def get_batch_retrieval_utility( # Accumulate per-memory stats from collections import defaultdict + pos_w: Dict[str, float] = defaultdict(float) sum_w: Dict[str, float] = defaultdict(float) plain_pos: Dict[str, float] = defaultdict(float) @@ -961,7 +949,7 @@ def get_batch_retrieval_utility( base = min(1.0, p) except (TypeError, ValueError): pass - + propensity = max(safe_floor, min(1.0, base * rank_prop)) ipw = 1.0 / propensity pos_w[mid] += outcome * ipw @@ -1014,7 +1002,6 @@ def count_user_scope_backfill_failures(self) -> int: row = conn.execute("SELECT COUNT(*) FROM user_scope_backfill_failures").fetchone() return row[0] if row else 0 - def get_missing_user_id_records(self, limit: int = 500) -> List[MemoryRecord]: """Fetch a batch of records that do not have metadata.user_id set.""" conn = self._get_conn() @@ -1048,9 +1035,7 @@ def count_missing_user_id(self) -> int: "SELECT COUNT(*) FROM memories WHERE json_extract(metadata, '$.user_id') IS NULL" ).fetchone() else: - row = conn.execute( - "SELECT COUNT(*) FROM memories WHERE metadata NOT LIKE '%\"user_id\"%'" - ).fetchone() + row = conn.execute("SELECT COUNT(*) FROM memories WHERE metadata NOT LIKE '%\"user_id\"%'").fetchone() return row[0] if row else 0 # --- Legacy Sources Cache --- @@ -1120,7 +1105,7 @@ def get_legacy_sources_cache( placeholders = ",".join("?" for _ in providers) conditions.append(f"provider IN ({placeholders})") params.extend(providers) - + if not include_ignored: conditions.append("ignored = 0") @@ -1134,11 +1119,11 @@ def get_legacy_sources_cache( def get_legacy_sources_stats(self) -> Dict[str, Any]: """Return statistics about cached legacy sources.""" conn = self._get_conn() - + # Count total row = conn.execute("SELECT COUNT(*) FROM legacy_sources_cache").fetchone() total = row[0] if row else 0 - + # Count "new" (seen for the first time in the last 24h) cutoff = time.time() - 86400 row = conn.execute( @@ -1146,17 +1131,15 @@ def get_legacy_sources_stats(self) -> Dict[str, Any]: (cutoff,), ).fetchone() new_24h = row[0] if row else 0 - + # Count non-ignored - row = conn.execute( - "SELECT COUNT(*) FROM legacy_sources_cache WHERE ignored = 0" - ).fetchone() + row = conn.execute("SELECT COUNT(*) FROM legacy_sources_cache WHERE ignored = 0").fetchone() active = row[0] if row else 0 # MAX(last_seen_at) as sync timestamp row = conn.execute("SELECT MAX(last_seen_at) FROM legacy_sources_cache").fetchone() last_sync = row[0] if row and row[0] else 0 - + return { "total_cached": total, "new_last_24h": new_24h, @@ -1175,6 +1158,7 @@ def add(self, record: MemoryRecord) -> str: # Initialize Elo rating if not present if "elo_rating" not in record.metadata: from muninn.scoring.elo import INITIAL_ELO + record.metadata["elo_rating"] = INITIAL_ELO conn.execute( @@ -1186,17 +1170,31 @@ def add(self, record: MemoryRecord) -> str: consolidation_gen, metadata, scope, media_type ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", ( - record.id, record.content, record.memory_type.value, - record.importance, record.recency_score, record.access_count, - record.novelty_score, record.created_at, record.ingested_at, - record.last_accessed, record.expires_at, - record.source_agent, record.project, record.branch, - record.namespace, record.provenance.value, - record.vector_id, record.embedding_model, - int(record.consolidated), record.parent_id, - record.consolidation_gen, json.dumps(record.metadata), - record.scope, record.media_type.value - ) + record.id, + record.content, + record.memory_type.value, + record.importance, + record.recency_score, + record.access_count, + record.novelty_score, + record.created_at, + record.ingested_at, + record.last_accessed, + record.expires_at, + record.source_agent, + record.project, + record.branch, + record.namespace, + record.provenance.value, + record.vector_id, + record.embedding_model, + int(record.consolidated), + record.parent_id, + record.consolidation_gen, + json.dumps(record.metadata), + record.scope, + record.media_type.value, + ), ) conn.commit() return record.id @@ -1233,7 +1231,7 @@ def update_elo_rating(self, memory_id: str, new_rating: float) -> bool: record = self.get(memory_id) if not record: return False - + metadata = record.metadata or {} metadata["elo_rating"] = float(new_rating) return self.update(memory_id, metadata=metadata) @@ -1316,7 +1314,7 @@ def get_all( if media_type is not None: conditions.append("media_type = ?") params.append(media_type) - + # v3.24.0: Cognitive Optimization filter if archived is not None: conditions.append("archived = ?") @@ -1343,9 +1341,7 @@ def get_by_ids(self, memory_ids: List[str]) -> List[MemoryRecord]: for i in range(0, len(ids), self._SQLITE_MAX_VARS): chunk = ids[i : i + self._SQLITE_MAX_VARS] placeholders = ",".join("?" for _ in chunk) - rows = conn.execute( - f"SELECT * FROM memories WHERE id IN ({placeholders})", chunk - ).fetchall() + rows = conn.execute(f"SELECT * FROM memories WHERE id IN ({placeholders})", chunk).fetchall() records.extend(self._row_to_record(row) for row in rows) return records @@ -1367,7 +1363,7 @@ def record_access(self, memory_id: str): conn = self._get_conn() conn.execute( "UPDATE memories SET access_count = access_count + 1, last_accessed = ? WHERE id = ?", - (time.time(), memory_id) + (time.time(), memory_id), ) conn.commit() @@ -1380,7 +1376,7 @@ def record_access_batch(self, memory_ids: List[str]): placeholders = ",".join("?" for _ in memory_ids) conn.execute( f"UPDATE memories SET access_count = access_count + 1, last_accessed = ? WHERE id IN ({placeholders})", - [now] + memory_ids + [now] + memory_ids, ) conn.commit() @@ -1438,13 +1434,10 @@ def search_content(self, query: str, limit: int = 20) -> List[MemoryRecord]: def get_random(self, limit: int = 10) -> List[MemoryRecord]: """Fetch a random sample of memory records.""" conn = self._get_conn() - rows = conn.execute( - "SELECT * FROM memories ORDER BY RANDOM() LIMIT ?", - (limit,) - ).fetchall() + rows = conn.execute("SELECT * FROM memories ORDER BY RANDOM() LIMIT ?", (limit,)).fetchall() return [self._row_to_record(row) for row in rows] def close(self): if self._conn: self._conn.close() - self._conn = None \ No newline at end of file + self._conn = None