Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions serve/src/pixelrag_serve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ class SearchRequest(BaseModel):
instruction: str | None = None # override query embedding instruction
include_images: bool = False # return base64-encoded tile images
articles_only: bool = False # drop Wikipedia meta pages (Portal:, List_of_, …)
hybrid: bool = False # enable hybrid text+visual search (RRF fusion with BM25)


# Wikipedia meta/aggregator pages that pollute "find the article" results.
Expand Down Expand Up @@ -431,6 +432,33 @@ async def search(req: SearchRequest):
index.nprobe = default_nprobe
t_search = time.time() - t0 - t_encode

# Hybrid: fuse FAISS visual results with BM25 text results via RRF
if req.hybrid and _state.get("bm25") and _state["bm25"].loaded:
from .hybrid import reciprocal_rank_fusion

for qi in range(len(req.queries)):
query_text = req.queries[qi].text
if not query_text:
continue
# Get FAISS ranked list
faiss_ranked = [
(int(indices[qi, j]), float(distances[qi, j]))
for j in range(fetch_k)
if int(indices[qi, j]) != -1
]
# Get BM25 ranked list
bm25_ranked = _state["bm25"].search(query_text, k=fetch_k)
# Fuse
fused = reciprocal_rank_fusion([faiss_ranked, bm25_ranked])
# Rewrite indices/distances for this query with fused order
for j, (vid, rrf_score) in enumerate(fused[:fetch_k]):
indices[qi, j] = vid
distances[qi, j] = rrf_score
# Pad remaining with -1
for j in range(len(fused), fetch_k):
indices[qi, j] = -1
distances[qi, j] = 0.0

# Build results
meta = _state["metadata"]
article_ids = meta["article_ids"]
Expand Down Expand Up @@ -670,6 +698,12 @@ def load(args):
cache,
)

# Optional: BM25 text index for hybrid search (RRF fusion)
from .hybrid import load_or_build_bm25

bm25 = load_or_build_bm25(args.index_dir, args.articles_json)
_state["bm25"] = bm25


def _derive_kiwix_book(kiwix_url: str) -> str:
"""Read the kiwix-serve catalog and return the /content/<book> id."""
Expand Down
193 changes: 193 additions & 0 deletions serve/src/pixelrag_serve/hybrid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
"""Hybrid search: BM25 text index + Reciprocal Rank Fusion with FAISS visual results.

Builds an inverted text index from article text (extracted during indexing or from
OCR/page content). At query time, fuses BM25 text results with FAISS visual results
using Reciprocal Rank Fusion (RRF), improving precision on text-heavy queries while
retaining visual retrieval for tables/charts/diagrams.

Reference: Cormack, Clarke & Buettcher (2009) — "Reciprocal Rank Fusion outperforms
Condorcet and individual Rank Learning Methods"
"""

import json
import logging
import math
import os
import re
from collections import defaultdict

logger = logging.getLogger(__name__)

# BM25 parameters (tuned for short document chunks)
_K1 = 1.2
_B = 0.75


def _tokenize(text: str) -> list[str]:
"""Simple whitespace + punctuation tokenizer with lowercasing."""
return re.findall(r"[a-z0-9]+", text.lower())


class BM25Index:
"""In-memory BM25 inverted index over article/chunk text.

Each document is identified by its vector_id (matching FAISS metadata).
"""

def __init__(self):
self.doc_count = 0
self.avg_dl = 0.0
self.doc_lens: dict[int, int] = {} # vector_id → doc length
self.df: dict[str, int] = defaultdict(int) # term → doc frequency
self.tf: dict[int, dict[str, int]] = {} # vector_id → {term: freq}
self._loaded = False

@property
def loaded(self) -> bool:
return self._loaded

def build_from_texts(self, texts: dict[int, str]):
"""Build index from {vector_id: text_content} mapping."""
self.doc_count = len(texts)
total_len = 0

for vid, text in texts.items():
tokens = _tokenize(text)
self.doc_lens[vid] = len(tokens)
total_len += len(tokens)

term_freqs: dict[str, int] = defaultdict(int)
for token in tokens:
term_freqs[token] += 1
self.tf[vid] = dict(term_freqs)

for term in term_freqs:
self.df[term] += 1

self.avg_dl = total_len / max(self.doc_count, 1)
self._loaded = True
logger.info("BM25 index built: %d docs, %d terms", self.doc_count, len(self.df))

def search(self, query: str, k: int = 100) -> list[tuple[int, float]]:
"""Return top-k (vector_id, score) pairs ranked by BM25 score."""
if not self._loaded:
return []

tokens = _tokenize(query)
if not tokens:
return []

scores: dict[int, float] = defaultdict(float)

for term in tokens:
if term not in self.df:
continue
idf = math.log(
(self.doc_count - self.df[term] + 0.5) / (self.df[term] + 0.5) + 1.0
)
for vid, term_freqs in self.tf.items():
if term not in term_freqs:
continue
tf = term_freqs[term]
dl = self.doc_lens[vid]
numerator = tf * (_K1 + 1)
denominator = tf + _K1 * (1 - _B + _B * dl / self.avg_dl)
scores[vid] += idf * numerator / denominator

ranked = sorted(scores.items(), key=lambda x: -x[1])
return ranked[:k]

def save(self, path: str):
"""Persist to JSON for fast reload."""
data = {
"doc_count": self.doc_count,
"avg_dl": self.avg_dl,
"doc_lens": self.doc_lens,
"df": dict(self.df),
"tf": {str(k): v for k, v in self.tf.items()},
}
with open(path, "w") as f:
json.dump(data, f)
logger.info("BM25 index saved: %s (%.1f MB)", path, os.path.getsize(path) / 1e6)

def load(self, path: str) -> bool:
"""Load from JSON. Returns True on success."""
if not os.path.exists(path):
return False
with open(path) as f:
data = json.load(f)
self.doc_count = data["doc_count"]
self.avg_dl = data["avg_dl"]
self.doc_lens = {int(k): v for k, v in data["doc_lens"].items()}
self.df = defaultdict(int, data["df"])
self.tf = {int(k): v for k, v in data["tf"].items()}
self._loaded = True
logger.info("BM25 index loaded: %d docs, %d terms", self.doc_count, len(self.df))
return True


def reciprocal_rank_fusion(
ranked_lists: list[list[tuple[int, float]]],
k: int = 60,
) -> list[tuple[int, float]]:
"""Fuse multiple ranked result lists using RRF.

Args:
ranked_lists: List of ranked results, each is [(vector_id, score), ...]
k: RRF constant (default 60, from the original paper)

Returns:
Fused ranked list of (vector_id, rrf_score) sorted descending.
"""
fused_scores: dict[int, float] = defaultdict(float)

for ranked in ranked_lists:
for rank, (vid, _score) in enumerate(ranked):
fused_scores[vid] += 1.0 / (k + rank + 1)

return sorted(fused_scores.items(), key=lambda x: -x[1])


def build_text_index_from_articles(
articles_json: str,
metadata_path: str | None = None,
) -> BM25Index:
"""Build BM25 index from articles.json titles/URLs.

For a richer index, pass article page text via metadata. This minimal
version uses article titles as the text representation — still useful
for entity/name queries like "Albert Einstein" or "Python programming".
"""
index = BM25Index()
texts: dict[int, str] = {}

with open(articles_json) as f:
articles = json.load(f)

for vid, article in enumerate(articles):
# Use title + URL path as searchable text
title = article.get("title", "")
url = article.get("url", "")
# Extract meaningful text from URL path
url_text = url.split("/")[-1].replace("_", " ").replace("%20", " ") if url else ""
text = f"{title} {url_text}"
if text.strip():
texts[vid] = text

index.build_from_texts(texts)
return index


def load_or_build_bm25(index_dir: str, articles_json: str) -> BM25Index:
"""Load cached BM25 index or build from articles.json."""
bm25_path = os.path.join(index_dir, "bm25_index.json")
index = BM25Index()

if index.load(bm25_path):
return index

# Build and cache
index = build_text_index_from_articles(articles_json)
if index.loaded:
index.save(bm25_path)
return index
117 changes: 117 additions & 0 deletions tests/test_hybrid_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""Tests for hybrid search: BM25 index + Reciprocal Rank Fusion."""

import json
import sys
from pathlib import Path

import pytest

sys.path.insert(0, str(Path(__file__).parents[1] / "serve" / "src"))
from pixelrag_serve.hybrid import BM25Index, reciprocal_rank_fusion, load_or_build_bm25


@pytest.fixture
def bm25():
idx = BM25Index()
idx.build_from_texts({
0: "Python programming language created by Guido van Rossum",
1: "Machine learning neural networks deep learning",
2: "Albert Einstein theory of relativity physics Nobel Prize",
3: "JavaScript React frontend web development",
4: "Python snake reptile biology animal",
})
return idx


def test_bm25_basic_search(bm25):
"""BM25 should rank relevant docs highest."""
results = bm25.search("Python programming", k=5)
# Doc 0 (Python programming) should rank first
assert results[0][0] == 0
assert results[0][1] > 0


def test_bm25_disambiguates(bm25):
"""'Python snake' should prefer the reptile doc over the programming one."""
results = bm25.search("python snake reptile", k=5)
vids = [vid for vid, _ in results]
# Doc 4 (snake) should rank above doc 0 (programming)
assert vids.index(4) < vids.index(0)


def test_bm25_empty_query(bm25):
"""Empty or punctuation-only queries should return empty."""
assert bm25.search("") == []
assert bm25.search("!!! ???") == []


def test_bm25_unknown_terms(bm25):
"""Query with no matching terms returns empty."""
assert bm25.search("xyzzy frobnicator") == []


def test_bm25_save_load(bm25, tmp_path):
"""Save and reload should produce identical search results."""
path = str(tmp_path / "bm25.json")
bm25.save(path)

loaded = BM25Index()
assert loaded.load(path)
assert loaded.doc_count == bm25.doc_count

# Same results
r1 = bm25.search("Einstein physics")
r2 = loaded.search("Einstein physics")
assert [vid for vid, _ in r1] == [vid for vid, _ in r2]


def test_rrf_fuses_two_lists():
"""RRF should combine rankings from two sources."""
# Visual search ranks: doc 5, doc 3, doc 1
visual = [(5, 0.9), (3, 0.7), (1, 0.5)]
# Text search ranks: doc 1, doc 5, doc 7
text = [(1, 3.2), (5, 2.1), (7, 1.0)]

fused = reciprocal_rank_fusion([visual, text])
fused_vids = [vid for vid, _ in fused]

# Doc 5 appears at rank 0 in visual and rank 1 in text — strong signal
# Doc 1 appears at rank 2 in visual and rank 0 in text — also strong
assert 5 in fused_vids[:2]
assert 1 in fused_vids[:2]
# Doc 7 only in text — should rank lower
assert fused_vids.index(7) > fused_vids.index(5)


def test_rrf_single_list():
"""RRF with one list should preserve original order."""
ranked = [(10, 0.9), (20, 0.8), (30, 0.7)]
fused = reciprocal_rank_fusion([ranked])
assert [vid for vid, _ in fused] == [10, 20, 30]


def test_load_or_build_from_articles(tmp_path):
"""Integration: build BM25 from articles.json and search."""
articles = [
{"title": "Python (programming language)", "url": "https://en.wikipedia.org/wiki/Python_(programming_language)"},
{"title": "Albert Einstein", "url": "https://en.wikipedia.org/wiki/Albert_Einstein"},
{"title": "Machine learning", "url": "https://en.wikipedia.org/wiki/Machine_learning"},
]
articles_path = tmp_path / "articles.json"
articles_path.write_text(json.dumps(articles))
index_dir = str(tmp_path)

bm25 = load_or_build_bm25(index_dir, str(articles_path))
assert bm25.loaded
assert bm25.doc_count == 3

# Should find Einstein
results = bm25.search("Einstein")
assert results[0][0] == 1 # article index 1

# Cache file should exist
assert (tmp_path / "bm25_index.json").exists()

# Reload from cache
bm25_cached = load_or_build_bm25(index_dir, str(articles_path))
assert bm25_cached.loaded