diff --git a/backend/app/config.py b/backend/app/config.py index de63e2b4b..e6487e670 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -31,6 +31,15 @@ class Config: LLM_API_KEY = os.environ.get('LLM_API_KEY') LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1') LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini') + + # Graph memory backend configuration + GRAPH_MEMORY_BACKEND = os.environ.get('GRAPH_MEMORY_BACKEND', 'zep_cloud') + GRAPHITI_MODEL_NAME = os.environ.get('GRAPHITI_MODEL_NAME', LLM_MODEL_NAME) + GRAPHITI_EMBEDDING_MODEL_NAME = os.environ.get('GRAPHITI_EMBEDDING_MODEL_NAME', 'text-embedding-3-small') + GRAPHITI_BRIDGE_URL = os.environ.get('GRAPHITI_BRIDGE_URL', 'http://graphiti-bridge:8008') + FALKORDB_HOST = os.environ.get('FALKORDB_HOST', 'localhost') + FALKORDB_PORT = int(os.environ.get('FALKORDB_PORT', '6379')) + FALKORDB_DATABASE = os.environ.get('FALKORDB_DATABASE', 'mirofish') # Zep配置 ZEP_API_KEY = os.environ.get('ZEP_API_KEY') @@ -69,7 +78,8 @@ def validate(cls) -> list[str]: errors: list[str] = [] if not cls.LLM_API_KEY: errors.append("LLM_API_KEY 未配置") - if not cls.ZEP_API_KEY: + graph_backend = (cls.GRAPH_MEMORY_BACKEND or 'zep_cloud').lower() + if graph_backend in {'zep', 'zep_cloud', 'zep-cloud'} and not cls.ZEP_API_KEY: errors.append("ZEP_API_KEY 未配置") return errors diff --git a/backend/app/graph_memory/__init__.py b/backend/app/graph_memory/__init__.py new file mode 100644 index 000000000..b951eee36 --- /dev/null +++ b/backend/app/graph_memory/__init__.py @@ -0,0 +1,13 @@ +"""Graph memory backend adapters.""" + +from .base import GraphMemoryAdapter +from .factory import create_graph_memory_adapter +from .graphiti_bridge_adapter import GraphitiBridgeGraphMemoryAdapter +from .zep_cloud_adapter import ZepCloudGraphMemoryAdapter + +__all__ = [ + "GraphMemoryAdapter", + "GraphitiBridgeGraphMemoryAdapter", + "ZepCloudGraphMemoryAdapter", + "create_graph_memory_adapter", +] diff --git 
a/backend/app/graph_memory/base.py b/backend/app/graph_memory/base.py new file mode 100644 index 000000000..d0b45be96 --- /dev/null +++ b/backend/app/graph_memory/base.py @@ -0,0 +1,66 @@ +"""Graph memory adapter contracts. + +This module defines the narrow graph-memory surface Mirofish needs. Concrete +backends can implement it without leaking vendor SDK details into services. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Protocol + + +class GraphMemoryAdapter(ABC): + """Backend-neutral graph memory interface used by Mirofish services.""" + + @abstractmethod + def create_graph(self, graph_id: str, name: str, description: str) -> Any: + """Create a graph and return the backend response.""" + + @abstractmethod + def set_ontology(self, graph_id: str, ontology: dict[str, Any]) -> Any: + """Apply ontology definitions for a graph.""" + + @abstractmethod + def add_text_batch(self, graph_id: str, chunks: list[str]) -> Any: + """Add a batch of text episodes to a graph.""" + + @abstractmethod + def add_text(self, graph_id: str, text: str) -> Any: + """Add a single text episode to a graph.""" + + @abstractmethod + def get_episode(self, episode_uuid: str) -> Any: + """Return one episode by UUID.""" + + @abstractmethod + def get_all_nodes(self, graph_id: str) -> list[Any]: + """Return all nodes for a graph.""" + + @abstractmethod + def get_all_edges(self, graph_id: str) -> list[Any]: + """Return all edges for a graph.""" + + @abstractmethod + def search(self, graph_id: str, query: str, limit: int = 10, scope: str = "edges", **kwargs: Any) -> Any: + """Search graph memory.""" + + @abstractmethod + def get_node(self, node_uuid: str) -> Any: + """Return one node by UUID.""" + + @abstractmethod + def get_node_edges(self, node_uuid: str) -> list[Any]: + """Return edges related to one node.""" + + @abstractmethod + def delete_graph(self, graph_id: str) -> Any: + """Delete a graph.""" + + +class SupportsRawClient(Protocol): 
+ """Compatibility escape hatch for legacy code not yet adapter-native.""" + + @property + def raw_client(self) -> Any: + """Return the underlying SDK client.""" diff --git a/backend/app/graph_memory/factory.py b/backend/app/graph_memory/factory.py new file mode 100644 index 000000000..484122ddf --- /dev/null +++ b/backend/app/graph_memory/factory.py @@ -0,0 +1,23 @@ +"""Graph memory adapter factory.""" + +from __future__ import annotations + +from typing import Optional + +from ..config import Config +from .base import GraphMemoryAdapter +from .zep_cloud_adapter import ZepCloudGraphMemoryAdapter + + +def create_graph_memory_adapter(api_key: Optional[str] = None, backend: Optional[str] = None) -> GraphMemoryAdapter: + selected_backend = (backend or Config.GRAPH_MEMORY_BACKEND).strip().lower() + + if selected_backend in {"zep", "zep_cloud", "zep-cloud"}: + return ZepCloudGraphMemoryAdapter(api_key=api_key or Config.ZEP_API_KEY) + + if selected_backend in {"graphiti", "graphiti_core", "graphiti-core", "graphiti_bridge", "graphiti-bridge"}: + from .graphiti_bridge_adapter import GraphitiBridgeGraphMemoryAdapter + + return GraphitiBridgeGraphMemoryAdapter(api_key=api_key or Config.LLM_API_KEY) + + raise ValueError(f"Unsupported GRAPH_MEMORY_BACKEND: {selected_backend}") diff --git a/backend/app/graph_memory/graphiti_bridge_adapter.py b/backend/app/graph_memory/graphiti_bridge_adapter.py new file mode 100644 index 000000000..cd1d6b73c --- /dev/null +++ b/backend/app/graph_memory/graphiti_bridge_adapter.py @@ -0,0 +1,128 @@ +"""HTTP adapter for the on-premise Graphiti bridge service.""" + +from __future__ import annotations + +import json +from types import SimpleNamespace +from typing import Any +from urllib.error import HTTPError +from urllib.parse import quote, urlencode +from urllib.request import Request, urlopen + +from .base import GraphMemoryAdapter +from ..config import Config + + +class GraphitiBridgeGraphMemoryAdapter(GraphMemoryAdapter): + """Graph memory 
adapter backed by the local Graphiti bridge service.""" + + def __init__(self, api_key: str | None = None, base_url: str | None = None): + self.base_url = (base_url or Config.GRAPHITI_BRIDGE_URL).rstrip("/") + self._node_graph_index: dict[str, str] = {} + + @property + def raw_client(self) -> None: + return None + + def create_graph(self, graph_id: str, name: str, description: str) -> Any: + return self._to_namespace(self._request("POST", "/graphs", {"graph_id": graph_id, "name": name, "description": description})) + + def set_ontology(self, graph_id: str, ontology: dict[str, Any]) -> Any: + return self._request("POST", f"/graphs/{quote(graph_id)}/ontology", ontology) + + def add_text_batch(self, graph_id: str, chunks: list[str]) -> list[Any]: + data = self._request("POST", f"/graphs/{quote(graph_id)}/episodes", {"chunks": chunks}) + return [self._episode(item) for item in data.get("episodes", [])] + + def add_text(self, graph_id: str, text: str) -> Any: + data = self._request("POST", f"/graphs/{quote(graph_id)}/episodes", {"text": text}) + episodes = data.get("episodes", []) + return self._episode(episodes[0]) if episodes else self._episode({"uuid": None, "processed": True}) + + def get_episode(self, episode_uuid: str) -> Any: + return self._episode({"uuid": episode_uuid, "processed": True}) + + def get_all_nodes(self, graph_id: str) -> list[Any]: + data = self._request("GET", f"/graphs/{quote(graph_id)}/nodes") + nodes = [self._node(item) for item in data.get("nodes", [])] + for node in nodes: + self._node_graph_index[node.uuid_] = graph_id + return nodes + + def get_all_edges(self, graph_id: str) -> list[Any]: + data = self._request("GET", f"/graphs/{quote(graph_id)}/edges") + return [self._edge(item) for item in data.get("edges", [])] + + def search(self, graph_id: str, query: str, limit: int = 10, scope: str = "edges", **kwargs: Any) -> Any: + data = self._request("POST", f"/graphs/{quote(graph_id)}/search", {"query": query, "limit": limit, "scope": scope}) + 
nodes = [self._node(item) for item in data.get("nodes", [])] + for node in nodes: + self._node_graph_index[node.uuid_] = graph_id + return SimpleNamespace(edges=[self._edge(item) for item in data.get("edges", [])], nodes=nodes) + + def get_node(self, node_uuid: str) -> Any: + graph_id = self._node_graph_index.get(node_uuid) + if not graph_id: + return None + query = urlencode({"graph_id": graph_id}) + data = self._request("GET", f"/nodes/{quote(node_uuid)}?{query}") + node = data.get("node") + return self._node(node) if node else None + + def get_node_edges(self, node_uuid: str) -> list[Any]: + graph_id = self._node_graph_index.get(node_uuid) + if not graph_id: + return [] + query = urlencode({"graph_id": graph_id}) + data = self._request("GET", f"/nodes/{quote(node_uuid)}/edges?{query}") + return [self._edge(item) for item in data.get("edges", [])] + + def delete_graph(self, graph_id: str) -> Any: + return self._request("DELETE", f"/graphs/{quote(graph_id)}") + + def _request(self, method: str, path: str, payload: dict[str, Any] | None = None) -> dict[str, Any]: + body = None if payload is None else json.dumps(payload).encode("utf-8") + headers = {"Content-Type": "application/json"} + req = Request(f"{self.base_url}{path}", data=body, headers=headers, method=method) + try: + with urlopen(req, timeout=120) as response: + raw = response.read().decode("utf-8") + return json.loads(raw) if raw else {} + except HTTPError as exc: + error_body = exc.read().decode("utf-8", errors="replace") + raise RuntimeError(f"Graphiti bridge request failed: {exc.code} {error_body}") from exc + + def _episode(self, data: dict[str, Any]) -> Any: + uuid = data.get("uuid") or data.get("uuid_") + return SimpleNamespace(uuid_=uuid, uuid=uuid, processed=data.get("processed", True)) + + def _node(self, data: dict[str, Any]) -> Any: + uuid = data.get("uuid") or data.get("uuid_") or "" + return SimpleNamespace( + uuid_=uuid, + uuid=uuid, + name=data.get("name") or "", + labels=data.get("labels") 
or [], + summary=data.get("summary") or "", + attributes=data.get("attributes") or {}, + created_at=data.get("created_at"), + ) + + def _edge(self, data: dict[str, Any]) -> Any: + uuid = data.get("uuid") or data.get("uuid_") or "" + return SimpleNamespace( + uuid_=uuid, + uuid=uuid, + name=data.get("name") or "", + fact=data.get("fact") or "", + source_node_uuid=data.get("source_node_uuid") or "", + target_node_uuid=data.get("target_node_uuid") or "", + attributes=data.get("attributes") or {}, + created_at=data.get("created_at"), + valid_at=data.get("valid_at"), + invalid_at=data.get("invalid_at"), + expired_at=data.get("expired_at"), + ) + + def _to_namespace(self, data: dict[str, Any]) -> Any: + return SimpleNamespace(**data) diff --git a/backend/app/graph_memory/zep_cloud_adapter.py b/backend/app/graph_memory/zep_cloud_adapter.py new file mode 100644 index 000000000..fae410743 --- /dev/null +++ b/backend/app/graph_memory/zep_cloud_adapter.py @@ -0,0 +1,121 @@ +"""Zep Cloud graph memory adapter.""" + +from __future__ import annotations + +import warnings +from typing import Any, Optional + +from pydantic import Field +from zep_cloud import EpisodeData, EntityEdgeSourceTarget +from zep_cloud.client import Zep +from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel + +from .base import GraphMemoryAdapter +from ..utils.zep_paging import fetch_all_edges, fetch_all_nodes + + +class ZepCloudGraphMemoryAdapter(GraphMemoryAdapter): + """Adapter preserving the existing Zep Cloud behavior.""" + + RESERVED_NAMES = {"uuid", "name", "group_id", "name_embedding", "summary", "created_at"} + + def __init__(self, api_key: str): + if not api_key: + raise ValueError("ZEP_API_KEY 未配置") + self._client = Zep(api_key=api_key) + + @property + def raw_client(self) -> Zep: + return self._client + + def create_graph(self, graph_id: str, name: str, description: str) -> Any: + return self._client.graph.create(graph_id=graph_id, name=name, description=description) 
+ + def set_ontology(self, graph_id: str, ontology: dict[str, Any]) -> Any: + warnings.filterwarnings("ignore", category=UserWarning, module="pydantic") + + entity_types: dict[str, type[EntityModel]] = {} + for entity_def in ontology.get("entity_types", []): + name = entity_def["name"] + description = entity_def.get("description", f"A {name} entity.") + attrs: dict[str, Any] = {"__doc__": description} + annotations: dict[str, Any] = {} + + for attr_def in entity_def.get("attributes", []): + attr_name = self._safe_attr_name(attr_def["name"]) + attr_desc = attr_def.get("description", attr_name) + attrs[attr_name] = Field(description=attr_desc, default=None) + annotations[attr_name] = Optional[EntityText] + + attrs["__annotations__"] = annotations + entity_class = type(name, (EntityModel,), attrs) + entity_class.__doc__ = description + entity_types[name] = entity_class + + edge_definitions: dict[str, tuple[type[EdgeModel], list[EntityEdgeSourceTarget]]] = {} + for edge_def in ontology.get("edge_types", []): + name = edge_def["name"] + description = edge_def.get("description", f"A {name} relationship.") + attrs = {"__doc__": description} + annotations = {} + + for attr_def in edge_def.get("attributes", []): + attr_name = self._safe_attr_name(attr_def["name"]) + attr_desc = attr_def.get("description", attr_name) + attrs[attr_name] = Field(description=attr_desc, default=None) + annotations[attr_name] = Optional[str] + + attrs["__annotations__"] = annotations + class_name = "".join(word.capitalize() for word in name.split("_")) + edge_class = type(class_name, (EdgeModel,), attrs) + edge_class.__doc__ = description + + source_targets = [ + EntityEdgeSourceTarget(source=st.get("source", "Entity"), target=st.get("target", "Entity")) + for st in edge_def.get("source_targets", []) + ] + if source_targets: + edge_definitions[name] = (edge_class, source_targets) + + if not entity_types and not edge_definitions: + return None + + return self._client.graph.set_ontology( + 
graph_ids=[graph_id], + entities=entity_types if entity_types else None, + edges=edge_definitions if edge_definitions else None, + ) + + def add_text_batch(self, graph_id: str, chunks: list[str]) -> Any: + episodes = [EpisodeData(data=chunk, type="text") for chunk in chunks] + return self._client.graph.add_batch(graph_id=graph_id, episodes=episodes) + + def add_text(self, graph_id: str, text: str) -> Any: + return self._client.graph.add(graph_id=graph_id, type="text", data=text) + + def get_episode(self, episode_uuid: str) -> Any: + return self._client.graph.episode.get(uuid_=episode_uuid) + + def get_all_nodes(self, graph_id: str) -> list[Any]: + return fetch_all_nodes(self._client, graph_id) + + def get_all_edges(self, graph_id: str) -> list[Any]: + return fetch_all_edges(self._client, graph_id) + + def search(self, graph_id: str, query: str, limit: int = 10, scope: str = "edges", **kwargs: Any) -> Any: + return self._client.graph.search(graph_id=graph_id, query=query, limit=limit, scope=scope, **kwargs) + + def get_node(self, node_uuid: str) -> Any: + return self._client.graph.node.get(uuid_=node_uuid) + + def get_node_edges(self, node_uuid: str) -> list[Any]: + return self._client.graph.node.get_entity_edges(node_uuid=node_uuid) + + def delete_graph(self, graph_id: str) -> Any: + return self._client.graph.delete(graph_id=graph_id) + + @classmethod + def _safe_attr_name(cls, attr_name: str) -> str: + if attr_name.lower() in cls.RESERVED_NAMES: + return f"entity_{attr_name}" + return attr_name diff --git a/backend/app/services/graph_builder.py b/backend/app/services/graph_builder.py index 37c9969c7..01391403b 100644 --- a/backend/app/services/graph_builder.py +++ b/backend/app/services/graph_builder.py @@ -10,12 +10,9 @@ from typing import Dict, Any, List, Optional, Callable from dataclasses import dataclass -from zep_cloud.client import Zep -from zep_cloud import EpisodeData, EntityEdgeSourceTarget - from ..config import Config from ..models.task import 
TaskManager, TaskStatus -from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges +from ..graph_memory import create_graph_memory_adapter from .text_processor import TextProcessor from ..utils.locale import t, get_locale, set_locale @@ -48,7 +45,8 @@ def __init__(self, api_key: Optional[str] = None): if not self.api_key: raise ValueError("ZEP_API_KEY 未配置") - self.client = Zep(api_key=self.api_key) + self.graph_memory = create_graph_memory_adapter(api_key=self.api_key) + self.client = getattr(self.graph_memory, 'raw_client', None) self.task_manager = TaskManager() def build_graph_async( @@ -194,7 +192,7 @@ def create_graph(self, name: str) -> str: """创建Zep图谱(公开方法)""" graph_id = f"mirofish_{uuid.uuid4().hex[:16]}" - self.client.graph.create( + self.graph_memory.create_graph( graph_id=graph_id, name=name, description="MiroFish Social Simulation Graph" @@ -204,93 +202,8 @@ def create_graph(self, name: str) -> str: def set_ontology(self, graph_id: str, ontology: Dict[str, Any]): """设置图谱本体(公开方法)""" - import warnings - from typing import Optional - from pydantic import Field - from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel - - # 抑制 Pydantic v2 关于 Field(default=None) 的警告 - # 这是 Zep SDK 要求的用法,警告来自动态类创建,可以安全忽略 - warnings.filterwarnings('ignore', category=UserWarning, module='pydantic') - - # Zep 保留名称,不能作为属性名 - RESERVED_NAMES = {'uuid', 'name', 'group_id', 'name_embedding', 'summary', 'created_at'} - - def safe_attr_name(attr_name: str) -> str: - """将保留名称转换为安全名称""" - if attr_name.lower() in RESERVED_NAMES: - return f"entity_{attr_name}" - return attr_name - - # 动态创建实体类型 - entity_types = {} - for entity_def in ontology.get("entity_types", []): - name = entity_def["name"] - description = entity_def.get("description", f"A {name} entity.") - - # 创建属性字典和类型注解(Pydantic v2 需要) - attrs = {"__doc__": description} - annotations = {} - - for attr_def in entity_def.get("attributes", []): - attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称 - 
attr_desc = attr_def.get("description", attr_name) - # Zep API 需要 Field 的 description,这是必需的 - attrs[attr_name] = Field(description=attr_desc, default=None) - annotations[attr_name] = Optional[EntityText] # 类型注解 - - attrs["__annotations__"] = annotations - - # 动态创建类 - entity_class = type(name, (EntityModel,), attrs) - entity_class.__doc__ = description - entity_types[name] = entity_class - - # 动态创建边类型 - edge_definitions = {} - for edge_def in ontology.get("edge_types", []): - name = edge_def["name"] - description = edge_def.get("description", f"A {name} relationship.") - - # 创建属性字典和类型注解 - attrs = {"__doc__": description} - annotations = {} - - for attr_def in edge_def.get("attributes", []): - attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称 - attr_desc = attr_def.get("description", attr_name) - # Zep API 需要 Field 的 description,这是必需的 - attrs[attr_name] = Field(description=attr_desc, default=None) - annotations[attr_name] = Optional[str] # 边属性用str类型 - - attrs["__annotations__"] = annotations - - # 动态创建类 - class_name = ''.join(word.capitalize() for word in name.split('_')) - edge_class = type(class_name, (EdgeModel,), attrs) - edge_class.__doc__ = description - - # 构建source_targets - source_targets = [] - for st in edge_def.get("source_targets", []): - source_targets.append( - EntityEdgeSourceTarget( - source=st.get("source", "Entity"), - target=st.get("target", "Entity") - ) - ) - - if source_targets: - edge_definitions[name] = (edge_class, source_targets) - - # 调用Zep API设置本体 - if entity_types or edge_definitions: - self.client.graph.set_ontology( - graph_ids=[graph_id], - entities=entity_types if entity_types else None, - edges=edge_definitions if edge_definitions else None, - ) - + self.graph_memory.set_ontology(graph_id, ontology) + def add_text_batches( self, graph_id: str, @@ -314,17 +227,11 @@ def add_text_batches( progress ) - # 构建episode数据 - episodes = [ - EpisodeData(data=chunk, type="text") - for chunk in batch_chunks - ] - - # 发送到Zep + # 发送到图谱记忆后端 try: 
- batch_result = self.client.graph.add_batch( + batch_result = self.graph_memory.add_text_batch( graph_id=graph_id, - episodes=episodes + chunks=batch_chunks ) # 收集返回的 episode uuid @@ -376,7 +283,7 @@ def _wait_for_episodes( # 检查每个 episode 的处理状态 for ep_uuid in list(pending_episodes): try: - episode = self.client.graph.episode.get(uuid_=ep_uuid) + episode = self.graph_memory.get_episode(ep_uuid) is_processed = getattr(episode, 'processed', False) if is_processed: @@ -403,10 +310,10 @@ def _wait_for_episodes( def _get_graph_info(self, graph_id: str) -> GraphInfo: """获取图谱信息""" # 获取节点(分页) - nodes = fetch_all_nodes(self.client, graph_id) + nodes = self.graph_memory.get_all_nodes(graph_id) # 获取边(分页) - edges = fetch_all_edges(self.client, graph_id) + edges = self.graph_memory.get_all_edges(graph_id) # 统计实体类型 entity_types = set() @@ -433,8 +340,8 @@ def get_graph_data(self, graph_id: str) -> Dict[str, Any]: Returns: 包含nodes和edges的字典,包括时间信息、属性等详细数据 """ - nodes = fetch_all_nodes(self.client, graph_id) - edges = fetch_all_edges(self.client, graph_id) + nodes = self.graph_memory.get_all_nodes(graph_id) + edges = self.graph_memory.get_all_edges(graph_id) # 创建节点映射用于获取节点名称 node_map = {} @@ -502,5 +409,5 @@ def get_graph_data(self, graph_id: str) -> Dict[str, Any]: def delete_graph(self, graph_id: str): """删除图谱""" - self.client.graph.delete(graph_id=graph_id) + self.graph_memory.delete_graph(graph_id) diff --git a/backend/app/services/oasis_profile_generator.py b/backend/app/services/oasis_profile_generator.py index 7704a627e..7a85b306b 100644 --- a/backend/app/services/oasis_profile_generator.py +++ b/backend/app/services/oasis_profile_generator.py @@ -16,12 +16,11 @@ from datetime import datetime from openai import OpenAI -from zep_cloud.client import Zep - from ..config import Config from ..utils.logger import get_logger from ..utils.locale import get_language_instruction, get_locale, set_locale, t from .zep_entity_reader import EntityNode, ZepEntityReader +from ..graph_memory 
import create_graph_memory_adapter logger = get_logger('mirofish.oasis_profile') @@ -205,7 +204,8 @@ def __init__( if self.zep_api_key: try: - self.zep_client = Zep(api_key=self.zep_api_key) + self.graph_memory = create_graph_memory_adapter(api_key=self.zep_api_key) + self.zep_client = getattr(self.graph_memory, 'raw_client', None) except Exception as e: logger.warning(f"Zep客户端初始化失败: {e}") @@ -324,7 +324,7 @@ def search_edges(): for attempt in range(max_retries): try: - return self.zep_client.graph.search( + return self.graph_memory.search( query=comprehensive_query, graph_id=self.graph_id, limit=30, @@ -349,7 +349,7 @@ def search_nodes(): for attempt in range(max_retries): try: - return self.zep_client.graph.search( + return self.graph_memory.search( query=comprehensive_query, graph_id=self.graph_id, limit=20, diff --git a/backend/app/services/zep_entity_reader.py b/backend/app/services/zep_entity_reader.py index 71661be49..2dcd2d0cb 100644 --- a/backend/app/services/zep_entity_reader.py +++ b/backend/app/services/zep_entity_reader.py @@ -7,11 +7,9 @@ from typing import Dict, Any, List, Optional, Set, Callable, TypeVar from dataclasses import dataclass, field -from zep_cloud.client import Zep - from ..config import Config from ..utils.logger import get_logger -from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges +from ..graph_memory import create_graph_memory_adapter logger = get_logger('mirofish.zep_entity_reader') @@ -83,7 +81,8 @@ def __init__(self, api_key: Optional[str] = None): if not self.api_key: raise ValueError("ZEP_API_KEY 未配置") - self.client = Zep(api_key=self.api_key) + self.graph_memory = create_graph_memory_adapter(api_key=self.api_key) + self.client = getattr(self.graph_memory, 'raw_client', None) def _call_with_retry( self, @@ -136,7 +135,7 @@ def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]: """ logger.info(f"获取图谱 {graph_id} 的所有节点...") - nodes = fetch_all_nodes(self.client, graph_id) + nodes = 
self.graph_memory.get_all_nodes(graph_id) nodes_data = [] for node in nodes: @@ -163,7 +162,7 @@ def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]: """ logger.info(f"获取图谱 {graph_id} 的所有边...") - edges = fetch_all_edges(self.client, graph_id) + edges = self.graph_memory.get_all_edges(graph_id) edges_data = [] for edge in edges: @@ -192,7 +191,7 @@ def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]: try: # 使用重试机制调用Zep API edges = self._call_with_retry( - func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid), + func=lambda: self.graph_memory.get_node_edges(node_uuid), operation_name=f"获取节点边(node={node_uuid[:8]}...)" ) @@ -348,7 +347,7 @@ def get_entity_with_context( try: # 使用重试机制获取节点 node = self._call_with_retry( - func=lambda: self.client.graph.node.get(uuid_=entity_uuid), + func=lambda: self.graph_memory.get_node(entity_uuid), operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)" ) diff --git a/backend/app/services/zep_graph_memory_updater.py b/backend/app/services/zep_graph_memory_updater.py index e034fee2b..9262f4dbd 100644 --- a/backend/app/services/zep_graph_memory_updater.py +++ b/backend/app/services/zep_graph_memory_updater.py @@ -12,11 +12,10 @@ from datetime import datetime from queue import Queue, Empty -from zep_cloud.client import Zep - from ..config import Config from ..utils.logger import get_logger from ..utils.locale import get_locale, set_locale +from ..graph_memory import create_graph_memory_adapter logger = get_logger('mirofish.zep_graph_memory_updater') @@ -243,7 +242,8 @@ def __init__(self, graph_id: str, api_key: Optional[str] = None): if not self.api_key: raise ValueError("ZEP_API_KEY未配置") - self.client = Zep(api_key=self.api_key) + self.graph_memory = create_graph_memory_adapter(api_key=self.api_key) + self.client = getattr(self.graph_memory, 'raw_client', None) # 活动队列 self._activity_queue: Queue = Queue() @@ -411,10 +411,9 @@ def _send_batch_activities(self, activities: List[AgentActivity], 
platform: str) # 带重试的发送 for attempt in range(self.MAX_RETRIES): try: - self.client.graph.add( + self.graph_memory.add_text( graph_id=self.graph_id, - type="text", - data=combined_text + text=combined_text ) self._total_sent += 1 diff --git a/backend/app/services/zep_tools.py b/backend/app/services/zep_tools.py index 3bc8a57ab..e3048076d 100644 --- a/backend/app/services/zep_tools.py +++ b/backend/app/services/zep_tools.py @@ -13,13 +13,11 @@ from typing import Dict, Any, List, Optional from dataclasses import dataclass, field -from zep_cloud.client import Zep - from ..config import Config from ..utils.logger import get_logger from ..utils.llm_client import LLMClient from ..utils.locale import get_locale, t -from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges +from ..graph_memory import create_graph_memory_adapter logger = get_logger('mirofish.zep_tools') @@ -427,7 +425,8 @@ def __init__(self, api_key: Optional[str] = None, llm_client: Optional[LLMClient if not self.api_key: raise ValueError("ZEP_API_KEY 未配置") - self.client = Zep(api_key=self.api_key) + self.graph_memory = create_graph_memory_adapter(api_key=self.api_key) + self.client = getattr(self.graph_memory, 'raw_client', None) # LLM客户端用于InsightForge生成子问题 self._llm_client = llm_client logger.info(t("console.zepToolsInitialized")) @@ -488,7 +487,7 @@ def search_graph( # 尝试使用Zep Cloud Search API try: search_results = self._call_with_retry( - func=lambda: self.client.graph.search( + func=lambda: self.graph_memory.search( graph_id=graph_id, query=query, limit=limit, @@ -659,7 +658,7 @@ def get_all_nodes(self, graph_id: str) -> List[NodeInfo]: """ logger.info(t("console.fetchingAllNodes", graphId=graph_id)) - nodes = fetch_all_nodes(self.client, graph_id) + nodes = self.graph_memory.get_all_nodes(graph_id) result = [] for node in nodes: @@ -688,7 +687,7 @@ def get_all_edges(self, graph_id: str, include_temporal: bool = True) -> List[Ed """ logger.info(t("console.fetchingAllEdges", graphId=graph_id)) - 
edges = fetch_all_edges(self.client, graph_id) + edges = self.graph_memory.get_all_edges(graph_id) result = [] for edge in edges: @@ -727,7 +726,7 @@ def get_node_detail(self, node_uuid: str) -> Optional[NodeInfo]: try: node = self._call_with_retry( - func=lambda: self.client.graph.node.get(uuid_=node_uuid), + func=lambda: self.graph_memory.get_node(node_uuid), operation_name=t("console.fetchNodeDetailOp", uuid=node_uuid[:8]) ) diff --git a/docker-compose.yml b/docker-compose.yml index 637f1dfae..bf1a3b431 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -11,4 +11,37 @@ services: - "5001:5001" restart: unless-stopped volumes: - - ./backend/uploads:/app/backend/uploads \ No newline at end of file + - ./backend/uploads:/app/backend/uploads + + graphiti-falkordb: + image: falkordb/falkordb:latest + container_name: graphiti-falkordb + profiles: + - graphiti + volumes: + - graphiti_falkordb_data:/data + restart: unless-stopped + + graphiti-bridge: + build: + context: . + dockerfile: graphiti_bridge/Dockerfile + image: mirofish-graphiti-bridge:latest + container_name: graphiti-bridge + profiles: + - graphiti + environment: + OPENAI_API_KEY: ${LLM_API_KEY} + OPENAI_BASE_URL: ${LLM_BASE_URL:-https://api.openai.com/v1} + MODEL_NAME: ${GRAPHITI_MODEL_NAME:-${LLM_MODEL_NAME:-gpt-5.4-mini}} + EMBEDDING_MODEL_NAME: ${GRAPHITI_EMBEDDING_MODEL_NAME:-text-embedding-3-small} + FALKORDB_HOST: graphiti-falkordb + FALKORDB_PORT: 6379 + depends_on: + - graphiti-falkordb + ports: + - "127.0.0.1:8008:8008" + restart: unless-stopped + +volumes: + graphiti_falkordb_data: diff --git a/docs/on-prem-graph-memory.md b/docs/on-prem-graph-memory.md new file mode 100644 index 000000000..7dc9d5d83 --- /dev/null +++ b/docs/on-prem-graph-memory.md @@ -0,0 +1,43 @@ +# On-Premise Graph Memory + +Mirofish now uses a graph-memory adapter layer. The default backend remains Zep Cloud, so existing behavior does not change unless the backend is explicitly switched. 
+ +## Default: Zep Cloud + +```env +GRAPH_MEMORY_BACKEND=zep_cloud +ZEP_API_KEY=... +``` + +## On-premise: Graphiti Bridge + FalkorDB + +Start the local graph-memory services: + +```bash +docker compose --profile graphiti up -d graphiti-falkordb graphiti-bridge +``` + +Switch Mirofish to the on-premise backend: + +```env +GRAPH_MEMORY_BACKEND=graphiti_bridge +GRAPHITI_BRIDGE_URL=http://graphiti-bridge:8008 +GRAPHITI_MODEL_NAME=gpt-5.4-mini +GRAPHITI_EMBEDDING_MODEL_NAME=text-embedding-3-small +``` + +Then rebuild/restart Mirofish: + +```bash +docker compose build mirofish +docker compose up -d mirofish +``` + +The Graphiti bridge runs in a separate container so its dependencies do not conflict with OASIS. FalkorDB data is stored in the `graphiti_falkordb_data` Docker volume. + +Health checks: + +```bash +curl http://127.0.0.1:8008/health +curl http://localhost:5001/health +``` diff --git a/graphiti_bridge/Dockerfile b/graphiti_bridge/Dockerfile new file mode 100644 index 000000000..bacc544f2 --- /dev/null +++ b/graphiti_bridge/Dockerfile @@ -0,0 +1,8 @@ +FROM python:3.12-slim + +WORKDIR /app +COPY graphiti_bridge/requirements.txt ./requirements.txt +RUN pip install --no-cache-dir -r requirements.txt +COPY graphiti_bridge/app.py ./app.py +EXPOSE 8008 +CMD ["python", "app.py"] diff --git a/graphiti_bridge/app.py b/graphiti_bridge/app.py new file mode 100644 index 000000000..3c0d13957 --- /dev/null +++ b/graphiti_bridge/app.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +import asyncio +import os +from datetime import datetime, timezone +from typing import Any, Optional + +from flask import Flask, jsonify, request +from pydantic import BaseModel, Field + +from graphiti_core.driver.falkordb_driver import FalkorDriver +from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig +from graphiti_core.graphiti import Graphiti +from graphiti_core.llm_client.config import LLMConfig +from graphiti_core.llm_client.openai_client import OpenAIClient 
from graphiti_core.nodes import EpisodeType

# Runtime dependencies pinned for this service (graphiti_bridge/requirements.txt):
#   flask>=3.0.0, graphiti-core[falkordb]==0.29.2, openai>=2.0.0, pydantic>=2.0.0

app = Flask(__name__)

# Connection and model settings. Bridge-specific variable names win; the
# Mirofish-wide LLM_* / GRAPHITI_* names are accepted as fallbacks.
FALKORDB_HOST = os.environ.get("FALKORDB_HOST", "graphiti-falkordb")
FALKORDB_PORT = int(os.environ.get("FALKORDB_PORT", "6379"))
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") or os.environ.get("LLM_API_KEY")
OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL") or os.environ.get(
    "LLM_BASE_URL", "https://api.openai.com/v1"
)
MODEL_NAME = os.environ.get("MODEL_NAME") or os.environ.get(
    "GRAPHITI_MODEL_NAME", "gpt-5.4-mini"
)
EMBEDDING_MODEL_NAME = os.environ.get("EMBEDDING_MODEL_NAME") or os.environ.get(
    "GRAPHITI_EMBEDDING_MODEL_NAME", "text-embedding-3-small"
)

# Per-process state: ontologies registered through the API and the graphs whose
# indices were already built. Both are lost when the process restarts.
ONTOLOGIES: dict[str, dict[str, Any]] = {}
INDICES_INITIALIZED: set[str] = set()

# Attribute names that collide with fields of Graphiti's own entity node model.
# "labels" and "attributes" are included because Graphiti rejects custom entity
# types that redefine any of its node fields.
RESERVED_NAMES = {
    "uuid",
    "name",
    "group_id",
    "name_embedding",
    "summary",
    "created_at",
    "labels",
    "attributes",
}

# Shared RETURN clauses so every endpoint emits the same node / edge shape.
_NODE_COLUMNS = (
    "n.uuid AS uuid, n.name AS name, labels(n) AS labels, "
    "n.summary AS summary, n.created_at AS created_at"
)
# Prefer the edge's own `name` property and fall back to the relationship type,
# so the listing endpoints agree with what edge_from_obj() reports for search.
_EDGE_COLUMNS = (
    "r.uuid AS uuid, coalesce(r.name, type(r)) AS name, r.fact AS fact, "
    "a.uuid AS source_node_uuid, b.uuid AS target_node_uuid, "
    "r.created_at AS created_at, r.valid_at AS valid_at, "
    "r.invalid_at AS invalid_at, r.expired_at AS expired_at"
)
_ALL_NODES_QUERY = f"MATCH (n:Entity) RETURN {_NODE_COLUMNS} LIMIT 2000"
_ONE_NODE_QUERY = f"MATCH (n:Entity {{uuid: $uuid}}) RETURN {_NODE_COLUMNS} LIMIT 1"
_ALL_EDGES_QUERY = f"MATCH (a:Entity)-[r]->(b:Entity) RETURN {_EDGE_COLUMNS} LIMIT 5000"
_NODE_EDGES_QUERY = (
    "MATCH (a:Entity)-[r]->(b:Entity) WHERE a.uuid = $uuid OR b.uuid = $uuid "
    f"RETURN {_EDGE_COLUMNS} LIMIT 5000"
)


def run(coro):
    """Run *coro* to completion and return its result.

    Flask views here are synchronous, so each call drives the coroutine with
    ``asyncio.run``, i.e. on a fresh event loop per call.
    """
    return asyncio.run(coro)


def safe_attr_name(attr_name: str) -> str:
    """Return *attr_name*, prefixed with ``entity_`` if it is a reserved name.

    The comparison is case-insensitive; the original spelling is kept in the
    returned name.
    """
    if attr_name.lower() in RESERVED_NAMES:
        return f"entity_{attr_name}"
    return attr_name


def build_ontology(ontology: dict[str, Any]) -> dict[str, Any]:
    """Translate a Mirofish ontology dict into Graphiti ``add_episode`` kwargs.

    Args:
        ontology: Dict with optional ``entity_types`` and ``edge_types`` lists.
            Each entity definition needs a ``name`` and may carry
            ``description`` and ``attributes``; each edge definition needs a
            ``name`` and may carry ``description`` and ``source_targets``.

    Returns:
        Dict with keys ``entity_types``, ``edge_types`` and ``edge_type_map``.
        Each value is ``None`` when the corresponding mapping is empty.

    Raises:
        KeyError: If an entity, attribute or edge definition has no ``name``.
    """
    entity_types: dict[str, type[BaseModel]] = {}
    entity_names: list[str] = []
    for entity_def in ontology.get("entity_types") or []:
        name = entity_def["name"]
        description = entity_def.get("description", f"A {name} entity.")
        attrs: dict[str, Any] = {"__doc__": description}
        annotations: dict[str, Any] = {}
        for attr_def in entity_def.get("attributes") or []:
            attr_name = safe_attr_name(attr_def["name"])
            attrs[attr_name] = Field(
                default=None, description=attr_def.get("description", attr_name)
            )
            # Every custom attribute is modelled as an optional string.
            annotations[attr_name] = Optional[str]
        attrs["__annotations__"] = annotations
        entity_types[name] = type(name, (BaseModel,), attrs)
        entity_names.append(name)

    edge_types: dict[str, type[BaseModel]] = {}
    edge_type_map: dict[tuple[str, str], list[str]] = {}
    for edge_def in ontology.get("edge_types") or []:
        name = edge_def["name"]
        description = edge_def.get("description", f"A {name} relationship.")
        edge_types[name] = type(
            name, (BaseModel,), {"__doc__": description, "__annotations__": {}}
        )
        for st in edge_def.get("source_targets") or []:
            source = st.get("source", "Entity")
            target = st.get("target", "Entity")
            edge_type_map.setdefault((source, target), []).append(name)

    if edge_types and not edge_type_map:
        # No edge declared its endpoints: allow every edge type between every
        # pair of known labels, including the generic "Entity" label.
        labels = entity_names + ["Entity"]
        edge_names = list(edge_types.keys())
        edge_type_map = {
            (source, target): edge_names for source in labels for target in labels
        }

    return {
        "entity_types": entity_types or None,
        "edge_types": edge_types or None,
        "edge_type_map": edge_type_map or None,
    }


async def _close_quietly(resource: Any) -> None:
    """Best-effort ``close()`` of a Graphiti client or driver.

    Cleanup failures are logged instead of raised so they never mask the
    result or the error of the request that used the resource.
    """
    close = getattr(resource, "close", None)
    if close is None:
        return
    try:
        outcome = close()
        if asyncio.iscoroutine(outcome):
            await outcome
    except Exception:  # cleanup boundary: log and carry on
        app.logger.warning("Failed to close %r", resource, exc_info=True)


async def graphiti(graph_id: str) -> Graphiti:
    """Build a Graphiti client bound to the FalkorDB graph *graph_id*.

    The configured chat and embedding models are passed explicitly: exporting
    ``MODEL_NAME`` / ``EMBEDDING_MODEL_NAME`` with ``setdefault`` alone does
    not hand them to the client objects. The environment defaults are still
    exported for any component that builds its own OpenAI client from the
    environment.

    Raises:
        RuntimeError: If neither OPENAI_API_KEY nor LLM_API_KEY is configured.
    """
    if not OPENAI_API_KEY:
        raise RuntimeError("OPENAI_API_KEY/LLM_API_KEY is not configured")
    os.environ.setdefault("OPENAI_API_KEY", OPENAI_API_KEY)
    os.environ.setdefault("OPENAI_BASE_URL", OPENAI_BASE_URL)
    os.environ.setdefault("MODEL_NAME", MODEL_NAME)
    os.environ.setdefault("EMBEDDING_MODEL_NAME", EMBEDDING_MODEL_NAME)

    llm_client = OpenAIClient(
        config=LLMConfig(
            api_key=OPENAI_API_KEY,
            base_url=OPENAI_BASE_URL,
            model=MODEL_NAME,
            # Use the same model for "small" tasks so custom endpoints only
            # need to serve one chat model.
            small_model=MODEL_NAME,
        )
    )
    embedder = OpenAIEmbedder(
        config=OpenAIEmbedderConfig(
            api_key=OPENAI_API_KEY,
            base_url=OPENAI_BASE_URL,
            embedding_model=EMBEDDING_MODEL_NAME,
        )
    )
    driver = FalkorDriver(host=FALKORDB_HOST, port=FALKORDB_PORT, database=graph_id)
    return Graphiti(graph_driver=driver, llm_client=llm_client, embedder=embedder)


async def ensure_indices(graph_id: str) -> None:
    """Build Graphiti indices and constraints for *graph_id* once per process."""
    if graph_id in INDICES_INITIALIZED:
        return
    client = await graphiti(graph_id)
    try:
        await client.build_indices_and_constraints()
    finally:
        await _close_quietly(client)
    # Only remember the graph after the build succeeded, so a failure is retried.
    INDICES_INITIALIZED.add(graph_id)


async def query_graph(graph_id: str, query: str, **params: Any) -> list[dict[str, Any]]:
    """Run a Cypher *query* against graph *graph_id* and return its rows.

    Args:
        graph_id: FalkorDB database (graph) name.
        query: Cypher text; values are supplied through ``$name`` parameters.
        **params: Query parameters forwarded to the driver.

    Returns:
        The first element of the driver's result, or an empty list when the
        driver returns nothing.
    """
    driver = FalkorDriver(host=FALKORDB_HOST, port=FALKORDB_PORT, database=graph_id)
    try:
        result = await driver.execute_query(query, **params)
    finally:
        await _close_quietly(driver)
    if not result:
        return []
    return result[0] or []


def edge_from_obj(edge: Any) -> dict[str, Any]:
    """Convert an edge object returned by Graphiti search into a plain dict.

    Missing attributes become empty strings (or ``{}`` for ``attributes``);
    timestamps are rendered with ``str()`` and become ``None`` when unset.
    """

    def stamp(field: str) -> Optional[str]:
        return str(getattr(edge, field, "") or "") or None

    return {
        "uuid": getattr(edge, "uuid", "") or getattr(edge, "uuid_", ""),
        "name": getattr(edge, "name", ""),
        "fact": getattr(edge, "fact", ""),
        "source_node_uuid": getattr(edge, "source_node_uuid", ""),
        "target_node_uuid": getattr(edge, "target_node_uuid", ""),
        "attributes": getattr(edge, "attributes", {}) or {},
        "created_at": stamp("created_at"),
        "valid_at": stamp("valid_at"),
        "invalid_at": stamp("invalid_at"),
        "expired_at": stamp("expired_at"),
    }


def _json_object() -> dict[str, Any]:
    """Return the request's JSON body, or ``{}`` when it is not a JSON object."""
    payload = request.get_json(force=True)
    return payload if isinstance(payload, dict) else {}


def _bad_request(message: str):
    """Build a 400 JSON error response carrying *message*."""
    return jsonify({"error": message}), 400


def _rank_nodes(rows: list[dict[str, Any]], query: str, limit: int) -> list[dict[str, Any]]:
    """Rank node rows by a simple keyword match over name and summary.

    Each query term longer than one character scores 1 when it occurs in the
    node text; the whole query occurring verbatim adds 10. Rows scoring 0 are
    dropped. Missing (``None``) names and summaries count as empty text rather
    than the literal word "None", and an empty query matches nothing.
    """
    phrase = query.strip().lower()
    terms = [term.lower() for term in query.split() if len(term) > 1]
    scored: list[tuple[int, dict[str, Any]]] = []
    for row in rows:
        text = f"{row.get('name') or ''} {row.get('summary') or ''}".lower()
        score = sum(1 for term in terms if term in text)
        if phrase and phrase in text:
            score += 10
        if score:
            scored.append((score, row))
    # Stable sort: rows with equal scores keep their query order.
    scored.sort(key=lambda item: item[0], reverse=True)
    return [row for _, row in scored[:limit]]


@app.get("/health")
def health():
    """Liveness probe."""
    return jsonify({"status": "ok", "service": "graphiti-bridge"})


@app.post("/graphs")
def create_graph():
    """Prepare a graph: build its indices and echo back the ``graph_id``."""
    graph_id = _json_object().get("graph_id")
    if not graph_id:
        return _bad_request("graph_id is required")
    run(ensure_indices(graph_id))
    return jsonify({"graph_id": graph_id})


@app.post("/graphs/<graph_id>/ontology")
def set_ontology(graph_id: str):
    """Register the ontology used for later episodes of *graph_id*."""
    ONTOLOGIES[graph_id] = build_ontology(_json_object())
    return jsonify({"ok": True})


@app.post("/graphs/<graph_id>/episodes")
def add_episodes(graph_id: str):
    """Add text episodes (``chunks`` list or a single ``text``) to a graph.

    Blank or non-string chunks are skipped; the request is rejected with 400
    when no text remains, instead of sending an empty episode for extraction.
    """
    payload = _json_object()
    raw_chunks = payload.get("chunks") or [payload.get("text", "")]
    if isinstance(raw_chunks, str):
        # A bare string would otherwise be iterated character by character.
        raw_chunks = [raw_chunks]
    chunks = [chunk for chunk in raw_chunks if isinstance(chunk, str) and chunk.strip()]
    if not chunks:
        return _bad_request("chunks or text must contain non-empty text")

    async def add_all():
        await ensure_indices(graph_id)
        client = await graphiti(graph_id)
        ontology = ONTOLOGIES.get(graph_id, {})
        out = []
        try:
            for index, chunk in enumerate(chunks, 1):
                result = await client.add_episode(
                    name=f"mirofish-chunk-{index}",
                    episode_body=chunk,
                    source_description="Mirofish text episode",
                    reference_time=datetime.now(timezone.utc),
                    source=EpisodeType.text,
                    group_id=graph_id,
                    entity_types=ontology.get("entity_types"),
                    edge_types=ontology.get("edge_types"),
                    edge_type_map=ontology.get("edge_type_map"),
                    custom_extraction_instructions="Extract actors and relationships relevant to the simulation. Prefer provided ontology labels when supported by the text.",
                )
                episode_uuid = getattr(result.episode, "uuid", None) or getattr(
                    result.episode, "uuid_", None
                )
                out.append({"uuid": episode_uuid, "processed": True})
        finally:
            await _close_quietly(client)
        return out

    return jsonify({"episodes": run(add_all())})


@app.get("/graphs/<graph_id>/episodes/<episode_uuid>")
def get_episode(graph_id: str, episode_uuid: str):
    """Report an episode as processed.

    Episodes are ingested synchronously by ``add_episodes``, so this endpoint
    does not query the graph and always answers ``processed: True``.
    """
    return jsonify({"uuid": episode_uuid, "processed": True})


@app.get("/graphs/<graph_id>/nodes")
def get_nodes(graph_id: str):
    """List up to 2000 entity nodes of a graph."""
    return jsonify({"nodes": run(query_graph(graph_id, _ALL_NODES_QUERY))})


@app.get("/graphs/<graph_id>/edges")
def get_edges(graph_id: str):
    """List up to 5000 entity-to-entity edges of a graph."""
    return jsonify({"edges": run(query_graph(graph_id, _ALL_EDGES_QUERY))})


@app.get("/nodes/<node_uuid>")
def get_node(node_uuid: str):
    """Return one node by UUID; ``graph_id`` is taken from the query string.

    Answers ``{"node": null}`` when ``graph_id`` is missing or nothing matches.
    """
    graph_id = request.args.get("graph_id")
    if not graph_id:
        return jsonify({"node": None})
    rows = run(query_graph(graph_id, _ONE_NODE_QUERY, uuid=node_uuid))
    return jsonify({"node": rows[0] if rows else None})


@app.get("/nodes/<node_uuid>/edges")
def get_node_edges(node_uuid: str):
    """Return edges that start or end at a node; needs ``?graph_id=``."""
    graph_id = request.args.get("graph_id")
    if not graph_id:
        return jsonify({"edges": []})
    rows = run(query_graph(graph_id, _NODE_EDGES_QUERY, uuid=node_uuid))
    return jsonify({"edges": rows})


@app.post("/graphs/<graph_id>/search")
def search(graph_id: str):
    """Search a graph.

    JSON body: ``query`` (str), ``limit`` (int, default 10) and ``scope``
    (``"edges"`` by default, ``"nodes"`` or ``"both"``). Edges come from
    Graphiti search; nodes come from a keyword ranking over name and summary.
    A Graphiti client is only created when edges are requested.
    """
    payload = _json_object()
    query = str(payload.get("query") or "")
    try:
        limit = int(payload.get("limit", 10))
    except (TypeError, ValueError):
        return _bad_request("limit must be an integer")
    if limit < 1:
        return _bad_request("limit must be at least 1")
    scope = payload.get("scope", "edges")

    async def do_search():
        edges: list[dict[str, Any]] = []
        if scope != "nodes":
            client = await graphiti(graph_id)
            try:
                found = await client.search(
                    query=query, group_ids=[graph_id], num_results=limit
                )
            finally:
                await _close_quietly(client)
            edges = [edge_from_obj(edge) for edge in found]

        nodes: list[dict[str, Any]] = []
        if scope in {"nodes", "both"}:
            rows = await query_graph(graph_id, _ALL_NODES_QUERY)
            nodes = _rank_nodes(rows, query, limit)
        return {"edges": edges, "nodes": nodes}

    return jsonify(run(do_search()))


@app.delete("/graphs/<graph_id>")
def delete_graph(graph_id: str):
    """Delete every node and relationship of a graph and forget its state."""
    run(query_graph(graph_id, "MATCH (n) DETACH DELETE n"))
    ONTOLOGIES.pop(graph_id, None)
    # Force indices to be rebuilt if the graph is created again.
    INDICES_INITIALIZED.discard(graph_id)
    return jsonify({"ok": True})


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", "8008")))