Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Continuous integration: lint, type-check and run the test suite for pushes
# to main and for pull requests targeting main.
name: CI

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

# Every step only reads the repository, so restrict the default GITHUB_TOKEN
# to read-only access instead of inheriting the repository-wide default.
permissions:
  contents: read

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v4
        with:
          version: "latest"

      - name: Set up Python
        run: uv python install 3.12

      # --locked makes the job fail when uv.lock is out of date with
      # pyproject.toml instead of silently re-resolving dependencies in CI.
      - name: Install dependencies
        run: uv sync --locked

      # NOTE(review): ruff and mypy do not appear in the visible `dev`
      # dependency group; confirm they are installed by `uv sync`.
      - name: Lint with ruff
        run: uv run ruff check .

      - name: Type check with mypy
        run: uv run mypy src

      # Placeholder keys only: the gateway tests patch the SDK client classes,
      # so no real provider call is expected during the test run.
      - name: Run tests
        run: uv run pytest
        env:
          OPENAI_API_KEY: "test-key"
          ANTHROPIC_API_KEY: "test-key"
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ strict = true

# pytest settings: collect tests from tests/ and put src/ on sys.path for the
# test run (presumably where the `raglab` package imported by the tests lives).
[tool.pytest.ini_options]
testpaths = ["tests"]
pythonpath = ["src"]

[tool.uv]
override-dependencies = [
Expand All @@ -53,4 +54,6 @@ override-dependencies = [
# Development-only dependency group (not part of the published package).
# NOTE(review): the CI workflow also runs ruff and mypy through `uv run`, and
# neither is listed in this visible fragment; confirm they are installed.
# NOTE(review): the visible tests use unittest.mock rather than pytest-mock's
# `mocker` fixture; confirm pytest-mock is actually needed.
[dependency-groups]
dev = [
"pre-commit>=4.6.0",
"pytest>=9.0.3",
"pytest-mock>=3.15.1",
]
Empty file added tests/__init__.py
Empty file.
16 changes: 16 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest

from raglab.gateway.base import LLMResponse


@pytest.fixture
def fake_llm_response() -> LLMResponse:
    """Provide a canned ``LLMResponse`` with fixed, easily asserted values.

    Every field is a literal, so tests that receive this fixture can compare
    against the exact numbers and strings below.
    """
    canned = LLMResponse(
        provider="openai",
        model="gpt-4o-mini",
        text="This is a test answer.",
        input_tokens=10,
        output_tokens=20,
        cost_usd=0.000005,
        latency_ms=123.45,
    )
    return canned
27 changes: 27 additions & 0 deletions tests/test_experiments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from raglab.experiments.runner import ExperimentConfig


def test_experiment_config_defaults():
    """Options left unset on ``ExperimentConfig`` take their default values."""
    minimal = ExperimentConfig(
        questions=["What is RAG?"],
        prompt_versions=["v1"],
        models=["gpt-4o-mini"],
    )

    observed = (minimal.top_k, minimal.provider, minimal.retriever)
    assert observed == (5, "openai", "chroma")


def test_experiment_config_matrix_size():
    """Crossing models, prompt versions and questions gives 2 x 2 x 3 = 12 runs.

    NOTE(review): the product is computed here with ``itertools.product`` over
    the config's own fields; the experiment runner itself is not exercised.
    """
    from itertools import product

    cfg = ExperimentConfig(
        models=["gpt-4o-mini", "claude-haiku-4-5-20251001"],
        prompt_versions=["v1", "v2"],
        questions=["Q1", "Q2", "Q3"],
    )

    matrix = list(product(cfg.models, cfg.prompt_versions, cfg.questions))

    # 2 models x 2 prompts x 3 questions = 12
    assert len(matrix) == 12
89 changes: 89 additions & 0 deletions tests/test_gateway.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from unittest.mock import MagicMock, patch

from raglab.gateway.anthropic import AnthropicProvider
from raglab.gateway.openai import OpenAIProvider


def test_openai_provider_returns_normalized_response():
    """``OpenAIProvider.generate`` surfaces the mocked completion's text and usage."""
    # Stand-in for the object the SDK's chat.completions.create call returns.
    completion = MagicMock()
    completion.model = "gpt-4o-mini"
    completion.usage.completion_tokens = 20
    completion.usage.prompt_tokens = 10
    completion.choices[0].message.content = "Hello from OpenAI"

    with patch("raglab.gateway.openai.OpenAI") as client_cls:
        fake_client = client_cls.return_value
        fake_client.chat.completions.create.return_value = completion

        result = OpenAIProvider(api_key="test-key").generate(
            messages=[{"role": "user", "content": "hello"}],
            model="gpt-4o-mini",
        )

    assert result.provider == "openai"
    assert result.text == "Hello from OpenAI"
    assert (result.input_tokens, result.output_tokens) == (10, 20)
    assert result.cost_usd >= 0
    assert result.latency_ms > 0


def test_anthropic_provider_returns_normalized_response():
    """``AnthropicProvider.generate`` surfaces the mocked reply's text and usage."""
    # Stand-in for the object the SDK's messages.create call returns.
    reply = MagicMock()
    reply.usage.output_tokens = 20
    reply.usage.input_tokens = 10
    reply.content[0].text = "Hello from Anthropic"

    conversation = [
        {"role": "system", "content": "you are helpful"},
        {"role": "user", "content": "hello"},
    ]

    with patch("raglab.gateway.anthropic.Anthropic") as client_cls:
        client_cls.return_value.messages.create.return_value = reply

        result = AnthropicProvider(api_key="test-key").generate(
            messages=conversation,
            model="claude-haiku-4-5-20251001",
        )

    assert result.provider == "anthropic"
    assert result.text == "Hello from Anthropic"
    assert (result.input_tokens, result.output_tokens) == (10, 20)


def test_gateway_normalizes_both_providers_to_same_shape():
    """Both providers return ``LLMResponse`` objects exposing the same field set."""
    from raglab.gateway.base import LLMResponse

    # OpenAI side: fake completion, then a patched client that returns it.
    openai_payload = MagicMock()
    openai_payload.usage.completion_tokens = 5
    openai_payload.usage.prompt_tokens = 5
    openai_payload.choices[0].message.content = "answer"

    with patch("raglab.gateway.openai.OpenAI") as openai_cls:
        openai_cls.return_value.chat.completions.create.return_value = openai_payload
        from_openai = OpenAIProvider("test").generate(
            [{"role": "user", "content": "hi"}], "gpt-4o-mini"
        )

    # Anthropic side: same idea with the messages API.
    anthropic_payload = MagicMock()
    anthropic_payload.usage.output_tokens = 5
    anthropic_payload.usage.input_tokens = 5
    anthropic_payload.content[0].text = "answer"

    with patch("raglab.gateway.anthropic.Anthropic") as anthropic_cls:
        anthropic_cls.return_value.messages.create.return_value = anthropic_payload
        from_anthropic = AnthropicProvider("test").generate(
            [{"role": "user", "content": "hi"}], "claude-haiku-4-5-20251001"
        )

    # Each result must be an LLMResponse, and the two must declare equal fields.
    for outcome in (from_openai, from_anthropic):
        assert isinstance(outcome, LLMResponse)

    openai_fields = set(from_openai.__class__.model_fields)
    anthropic_fields = set(from_anthropic.__class__.model_fields)
    assert openai_fields == anthropic_fields
30 changes: 30 additions & 0 deletions tests/test_ingestion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from raglab.ingestion.chunkers import recursive_chunk


def test_chunk_splits_text_into_correct_sizes():
    """A 1000-character input is split into several chunks, none above ``chunk_size``."""
    limit = 100
    pieces = recursive_chunk("a" * 1000, chunk_size=limit, overlap=10)

    assert len(pieces) > 1

    # Collect offenders explicitly so a failure shows which chunks are too long.
    oversized = [piece for piece in pieces if len(piece) > limit]
    assert not oversized


def test_chunk_overlap_exists_between_consecutive_chunks():
    """The last 10 characters of the first chunk reappear at the start of the second."""
    pieces = recursive_chunk("abcdefghij" * 100, chunk_size=50, overlap=10)

    first, second = pieces[0], pieces[1]
    tail_of_first = first[-10:]
    head_of_second = second[:10]
    assert tail_of_first == head_of_second


def test_chunk_empty_text_returns_empty_list():
    """Empty and whitespace-only inputs both produce no chunks."""
    for blank in ("", " "):
        assert recursive_chunk(blank) == []


def test_chunk_short_text_returns_single_chunk():
    """Text shorter than ``chunk_size`` comes back as one unmodified chunk."""
    snippet = "short text"
    pieces = recursive_chunk(snippet, chunk_size=100, overlap=10)

    assert len(pieces) == 1
    only_piece = pieces[0]
    assert only_piece == snippet
50 changes: 50 additions & 0 deletions tests/test_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from unittest.mock import MagicMock, patch

from raglab.retrieval.chroma import ChromaRetriever


def test_chroma_retriever_returns_retrieved_chunks():
    """``ChromaRetriever.retrieve`` maps a mocked query result onto chunk objects."""
    chunk_metadata = [
        {"source": "test.pdf", "page": 1, "chunk_index": 0},
        {"source": "test.pdf", "page": 1, "chunk_index": 1},
    ]
    collection = MagicMock()
    collection.query.return_value = {
        "documents": [["chunk text 1", "chunk text 2"]],
        "metadatas": [chunk_metadata],
        "distances": [[0.1, 0.3]],
    }

    # Replace both the Chroma client and the embedding call with test doubles.
    client_patch = patch("raglab.retrieval.chroma.chromadb.PersistentClient")
    embed_patch = patch(
        "raglab.retrieval.chroma.embed_batch", return_value=[[0.1] * 1536]
    )
    with client_patch as client_cls, embed_patch:
        client_cls.return_value.get_collection.return_value = collection

        hits = ChromaRetriever(collection_name="test").retrieve(
            "what is RAG?", top_k=2
        )

    assert len(hits) == 2

    best, runner_up = hits[0], hits[1]
    assert best.text == "chunk text 1"
    assert best.source == "test.pdf"
    assert best.page == 1
    # higher score = more similar (1 - distance)
    assert best.score > runner_up.score


def test_chroma_retriever_score_is_inverted_distance():
    """A mocked distance of 0.2 must come back as a score of 0.8."""
    collection = MagicMock()
    collection.query.return_value = {
        "documents": [["chunk"]],
        "metadatas": [[{"source": "f.pdf", "page": 1, "chunk_index": 0}]],
        "distances": [[0.2]],
    }

    # Replace both the Chroma client and the embedding call with test doubles.
    client_patch = patch("raglab.retrieval.chroma.chromadb.PersistentClient")
    embed_patch = patch(
        "raglab.retrieval.chroma.embed_batch", return_value=[[0.1] * 1536]
    )
    with client_patch as client_cls, embed_patch:
        client_cls.return_value.get_collection.return_value = collection

        hits = ChromaRetriever(collection_name="test").retrieve("query")

    # score should be 1 - 0.2 = 0.8
    expected_score = round(1 - 0.2, 4)
    assert hits[0].score == expected_score
54 changes: 53 additions & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading