Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion backend/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@ class Config:
LLM_API_KEY = os.environ.get('LLM_API_KEY')
LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1')
LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini')

# Graph memory backend configuration
GRAPH_MEMORY_BACKEND = os.environ.get('GRAPH_MEMORY_BACKEND', 'zep_cloud')
GRAPHITI_MODEL_NAME = os.environ.get('GRAPHITI_MODEL_NAME', LLM_MODEL_NAME)
GRAPHITI_EMBEDDING_MODEL_NAME = os.environ.get('GRAPHITI_EMBEDDING_MODEL_NAME', 'text-embedding-3-small')
GRAPHITI_BRIDGE_URL = os.environ.get('GRAPHITI_BRIDGE_URL', 'http://graphiti-bridge:8008')
FALKORDB_HOST = os.environ.get('FALKORDB_HOST', 'localhost')
FALKORDB_PORT = int(os.environ.get('FALKORDB_PORT', '6379'))
FALKORDB_DATABASE = os.environ.get('FALKORDB_DATABASE', 'mirofish')

# Zep配置
ZEP_API_KEY = os.environ.get('ZEP_API_KEY')
Expand Down Expand Up @@ -69,7 +78,8 @@ def validate(cls) -> list[str]:
errors: list[str] = []
if not cls.LLM_API_KEY:
errors.append("LLM_API_KEY 未配置")
if not cls.ZEP_API_KEY:
graph_backend = (cls.GRAPH_MEMORY_BACKEND or 'zep_cloud').lower()
if graph_backend in {'zep', 'zep_cloud', 'zep-cloud'} and not cls.ZEP_API_KEY:
errors.append("ZEP_API_KEY 未配置")
return errors

13 changes: 13 additions & 0 deletions backend/app/graph_memory/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Graph memory backend adapters."""

from .base import GraphMemoryAdapter
from .factory import create_graph_memory_adapter
from .graphiti_bridge_adapter import GraphitiBridgeGraphMemoryAdapter
from .zep_cloud_adapter import ZepCloudGraphMemoryAdapter

__all__ = [
"GraphMemoryAdapter",
"GraphitiBridgeGraphMemoryAdapter",
"ZepCloudGraphMemoryAdapter",
"create_graph_memory_adapter",
]
66 changes: 66 additions & 0 deletions backend/app/graph_memory/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Graph memory adapter contracts.

This module defines the narrow graph-memory surface Mirofish needs. Concrete
backends can implement it without leaking vendor SDK details into services.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Protocol


class GraphMemoryAdapter(ABC):
"""Backend-neutral graph memory interface used by Mirofish services."""

@abstractmethod
def create_graph(self, graph_id: str, name: str, description: str) -> Any:
"""Create a graph and return the backend response."""

@abstractmethod
def set_ontology(self, graph_id: str, ontology: dict[str, Any]) -> Any:
"""Apply ontology definitions for a graph."""

@abstractmethod
def add_text_batch(self, graph_id: str, chunks: list[str]) -> Any:
"""Add a batch of text episodes to a graph."""

@abstractmethod
def add_text(self, graph_id: str, text: str) -> Any:
"""Add a single text episode to a graph."""

@abstractmethod
def get_episode(self, episode_uuid: str) -> Any:
"""Return one episode by UUID."""

@abstractmethod
def get_all_nodes(self, graph_id: str) -> list[Any]:
"""Return all nodes for a graph."""

@abstractmethod
def get_all_edges(self, graph_id: str) -> list[Any]:
"""Return all edges for a graph."""

@abstractmethod
def search(self, graph_id: str, query: str, limit: int = 10, scope: str = "edges", **kwargs: Any) -> Any:
"""Search graph memory."""

@abstractmethod
def get_node(self, node_uuid: str) -> Any:
"""Return one node by UUID."""

@abstractmethod
def get_node_edges(self, node_uuid: str) -> list[Any]:
"""Return edges related to one node."""

@abstractmethod
def delete_graph(self, graph_id: str) -> Any:
"""Delete a graph."""


class SupportsRawClient(Protocol):
"""Compatibility escape hatch for legacy code not yet adapter-native."""

@property
def raw_client(self) -> Any:
"""Return the underlying SDK client."""
23 changes: 23 additions & 0 deletions backend/app/graph_memory/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Graph memory adapter factory."""

from __future__ import annotations

from typing import Optional

from ..config import Config
from .base import GraphMemoryAdapter
from .zep_cloud_adapter import ZepCloudGraphMemoryAdapter


def create_graph_memory_adapter(api_key: Optional[str] = None, backend: Optional[str] = None) -> GraphMemoryAdapter:
selected_backend = (backend or Config.GRAPH_MEMORY_BACKEND).strip().lower()

if selected_backend in {"zep", "zep_cloud", "zep-cloud"}:
return ZepCloudGraphMemoryAdapter(api_key=api_key or Config.ZEP_API_KEY)

if selected_backend in {"graphiti", "graphiti_core", "graphiti-core", "graphiti_bridge", "graphiti-bridge"}:
from .graphiti_bridge_adapter import GraphitiBridgeGraphMemoryAdapter

return GraphitiBridgeGraphMemoryAdapter(api_key=api_key or Config.LLM_API_KEY)

raise ValueError(f"Unsupported GRAPH_MEMORY_BACKEND: {selected_backend}")
128 changes: 128 additions & 0 deletions backend/app/graph_memory/graphiti_bridge_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""HTTP adapter for the on-premise Graphiti bridge service."""

from __future__ import annotations

import json
from types import SimpleNamespace
from typing import Any
from urllib.error import HTTPError
from urllib.parse import quote, urlencode
from urllib.request import Request, urlopen

from .base import GraphMemoryAdapter
from ..config import Config


class GraphitiBridgeGraphMemoryAdapter(GraphMemoryAdapter):
"""Graph memory adapter backed by the local Graphiti bridge service."""

def __init__(self, api_key: str | None = None, base_url: str | None = None):
self.base_url = (base_url or Config.GRAPHITI_BRIDGE_URL).rstrip("/")
self._node_graph_index: dict[str, str] = {}

@property
def raw_client(self) -> None:
return None

def create_graph(self, graph_id: str, name: str, description: str) -> Any:
return self._to_namespace(self._request("POST", "/graphs", {"graph_id": graph_id, "name": name, "description": description}))

def set_ontology(self, graph_id: str, ontology: dict[str, Any]) -> Any:
return self._request("POST", f"/graphs/{quote(graph_id)}/ontology", ontology)

def add_text_batch(self, graph_id: str, chunks: list[str]) -> list[Any]:
data = self._request("POST", f"/graphs/{quote(graph_id)}/episodes", {"chunks": chunks})
return [self._episode(item) for item in data.get("episodes", [])]

def add_text(self, graph_id: str, text: str) -> Any:
data = self._request("POST", f"/graphs/{quote(graph_id)}/episodes", {"text": text})
episodes = data.get("episodes", [])
return self._episode(episodes[0]) if episodes else self._episode({"uuid": None, "processed": True})

def get_episode(self, episode_uuid: str) -> Any:
return self._episode({"uuid": episode_uuid, "processed": True})

def get_all_nodes(self, graph_id: str) -> list[Any]:
data = self._request("GET", f"/graphs/{quote(graph_id)}/nodes")
nodes = [self._node(item) for item in data.get("nodes", [])]
for node in nodes:
self._node_graph_index[node.uuid_] = graph_id
return nodes

def get_all_edges(self, graph_id: str) -> list[Any]:
data = self._request("GET", f"/graphs/{quote(graph_id)}/edges")
return [self._edge(item) for item in data.get("edges", [])]

def search(self, graph_id: str, query: str, limit: int = 10, scope: str = "edges", **kwargs: Any) -> Any:
data = self._request("POST", f"/graphs/{quote(graph_id)}/search", {"query": query, "limit": limit, "scope": scope})
nodes = [self._node(item) for item in data.get("nodes", [])]
for node in nodes:
self._node_graph_index[node.uuid_] = graph_id
return SimpleNamespace(edges=[self._edge(item) for item in data.get("edges", [])], nodes=nodes)

def get_node(self, node_uuid: str) -> Any:
graph_id = self._node_graph_index.get(node_uuid)
if not graph_id:
return None
query = urlencode({"graph_id": graph_id})
data = self._request("GET", f"/nodes/{quote(node_uuid)}?{query}")
node = data.get("node")
return self._node(node) if node else None

def get_node_edges(self, node_uuid: str) -> list[Any]:
graph_id = self._node_graph_index.get(node_uuid)
if not graph_id:
return []
query = urlencode({"graph_id": graph_id})
data = self._request("GET", f"/nodes/{quote(node_uuid)}/edges?{query}")
return [self._edge(item) for item in data.get("edges", [])]

def delete_graph(self, graph_id: str) -> Any:
return self._request("DELETE", f"/graphs/{quote(graph_id)}")

def _request(self, method: str, path: str, payload: dict[str, Any] | None = None) -> dict[str, Any]:
body = None if payload is None else json.dumps(payload).encode("utf-8")
headers = {"Content-Type": "application/json"}
req = Request(f"{self.base_url}{path}", data=body, headers=headers, method=method)
try:
with urlopen(req, timeout=120) as response:
raw = response.read().decode("utf-8")
return json.loads(raw) if raw else {}
except HTTPError as exc:
error_body = exc.read().decode("utf-8", errors="replace")
raise RuntimeError(f"Graphiti bridge request failed: {exc.code} {error_body}") from exc

def _episode(self, data: dict[str, Any]) -> Any:
uuid = data.get("uuid") or data.get("uuid_")
return SimpleNamespace(uuid_=uuid, uuid=uuid, processed=data.get("processed", True))

def _node(self, data: dict[str, Any]) -> Any:
uuid = data.get("uuid") or data.get("uuid_") or ""
return SimpleNamespace(
uuid_=uuid,
uuid=uuid,
name=data.get("name") or "",
labels=data.get("labels") or [],
summary=data.get("summary") or "",
attributes=data.get("attributes") or {},
created_at=data.get("created_at"),
)

def _edge(self, data: dict[str, Any]) -> Any:
uuid = data.get("uuid") or data.get("uuid_") or ""
return SimpleNamespace(
uuid_=uuid,
uuid=uuid,
name=data.get("name") or "",
fact=data.get("fact") or "",
source_node_uuid=data.get("source_node_uuid") or "",
target_node_uuid=data.get("target_node_uuid") or "",
attributes=data.get("attributes") or {},
created_at=data.get("created_at"),
valid_at=data.get("valid_at"),
invalid_at=data.get("invalid_at"),
expired_at=data.get("expired_at"),
)

def _to_namespace(self, data: dict[str, Any]) -> Any:
return SimpleNamespace(**data)
121 changes: 121 additions & 0 deletions backend/app/graph_memory/zep_cloud_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""Zep Cloud graph memory adapter."""

from __future__ import annotations

import warnings
from typing import Any, Optional

from pydantic import Field
from zep_cloud import EpisodeData, EntityEdgeSourceTarget
from zep_cloud.client import Zep
from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel

from .base import GraphMemoryAdapter
from ..utils.zep_paging import fetch_all_edges, fetch_all_nodes


class ZepCloudGraphMemoryAdapter(GraphMemoryAdapter):
"""Adapter preserving the existing Zep Cloud behavior."""

RESERVED_NAMES = {"uuid", "name", "group_id", "name_embedding", "summary", "created_at"}

def __init__(self, api_key: str):
if not api_key:
raise ValueError("ZEP_API_KEY 未配置")
self._client = Zep(api_key=api_key)

@property
def raw_client(self) -> Zep:
return self._client

def create_graph(self, graph_id: str, name: str, description: str) -> Any:
return self._client.graph.create(graph_id=graph_id, name=name, description=description)

def set_ontology(self, graph_id: str, ontology: dict[str, Any]) -> Any:
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")

entity_types: dict[str, type[EntityModel]] = {}
for entity_def in ontology.get("entity_types", []):
name = entity_def["name"]
description = entity_def.get("description", f"A {name} entity.")
attrs: dict[str, Any] = {"__doc__": description}
annotations: dict[str, Any] = {}

for attr_def in entity_def.get("attributes", []):
attr_name = self._safe_attr_name(attr_def["name"])
attr_desc = attr_def.get("description", attr_name)
attrs[attr_name] = Field(description=attr_desc, default=None)
annotations[attr_name] = Optional[EntityText]

attrs["__annotations__"] = annotations
entity_class = type(name, (EntityModel,), attrs)
entity_class.__doc__ = description
entity_types[name] = entity_class

edge_definitions: dict[str, tuple[type[EdgeModel], list[EntityEdgeSourceTarget]]] = {}
for edge_def in ontology.get("edge_types", []):
name = edge_def["name"]
description = edge_def.get("description", f"A {name} relationship.")
attrs = {"__doc__": description}
annotations = {}

for attr_def in edge_def.get("attributes", []):
attr_name = self._safe_attr_name(attr_def["name"])
attr_desc = attr_def.get("description", attr_name)
attrs[attr_name] = Field(description=attr_desc, default=None)
annotations[attr_name] = Optional[str]

attrs["__annotations__"] = annotations
class_name = "".join(word.capitalize() for word in name.split("_"))
edge_class = type(class_name, (EdgeModel,), attrs)
edge_class.__doc__ = description

source_targets = [
EntityEdgeSourceTarget(source=st.get("source", "Entity"), target=st.get("target", "Entity"))
for st in edge_def.get("source_targets", [])
]
if source_targets:
edge_definitions[name] = (edge_class, source_targets)

if not entity_types and not edge_definitions:
return None

return self._client.graph.set_ontology(
graph_ids=[graph_id],
entities=entity_types if entity_types else None,
edges=edge_definitions if edge_definitions else None,
)

def add_text_batch(self, graph_id: str, chunks: list[str]) -> Any:
episodes = [EpisodeData(data=chunk, type="text") for chunk in chunks]
return self._client.graph.add_batch(graph_id=graph_id, episodes=episodes)

def add_text(self, graph_id: str, text: str) -> Any:
return self._client.graph.add(graph_id=graph_id, type="text", data=text)

def get_episode(self, episode_uuid: str) -> Any:
return self._client.graph.episode.get(uuid_=episode_uuid)

def get_all_nodes(self, graph_id: str) -> list[Any]:
return fetch_all_nodes(self._client, graph_id)

def get_all_edges(self, graph_id: str) -> list[Any]:
return fetch_all_edges(self._client, graph_id)

def search(self, graph_id: str, query: str, limit: int = 10, scope: str = "edges", **kwargs: Any) -> Any:
return self._client.graph.search(graph_id=graph_id, query=query, limit=limit, scope=scope, **kwargs)

def get_node(self, node_uuid: str) -> Any:
return self._client.graph.node.get(uuid_=node_uuid)

def get_node_edges(self, node_uuid: str) -> list[Any]:
return self._client.graph.node.get_entity_edges(node_uuid=node_uuid)

def delete_graph(self, graph_id: str) -> Any:
return self._client.graph.delete(graph_id=graph_id)

@classmethod
def _safe_attr_name(cls, attr_name: str) -> str:
if attr_name.lower() in cls.RESERVED_NAMES:
return f"entity_{attr_name}"
return attr_name
Loading