diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..466fadb --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,37 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v4 + with: + version: "latest" + + - name: Set up Python + run: uv python install 3.12 + + - name: Install dependencies + run: uv sync + + - name: Lint with ruff + run: uv run ruff check . + + - name: Type check with mypy + run: uv run mypy src + + - name: Run tests + run: uv run pytest + env: + OPENAI_API_KEY: "test-key" + ANTHROPIC_API_KEY: "test-key" diff --git a/pyproject.toml b/pyproject.toml index 3092169..5afbf46 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ strict = true [tool.pytest.ini_options] testpaths = ["tests"] +pythonpath = ["src"] [tool.uv] override-dependencies = [ @@ -53,4 +54,6 @@ override-dependencies = [ [dependency-groups] dev = [ "pre-commit>=4.6.0", + "pytest>=9.0.3", + "pytest-mock>=3.15.1", ] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1602276 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,16 @@ +import pytest + +from raglab.gateway.base import LLMResponse + + +@pytest.fixture +def fake_llm_response() -> LLMResponse: + return LLMResponse( + text="This is a test answer.", + input_tokens=10, + output_tokens=20, + cost_usd=0.000005, + latency_ms=123.45, + model="gpt-4o-mini", + provider="openai", + ) diff --git a/tests/test_experiments.py b/tests/test_experiments.py new file mode 100644 index 0000000..5271261 --- /dev/null +++ b/tests/test_experiments.py @@ -0,0 +1,27 @@ +from raglab.experiments.runner import ExperimentConfig + + +def test_experiment_config_defaults(): + config = ExperimentConfig( + models=["gpt-4o-mini"], + prompt_versions=["v1"], + questions=["What is RAG?"], + ) + assert config.top_k == 5 + assert config.provider == "openai" + assert config.retriever == "chroma" + + +def test_experiment_config_matrix_size(): + config = ExperimentConfig( + models=["gpt-4o-mini", "claude-haiku-4-5-20251001"], + prompt_versions=["v1", "v2"], + questions=["Q1", "Q2", "Q3"], + ) + import itertools + + combinations = list( + itertools.product(config.models, config.prompt_versions, config.questions) + ) + # 2 models x 2 prompts x 3 questions = 12 + assert len(combinations) == 12 diff --git a/tests/test_gateway.py b/tests/test_gateway.py new file mode 100644 index 0000000..86b9882 --- /dev/null +++ b/tests/test_gateway.py @@ -0,0 +1,89 @@ +from unittest.mock import MagicMock, patch + +from raglab.gateway.anthropic import AnthropicProvider +from raglab.gateway.openai import OpenAIProvider + + +def test_openai_provider_returns_normalized_response(): + # build a fake OpenAI response object + mock_response = MagicMock() + mock_response.choices[0].message.content = "Hello from OpenAI" + mock_response.usage.prompt_tokens = 10 + mock_response.usage.completion_tokens = 20 + mock_response.model = "gpt-4o-mini" + + with patch("raglab.gateway.openai.OpenAI") as mock_client_class: + mock_client_class.return_value.chat.completions.create.return_value = ( + mock_response + ) + + provider = OpenAIProvider(api_key="test-key") + result = provider.generate( + messages=[{"role": "user", "content": "hello"}], + model="gpt-4o-mini", + ) + + assert result.text == "Hello from OpenAI" + assert result.provider == "openai" + assert result.input_tokens == 10 + assert result.output_tokens == 20 + assert result.cost_usd >= 0 + assert result.latency_ms > 0 + + +def test_anthropic_provider_returns_normalized_response(): + mock_response = MagicMock() + mock_response.content[0].text = "Hello from Anthropic" + mock_response.usage.input_tokens = 10 + mock_response.usage.output_tokens = 20 + + with patch("raglab.gateway.anthropic.Anthropic") as mock_client_class: + mock_client_class.return_value.messages.create.return_value = mock_response + + provider = AnthropicProvider(api_key="test-key") + result = provider.generate( + messages=[ + {"role": "system", "content": "you are helpful"}, + {"role": "user", "content": "hello"}, + ], + model="claude-haiku-4-5-20251001", + ) + + assert result.text == "Hello from Anthropic" + assert result.provider == "anthropic" + assert result.input_tokens == 10 + assert result.output_tokens == 20 + + +def test_gateway_normalizes_both_providers_to_same_shape(): + # both providers must return LLMResponse with identical fields + from raglab.gateway.base import LLMResponse + + mock_openai = MagicMock() + mock_openai.choices[0].message.content = "answer" + mock_openai.usage.prompt_tokens = 5 + mock_openai.usage.completion_tokens = 5 + + mock_anthropic = MagicMock() + mock_anthropic.content[0].text = "answer" + mock_anthropic.usage.input_tokens = 5 + mock_anthropic.usage.output_tokens = 5 + + with patch("raglab.gateway.openai.OpenAI") as oai: + oai.return_value.chat.completions.create.return_value = mock_openai + openai_result = OpenAIProvider("test").generate( + [{"role": "user", "content": "hi"}], "gpt-4o-mini" + ) + + with patch("raglab.gateway.anthropic.Anthropic") as ant: + ant.return_value.messages.create.return_value = mock_anthropic + anthropic_result = AnthropicProvider("test").generate( + [{"role": "user", "content": "hi"}], "claude-haiku-4-5-20251001" + ) + + # both results must be LLMResponse instances with same fields + assert isinstance(openai_result, LLMResponse) + assert isinstance(anthropic_result, LLMResponse) + assert set(openai_result.__class__.model_fields) == set( + anthropic_result.__class__.model_fields + ) diff --git a/tests/test_ingestion.py b/tests/test_ingestion.py new file mode 100644 index 0000000..910f562 --- /dev/null +++ b/tests/test_ingestion.py @@ -0,0 +1,30 @@ +from raglab.ingestion.chunkers import recursive_chunk + + +def test_chunk_splits_text_into_correct_sizes(): + text = "a" * 1000 + chunks = recursive_chunk(text, chunk_size=100, overlap=10) + + assert len(chunks) > 1 + # no chunk should exceed chunk_size + assert all(len(c) <= 100 for c in chunks) + + +def test_chunk_overlap_exists_between_consecutive_chunks(): + text = "abcdefghij" * 100 + chunks = recursive_chunk(text, chunk_size=50, overlap=10) + + # the end of chunk N should appear at the start of chunk N+1 + assert chunks[0][-10:] == chunks[1][:10] + + +def test_chunk_empty_text_returns_empty_list(): + assert recursive_chunk("") == [] + assert recursive_chunk(" ") == [] + + +def test_chunk_short_text_returns_single_chunk(): + text = "short text" + chunks = recursive_chunk(text, chunk_size=100, overlap=10) + assert len(chunks) == 1 + assert chunks[0] == text diff --git a/tests/test_retrieval.py b/tests/test_retrieval.py new file mode 100644 index 0000000..9a629b7 --- /dev/null +++ b/tests/test_retrieval.py @@ -0,0 +1,50 @@ +from unittest.mock import MagicMock, patch + +from raglab.retrieval.chroma import ChromaRetriever + + +def test_chroma_retriever_returns_retrieved_chunks(): + mock_collection = MagicMock() + mock_collection.query.return_value = { + "documents": [["chunk text 1", "chunk text 2"]], + "metadatas": [ + [ + {"source": "test.pdf", "page": 1, "chunk_index": 0}, + {"source": "test.pdf", "page": 1, "chunk_index": 1}, + ] + ], + "distances": [[0.1, 0.3]], + } + + with patch("raglab.retrieval.chroma.chromadb.PersistentClient") as mock_client: + with patch("raglab.retrieval.chroma.embed_batch", return_value=[[0.1] * 1536]): + mock_client.return_value.get_collection.return_value = mock_collection + + retriever = ChromaRetriever(collection_name="test") + results = retriever.retrieve("what is RAG?", top_k=2) + + assert len(results) == 2 + assert results[0].text == "chunk text 1" + assert results[0].source == "test.pdf" + assert results[0].page == 1 + # higher score = more similar (1 - distance) + assert results[0].score > results[1].score + + +def test_chroma_retriever_score_is_inverted_distance(): + mock_collection = MagicMock() + mock_collection.query.return_value = { + "documents": [["chunk"]], + "metadatas": [[{"source": "f.pdf", "page": 1, "chunk_index": 0}]], + "distances": [[0.2]], + } + + with patch("raglab.retrieval.chroma.chromadb.PersistentClient") as mock_client: + with patch("raglab.retrieval.chroma.embed_batch", return_value=[[0.1] * 1536]): + mock_client.return_value.get_collection.return_value = mock_collection + + retriever = ChromaRetriever(collection_name="test") + results = retriever.retrieve("query") + + # score should be 1 - 0.2 = 0.8 + assert results[0].score == round(1 - 0.2, 4) diff --git a/uv.lock b/uv.lock index f605c4f..3d4b437 100644 --- a/uv.lock +++ b/uv.lock @@ -1086,6 +1086,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8a/db/55a262f3606bebcae07cc14095338471ad7c0bbcaa37707e6f0ee49725b7/importlib_resources-7.1.0-py3-none-any.whl", hash = "sha256:1bd7b48b4088eddb2cd16382150bb515af0bd2c70128194392725f82ad2c96a1", size = 37232, upload-time = "2026-04-12T16:36:08.219Z" }, ] +[[package]] +name = "iniconfig" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, +] + [[package]] name = "instructor" version = "1.15.1" @@ -2212,6 +2221,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/75/a6/a0a304dc33b49145b21f4808d763822111e67d1c3a32b524a1baf947b6e1/platformdirs-4.9.6-py3-none-any.whl", hash = "sha256:e61adb1d5e5cb3441b4b7710bea7e4c12250ca49439228cc1021c00dcfac0917", size = 21348, upload-time = "2026-04-09T00:04:09.463Z" }, ] +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, +] + [[package]] name = "pre-commit" version = "4.6.0" @@ -2646,6 +2664,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bd/24/12818598c362d7f300f18e74db45963dbcb85150324092410c8b49405e42/pyproject_hooks-1.2.0-py3-none-any.whl", hash = "sha256:9e5c6bfa8dcc30091c74b0cf803c81fdd29d94f01992a7707bc97babb1141913", size = 10216, upload-time = "2024-09-29T09:24:11.978Z" }, ] +[[package]] +name = "pytest" +version = "9.0.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7d/0d/549bd94f1a0a402dc8cf64563a117c0f3765662e2e668477624baeec44d5/pytest-9.0.3.tar.gz", hash = "sha256:b86ada508af81d19edeb213c681b1d48246c1a91d304c6c81a427674c17eb91c", size = 1572165, upload-time = "2026-04-07T17:16:18.027Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d4/24/a372aaf5c9b7208e7112038812994107bc65a84cd00e0354a88c2c77a617/pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9", size = 375249, upload-time = "2026-04-07T17:16:16.13Z" }, +] + +[[package]] +name = "pytest-mock" +version = "3.15.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/68/14/eb014d26be205d38ad5ad20d9a80f7d201472e08167f0bb4361e251084a9/pytest_mock-3.15.1.tar.gz", hash = "sha256:1849a238f6f396da19762269de72cb1814ab44416fa73a8686deac10b0d87a0f", size = 34036, upload-time = "2025-09-16T16:37:27.081Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5a/cc/06253936f4a7fa2e0f48dfe6d851d9c56df896a9ab09ac019d70b760619c/pytest_mock-3.15.1-py3-none-any.whl", hash = "sha256:0a25e2eb88fe5168d535041d09a4529a188176ae608a6d249ee65abc0949630d", size = 10095, upload-time = "2025-09-16T16:37:25.734Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -2788,6 +2834,8 @@ dependencies = [ [package.dev-dependencies] dev = [ { name = "pre-commit" }, + { name = "pytest" }, + { name = "pytest-mock" }, ] [package.metadata] @@ -2808,7 +2856,11 @@ requires-dist = [ ] [package.metadata.requires-dev] -dev = [{ name = "pre-commit", specifier = ">=4.6.0" }] +dev = [ + { name = "pre-commit", specifier = ">=4.6.0" }, + { name = "pytest", specifier = ">=9.0.3" }, + { name = "pytest-mock", specifier = ">=3.15.1" }, +] [[package]] name = "rank-bm25"