diff --git a/pydebeziumai/retrieval/__init__.py b/pydebeziumai/retrieval/__init__.py new file mode 100755 index 0000000..a2678a2 --- /dev/null +++ b/pydebeziumai/retrieval/__init__.py @@ -0,0 +1,7 @@ +"""Retrieval layer integration for LangChain and LangGraph.""" + +from __future__ import annotations + +from pydebeziumai.retrieval.langgraph import create_retriever_node, create_retriever_tool + +__all__ = ["create_retriever_node", "create_retriever_tool"] diff --git a/pydebeziumai/retrieval/langgraph.py b/pydebeziumai/retrieval/langgraph.py new file mode 100755 index 0000000..d5ce612 --- /dev/null +++ b/pydebeziumai/retrieval/langgraph.py @@ -0,0 +1,114 @@ +"""LangGraph integration for PyDebeziumAI vector store adapters.""" + +from __future__ import annotations + +import logging +from collections.abc import Callable +from typing import Any + +from langchain_core.tools import BaseTool, Tool + +from pydebeziumai.adapters.base import VectorStoreAdapter + +logger = logging.getLogger(__name__) + + +def create_retriever_tool( + adapter: VectorStoreAdapter, + name: str, + description: str, + **kwargs: Any, +) -> BaseTool: + """Create a LangChain Tool backed by the VectorStoreAdapter. + + This tool is designed to be passed to LangGraph agents or standard LangChain + agent executors. + + Args: + adapter: The VectorStoreAdapter instance to query. + name: Name of the tool. + description: Description of what the tool does and when to call it. + **kwargs: Additional parameters passed to adapter.as_retriever(). + + Returns: + A LangChain Tool instance. + """ + retriever = adapter.as_retriever(**kwargs) + try: + from langchain.tools.retriever import create_retriever_tool as langchain_create_tool + + # Note: langchain's create_retriever_tool accepts a BaseRetriever + return langchain_create_tool(retriever, name, description) + except ImportError: + logger.debug("langchain.tools.retriever not available. Falling back to custom Tool.") + + def retrieve(query: str) -> str: + """Call the retriever with the query and format results.""" + docs = retriever.invoke(query) + formatted_docs = [] + for i, doc in enumerate(docs): + source = doc.metadata.get("source", "unknown") + formatted_docs.append(f"Document {i + 1} (Source: {source}):\n{doc.page_content}") + return "\n\n".join(formatted_docs) + + return Tool( + name=name, + description=description, + func=retrieve, + ) + + +def create_retriever_node( + adapter: VectorStoreAdapter, + state_key: str = "documents", + query_key: str = "query", + query_extractor: Callable[[Any], str] | None = None, + **retriever_kwargs: Any, +) -> Callable[[Any], dict[str, Any]]: + """Create a LangGraph-compatible node that queries the VectorStoreAdapter. + + This node executes the retrieval query and updates the graph state with the + resulting list of Document objects under the specified state_key. + + Args: + adapter: The VectorStoreAdapter instance. + state_key: The key in the returned state dictionary where retrieved documents will be stored. + query_key: The key in the input state where the search query string is stored. + Only used if query_extractor is None. + query_extractor: An optional callable that extracts the query string from the graph state. + If provided, it overrides query_key. + **retriever_kwargs: Options passed to adapter.as_retriever() (e.g. search_kwargs). + + Returns: + A callable node function for LangGraph. + """ + retriever = adapter.as_retriever(**retriever_kwargs) + + def node_fn(state: Any) -> dict[str, Any]: + if query_extractor is not None: + query = query_extractor(state) + elif isinstance(state, dict): + if query_key in state: + query = state[query_key] + elif "messages" in state and len(state["messages"]) > 0: + # Smart extraction from standard messages list + last_msg = state["messages"][-1] + if hasattr(last_msg, "content"): + query = last_msg.content + elif isinstance(last_msg, dict): + query = last_msg.get("content", "") + else: + query = str(last_msg) + else: + query = "" + else: + # If state is an object/Pydantic model + query = getattr(state, query_key, "") + + if not isinstance(query, str): + raise ValueError(f"Extracted query must be a string, got {type(query)}") + + documents = retriever.invoke(query) + return {state_key: documents} + + return node_fn diff --git a/tests/unit/test_retrieval.py b/tests/unit/test_retrieval.py new file mode 100755 index 0000000..90b75b6 --- /dev/null +++ b/tests/unit/test_retrieval.py @@ -0,0 +1,148 @@ +"""Unit tests for retrieval layer and LangGraph node helpers.""" + +from __future__ import annotations + +from typing import Any + +import pytest +from langchain_core.documents import Document +from langchain_core.retrievers import BaseRetriever + +from pydebeziumai.adapters.base import VectorStoreAdapter +from pydebeziumai.retrieval.langgraph import create_retriever_node, create_retriever_tool + + +class MockRetriever(BaseRetriever): + """Mock LangChain retriever for testing.""" + + mock_docs: list[Document] + + def _get_relevant_documents(self, query: str, *, run_manager: Any = None) -> list[Document]: + return self.mock_docs + + +class MockAdapter(VectorStoreAdapter): + """Mock vector store adapter to pass to helpers.""" + + def __init__(self, mock_docs: list[Document]) -> None: + self.retriever = MockRetriever(mock_docs=mock_docs) + + def upsert(self, document: Document) -> None: + pass + + def delete(self, doc_id: str) -> None: + pass + + def as_retriever(self, **kwargs: Any) -> BaseRetriever: + return self.retriever + + +class MockMessage: + """Mock message class resembling LangChain's HumanMessage/AIMessage.""" + + def __init__(self, content: str) -> None: + self.content = content + + +class MockStateObject: + """Mock state object class to test attribute extraction.""" + + def __init__(self, query: str) -> None: + self.query = query + + +@pytest.fixture +def mock_documents() -> list[Document]: + """Fixture returning sample documents.""" + return [ + Document(page_content="Apple is a fruit.", metadata={"source": "db.fruits"}), + Document(page_content="Carrot is a vegetable.", metadata={"source": "db.vegetables"}), + ] + + +@pytest.fixture +def mock_adapter(mock_documents: list[Document]) -> MockAdapter: + """Fixture returning a mock adapter configured with sample documents.""" + return MockAdapter(mock_documents) + + +def test_create_retriever_tool(mock_adapter: MockAdapter) -> None: + """Test that create_retriever_tool generates a Tool with correct behaviors.""" + tool = create_retriever_tool( + adapter=mock_adapter, + name="test_tool", + description="A test tool to retrieve information", + ) + + assert tool.name == "test_tool" + assert tool.description == "A test tool to retrieve information" + + # Invoke tool and check formatting + res = tool.invoke("test query") + assert "Apple is a fruit." in res + assert "Carrot is a vegetable." in res + + +def test_create_retriever_node_dict_query(mock_adapter: MockAdapter, mock_documents: list[Document]) -> None: + """Test create_retriever_node with query key in a dict state.""" + node = create_retriever_node(mock_adapter, state_key="results", query_key="search_term") + + state = {"search_term": "query info"} + output = node(state) + + assert isinstance(output, dict) + assert "results" in output + assert output["results"] == mock_documents + + +def test_create_retriever_node_dict_messages(mock_adapter: MockAdapter, mock_documents: list[Document]) -> None: + """Test create_retriever_node with a messages key in a dict state.""" + node = create_retriever_node(mock_adapter, state_key="results") + + state: dict[str, Any] = { + "messages": [ + MockMessage("hello"), + MockMessage("tell me about vegetables"), + ] + } + output = node(state) + + assert output["results"] == mock_documents + + +def test_create_retriever_node_object_query(mock_adapter: MockAdapter, mock_documents: list[Document]) -> None: + """Test create_retriever_node with an object state and attribute query.""" + node = create_retriever_node(mock_adapter, state_key="results", query_key="query") + + state = MockStateObject("custom query text") + output = node(state) + + assert output["results"] == mock_documents + + +def test_create_retriever_node_custom_extractor(mock_adapter: MockAdapter, mock_documents: list[Document]) -> None: + """Test create_retriever_node with a custom query extractor function.""" + node = create_retriever_node( + mock_adapter, + state_key="results", + query_extractor=lambda s: s["nested"]["query_value"], + ) + + state = { + "nested": { + "query_value": "extracted text", + } + } + output = node(state) + + assert output["results"] == mock_documents + + +def test_create_retriever_node_invalid_query_type(mock_adapter: MockAdapter) -> None: + """Test create_retriever_node raises ValueError when the extracted query is not a string.""" + node = create_retriever_node(mock_adapter, query_key="search") + + state = {"search": 12345} # Integer instead of string + + with pytest.raises(ValueError, match="Extracted query must be a string"): + node(state)