Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pydebeziumai/retrieval/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Retrieval layer integration for LangChain and LangGraph."""

from __future__ import annotations

from pydebeziumai.retrieval.langgraph import create_retriever_node, create_retriever_tool

__all__ = ["create_retriever_node", "create_retriever_tool"]
116 changes: 116 additions & 0 deletions pydebeziumai/retrieval/langgraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""LangGraph integration for PyDebeziumAI vector store adapters."""

from __future__ import annotations

import logging
from collections.abc import Callable
from typing import Any

from langchain_core.tools import BaseTool, Tool

from pydebeziumai.adapters.base import VectorStoreAdapter

logger = logging.getLogger(__name__)


def create_retriever_tool(
adapter: VectorStoreAdapter,
name: str,
description: str,
**kwargs: Any,
) -> BaseTool:
"""Create a LangChain Tool backed by the VectorStoreAdapter.

This tool is designed to be passed to LangGraph agents or standard LangChain
agent executors.

Args:
adapter: The VectorStoreAdapter instance to query.
name: Name of the tool.
description: Description of what the tool does and when to call it.
**kwargs: Additional parameters passed to adapter.as_retriever().

Returns:
A LangChain Tool instance.
"""
retriever = adapter.as_retriever(**kwargs)

def retrieve(query: str) -> str:
"""Call the retriever with the query and format results."""
docs = retriever.invoke(query)
formatted_docs = []
for i, doc in enumerate(docs):
source = doc.metadata.get("source", "unknown")
formatted_docs.append(f"Document {i + 1} (Source: {source}):\n{doc.page_content}")
return "\n\n".join(formatted_docs)

# We try to use langchain's native tool definition structure if possible,
# fallback to manual langchain-core Tool creation.
try:
from langchain.tools.retriever import create_retriever_tool as langchain_create_tool

# Note: langchain's create_retriever_tool accepts a BaseRetriever
return langchain_create_tool(retriever, name, description)
except ImportError:
logger.debug("langchain.tools.retriever not available. Falling back to custom Tool.")
return Tool(
name=name,
description=description,
func=retrieve,
)


def create_retriever_node(
adapter: VectorStoreAdapter,
state_key: str = "documents",
query_key: str = "query",
query_extractor: Callable[[Any], str] | None = None,
**retriever_kwargs: Any,
) -> Callable[[Any], dict[str, Any]]:
"""Create a LangGraph-compatible node that queries the VectorStoreAdapter.

This node executes the retrieval query and updates the graph state with the
resulting list of Document objects under the specified state_key.

Args:
adapter: The VectorStoreAdapter instance.
state_key: The key in the returned state dictionary where retrieved documents will be stored.
query_key: The key in the input state where the search query string is stored.
Only used if query_extractor is None.
query_extractor: An optional callable that extracts the query string from the graph state.
If provided, it overrides query_key.
**retriever_kwargs: Options passed to adapter.as_retriever() (e.g. search_kwargs).

Returns:
A callable node function for LangGraph.
"""
retriever = adapter.as_retriever(**retriever_kwargs)

def node_fn(state: Any) -> dict[str, Any]:
if query_extractor is not None:
query = query_extractor(state)
elif isinstance(state, dict):
if query_key in state:
query = state[query_key]
elif "messages" in state and len(state["messages"]) > 0:
# Smart extraction from standard messages list
last_msg = state["messages"][-1]
if hasattr(last_msg, "content"):
query = last_msg.content
elif isinstance(last_msg, dict):
query = last_msg.get("content", "")
else:
query = str(last_msg)
else:
query = ""
else:
# If state is an object/Pydantic model
query = getattr(state, query_key, "")

if not isinstance(query, str):
raise ValueError(f"Extracted query must be a string, got {type(query)}")

documents = retriever.invoke(query)
return {state_key: documents}

return node_fn
148 changes: 148 additions & 0 deletions tests/unit/test_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
"""Unit tests for retrieval layer and LangGraph node helpers."""

from __future__ import annotations

from typing import Any

import pytest
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

from pydebeziumai.adapters.base import VectorStoreAdapter
from pydebeziumai.retrieval.langgraph import create_retriever_node, create_retriever_tool


class MockRetriever(BaseRetriever):
"""Mock LangChain retriever for testing."""

mock_docs: list[Document]

def _get_relevant_documents(self, query: str, *, run_manager: Any = None) -> list[Document]:
return self.mock_docs


class MockAdapter(VectorStoreAdapter):
"""Mock vector store adapter to pass to helpers."""

def __init__(self, mock_docs: list[Document]) -> None:
self.retriever = MockRetriever(mock_docs=mock_docs)

def upsert(self, document: Document) -> None:
pass

def delete(self, doc_id: str) -> None:
pass

def as_retriever(self, **kwargs: Any) -> BaseRetriever:
return self.retriever


class MockMessage:
"""Mock message class resembling LangChain's HumanMessage/AIMessage."""

def __init__(self, content: str) -> None:
self.content = content


class MockStateObject:
"""Mock state object class to test attribute extraction."""

def __init__(self, query: str) -> None:
self.query = query


@pytest.fixture
def mock_documents() -> list[Document]:
"""Fixture returning sample documents."""
return [
Document(page_content="Apple is a fruit.", metadata={"source": "db.fruits"}),
Document(page_content="Carrot is a vegetable.", metadata={"source": "db.vegetables"}),
]


@pytest.fixture
def mock_adapter(mock_documents: list[Document]) -> MockAdapter:
"""Fixture returning a mock adapter configured with sample documents."""
return MockAdapter(mock_documents)


def test_create_retriever_tool(mock_adapter: MockAdapter) -> None:
"""Test that create_retriever_tool generates a Tool with correct behaviors."""
tool = create_retriever_tool(
adapter=mock_adapter,
name="test_tool",
description="A test tool to retrieve information",
)

assert tool.name == "test_tool"
assert tool.description == "A test tool to retrieve information"

# Invoke tool and check formatting
res = tool.invoke("test query")
assert "Apple is a fruit." in res
assert "Carrot is a vegetable." in res


def test_create_retriever_node_dict_query(mock_adapter: MockAdapter, mock_documents: list[Document]) -> None:
"""Test create_retriever_node with query key in a dict state."""
node = create_retriever_node(mock_adapter, state_key="results", query_key="search_term")

state = {"search_term": "query info"}
output = node(state)

assert isinstance(output, dict)
assert "results" in output
assert output["results"] == mock_documents


def test_create_retriever_node_dict_messages(mock_adapter: MockAdapter, mock_documents: list[Document]) -> None:
"""Test create_retriever_node with a messages key in a dict state."""
node = create_retriever_node(mock_adapter, state_key="results")

state: dict[str, Any] = {
"messages": [
MockMessage("hello"),
MockMessage("tell me about vegetables"),
]
}
output = node(state)

assert output["results"] == mock_documents


def test_create_retriever_node_object_query(mock_adapter: MockAdapter, mock_documents: list[Document]) -> None:
"""Test create_retriever_node with an object state and attribute query."""
node = create_retriever_node(mock_adapter, state_key="results", query_key="query")

state = MockStateObject("custom query text")
output = node(state)

assert output["results"] == mock_documents


def test_create_retriever_node_custom_extractor(mock_adapter: MockAdapter, mock_documents: list[Document]) -> None:
"""Test create_retriever_node with a custom query extractor function."""
node = create_retriever_node(
mock_adapter,
state_key="results",
query_extractor=lambda s: s["nested"]["query_value"],
)

state = {
"nested": {
"query_value": "extracted text",
}
}
output = node(state)

assert output["results"] == mock_documents


def test_create_retriever_node_invalid_query_type(mock_adapter: MockAdapter) -> None:
"""Test create_retriever_node raises ValueError when the extracted query is not a string."""
node = create_retriever_node(mock_adapter, query_key="search")

state = {"search": 12345} # Integer instead of string

with pytest.raises(ValueError, match="Extracted query must be a string"):
node(state)
Loading