diff --git a/.env.example b/.env.example index 78a3b72c07..0491a50112 100644 --- a/.env.example +++ b/.env.example @@ -5,12 +5,22 @@ LLM_API_KEY=your_api_key_here LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 LLM_MODEL_NAME=qwen-plus -# ===== ZEP记忆图谱配置 ===== -# 每月免费额度即可支撑简单使用:https://app.getzep.com/ +# ===== Graph memory backend ===== +# Default cloud mode: +GRAPH_MEMORY_BACKEND=zep_cloud + +# Zep Cloud configuration. Required only when GRAPH_MEMORY_BACKEND=zep_cloud. +# Free monthly quota is sufficient for simple usage: https://app.getzep.com/ ZEP_API_KEY=your_zep_api_key_here +# Local on-premise mode. Enable with: +# GRAPH_MEMORY_BACKEND=graphiti_bridge +# GRAPHITI_BRIDGE_URL=http://graphiti-bridge:8008 +# GRAPHITI_MODEL_NAME=gpt-5.4-mini +# GRAPHITI_EMBEDDING_MODEL_NAME=text-embedding-3-small + # ===== 加速 LLM 配置(可选)===== # 注意如果不使用加速配置,env文件中就不要出现下面的配置项 LLM_BOOST_API_KEY=your_api_key_here LLM_BOOST_BASE_URL=your_base_url_here -LLM_BOOST_MODEL_NAME=your_model_name_here \ No newline at end of file +LLM_BOOST_MODEL_NAME=your_model_name_here diff --git a/README.md b/README.md index de082935a7..8eef4db8d1 100644 --- a/README.md +++ b/README.md @@ -176,6 +176,10 @@ Reads `.env` from root directory by default, maps ports `3000 (frontend) / 5001 > Mirror address for faster pulling is provided as comments in `docker-compose.yml`, replace if needed. +### Optional: On-Premise Graph Memory + +The default graph-memory backend remains Zep Cloud. For local graph memory, MiroFish can run a Graphiti bridge and FalkorDB through Docker Compose. See [On-Premise Graph Memory](./docs/on-prem-graph-memory.md) for architecture, configuration, health checks, and installation-agent instructions. + ## 📬 Join the Conversation
@@ -200,4 +204,4 @@ MiroFish's simulation engine is powered by **[OASIS (Open Agent Social Interacti Star History Chart - \ No newline at end of file + diff --git a/backend/app/api/report.py b/backend/app/api/report.py index d7f2a4d03a..fce99270d7 100644 --- a/backend/app/api/report.py +++ b/backend/app/api/report.py @@ -20,6 +20,46 @@ logger = get_logger('mirofish.api.report') +def _get_status_request_data(): + """Read status parameters from JSON bodies or query strings.""" + if request.method == 'GET': + return request.args + return request.get_json(silent=True) or {} + + +def _find_report_task(task_manager: TaskManager, task_id: str = None, report_id: str = None, simulation_id: str = None): + if task_id: + task = task_manager.get_task(task_id) + return task.to_dict() if task else None + + for task in task_manager.list_tasks('report_generate'): + metadata = task.get('metadata') or {} + if report_id and metadata.get('report_id') == report_id: + return task + if simulation_id and metadata.get('simulation_id') == simulation_id: + return task + return None + + +def _report_status_payload(report, progress=None): + status = report.status.value if hasattr(report.status, 'value') else str(report.status) + payload = { + "simulation_id": report.simulation_id, + "report_id": report.report_id, + "status": status, + "progress": 100 if report.status == ReportStatus.COMPLETED else 0, + "message": t('api.reportGenerated') if report.status == ReportStatus.COMPLETED else status, + "already_completed": report.status == ReportStatus.COMPLETED + } + if progress: + payload.update(progress) + payload["simulation_id"] = report.simulation_id + payload["report_id"] = report.report_id + payload["status"] = progress.get("status", status) + payload["already_completed"] = report.status == ReportStatus.COMPLETED + return payload + + # ============== 报告生成接口 ============== @report_bp.route('/generate', methods=['POST']) @@ -200,58 +240,71 @@ def progress_callback(stage, progress, message): }), 500 -@report_bp.route('/generate/status', methods=['POST']) +@report_bp.route('/generate/status', methods=['GET', 'POST']) def get_generate_status(): """ 查询报告生成任务进度 - 请求(JSON): - { - "task_id": "task_xxxx", // 可选,generate返回的task_id - "simulation_id": "sim_xxxx" // 可选,模拟ID - } - - 返回: - { - "success": true, - "data": { - "task_id": "task_xxxx", - "status": "processing|completed|failed", - "progress": 45, - "message": "..." - } - } + 支持通过 JSON body 或 query string 传入 task_id、report_id 或 simulation_id。 """ try: - data = request.get_json() or {} + data = _get_status_request_data() task_id = data.get('task_id') + report_id = data.get('report_id') simulation_id = data.get('simulation_id') - - # 如果提供了simulation_id,先检查是否已有完成的报告 + + task_manager = TaskManager() + + if report_id: + report = ReportManager.get_report(report_id) + if report: + return jsonify({ + "success": True, + "data": _report_status_payload(report, ReportManager.get_progress(report_id)) + }) + + task = _find_report_task(task_manager, report_id=report_id) + if task: + return jsonify({"success": True, "data": task}) + + return jsonify({ + "success": False, + "error": t('api.reportNotFound', id=report_id) + }), 404 + if simulation_id: existing_report = ReportManager.get_report_by_simulation(simulation_id) - if existing_report and existing_report.status == ReportStatus.COMPLETED: + if existing_report: return jsonify({ "success": True, - "data": { - "simulation_id": simulation_id, - "report_id": existing_report.report_id, - "status": "completed", - "progress": 100, - "message": t('api.reportGenerated'), - "already_completed": True - } + "data": _report_status_payload( + existing_report, + ReportManager.get_progress(existing_report.report_id) + ) }) - + + task = _find_report_task(task_manager, simulation_id=simulation_id) + if task: + return jsonify({"success": True, "data": task}) + + return jsonify({ + "success": True, + "data": { + "simulation_id": simulation_id, + "status": "not_started", + "progress": 0, + "message": t('api.requireTaskOrSimId') + } + }) + if not task_id: return jsonify({ "success": False, "error": t('api.requireTaskOrSimId') }), 400 - - task_manager = TaskManager() - task = task_manager.get_task(task_id) + + task = _find_report_task(task_manager, task_id=task_id) if not task: return jsonify({ @@ -261,7 +314,7 @@ def get_generate_status(): return jsonify({ "success": True, - "data": task.to_dict() + "data": task }) except Exception as e: diff --git a/backend/app/config.py b/backend/app/config.py index de63e2b4b0..e6487e6705 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -31,6 +31,15 @@ class Config: LLM_API_KEY = os.environ.get('LLM_API_KEY') LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1') LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini') + + # Graph memory backend configuration + GRAPH_MEMORY_BACKEND = os.environ.get('GRAPH_MEMORY_BACKEND', 'zep_cloud') + GRAPHITI_MODEL_NAME = os.environ.get('GRAPHITI_MODEL_NAME', LLM_MODEL_NAME) + GRAPHITI_EMBEDDING_MODEL_NAME = os.environ.get('GRAPHITI_EMBEDDING_MODEL_NAME', 'text-embedding-3-small') + GRAPHITI_BRIDGE_URL = os.environ.get('GRAPHITI_BRIDGE_URL', 'http://graphiti-bridge:8008') + FALKORDB_HOST = os.environ.get('FALKORDB_HOST', 'localhost') + FALKORDB_PORT = int(os.environ.get('FALKORDB_PORT', '6379')) + FALKORDB_DATABASE = os.environ.get('FALKORDB_DATABASE', 'mirofish') # Zep配置 ZEP_API_KEY = os.environ.get('ZEP_API_KEY') @@ -69,7 +78,8 @@ def validate(cls) -> list[str]: errors: list[str] = [] if not cls.LLM_API_KEY: errors.append("LLM_API_KEY 未配置") - if not cls.ZEP_API_KEY: + graph_backend = (cls.GRAPH_MEMORY_BACKEND or 'zep_cloud').lower() + if graph_backend in {'zep', 'zep_cloud', 'zep-cloud'} and not cls.ZEP_API_KEY: errors.append("ZEP_API_KEY 未配置") return errors diff --git a/backend/app/graph_memory/__init__.py b/backend/app/graph_memory/__init__.py new file mode 100644 index 0000000000..b951eee360 --- /dev/null +++ b/backend/app/graph_memory/__init__.py @@ -0,0 +1,13 @@ +"""Graph memory backend adapters.""" + +from .base import GraphMemoryAdapter +from .factory import create_graph_memory_adapter +from .graphiti_bridge_adapter import GraphitiBridgeGraphMemoryAdapter +from .zep_cloud_adapter import ZepCloudGraphMemoryAdapter + +__all__ = [ + "GraphMemoryAdapter", + "GraphitiBridgeGraphMemoryAdapter", + "ZepCloudGraphMemoryAdapter", + "create_graph_memory_adapter", +] diff --git a/backend/app/graph_memory/base.py b/backend/app/graph_memory/base.py new file mode 100644 index 0000000000..d0b45be96d --- /dev/null +++ b/backend/app/graph_memory/base.py @@ -0,0 +1,66 @@ +"""Graph memory adapter contracts. + +This module defines the narrow graph-memory surface Mirofish needs. Concrete +backends can implement it without leaking vendor SDK details into services. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Protocol + + +class GraphMemoryAdapter(ABC): + """Backend-neutral graph memory interface used by Mirofish services.""" + + @abstractmethod + def create_graph(self, graph_id: str, name: str, description: str) -> Any: + """Create a graph and return the backend response.""" + + @abstractmethod + def set_ontology(self, graph_id: str, ontology: dict[str, Any]) -> Any: + """Apply ontology definitions for a graph.""" + + @abstractmethod + def add_text_batch(self, graph_id: str, chunks: list[str]) -> Any: + """Add a batch of text episodes to a graph.""" + + @abstractmethod + def add_text(self, graph_id: str, text: str) -> Any: + """Add a single text episode to a graph.""" + + @abstractmethod + def get_episode(self, episode_uuid: str) -> Any: + """Return one episode by UUID.""" + + @abstractmethod + def get_all_nodes(self, graph_id: str) -> list[Any]: + """Return all nodes for a graph.""" + + @abstractmethod + def get_all_edges(self, graph_id: str) -> list[Any]: + """Return all edges for a graph.""" + + @abstractmethod + def search(self, graph_id: str, query: str, limit: int = 10, scope: str = "edges", **kwargs: Any) -> Any: + """Search graph memory.""" + + @abstractmethod + def get_node(self, node_uuid: str) -> Any: + """Return one node by UUID.""" + + @abstractmethod + def get_node_edges(self, node_uuid: str) -> list[Any]: + """Return edges related to one node.""" + + @abstractmethod + def delete_graph(self, graph_id: str) -> Any: + """Delete a graph.""" + + +class SupportsRawClient(Protocol): + """Compatibility escape hatch for legacy code not yet adapter-native.""" + + @property + def raw_client(self) -> Any: + """Return the underlying SDK client.""" diff --git a/backend/app/graph_memory/factory.py b/backend/app/graph_memory/factory.py new file mode 100644 index 0000000000..484122ddf5 --- /dev/null +++ b/backend/app/graph_memory/factory.py @@ -0,0 +1,23 @@ +"""Graph memory adapter factory.""" + +from __future__ import annotations + +from typing import Optional + +from ..config import Config +from .base import GraphMemoryAdapter +from .zep_cloud_adapter import ZepCloudGraphMemoryAdapter + + +def create_graph_memory_adapter(api_key: Optional[str] = None, backend: Optional[str] = None) -> GraphMemoryAdapter: + selected_backend = (backend or Config.GRAPH_MEMORY_BACKEND).strip().lower() + + if selected_backend in {"zep", "zep_cloud", "zep-cloud"}: + return ZepCloudGraphMemoryAdapter(api_key=api_key or Config.ZEP_API_KEY) + + if selected_backend in {"graphiti", "graphiti_core", "graphiti-core", "graphiti_bridge", "graphiti-bridge"}: + from .graphiti_bridge_adapter import GraphitiBridgeGraphMemoryAdapter + + return GraphitiBridgeGraphMemoryAdapter(api_key=api_key or Config.LLM_API_KEY) + + raise ValueError(f"Unsupported GRAPH_MEMORY_BACKEND: {selected_backend}") diff --git a/backend/app/graph_memory/graphiti_bridge_adapter.py b/backend/app/graph_memory/graphiti_bridge_adapter.py new file mode 100644 index 0000000000..cd1d6b73c0 --- /dev/null +++ b/backend/app/graph_memory/graphiti_bridge_adapter.py @@ -0,0 +1,128 @@ +"""HTTP adapter for the on-premise Graphiti bridge service.""" + +from __future__ import annotations + +import json +from types import SimpleNamespace +from typing import Any +from urllib.error import HTTPError +from urllib.parse import quote, urlencode +from urllib.request import Request, urlopen + +from .base import GraphMemoryAdapter +from ..config import Config + + +class GraphitiBridgeGraphMemoryAdapter(GraphMemoryAdapter): + """Graph memory adapter backed by the local Graphiti bridge service.""" + + def __init__(self, api_key: str | None = None, base_url: str | None = None): + self.base_url = (base_url or Config.GRAPHITI_BRIDGE_URL).rstrip("/") + self._node_graph_index: dict[str, str] = {} + + @property + def raw_client(self) -> None: + return None + + def create_graph(self, graph_id: str, name: str, description: str) -> Any: + return self._to_namespace(self._request("POST", "/graphs", {"graph_id": graph_id, "name": name, "description": description})) + + def set_ontology(self, graph_id: str, ontology: dict[str, Any]) -> Any: + return self._request("POST", f"/graphs/{quote(graph_id)}/ontology", ontology) + + def add_text_batch(self, graph_id: str, chunks: list[str]) -> list[Any]: + data = self._request("POST", f"/graphs/{quote(graph_id)}/episodes", {"chunks": chunks}) + return [self._episode(item) for item in data.get("episodes", [])] + + def add_text(self, graph_id: str, text: str) -> Any: + data = self._request("POST", f"/graphs/{quote(graph_id)}/episodes", {"text": text}) + episodes = data.get("episodes", []) + return self._episode(episodes[0]) if episodes else self._episode({"uuid": None, "processed": True}) + + def get_episode(self, episode_uuid: str) -> Any: + return self._episode({"uuid": episode_uuid, "processed": True}) + + def get_all_nodes(self, graph_id: str) -> list[Any]: + data = self._request("GET", f"/graphs/{quote(graph_id)}/nodes") + nodes = [self._node(item) for item in data.get("nodes", [])] + for node in nodes: + self._node_graph_index[node.uuid_] = graph_id + return nodes + + def get_all_edges(self, graph_id: str) -> list[Any]: + data = self._request("GET", f"/graphs/{quote(graph_id)}/edges") + return [self._edge(item) for item in data.get("edges", [])] + + def search(self, graph_id: str, query: str, limit: int = 10, scope: str = "edges", **kwargs: Any) -> Any: + data = self._request("POST", f"/graphs/{quote(graph_id)}/search", {"query": query, "limit": limit, "scope": scope}) + nodes = [self._node(item) for item in data.get("nodes", [])] + for node in nodes: + self._node_graph_index[node.uuid_] = graph_id + return SimpleNamespace(edges=[self._edge(item) for item in data.get("edges", [])], nodes=nodes) + + def get_node(self, node_uuid: str) -> Any: + graph_id = self._node_graph_index.get(node_uuid) + if not graph_id: + return None + query = urlencode({"graph_id": graph_id}) + data = self._request("GET", f"/nodes/{quote(node_uuid)}?{query}") + node = data.get("node") + return self._node(node) if node else None + + def get_node_edges(self, node_uuid: str) -> list[Any]: + graph_id = self._node_graph_index.get(node_uuid) + if not graph_id: + return [] + query = urlencode({"graph_id": graph_id}) + data = self._request("GET", f"/nodes/{quote(node_uuid)}/edges?{query}") + return [self._edge(item) for item in data.get("edges", [])] + + def delete_graph(self, graph_id: str) -> Any: + return self._request("DELETE", f"/graphs/{quote(graph_id)}") + + def _request(self, method: str, path: str, payload: dict[str, Any] | None = None) -> dict[str, Any]: + body = None if payload is None else json.dumps(payload).encode("utf-8") + headers = {"Content-Type": "application/json"} + req = Request(f"{self.base_url}{path}", data=body, headers=headers, method=method) + try: + with urlopen(req, timeout=120) as response: + raw = response.read().decode("utf-8") + return json.loads(raw) if raw else {} + except HTTPError as exc: + error_body = exc.read().decode("utf-8", errors="replace") + raise RuntimeError(f"Graphiti bridge request failed: {exc.code} {error_body}") from exc + + def _episode(self, data: dict[str, Any]) -> Any: + uuid = data.get("uuid") or data.get("uuid_") + return SimpleNamespace(uuid_=uuid, uuid=uuid, processed=data.get("processed", True)) + + def _node(self, data: dict[str, Any]) -> Any: + uuid = data.get("uuid") or data.get("uuid_") or "" + return SimpleNamespace( + uuid_=uuid, + uuid=uuid, + name=data.get("name") or "", + labels=data.get("labels") or [], + summary=data.get("summary") or "", + attributes=data.get("attributes") or {}, + created_at=data.get("created_at"), + ) + + def _edge(self, data: dict[str, Any]) -> Any: + uuid = data.get("uuid") or data.get("uuid_") or "" + return SimpleNamespace( + uuid_=uuid, + uuid=uuid, + name=data.get("name") or "", + fact=data.get("fact") or "", + source_node_uuid=data.get("source_node_uuid") or "", + target_node_uuid=data.get("target_node_uuid") or "", + attributes=data.get("attributes") or {}, + created_at=data.get("created_at"), + valid_at=data.get("valid_at"), + invalid_at=data.get("invalid_at"), + expired_at=data.get("expired_at"), + ) + + def _to_namespace(self, data: dict[str, Any]) -> Any: + return SimpleNamespace(**data) diff --git a/backend/app/graph_memory/zep_cloud_adapter.py b/backend/app/graph_memory/zep_cloud_adapter.py new file mode 100644 index 0000000000..fae4107437 --- /dev/null +++ b/backend/app/graph_memory/zep_cloud_adapter.py @@ -0,0 +1,121 @@ +"""Zep Cloud graph memory adapter.""" + +from __future__ import annotations + +import warnings +from typing import Any, Optional + +from pydantic import Field +from zep_cloud import EpisodeData, EntityEdgeSourceTarget +from zep_cloud.client import Zep +from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel + +from .base import GraphMemoryAdapter +from ..utils.zep_paging import fetch_all_edges, fetch_all_nodes + + +class ZepCloudGraphMemoryAdapter(GraphMemoryAdapter): + """Adapter preserving the existing Zep Cloud behavior.""" + + RESERVED_NAMES = {"uuid", "name", "group_id", "name_embedding", "summary", "created_at"} + + def __init__(self, api_key: str): + if not api_key: + raise ValueError("ZEP_API_KEY 未配置") + self._client = Zep(api_key=api_key) + + @property + def raw_client(self) -> Zep: + return self._client + + def create_graph(self, graph_id: str, name: str, description: str) -> Any: + return self._client.graph.create(graph_id=graph_id, name=name, description=description) + + def set_ontology(self, graph_id: str, ontology: dict[str, Any]) -> Any: + warnings.filterwarnings("ignore", category=UserWarning, module="pydantic") + + entity_types: dict[str, type[EntityModel]] = {} + for entity_def in ontology.get("entity_types", []): + name = entity_def["name"] + description = entity_def.get("description", f"A {name} entity.") + attrs: dict[str, Any] = {"__doc__": description} + annotations: dict[str, Any] = {} + + for attr_def in entity_def.get("attributes", []): + attr_name = self._safe_attr_name(attr_def["name"]) + attr_desc = attr_def.get("description", attr_name) + attrs[attr_name] = Field(description=attr_desc, default=None) + annotations[attr_name] = Optional[EntityText] + + attrs["__annotations__"] = annotations + entity_class = type(name, (EntityModel,), attrs) + entity_class.__doc__ = description + entity_types[name] = entity_class + + edge_definitions: dict[str, tuple[type[EdgeModel], list[EntityEdgeSourceTarget]]] = {} + for edge_def in ontology.get("edge_types", []): + name = edge_def["name"] + description = edge_def.get("description", f"A {name} relationship.") + attrs = {"__doc__": description} + annotations = {} + + for attr_def in edge_def.get("attributes", []): + attr_name = self._safe_attr_name(attr_def["name"]) + attr_desc = attr_def.get("description", attr_name) + attrs[attr_name] = Field(description=attr_desc, default=None) + annotations[attr_name] = Optional[str] + + attrs["__annotations__"] = annotations + class_name = "".join(word.capitalize() for word in name.split("_")) + edge_class = type(class_name, (EdgeModel,), attrs) + edge_class.__doc__ = description + + source_targets = [ + EntityEdgeSourceTarget(source=st.get("source", "Entity"), target=st.get("target", "Entity")) + for st in edge_def.get("source_targets", []) + ] + if source_targets: + edge_definitions[name] = (edge_class, source_targets) + + if not entity_types and not edge_definitions: + return None + + return self._client.graph.set_ontology( + graph_ids=[graph_id], + entities=entity_types if entity_types else None, + edges=edge_definitions if edge_definitions else None, + ) + + def add_text_batch(self, graph_id: str, chunks: list[str]) -> Any: + episodes = [EpisodeData(data=chunk, type="text") for chunk in chunks] + return self._client.graph.add_batch(graph_id=graph_id, episodes=episodes) + + def add_text(self, graph_id: str, text: str) -> Any: + return self._client.graph.add(graph_id=graph_id, type="text", data=text) + + def get_episode(self, episode_uuid: str) -> Any: + return self._client.graph.episode.get(uuid_=episode_uuid) + + def get_all_nodes(self, graph_id: str) -> list[Any]: + return fetch_all_nodes(self._client, graph_id) + + def get_all_edges(self, graph_id: str) -> list[Any]: + return fetch_all_edges(self._client, graph_id) + + def search(self, graph_id: str, query: str, limit: int = 10, scope: str = "edges", **kwargs: Any) -> Any: + return self._client.graph.search(graph_id=graph_id, query=query, limit=limit, scope=scope, **kwargs) + + def get_node(self, node_uuid: str) -> Any: + return self._client.graph.node.get(uuid_=node_uuid) + + def get_node_edges(self, node_uuid: str) -> list[Any]: + return self._client.graph.node.get_entity_edges(node_uuid=node_uuid) + + def delete_graph(self, graph_id: str) -> Any: + return self._client.graph.delete(graph_id=graph_id) + + @classmethod + def _safe_attr_name(cls, attr_name: str) -> str: + if attr_name.lower() in cls.RESERVED_NAMES: + return f"entity_{attr_name}" + return attr_name diff --git a/backend/app/services/graph_builder.py b/backend/app/services/graph_builder.py index 37c9969c79..01391403b9 100644 --- a/backend/app/services/graph_builder.py +++ b/backend/app/services/graph_builder.py @@ -10,12 +10,9 @@ from typing import Dict, Any, List, Optional, Callable from dataclasses import dataclass -from zep_cloud.client import Zep -from zep_cloud import EpisodeData, EntityEdgeSourceTarget - from ..config import Config from ..models.task import TaskManager, TaskStatus -from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges +from ..graph_memory import create_graph_memory_adapter from .text_processor import TextProcessor from ..utils.locale import t, get_locale, set_locale @@ -48,7 +45,8 @@ def __init__(self, api_key: Optional[str] = None): if not self.api_key: raise ValueError("ZEP_API_KEY 未配置") - self.client = Zep(api_key=self.api_key) + self.graph_memory = create_graph_memory_adapter(api_key=self.api_key) + self.client = getattr(self.graph_memory, 'raw_client', None) self.task_manager = TaskManager() def build_graph_async( @@ -194,7 +192,7 @@ def create_graph(self, name: str) -> str: """创建Zep图谱(公开方法)""" graph_id = f"mirofish_{uuid.uuid4().hex[:16]}" - self.client.graph.create( + self.graph_memory.create_graph( graph_id=graph_id, name=name, description="MiroFish Social Simulation Graph" @@ -204,93 +202,8 @@ def create_graph(self, name: str) -> str: def set_ontology(self, graph_id: str, ontology: Dict[str, Any]): """设置图谱本体(公开方法)""" - import warnings - from typing import Optional - from pydantic import Field - from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel - - # 抑制 Pydantic v2 关于 Field(default=None) 的警告 - # 这是 Zep SDK 要求的用法,警告来自动态类创建,可以安全忽略 - warnings.filterwarnings('ignore', category=UserWarning, module='pydantic') - - # Zep 保留名称,不能作为属性名 - RESERVED_NAMES = {'uuid', 'name', 'group_id', 'name_embedding', 'summary', 'created_at'} - - def safe_attr_name(attr_name: str) -> str: - """将保留名称转换为安全名称""" - if attr_name.lower() in RESERVED_NAMES: - return f"entity_{attr_name}" - return attr_name - - # 动态创建实体类型 - entity_types = {} - for entity_def in ontology.get("entity_types", []): - name = entity_def["name"] - description = entity_def.get("description", f"A {name} entity.") - - # 创建属性字典和类型注解(Pydantic v2 需要) - attrs = {"__doc__": description} - annotations = {} - - for attr_def in entity_def.get("attributes", []): - attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称 - attr_desc = attr_def.get("description", attr_name) - # Zep API 需要 Field 的 description,这是必需的 - attrs[attr_name] = Field(description=attr_desc, default=None) - annotations[attr_name] = Optional[EntityText] # 类型注解 - - attrs["__annotations__"] = annotations - - # 动态创建类 - entity_class = type(name, (EntityModel,), attrs) - entity_class.__doc__ = description - entity_types[name] = entity_class - - # 动态创建边类型 - edge_definitions = {} - for edge_def in ontology.get("edge_types", []): - name = edge_def["name"] - description = edge_def.get("description", f"A {name} relationship.") - - # 创建属性字典和类型注解 - attrs = {"__doc__": description} - annotations = {} - - for attr_def in edge_def.get("attributes", []): - attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称 - attr_desc = attr_def.get("description", attr_name) - # Zep API 需要 Field 的 description,这是必需的 - attrs[attr_name] = Field(description=attr_desc, default=None) - annotations[attr_name] = Optional[str] # 边属性用str类型 - - attrs["__annotations__"] = annotations - - # 动态创建类 - class_name = ''.join(word.capitalize() for word in name.split('_')) - edge_class = type(class_name, (EdgeModel,), attrs) - edge_class.__doc__ = description - - # 构建source_targets - source_targets = [] - for st in edge_def.get("source_targets", []): - source_targets.append( - EntityEdgeSourceTarget( - source=st.get("source", "Entity"), - target=st.get("target", "Entity") - ) - ) - - if source_targets: - edge_definitions[name] = (edge_class, source_targets) - - # 调用Zep API设置本体 - if entity_types or edge_definitions: - self.client.graph.set_ontology( - graph_ids=[graph_id], - entities=entity_types if entity_types else None, - edges=edge_definitions if edge_definitions else None, - ) - + self.graph_memory.set_ontology(graph_id, ontology) + def add_text_batches( self, graph_id: str, @@ -314,17 +227,11 @@ def add_text_batches( progress ) - # 构建episode数据 - episodes = [ - EpisodeData(data=chunk, type="text") - for chunk in batch_chunks - ] - - # 发送到Zep + # 发送到图谱记忆后端 try: - batch_result = self.client.graph.add_batch( + batch_result = self.graph_memory.add_text_batch( graph_id=graph_id, - episodes=episodes + chunks=batch_chunks ) # 收集返回的 episode uuid @@ -376,7 +283,7 @@ def _wait_for_episodes( # 检查每个 episode 的处理状态 for ep_uuid in list(pending_episodes): try: - episode = self.client.graph.episode.get(uuid_=ep_uuid) + episode = self.graph_memory.get_episode(ep_uuid) is_processed = getattr(episode, 'processed', False) if is_processed: @@ -403,10 +310,10 @@ def _wait_for_episodes( def _get_graph_info(self, graph_id: str) -> GraphInfo: """获取图谱信息""" # 获取节点(分页) - nodes = fetch_all_nodes(self.client, graph_id) + nodes = self.graph_memory.get_all_nodes(graph_id) # 获取边(分页) - edges = fetch_all_edges(self.client, graph_id) + edges = self.graph_memory.get_all_edges(graph_id) # 统计实体类型 entity_types = set() @@ -433,8 +340,8 @@ def get_graph_data(self, graph_id: str) -> Dict[str, Any]: Returns: 包含nodes和edges的字典,包括时间信息、属性等详细数据 """ - nodes = fetch_all_nodes(self.client, graph_id) - edges = fetch_all_edges(self.client, graph_id) + nodes = self.graph_memory.get_all_nodes(graph_id) + edges = self.graph_memory.get_all_edges(graph_id) # 创建节点映射用于获取节点名称 node_map = {} @@ -502,5 +409,5 @@ def get_graph_data(self, graph_id: str) -> Dict[str, Any]: def delete_graph(self, graph_id: str): """删除图谱""" - self.client.graph.delete(graph_id=graph_id) + self.graph_memory.delete_graph(graph_id) diff --git a/backend/app/services/oasis_profile_generator.py b/backend/app/services/oasis_profile_generator.py index 7704a627eb..7a85b306b6 100644 --- a/backend/app/services/oasis_profile_generator.py +++ b/backend/app/services/oasis_profile_generator.py @@ -16,12 +16,11 @@ from datetime import datetime from openai import OpenAI -from zep_cloud.client import Zep - from ..config import Config from ..utils.logger import get_logger from ..utils.locale import get_language_instruction, get_locale, set_locale, t from .zep_entity_reader import EntityNode, ZepEntityReader +from ..graph_memory import create_graph_memory_adapter logger = get_logger('mirofish.oasis_profile') @@ -205,7 +204,8 @@ def __init__( if self.zep_api_key: try: - self.zep_client = Zep(api_key=self.zep_api_key) + self.graph_memory = create_graph_memory_adapter(api_key=self.zep_api_key) + self.zep_client = getattr(self.graph_memory, 'raw_client', None) except Exception as e: logger.warning(f"Zep客户端初始化失败: {e}") @@ -324,7 +324,7 @@ def search_edges(): for attempt in range(max_retries): try: - return self.zep_client.graph.search( + return self.graph_memory.search( query=comprehensive_query, graph_id=self.graph_id, limit=30, @@ -349,7 +349,7 @@ def search_nodes(): for attempt in range(max_retries): try: - return self.zep_client.graph.search( + return self.graph_memory.search( query=comprehensive_query, graph_id=self.graph_id, limit=20, diff --git a/backend/app/services/simulation_runner.py b/backend/app/services/simulation_runner.py index e86021f808..e4c858d037 100644 --- a/backend/app/services/simulation_runner.py +++ b/backend/app/services/simulation_runner.py @@ -1702,7 +1702,7 @@ def _get_interview_history_from_db( "agent_id": user_id, "response": info.get("response", info), "prompt": info.get("prompt", ""), - "timestamp": created_at, + "timestamp": str(created_at) if created_at is not None else "", "platform": platform_name }) @@ -1713,6 +1713,28 @@ def _get_interview_history_from_db( return results + @staticmethod + def _timestamp_sort_value(value: Any) -> float: + """Return a stable numeric sort value for mixed SQLite timestamp formats.""" + if value is None: + return 0.0 + if isinstance(value, (int, float)): + return float(value) + + text_value = str(value).strip() + if not text_value: + return 0.0 + + try: + return float(text_value) + except ValueError: + pass + + try: + return datetime.fromisoformat(text_value.replace("Z", "+00:00")).timestamp() + except ValueError: + return 0.0 + @classmethod def get_interview_history( cls, @@ -1757,12 +1779,11 @@ def get_interview_history( ) results.extend(platform_results) - # 按时间降序排序 - results.sort(key=lambda x: x.get("timestamp", ""), reverse=True) + # 按时间降序排序,兼容不同平台写入的字符串/数字时间戳 + results.sort(key=lambda x: cls._timestamp_sort_value(x.get("timestamp")), reverse=True) # 如果查询了多个平台,限制总数 if len(platforms) > 1 and len(results) > limit: results = results[:limit] return results - diff --git a/backend/app/services/zep_entity_reader.py b/backend/app/services/zep_entity_reader.py index 71661be499..2dcd2d0cbc 100644 --- a/backend/app/services/zep_entity_reader.py +++ b/backend/app/services/zep_entity_reader.py @@ -7,11 +7,9 @@ from typing import Dict, Any, List, Optional, Set, Callable, TypeVar from dataclasses import dataclass, field -from zep_cloud.client import Zep - from ..config import Config from ..utils.logger import get_logger -from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges +from ..graph_memory import create_graph_memory_adapter logger = get_logger('mirofish.zep_entity_reader') @@ -83,7 +81,8 @@ def __init__(self, api_key: Optional[str] = None): if not self.api_key: raise ValueError("ZEP_API_KEY 未配置") - self.client = Zep(api_key=self.api_key) + self.graph_memory = create_graph_memory_adapter(api_key=self.api_key) + self.client = getattr(self.graph_memory, 'raw_client', None) def _call_with_retry( self, @@ -136,7 +135,7 @@ def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]: """ logger.info(f"获取图谱 {graph_id} 的所有节点...") - nodes = fetch_all_nodes(self.client, graph_id) + nodes = self.graph_memory.get_all_nodes(graph_id) nodes_data = [] for node in nodes: @@ -163,7 +162,7 @@ def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]: """ logger.info(f"获取图谱 {graph_id} 的所有边...") - edges = fetch_all_edges(self.client, graph_id) + edges = self.graph_memory.get_all_edges(graph_id) edges_data = [] for edge in edges: @@ -192,7 +191,7 @@ def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]: try: # 使用重试机制调用Zep API edges = self._call_with_retry( - func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid), + func=lambda: self.graph_memory.get_node_edges(node_uuid), operation_name=f"获取节点边(node={node_uuid[:8]}...)" ) @@ -348,7 +347,7 @@ def get_entity_with_context( try: # 使用重试机制获取节点 node = self._call_with_retry( - func=lambda: self.client.graph.node.get(uuid_=entity_uuid), + func=lambda: self.graph_memory.get_node(entity_uuid), operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)" ) diff --git a/backend/app/services/zep_graph_memory_updater.py b/backend/app/services/zep_graph_memory_updater.py index e034fee2b2..9262f4dbdb 100644 --- a/backend/app/services/zep_graph_memory_updater.py +++ b/backend/app/services/zep_graph_memory_updater.py @@ -12,11 +12,10 @@ from datetime import datetime from queue import Queue, Empty -from zep_cloud.client import Zep - from ..config import Config from ..utils.logger import get_logger from ..utils.locale import get_locale, set_locale +from ..graph_memory import create_graph_memory_adapter logger = get_logger('mirofish.zep_graph_memory_updater') @@ -243,7 +242,8 @@ def __init__(self, graph_id: str, api_key: Optional[str] = None): if not self.api_key: raise ValueError("ZEP_API_KEY未配置") - self.client = Zep(api_key=self.api_key) + self.graph_memory = create_graph_memory_adapter(api_key=self.api_key) + self.client = getattr(self.graph_memory, 'raw_client', None) # 活动队列 self._activity_queue: Queue = Queue() @@ -411,10 +411,9 @@ def _send_batch_activities(self, activities: List[AgentActivity], platform: str) # 带重试的发送 for attempt in range(self.MAX_RETRIES): try: - self.client.graph.add( + self.graph_memory.add_text( graph_id=self.graph_id, - type="text", - data=combined_text + text=combined_text ) self._total_sent += 1 diff --git a/backend/app/services/zep_tools.py b/backend/app/services/zep_tools.py index 3bc8a57abb..e3048076df 100644 --- a/backend/app/services/zep_tools.py +++ b/backend/app/services/zep_tools.py @@ -13,13 +13,11 @@ from typing import Dict, Any, List, Optional from dataclasses import dataclass, field -from zep_cloud.client import Zep - from ..config import Config from ..utils.logger import get_logger from ..utils.llm_client import LLMClient from ..utils.locale import get_locale, t -from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges +from ..graph_memory import create_graph_memory_adapter logger = get_logger('mirofish.zep_tools') @@ -427,7 +425,8 @@ def __init__(self, api_key: Optional[str] = None, llm_client: Optional[LLMClient if not self.api_key: raise ValueError("ZEP_API_KEY 未配置") - self.client = Zep(api_key=self.api_key) + self.graph_memory = create_graph_memory_adapter(api_key=self.api_key) + self.client = getattr(self.graph_memory, 'raw_client', None) # LLM客户端用于InsightForge生成子问题 self._llm_client = llm_client logger.info(t("console.zepToolsInitialized")) @@ -488,7 +487,7 @@ def search_graph( # 尝试使用Zep Cloud Search API try: search_results = self._call_with_retry( - func=lambda: self.client.graph.search( + func=lambda: self.graph_memory.search( graph_id=graph_id, query=query, limit=limit, @@ -659,7 +658,7 @@ def get_all_nodes(self, graph_id: str) -> List[NodeInfo]: """ logger.info(t("console.fetchingAllNodes", graphId=graph_id)) - nodes = fetch_all_nodes(self.client, graph_id) + nodes = self.graph_memory.get_all_nodes(graph_id) result = [] for node in nodes: @@ -688,7 +687,7 @@ def get_all_edges(self, graph_id: str, include_temporal: bool = True) -> List[Ed """ logger.info(t("console.fetchingAllEdges", graphId=graph_id)) - edges = fetch_all_edges(self.client, graph_id) + edges = self.graph_memory.get_all_edges(graph_id) result = [] for edge in edges: @@ -727,7 +726,7 @@ def get_node_detail(self, node_uuid: str) -> Optional[NodeInfo]: try: node = self._call_with_retry( - func=lambda: self.client.graph.node.get(uuid_=node_uuid), + func=lambda: self.graph_memory.get_node(node_uuid), operation_name=t("console.fetchNodeDetailOp", uuid=node_uuid[:8]) ) diff --git a/backend/scripts/run_reddit_simulation.py b/backend/scripts/run_reddit_simulation.py index 14907cbda5..1673a719a4 100644 --- a/backend/scripts/run_reddit_simulation.py +++ b/backend/scripts/run_reddit_simulation.py @@ -52,24 +52,24 @@ class UnicodeFormatter(logging.Formatter): """自定义格式化器,将 Unicode 转义序列转换为可读字符""" - + UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})') - + def format(self, record): result = super().format(record) - + def replace_unicode(match): try: return chr(int(match.group(1), 16)) except (ValueError, OverflowError): return match.group(0) - + return self.UNICODE_ESCAPE_PATTERN.sub(replace_unicode, result) class MaxTokensWarningFilter(logging.Filter): """过滤掉 camel-ai 关于 max_tokens 的警告(我们故意不设置 max_tokens,让模型自行决定)""" - + def filter(self, record): # 过滤掉包含 max_tokens 警告的日志 if "max_tokens" in record.getMessage() and "Invalid or missing" in record.getMessage(): @@ -84,7 +84,7 @@ def filter(self, record): def setup_oasis_logging(log_dir: str): """配置 OASIS 的日志,使用固定名称的日志文件""" os.makedirs(log_dir, exist_ok=True) - + # 清理旧的日志文件 for f in os.listdir(log_dir): old_log = os.path.join(log_dir, f) @@ -93,9 +93,9 @@ def setup_oasis_logging(log_dir: str): os.remove(old_log) except OSError: pass - + formatter = UnicodeFormatter("%(levelname)s - %(asctime)s - %(name)s - %(message)s") - + loggers_config = { "social.agent": os.path.join(log_dir, "social.agent.log"), "social.twitter": os.path.join(log_dir, "social.twitter.log"), @@ -103,7 +103,7 @@ def setup_oasis_logging(log_dir: str): "oasis.env": os.path.join(log_dir, "oasis.env.log"), "table": os.path.join(log_dir, "table.log"), } - + for logger_name, log_file in loggers_config.items(): logger = logging.getLogger(logger_name) logger.setLevel(logging.DEBUG) @@ -130,6 +130,8 @@ def setup_oasis_logging(log_dir: str): print("请先安装: pip install oasis-ai camel-ai") sys.exit(1) +from action_logger import PlatformActionLogger + # IPC相关常量 IPC_COMMANDS_DIR = "ipc_commands" @@ -145,7 +147,7 @@ class CommandType: class IPCHandler: """IPC命令处理器""" - + def __init__(self, simulation_dir: str, env, agent_graph): self.simulation_dir = simulation_dir self.env = env @@ -154,11 +156,11 @@ def __init__(self, simulation_dir: str, env, agent_graph): self.responses_dir = os.path.join(simulation_dir, IPC_RESPONSES_DIR) self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE) self._running = True - + # 确保目录存在 os.makedirs(self.commands_dir, exist_ok=True) os.makedirs(self.responses_dir, exist_ok=True) - + def update_status(self, status: str): """更新环境状态""" with open(self.status_file, 'w', encoding='utf-8') as f: @@ -166,30 +168,30 @@ def update_status(self, status: str): "status": status, "timestamp": datetime.now().isoformat() }, f, ensure_ascii=False, indent=2) - + def poll_command(self) -> Optional[Dict[str, Any]]: """轮询获取待处理命令""" if not os.path.exists(self.commands_dir): return None - + # 获取命令文件(按时间排序) command_files = [] for filename in os.listdir(self.commands_dir): if filename.endswith('.json'): filepath = os.path.join(self.commands_dir, filename) command_files.append((filepath, os.path.getmtime(filepath))) - + command_files.sort(key=lambda x: x[1]) - + for filepath, _ in command_files: try: with open(filepath, 'r', encoding='utf-8') as f: return json.load(f) except (json.JSONDecodeError, OSError): continue - + return None - + def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None): """发送响应""" response = { @@ -199,56 +201,56 @@ def send_response(self, command_id: str, status: str, result: Dict = None, error "error": error, "timestamp": datetime.now().isoformat() } - + response_file = os.path.join(self.responses_dir, f"{command_id}.json") with open(response_file, 'w', encoding='utf-8') as f: json.dump(response, f, ensure_ascii=False, indent=2) - + # 删除命令文件 command_file = os.path.join(self.commands_dir, f"{command_id}.json") try: os.remove(command_file) except OSError: pass - + async def handle_interview(self, command_id: str, agent_id: int, prompt: str) -> bool: """ 处理单个Agent采访命令 - + Returns: True 表示成功,False 表示失败 """ try: # 获取Agent agent = self.agent_graph.get_agent(agent_id) - + # 创建Interview动作 interview_action = ManualAction( action_type=ActionType.INTERVIEW, action_args={"prompt": prompt} ) - + # 执行Interview actions = {agent: interview_action} await self.env.step(actions) - + # 从数据库获取结果 result = self._get_interview_result(agent_id) - + self.send_response(command_id, "completed", result=result) print(f" Interview完成: agent_id={agent_id}") return True - + except Exception as e: error_msg = str(e) print(f" Interview失败: agent_id={agent_id}, error={error_msg}") self.send_response(command_id, "failed", error=error_msg) return False - + async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) -> bool: """ 处理批量采访命令 - + Args: interviews: [{"agent_id": int, "prompt": str}, ...] """ @@ -256,11 +258,11 @@ async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) # 构建动作字典 actions = {} agent_prompts = {} # 记录每个agent的prompt - + for interview in interviews: agent_id = interview.get("agent_id") prompt = interview.get("prompt", "") - + try: agent = self.agent_graph.get_agent(agent_id) actions[agent] = ManualAction( @@ -270,50 +272,50 @@ async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) agent_prompts[agent_id] = prompt except Exception as e: print(f" 警告: 无法获取Agent {agent_id}: {e}") - + if not actions: self.send_response(command_id, "failed", error="没有有效的Agent") return False - + # 执行批量Interview await self.env.step(actions) - + # 获取所有结果 results = {} for agent_id in agent_prompts.keys(): result = self._get_interview_result(agent_id) results[agent_id] = result - + self.send_response(command_id, "completed", result={ "interviews_count": len(results), "results": results }) print(f" 批量Interview完成: {len(results)} 个Agent") return True - + except Exception as e: error_msg = str(e) print(f" 批量Interview失败: {error_msg}") self.send_response(command_id, "failed", error=error_msg) return False - + def _get_interview_result(self, agent_id: int) -> Dict[str, Any]: """从数据库获取最新的Interview结果""" db_path = os.path.join(self.simulation_dir, "reddit_simulation.db") - + result = { "agent_id": agent_id, "response": None, "timestamp": None } - + if not os.path.exists(db_path): return result - + try: conn = sqlite3.connect(db_path) cursor = conn.cursor() - + # 查询最新的Interview记录 cursor.execute(""" SELECT user_id, info, created_at @@ -322,7 +324,7 @@ def _get_interview_result(self, agent_id: int) -> Dict[str, Any]: ORDER BY created_at DESC LIMIT 1 """, (ActionType.INTERVIEW.value, agent_id)) - + row = cursor.fetchone() if row: user_id, info_json, created_at = row @@ -332,31 +334,31 @@ def _get_interview_result(self, agent_id: int) -> Dict[str, Any]: result["timestamp"] = created_at except json.JSONDecodeError: result["response"] = info_json - + conn.close() - + except Exception as e: print(f" 读取Interview结果失败: {e}") - + return result - + async def process_commands(self) -> bool: """ 处理所有待处理命令 - + Returns: True 表示继续运行,False 表示应该退出 """ command = self.poll_command() if not command: return True - + command_id = command.get("command_id") command_type = command.get("command_type") args = command.get("args", {}) - + print(f"\n收到IPC命令: {command_type}, id={command_id}") - + if command_type == CommandType.INTERVIEW: await self.handle_interview( command_id, @@ -364,19 +366,19 @@ async def process_commands(self) -> bool: args.get("prompt", "") ) return True - + elif command_type == CommandType.BATCH_INTERVIEW: await self.handle_batch_interview( command_id, args.get("interviews", []) ) return True - + elif command_type == CommandType.CLOSE_ENV: print("收到关闭环境命令") self.send_response(command_id, "completed", result={"message": "环境即将关闭"}) return False - + else: self.send_response(command_id, "failed", error=f"未知命令类型: {command_type}") return True @@ -384,7 +386,7 @@ async def process_commands(self) -> bool: class RedditSimulationRunner: """Reddit模拟运行器""" - + # Reddit可用动作(不包含INTERVIEW,INTERVIEW只能通过ManualAction手动触发) AVAILABLE_ACTIONS = [ ActionType.LIKE_POST, @@ -401,11 +403,11 @@ class RedditSimulationRunner: ActionType.FOLLOW, ActionType.MUTE, ] - + def __init__(self, config_path: str, wait_for_commands: bool = True): """ 初始化模拟运行器 - + Args: config_path: 配置文件路径 (simulation_config.json) wait_for_commands: 模拟完成后是否等待命令(默认True) @@ -417,24 +419,27 @@ def __init__(self, config_path: str, wait_for_commands: bool = True): self.env = None self.agent_graph = None self.ipc_handler = None - + self.action_logger = PlatformActionLogger("reddit", self.simulation_dir) + self._last_synced_trace_rowid = 0 + self._synced_actions_count = 0 + def _load_config(self) -> Dict[str, Any]: """加载配置文件""" with open(self.config_path, 'r', encoding='utf-8') as f: return json.load(f) - + def _get_profile_path(self) -> str: """获取Profile文件路径""" return os.path.join(self.simulation_dir, "reddit_profiles.json") - + def _get_db_path(self) -> str: """获取数据库路径""" return os.path.join(self.simulation_dir, "reddit_simulation.db") - + def _create_model(self): """ 创建LLM模型 - + 统一使用项目根目录 .env 文件中的配置(优先级最高): - LLM_API_KEY: API密钥 - LLM_BASE_URL: API基础URL @@ -444,31 +449,31 @@ def _create_model(self): llm_api_key = os.environ.get("LLM_API_KEY", "") llm_base_url = os.environ.get("LLM_BASE_URL", "") llm_model = os.environ.get("LLM_MODEL_NAME", "") - + # 如果 .env 中没有,则使用 config 作为备用 if not llm_model: llm_model = self.config.get("llm_model", "gpt-4o-mini") - + # 设置 camel-ai 所需的环境变量 if llm_api_key: os.environ["OPENAI_API_KEY"] = llm_api_key - + if not os.environ.get("OPENAI_API_KEY"): raise ValueError("缺少 API Key 配置,请在项目根目录 .env 文件中设置 LLM_API_KEY") - + if llm_base_url: os.environ["OPENAI_API_BASE_URL"] = llm_base_url - + print(f"LLM配置: model={llm_model}, base_url={llm_base_url[:40] if llm_base_url else '默认'}...") - + return ModelFactory.create( model_platform=ModelPlatformType.OPENAI, model_type=llm_model, ) - + def _get_active_agents_for_round( - self, - env, + self, + env, current_hour: int, round_num: int ) -> List: @@ -477,39 +482,39 @@ def _get_active_agents_for_round( """ time_config = self.config.get("time_config", {}) agent_configs = self.config.get("agent_configs", []) - + base_min = time_config.get("agents_per_hour_min", 5) base_max = time_config.get("agents_per_hour_max", 20) - + peak_hours = time_config.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22]) off_peak_hours = time_config.get("off_peak_hours", [0, 1, 2, 3, 4, 5]) - + if current_hour in peak_hours: multiplier = time_config.get("peak_activity_multiplier", 1.5) elif current_hour in off_peak_hours: multiplier = time_config.get("off_peak_activity_multiplier", 0.3) else: multiplier = 1.0 - + target_count = int(random.uniform(base_min, base_max) * multiplier) - + candidates = [] for cfg in agent_configs: agent_id = cfg.get("agent_id", 0) active_hours = cfg.get("active_hours", list(range(8, 23))) activity_level = cfg.get("activity_level", 0.5) - + if current_hour not in active_hours: continue - + if random.random() < activity_level: candidates.append(agent_id) - + selected_ids = random.sample( - candidates, + candidates, min(target_count, len(candidates)) ) if candidates else [] - + active_agents = [] for agent_id in selected_ids: try: @@ -517,12 +522,74 @@ def _get_active_agents_for_round( active_agents.append((agent_id, agent)) except Exception: pass - + return active_agents - + + def _get_agent_name(self, agent_id: int) -> str: + """Return the configured entity name for an OASIS agent.""" + for agent_config in self.config.get("agent_configs", []): + if agent_config.get("agent_id") == agent_id: + return agent_config.get("entity_name") or f"Agent {agent_id}" + return f"Agent {agent_id}" + + def _sync_trace_actions(self, round_num: int) -> int: + """Mirror new OASIS trace rows into Mirofish action logs.""" + db_path = self._get_db_path() + if not os.path.exists(db_path): + return 0 + + try: + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + cursor.execute( + """ + SELECT rowid, user_id, created_at, action, info + FROM trace + WHERE rowid > ? + ORDER BY rowid ASC + """, + (self._last_synced_trace_rowid,), + ) + rows = cursor.fetchall() + conn.close() + except Exception as e: + print(f" 警告: 无法同步 trace 到 actions 日志: {e}") + return 0 + + synced_actions = 0 + ignored_actions = {"sign_up"} + + for rowid, user_id, created_at, action, info_json in rows: + self._last_synced_trace_rowid = max(self._last_synced_trace_rowid, rowid) + + if action in ignored_actions: + continue + + try: + action_args = json.loads(info_json) if info_json else {} + except json.JSONDecodeError: + action_args = {"raw_info": info_json} + + agent_id = int(user_id) if user_id is not None else 0 + action_type = str(action or "").upper() + + self.action_logger.log_action( + round_num=round_num, + agent_id=agent_id, + agent_name=self._get_agent_name(agent_id), + action_type=action_type, + action_args=action_args, + result=json.dumps(action_args, ensure_ascii=False) if action_args else None, + success=True, + ) + synced_actions += 1 + self._synced_actions_count += 1 + + return synced_actions + async def run(self, max_rounds: int = None): """运行Reddit模拟 - + Args: max_rounds: 最大模拟轮数(可选,用于截断过长的模拟) """ @@ -532,19 +599,19 @@ async def run(self, max_rounds: int = None): print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}") print(f"等待命令模式: {'启用' if self.wait_for_commands else '禁用'}") print("=" * 60) - + time_config = self.config.get("time_config", {}) total_hours = time_config.get("total_simulation_hours", 72) minutes_per_round = time_config.get("minutes_per_round", 30) total_rounds = (total_hours * 60) // minutes_per_round - + # 如果指定了最大轮数,则截断 if max_rounds is not None and max_rounds > 0: original_rounds = total_rounds total_rounds = min(total_rounds, max_rounds) if total_rounds < original_rounds: print(f"\n轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})") - + print(f"\n模拟参数:") print(f" - 总模拟时长: {total_hours}小时") print(f" - 每轮时间: {minutes_per_round}分钟") @@ -552,27 +619,27 @@ async def run(self, max_rounds: int = None): if max_rounds: print(f" - 最大轮数限制: {max_rounds}") print(f" - Agent数量: {len(self.config.get('agent_configs', []))}") - + print("\n初始化LLM模型...") model = self._create_model() - + print("加载Agent Profile...") profile_path = self._get_profile_path() if not os.path.exists(profile_path): print(f"错误: Profile文件不存在: {profile_path}") return - + self.agent_graph = await generate_reddit_agent_graph( profile_path=profile_path, model=model, available_actions=self.AVAILABLE_ACTIONS, ) - + db_path = self._get_db_path() if os.path.exists(db_path): os.remove(db_path) print(f"已删除旧数据库: {db_path}") - + print("创建OASIS环境...") self.env = oasis.make( agent_graph=self.agent_graph, @@ -580,18 +647,19 @@ async def run(self, max_rounds: int = None): database_path=db_path, semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载 ) - + await self.env.reset() print("环境初始化完成\n") - + # 初始化IPC处理器 self.ipc_handler = IPCHandler(self.simulation_dir, self.env, self.agent_graph) self.ipc_handler.update_status("running") - + self.action_logger.log_simulation_start(self.config) + # 执行初始事件 event_config = self.config.get("event_config", {}) initial_posts = event_config.get("initial_posts", []) - + if initial_posts: print(f"执行初始事件 ({len(initial_posts)}条初始帖子)...") initial_actions = {} @@ -614,34 +682,40 @@ async def run(self, max_rounds: int = None): ) except Exception as e: print(f" 警告: 无法为Agent {agent_id}创建初始帖子: {e}") - + if initial_actions: await self.env.step(initial_actions) - print(f" 已发布 {len(initial_actions)} 条初始帖子") - + synced = self._sync_trace_actions(round_num=0) + print(f" 已发布 {len(initial_actions)} 条初始帖子,同步 {synced} 条动作") + # 主模拟循环 print("\n开始模拟循环...") start_time = datetime.now() - + for round_num in range(total_rounds): simulated_minutes = round_num * minutes_per_round simulated_hour = (simulated_minutes // 60) % 24 simulated_day = simulated_minutes // (60 * 24) + 1 - + active_agents = self._get_active_agents_for_round( self.env, simulated_hour, round_num ) - + + self.action_logger.log_round_start(round_num + 1, simulated_hour) if not active_agents: + self.action_logger.log_round_end(round_num + 1, 0) continue - + actions = { agent: LLMAction() for _, agent in active_agents } - + await self.env.step(actions) - + + synced = self._sync_trace_actions(round_num=round_num + 1) + self.action_logger.log_round_end(round_num + 1, synced) + if (round_num + 1) % 10 == 0 or round_num == 0: elapsed = (datetime.now() - start_time).total_seconds() progress = (round_num + 1) / total_rounds * 100 @@ -649,21 +723,23 @@ async def run(self, max_rounds: int = None): f"Round {round_num + 1}/{total_rounds} ({progress:.1f}%) " f"- {len(active_agents)} agents active " f"- elapsed: {elapsed:.1f}s") - + + self.action_logger.log_simulation_end(total_rounds, self._synced_actions_count) + total_elapsed = (datetime.now() - start_time).total_seconds() print(f"\n模拟循环完成!") print(f" - 总耗时: {total_elapsed:.1f}秒") print(f" - 数据库: {db_path}") - + # 是否进入等待命令模式 if self.wait_for_commands: print("\n" + "=" * 60) print("进入等待命令模式 - 环境保持运行") print("支持的命令: interview, batch_interview, close_env") print("=" * 60) - + self.ipc_handler.update_status("alive") - + # 等待命令循环(使用全局 _shutdown_event) try: while not _shutdown_event.is_set(): @@ -681,13 +757,13 @@ async def run(self, max_rounds: int = None): print("\n任务被取消") except Exception as e: print(f"\n命令处理出错: {e}") - + print("\n关闭环境...") - + # 关闭环境 self.ipc_handler.update_status("stopped") await self.env.close() - + print("环境已关闭") print("=" * 60) @@ -695,8 +771,8 @@ async def run(self, max_rounds: int = None): async def main(): parser = argparse.ArgumentParser(description='OASIS Reddit模拟') parser.add_argument( - '--config', - type=str, + '--config', + type=str, required=True, help='配置文件路径 (simulation_config.json)' ) @@ -712,21 +788,21 @@ async def main(): default=False, help='模拟完成后立即关闭环境,不进入等待命令模式' ) - + args = parser.parse_args() - + # 在 main 函数开始时创建 shutdown 事件 global _shutdown_event _shutdown_event = asyncio.Event() - + if not os.path.exists(args.config): print(f"错误: 配置文件不存在: {args.config}") sys.exit(1) - + # 初始化日志配置(使用固定文件名,清理旧日志) simulation_dir = os.path.dirname(args.config) or "." setup_oasis_logging(os.path.join(simulation_dir, "log")) - + runner = RedditSimulationRunner( config_path=args.config, wait_for_commands=not args.no_wait @@ -751,7 +827,7 @@ def signal_handler(signum, frame): # 重复收到信号才强制退出 print("强制退出...") sys.exit(1) - + signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) diff --git a/docker-compose.yml b/docker-compose.yml index 637f1dfaee..bf1a3b431c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -11,4 +11,37 @@ services: - "5001:5001" restart: unless-stopped volumes: - - ./backend/uploads:/app/backend/uploads \ No newline at end of file + - ./backend/uploads:/app/backend/uploads + + graphiti-falkordb: + image: falkordb/falkordb:latest + container_name: graphiti-falkordb + profiles: + - graphiti + volumes: + - graphiti_falkordb_data:/data + restart: unless-stopped + + graphiti-bridge: + build: + context: . + dockerfile: graphiti_bridge/Dockerfile + image: mirofish-graphiti-bridge:latest + container_name: graphiti-bridge + profiles: + - graphiti + environment: + OPENAI_API_KEY: ${LLM_API_KEY} + OPENAI_BASE_URL: ${LLM_BASE_URL:-https://api.openai.com/v1} + MODEL_NAME: ${GRAPHITI_MODEL_NAME:-${LLM_MODEL_NAME:-gpt-5.4-mini}} + EMBEDDING_MODEL_NAME: ${GRAPHITI_EMBEDDING_MODEL_NAME:-text-embedding-3-small} + FALKORDB_HOST: graphiti-falkordb + FALKORDB_PORT: 6379 + depends_on: + - graphiti-falkordb + ports: + - "127.0.0.1:8008:8008" + restart: unless-stopped + +volumes: + graphiti_falkordb_data: diff --git a/docs/on-prem-graph-memory.md b/docs/on-prem-graph-memory.md new file mode 100644 index 0000000000..35fb3be171 --- /dev/null +++ b/docs/on-prem-graph-memory.md @@ -0,0 +1,211 @@ +# On-Premise Graph Memory + +This document describes the intended on-premise graph-memory setup for MiroFish. The goal is to keep the default Zep Cloud setup unchanged while allowing operators to run graph memory locally with Graphiti in a separate Docker service. + +## Target architecture + +MiroFish should not vendor or fork Graphiti. Graphiti stays an external service and is consumed through the graph-memory adapter layer. + +```text +MiroFish application + -> graph-memory adapter + -> Zep Cloud adapter, default + -> Graphiti bridge adapter, optional on-premise mode + -> Graphiti bridge container + -> Graphiti Core from the public Graphiti package + -> FalkorDB container for graph storage +``` + +This separation keeps the MiroFish codebase small, avoids dependency conflicts with the existing backend, and makes the deployment model explicit. The Graphiti bridge container owns the Graphiti Python dependencies. MiroFish only calls the bridge through HTTP. + +## Repository responsibilities + +The MiroFish repository contains: + +- the graph-memory provider interface; +- the default Zep Cloud implementation; +- the optional Graphiti bridge implementation; +- the `graphiti_bridge` service wrapper; +- Docker Compose wiring for the bridge and FalkorDB; +- environment variables selecting the active backend. + +The MiroFish repository does not contain: + +- a copied Graphiti source tree; +- local modifications to the Graphiti upstream project; +- a hard dependency on on-premise graph memory for default users. + +## Backend selection + +Zep Cloud remains the default backend. Existing installations continue to work with the current configuration. + +```env +GRAPH_MEMORY_BACKEND=zep_cloud +ZEP_API_KEY=your_zep_api_key_here +``` + +On-premise mode is enabled explicitly: + +```env +GRAPH_MEMORY_BACKEND=graphiti_bridge +GRAPHITI_BRIDGE_URL=http://graphiti-bridge:8008 +GRAPHITI_MODEL_NAME=gpt-5.4-mini +GRAPHITI_EMBEDDING_MODEL_NAME=text-embedding-3-small +``` + +The bridge receives the LLM credentials through the existing OpenAI-compatible variables: + +```env +LLM_API_KEY=your_api_key_here +LLM_BASE_URL=https://api.openai.com/v1 +LLM_MODEL_NAME=gpt-5.4-mini +``` + +## Docker services + +The on-premise graph-memory stack uses two additional services: + +| Service | Purpose | +| --- | --- | +| `graphiti-bridge` | HTTP compatibility layer used by MiroFish. It imports `graphiti-core` and translates MiroFish graph-memory calls to Graphiti operations. | +| `graphiti-falkordb` | Local graph database used by Graphiti. Data is persisted in the `graphiti_falkordb_data` Docker volume. | + +Start the graph-memory services with the `graphiti` profile: + +```bash +docker compose --profile graphiti up -d graphiti-falkordb graphiti-bridge +``` + +Then start or restart MiroFish: + +```bash +docker compose build mirofish +docker compose up -d mirofish +``` + +For a full local stack, run: + +```bash +docker compose --profile graphiti up -d +``` + +## Expected runtime URLs + +Inside Docker Compose: + +```text +MiroFish backend -> http://graphiti-bridge:8008 +Graphiti bridge -> graphiti-falkordb:6379 +``` + +From the host machine: + +```text +MiroFish frontend: http://localhost:3000 +MiroFish backend: http://localhost:5001 +Graphiti bridge: http://127.0.0.1:8008 +``` + +The bridge is bound to `127.0.0.1` on the host so it is not exposed on the network by default. + +## Health checks + +Check the bridge: + +```bash +curl http://127.0.0.1:8008/health +``` + +Check the MiroFish backend: + +```bash +curl http://localhost:5001/health +``` + +Expected bridge response: + +```json +{ + "service": "graphiti-bridge", + "status": "ok" +} +``` + +## Smoke-test flow + +After startup, verify the complete path through the application rather than only checking container health. + +1. Open `http://localhost:3000`. +2. Create a new project. +3. Add a small but meaningful source document with several named actors and relationships. +4. Generate the ontology. +5. Build the graph. +6. Confirm that the graph contains more than placeholder nodes and edges. +7. Create a simulation. +8. Run at least one round. +9. Confirm that the simulation can read from graph memory and complete without a Zep Cloud key. + +For API-level validation, the same flow should cover ontology generation, graph build, simulation creation, simulation start, simulation status, graph readback, and report generation. + +## Operational notes + +- Use Zep Cloud for the default upstream-compatible mode. +- Use `graphiti_bridge` for local, no-cloud graph memory. +- Keep Graphiti isolated in its own container to avoid dependency conflicts. +- Keep the bridge API narrow and aligned with the MiroFish graph-memory adapter interface. +- Store graph data in the Docker volume `graphiti_falkordb_data`. +- Remove the volume only when local graph-memory data should be discarded. +- Do not expose the bridge port publicly unless an operator has added authentication and network controls. + +## Fresh MacBook installation notes + +For a clean on-premise installation on a new MacBook, the operator should install Docker Desktop, clone the MiroFish fork or the target branch, create `.env`, select `GRAPH_MEMORY_BACKEND=graphiti_bridge`, and start the Compose stack with the `graphiti` profile. + +The installation should not require a Zep Cloud account in on-premise mode. It still requires an OpenAI-compatible LLM endpoint because Graphiti extraction and embeddings use model calls. + +Recommended `.env` shape for on-premise testing: + +```env +LLM_API_KEY=your_api_key_here +LLM_BASE_URL=https://api.openai.com/v1 +LLM_MODEL_NAME=gpt-5.4-mini + +GRAPH_MEMORY_BACKEND=graphiti_bridge +GRAPHITI_BRIDGE_URL=http://graphiti-bridge:8008 +GRAPHITI_MODEL_NAME=gpt-5.4-mini +GRAPHITI_EMBEDDING_MODEL_NAME=text-embedding-3-small +``` + +## Instructions for an AI installation agent + +Use this section as the direct installation brief for an automated coding or operations agent. + +Objective: install and verify MiroFish with local on-premise graph memory. Do not use Zep Cloud. Do not copy Graphiti source code into MiroFish. Run Graphiti through the existing `graphiti_bridge` Docker service and persist graph data in FalkorDB. + +Steps: + +1. Confirm Docker Desktop or Docker Engine is installed and running. +2. Clone the target MiroFish repository and checkout the on-premise graph-memory branch. +3. Create `.env` from `.env.example`. +4. Set `GRAPH_MEMORY_BACKEND=graphiti_bridge`. +5. Set `GRAPHITI_BRIDGE_URL=http://graphiti-bridge:8008`. +6. Set `LLM_API_KEY`, `LLM_BASE_URL`, and `LLM_MODEL_NAME` to an OpenAI-compatible model endpoint. +7. Set `GRAPHITI_MODEL_NAME` and `GRAPHITI_EMBEDDING_MODEL_NAME` if different from the defaults. +8. Start the stack with `docker compose --profile graphiti up -d --build`. +9. Verify `http://127.0.0.1:8008/health` returns status `ok`. +10. Verify `http://localhost:5001/health` returns successfully. +11. Open `http://localhost:3000` and run a complete graph-build and simulation smoke test. +12. Inspect the graph result and reject the installation if only placeholder or generic nodes are created from a rich source document. +13. Record the tested commit, `.env` backend mode, container status, and smoke-test result. + +Acceptance criteria: + +- The application starts from Docker Compose. +- No Zep Cloud API key is required in `graphiti_bridge` mode. +- The bridge health check succeeds. +- A project can build a graph from source documents. +- Graph nodes and edges reflect the document content. +- A simulation can be created and started against the local graph memory. +- The setup remains compatible with the default Zep Cloud mode when `GRAPH_MEMORY_BACKEND=zep_cloud` is selected. + +Do not mark the installation complete if the graph build produces only a few generic nodes from a rich source document, if the simulation fails during startup, or if the application still requires `ZEP_API_KEY` while `GRAPH_MEMORY_BACKEND=graphiti_bridge` is selected. diff --git a/graphiti_bridge/Dockerfile b/graphiti_bridge/Dockerfile new file mode 100644 index 0000000000..bacc544f27 --- /dev/null +++ b/graphiti_bridge/Dockerfile @@ -0,0 +1,8 @@ +FROM python:3.12-slim + +WORKDIR /app +COPY graphiti_bridge/requirements.txt ./requirements.txt +RUN pip install --no-cache-dir -r requirements.txt +COPY graphiti_bridge/app.py ./app.py +EXPOSE 8008 +CMD ["python", "app.py"] diff --git a/graphiti_bridge/app.py b/graphiti_bridge/app.py new file mode 100644 index 0000000000..3c0d139575 --- /dev/null +++ b/graphiti_bridge/app.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +import asyncio +import os +from datetime import datetime, timezone +from typing import Any, Optional + +from flask import Flask, jsonify, request +from pydantic import BaseModel, Field + +from graphiti_core.driver.falkordb_driver import FalkorDriver +from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig +from graphiti_core.graphiti import Graphiti +from graphiti_core.llm_client.config import LLMConfig +from graphiti_core.llm_client.openai_client import OpenAIClient +from graphiti_core.nodes import EpisodeType + +app = Flask(__name__) + +FALKORDB_HOST = os.environ.get("FALKORDB_HOST", "graphiti-falkordb") +FALKORDB_PORT = int(os.environ.get("FALKORDB_PORT", "6379")) +OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") or os.environ.get("LLM_API_KEY") +OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL") or os.environ.get("LLM_BASE_URL", "https://api.openai.com/v1") +MODEL_NAME = os.environ.get("MODEL_NAME") or os.environ.get("GRAPHITI_MODEL_NAME", "gpt-5.4-mini") +EMBEDDING_MODEL_NAME = os.environ.get("EMBEDDING_MODEL_NAME") or os.environ.get("GRAPHITI_EMBEDDING_MODEL_NAME", "text-embedding-3-small") + +ONTOLOGIES: dict[str, dict[str, Any]] = {} +INDICES_INITIALIZED: set[str] = set() +RESERVED_NAMES = {"uuid", "name", "group_id", "name_embedding", "summary", "created_at"} + + +def run(coro): + return asyncio.run(coro) + + +def safe_attr_name(attr_name: str) -> str: + if attr_name.lower() in RESERVED_NAMES: + return f"entity_{attr_name}" + return attr_name + + +def build_ontology(ontology: dict[str, Any]) -> dict[str, Any]: + entity_types: dict[str, type[BaseModel]] = {} + entity_names: list[str] = [] + for entity_def in ontology.get("entity_types", []): + name = entity_def["name"] + description = entity_def.get("description", f"A {name} entity.") + attrs: dict[str, Any] = {"__doc__": description} + annotations: dict[str, Any] = {} + for attr_def in entity_def.get("attributes", []): + attr_name = safe_attr_name(attr_def["name"]) + attrs[attr_name] = Field(default=None, description=attr_def.get("description", attr_name)) + annotations[attr_name] = Optional[str] + attrs["__annotations__"] = annotations + entity_types[name] = type(name, (BaseModel,), attrs) + entity_names.append(name) + + edge_types: dict[str, type[BaseModel]] = {} + edge_type_map: dict[tuple[str, str], list[str]] = {} + for edge_def in ontology.get("edge_types", []): + name = edge_def["name"] + description = edge_def.get("description", f"A {name} relationship.") + edge_types[name] = type(name, (BaseModel,), {"__doc__": description, "__annotations__": {}}) + for st in edge_def.get("source_targets", []): + source = st.get("source", "Entity") + target = st.get("target", "Entity") + edge_type_map.setdefault((source, target), []).append(name) + + if edge_types and not edge_type_map: + labels = entity_names + ["Entity"] + edge_names = list(edge_types.keys()) + edge_type_map = {(source, target): edge_names for source in labels for target in labels} + + return {"entity_types": entity_types or None, "edge_types": edge_types or None, "edge_type_map": edge_type_map or None} + + +async def graphiti(graph_id: str) -> Graphiti: + if not OPENAI_API_KEY: + raise RuntimeError("OPENAI_API_KEY/LLM_API_KEY is not configured") + os.environ.setdefault("OPENAI_API_KEY", OPENAI_API_KEY) + os.environ.setdefault("OPENAI_BASE_URL", OPENAI_BASE_URL) + os.environ.setdefault("MODEL_NAME", MODEL_NAME) + os.environ.setdefault("EMBEDDING_MODEL_NAME", EMBEDDING_MODEL_NAME) + driver = FalkorDriver(host=FALKORDB_HOST, port=FALKORDB_PORT, database=graph_id) + return Graphiti(graph_driver=driver) + + +async def ensure_indices(graph_id: str) -> None: + if graph_id in INDICES_INITIALIZED: + return + client = await graphiti(graph_id) + await client.build_indices_and_constraints() + INDICES_INITIALIZED.add(graph_id) + + +async def query_graph(graph_id: str, query: str, **params: Any) -> list[dict[str, Any]]: + driver = FalkorDriver(host=FALKORDB_HOST, port=FALKORDB_PORT, database=graph_id) + rows, _, _ = await driver.execute_query(query, **params) + return rows + + +def edge_from_obj(edge: Any) -> dict[str, Any]: + return { + "uuid": getattr(edge, "uuid", "") or getattr(edge, "uuid_", ""), + "name": getattr(edge, "name", ""), + "fact": getattr(edge, "fact", ""), + "source_node_uuid": getattr(edge, "source_node_uuid", ""), + "target_node_uuid": getattr(edge, "target_node_uuid", ""), + "attributes": getattr(edge, "attributes", {}) or {}, + "created_at": str(getattr(edge, "created_at", "") or "") or None, + "valid_at": str(getattr(edge, "valid_at", "") or "") or None, + "invalid_at": str(getattr(edge, "invalid_at", "") or "") or None, + "expired_at": str(getattr(edge, "expired_at", "") or "") or None, + } + + +@app.get("/health") +def health(): + return jsonify({"status": "ok", "service": "graphiti-bridge"}) + + +@app.post("/graphs") +def create_graph(): + payload = request.get_json(force=True) + graph_id = payload["graph_id"] + run(ensure_indices(graph_id)) + return jsonify({"graph_id": graph_id}) + + +@app.post("/graphs//ontology") +def set_ontology(graph_id: str): + ONTOLOGIES[graph_id] = build_ontology(request.get_json(force=True) or {}) + return jsonify({"ok": True}) + + +@app.post("/graphs//episodes") +def add_episodes(graph_id: str): + payload = request.get_json(force=True) + chunks = payload.get("chunks") or [payload.get("text", "")] + + async def add_all(): + await ensure_indices(graph_id) + client = await graphiti(graph_id) + ontology = ONTOLOGIES.get(graph_id, {}) + out = [] + for index, chunk in enumerate(chunks, 1): + result = await client.add_episode( + name=f"mirofish-chunk-{index}", + episode_body=chunk, + source_description="Mirofish text episode", + reference_time=datetime.now(timezone.utc), + source=EpisodeType.text, + group_id=graph_id, + entity_types=ontology.get("entity_types"), + edge_types=ontology.get("edge_types"), + edge_type_map=ontology.get("edge_type_map"), + custom_extraction_instructions="Extract actors and relationships relevant to the simulation. Prefer provided ontology labels when supported by the text.", + ) + episode_uuid = getattr(result.episode, "uuid", None) or getattr(result.episode, "uuid_", None) + out.append({"uuid": episode_uuid, "processed": True}) + return out + + return jsonify({"episodes": run(add_all())}) + + +@app.get("/graphs//episodes/") +def get_episode(graph_id: str, episode_uuid: str): + return jsonify({"uuid": episode_uuid, "processed": True}) + + +@app.get("/graphs//nodes") +def get_nodes(graph_id: str): + rows = run(query_graph(graph_id, """ + MATCH (n:Entity) + RETURN n.uuid AS uuid, n.name AS name, labels(n) AS labels, n.summary AS summary, n.created_at AS created_at + LIMIT 2000 + """)) + return jsonify({"nodes": rows}) + + +@app.get("/graphs//edges") +def get_edges(graph_id: str): + rows = run(query_graph(graph_id, """ + MATCH (a:Entity)-[r]->(b:Entity) + RETURN r.uuid AS uuid, type(r) AS name, r.fact AS fact, + a.uuid AS source_node_uuid, b.uuid AS target_node_uuid, + r.created_at AS created_at, r.valid_at AS valid_at, r.invalid_at AS invalid_at, r.expired_at AS expired_at + LIMIT 5000 + """)) + return jsonify({"edges": rows}) + + +@app.get("/nodes/") +def get_node(node_uuid: str): + graph_id = request.args.get("graph_id") + if not graph_id: + return jsonify({"node": None}) + rows = run(query_graph(graph_id, """ + MATCH (n:Entity {uuid: $uuid}) + RETURN n.uuid AS uuid, n.name AS name, labels(n) AS labels, n.summary AS summary, n.created_at AS created_at + LIMIT 1 + """, uuid=node_uuid)) + return jsonify({"node": rows[0] if rows else None}) + + +@app.get("/nodes//edges") +def get_node_edges(node_uuid: str): + graph_id = request.args.get("graph_id") + if not graph_id: + return jsonify({"edges": []}) + rows = run(query_graph(graph_id, """ + MATCH (a:Entity)-[r]->(b:Entity) + WHERE a.uuid = $uuid OR b.uuid = $uuid + RETURN r.uuid AS uuid, type(r) AS name, r.fact AS fact, + a.uuid AS source_node_uuid, b.uuid AS target_node_uuid, + r.created_at AS created_at, r.valid_at AS valid_at, r.invalid_at AS invalid_at, r.expired_at AS expired_at + LIMIT 5000 + """, uuid=node_uuid)) + return jsonify({"edges": rows}) + + +@app.post("/graphs//search") +def search(graph_id: str): + payload = request.get_json(force=True) + query = payload.get("query", "") + limit = int(payload.get("limit", 10)) + scope = payload.get("scope", "edges") + + async def do_search(): + client = await graphiti(graph_id) + edges = [] if scope == "nodes" else [edge_from_obj(edge) for edge in await client.search(query=query, group_ids=[graph_id], num_results=limit)] + nodes = [] + if scope in {"nodes", "both"}: + terms = [term.lower() for term in query.split() if len(term) > 1] + rows = await query_graph(graph_id, """ + MATCH (n:Entity) + RETURN n.uuid AS uuid, n.name AS name, labels(n) AS labels, n.summary AS summary, n.created_at AS created_at + LIMIT 2000 + """) + scored = [] + for row in rows: + text = f"{row.get('name', '')} {row.get('summary', '')}".lower() + score = sum(1 for term in terms if term in text) + if query.lower() in text: + score += 10 + if score: + scored.append((score, row)) + scored.sort(key=lambda item: item[0], reverse=True) + nodes = [row for _, row in scored[:limit]] + return {"edges": edges, "nodes": nodes} + + return jsonify(run(do_search())) + + +@app.delete("/graphs/") +def delete_graph(graph_id: str): + run(query_graph(graph_id, "MATCH (n) DETACH DELETE n")) + ONTOLOGIES.pop(graph_id, None) + INDICES_INITIALIZED.discard(graph_id) + return jsonify({"ok": True}) + + +if __name__ == "__main__": + app.run(host="0.0.0.0", port=int(os.environ.get("PORT", "8008"))) diff --git a/graphiti_bridge/requirements.txt b/graphiti_bridge/requirements.txt new file mode 100644 index 0000000000..8f937a6e05 --- /dev/null +++ b/graphiti_bridge/requirements.txt @@ -0,0 +1,4 @@ +flask>=3.0.0 +graphiti-core[falkordb]==0.29.2 +openai>=2.0.0 +pydantic>=2.0.0