diff --git a/serve/src/pixelrag_serve/api.py b/serve/src/pixelrag_serve/api.py index 65b6d40..b73ad5b 100644 --- a/serve/src/pixelrag_serve/api.py +++ b/serve/src/pixelrag_serve/api.py @@ -147,6 +147,7 @@ class SearchRequest(BaseModel): instruction: str | None = None # override query embedding instruction include_images: bool = False # return base64-encoded tile images articles_only: bool = False # drop Wikipedia meta pages (Portal:, List_of_, …) + hybrid: bool = False # enable hybrid text+visual search (RRF fusion with BM25) # Wikipedia meta/aggregator pages that pollute "find the article" results. @@ -431,6 +432,33 @@ async def search(req: SearchRequest): index.nprobe = default_nprobe t_search = time.time() - t0 - t_encode + # Hybrid: fuse FAISS visual results with BM25 text results via RRF + if req.hybrid and _state.get("bm25") and _state["bm25"].loaded: + from .hybrid import reciprocal_rank_fusion + + for qi in range(len(req.queries)): + query_text = req.queries[qi].text + if not query_text: + continue + # Get FAISS ranked list + faiss_ranked = [ + (int(indices[qi, j]), float(distances[qi, j])) + for j in range(fetch_k) + if int(indices[qi, j]) != -1 + ] + # Get BM25 ranked list + bm25_ranked = _state["bm25"].search(query_text, k=fetch_k) + # Fuse + fused = reciprocal_rank_fusion([faiss_ranked, bm25_ranked]) + # Rewrite indices/distances for this query with fused order + for j, (vid, rrf_score) in enumerate(fused[:fetch_k]): + indices[qi, j] = vid + distances[qi, j] = rrf_score + # Pad remaining with -1 + for j in range(len(fused), fetch_k): + indices[qi, j] = -1 + distances[qi, j] = 0.0 + # Build results meta = _state["metadata"] article_ids = meta["article_ids"] @@ -670,6 +698,12 @@ def load(args): cache, ) + # Optional: BM25 text index for hybrid search (RRF fusion) + from .hybrid import load_or_build_bm25 + + bm25 = load_or_build_bm25(args.index_dir, args.articles_json) + _state["bm25"] = bm25 + def _derive_kiwix_book(kiwix_url: str) -> str: """Read the kiwix-serve catalog and return the /content/ id.""" diff --git a/serve/src/pixelrag_serve/hybrid.py b/serve/src/pixelrag_serve/hybrid.py new file mode 100644 index 0000000..251eb4a --- /dev/null +++ b/serve/src/pixelrag_serve/hybrid.py @@ -0,0 +1,193 @@ +"""Hybrid search: BM25 text index + Reciprocal Rank Fusion with FAISS visual results. + +Builds an inverted text index from article text (extracted during indexing or from +OCR/page content). At query time, fuses BM25 text results with FAISS visual results +using Reciprocal Rank Fusion (RRF), improving precision on text-heavy queries while +retaining visual retrieval for tables/charts/diagrams. + +Reference: Cormack, Clarke & Buettcher (2009) — "Reciprocal Rank Fusion outperforms +Condorcet and individual Rank Learning Methods" +""" + +import json +import logging +import math +import os +import re +from collections import defaultdict + +logger = logging.getLogger(__name__) + +# BM25 parameters (tuned for short document chunks) +_K1 = 1.2 +_B = 0.75 + + +def _tokenize(text: str) -> list[str]: + """Simple whitespace + punctuation tokenizer with lowercasing.""" + return re.findall(r"[a-z0-9]+", text.lower()) + + +class BM25Index: + """In-memory BM25 inverted index over article/chunk text. + + Each document is identified by its vector_id (matching FAISS metadata). + """ + + def __init__(self): + self.doc_count = 0 + self.avg_dl = 0.0 + self.doc_lens: dict[int, int] = {} # vector_id → doc length + self.df: dict[str, int] = defaultdict(int) # term → doc frequency + self.tf: dict[int, dict[str, int]] = {} # vector_id → {term: freq} + self._loaded = False + + @property + def loaded(self) -> bool: + return self._loaded + + def build_from_texts(self, texts: dict[int, str]): + """Build index from {vector_id: text_content} mapping.""" + self.doc_count = len(texts) + total_len = 0 + + for vid, text in texts.items(): + tokens = _tokenize(text) + self.doc_lens[vid] = len(tokens) + total_len += len(tokens) + + term_freqs: dict[str, int] = defaultdict(int) + for token in tokens: + term_freqs[token] += 1 + self.tf[vid] = dict(term_freqs) + + for term in term_freqs: + self.df[term] += 1 + + self.avg_dl = total_len / max(self.doc_count, 1) + self._loaded = True + logger.info("BM25 index built: %d docs, %d terms", self.doc_count, len(self.df)) + + def search(self, query: str, k: int = 100) -> list[tuple[int, float]]: + """Return top-k (vector_id, score) pairs ranked by BM25 score.""" + if not self._loaded: + return [] + + tokens = _tokenize(query) + if not tokens: + return [] + + scores: dict[int, float] = defaultdict(float) + + for term in tokens: + if term not in self.df: + continue + idf = math.log( + (self.doc_count - self.df[term] + 0.5) / (self.df[term] + 0.5) + 1.0 + ) + for vid, term_freqs in self.tf.items(): + if term not in term_freqs: + continue + tf = term_freqs[term] + dl = self.doc_lens[vid] + numerator = tf * (_K1 + 1) + denominator = tf + _K1 * (1 - _B + _B * dl / self.avg_dl) + scores[vid] += idf * numerator / denominator + + ranked = sorted(scores.items(), key=lambda x: -x[1]) + return ranked[:k] + + def save(self, path: str): + """Persist to JSON for fast reload.""" + data = { + "doc_count": self.doc_count, + "avg_dl": self.avg_dl, + "doc_lens": self.doc_lens, + "df": dict(self.df), + "tf": {str(k): v for k, v in self.tf.items()}, + } + with open(path, "w") as f: + json.dump(data, f) + logger.info("BM25 index saved: %s (%.1f MB)", path, os.path.getsize(path) / 1e6) + + def load(self, path: str) -> bool: + """Load from JSON. Returns True on success.""" + if not os.path.exists(path): + return False + with open(path) as f: + data = json.load(f) + self.doc_count = data["doc_count"] + self.avg_dl = data["avg_dl"] + self.doc_lens = {int(k): v for k, v in data["doc_lens"].items()} + self.df = defaultdict(int, data["df"]) + self.tf = {int(k): v for k, v in data["tf"].items()} + self._loaded = True + logger.info("BM25 index loaded: %d docs, %d terms", self.doc_count, len(self.df)) + return True + + +def reciprocal_rank_fusion( + ranked_lists: list[list[tuple[int, float]]], + k: int = 60, +) -> list[tuple[int, float]]: + """Fuse multiple ranked result lists using RRF. + + Args: + ranked_lists: List of ranked results, each is [(vector_id, score), ...] + k: RRF constant (default 60, from the original paper) + + Returns: + Fused ranked list of (vector_id, rrf_score) sorted descending. + """ + fused_scores: dict[int, float] = defaultdict(float) + + for ranked in ranked_lists: + for rank, (vid, _score) in enumerate(ranked): + fused_scores[vid] += 1.0 / (k + rank + 1) + + return sorted(fused_scores.items(), key=lambda x: -x[1]) + + +def build_text_index_from_articles( + articles_json: str, + metadata_path: str | None = None, +) -> BM25Index: + """Build BM25 index from articles.json titles/URLs. + + For a richer index, pass article page text via metadata. This minimal + version uses article titles as the text representation — still useful + for entity/name queries like "Albert Einstein" or "Python programming". + """ + index = BM25Index() + texts: dict[int, str] = {} + + with open(articles_json) as f: + articles = json.load(f) + + for vid, article in enumerate(articles): + # Use title + URL path as searchable text + title = article.get("title", "") + url = article.get("url", "") + # Extract meaningful text from URL path + url_text = url.split("/")[-1].replace("_", " ").replace("%20", " ") if url else "" + text = f"{title} {url_text}" + if text.strip(): + texts[vid] = text + + index.build_from_texts(texts) + return index + + +def load_or_build_bm25(index_dir: str, articles_json: str) -> BM25Index: + """Load cached BM25 index or build from articles.json.""" + bm25_path = os.path.join(index_dir, "bm25_index.json") + index = BM25Index() + + if index.load(bm25_path): + return index + + # Build and cache + index = build_text_index_from_articles(articles_json) + if index.loaded: + index.save(bm25_path) + return index diff --git a/tests/test_hybrid_search.py b/tests/test_hybrid_search.py new file mode 100644 index 0000000..3f406e4 --- /dev/null +++ b/tests/test_hybrid_search.py @@ -0,0 +1,117 @@ +"""Tests for hybrid search: BM25 index + Reciprocal Rank Fusion.""" + +import json +import sys +from pathlib import Path + +import pytest + +sys.path.insert(0, str(Path(__file__).parents[1] / "serve" / "src")) +from pixelrag_serve.hybrid import BM25Index, reciprocal_rank_fusion, load_or_build_bm25 + + +@pytest.fixture +def bm25(): + idx = BM25Index() + idx.build_from_texts({ + 0: "Python programming language created by Guido van Rossum", + 1: "Machine learning neural networks deep learning", + 2: "Albert Einstein theory of relativity physics Nobel Prize", + 3: "JavaScript React frontend web development", + 4: "Python snake reptile biology animal", + }) + return idx + + +def test_bm25_basic_search(bm25): + """BM25 should rank relevant docs highest.""" + results = bm25.search("Python programming", k=5) + # Doc 0 (Python programming) should rank first + assert results[0][0] == 0 + assert results[0][1] > 0 + + +def test_bm25_disambiguates(bm25): + """'Python snake' should prefer the reptile doc over the programming one.""" + results = bm25.search("python snake reptile", k=5) + vids = [vid for vid, _ in results] + # Doc 4 (snake) should rank above doc 0 (programming) + assert vids.index(4) < vids.index(0) + + +def test_bm25_empty_query(bm25): + """Empty or punctuation-only queries should return empty.""" + assert bm25.search("") == [] + assert bm25.search("!!! ???") == [] + + +def test_bm25_unknown_terms(bm25): + """Query with no matching terms returns empty.""" + assert bm25.search("xyzzy frobnicator") == [] + + +def test_bm25_save_load(bm25, tmp_path): + """Save and reload should produce identical search results.""" + path = str(tmp_path / "bm25.json") + bm25.save(path) + + loaded = BM25Index() + assert loaded.load(path) + assert loaded.doc_count == bm25.doc_count + + # Same results + r1 = bm25.search("Einstein physics") + r2 = loaded.search("Einstein physics") + assert [vid for vid, _ in r1] == [vid for vid, _ in r2] + + +def test_rrf_fuses_two_lists(): + """RRF should combine rankings from two sources.""" + # Visual search ranks: doc 5, doc 3, doc 1 + visual = [(5, 0.9), (3, 0.7), (1, 0.5)] + # Text search ranks: doc 1, doc 5, doc 7 + text = [(1, 3.2), (5, 2.1), (7, 1.0)] + + fused = reciprocal_rank_fusion([visual, text]) + fused_vids = [vid for vid, _ in fused] + + # Doc 5 appears at rank 0 in visual and rank 1 in text — strong signal + # Doc 1 appears at rank 2 in visual and rank 0 in text — also strong + assert 5 in fused_vids[:2] + assert 1 in fused_vids[:2] + # Doc 7 only in text — should rank lower + assert fused_vids.index(7) > fused_vids.index(5) + + +def test_rrf_single_list(): + """RRF with one list should preserve original order.""" + ranked = [(10, 0.9), (20, 0.8), (30, 0.7)] + fused = reciprocal_rank_fusion([ranked]) + assert [vid for vid, _ in fused] == [10, 20, 30] + + +def test_load_or_build_from_articles(tmp_path): + """Integration: build BM25 from articles.json and search.""" + articles = [ + {"title": "Python (programming language)", "url": "https://en.wikipedia.org/wiki/Python_(programming_language)"}, + {"title": "Albert Einstein", "url": "https://en.wikipedia.org/wiki/Albert_Einstein"}, + {"title": "Machine learning", "url": "https://en.wikipedia.org/wiki/Machine_learning"}, + ] + articles_path = tmp_path / "articles.json" + articles_path.write_text(json.dumps(articles)) + index_dir = str(tmp_path) + + bm25 = load_or_build_bm25(index_dir, str(articles_path)) + assert bm25.loaded + assert bm25.doc_count == 3 + + # Should find Einstein + results = bm25.search("Einstein") + assert results[0][0] == 1 # article index 1 + + # Cache file should exist + assert (tmp_path / "bm25_index.json").exists() + + # Reload from cache + bm25_cached = load_or_build_bm25(index_dir, str(articles_path)) + assert bm25_cached.loaded