Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions integrations/mem0-plugin/scripts/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,31 @@
from __future__ import annotations

import json
import os
import urllib.request

SEARCH_URL = "https://api.mem0.ai/v3/memories/search/"
SEARCH_TIMEOUT = 5


def should_rerank() -> bool:
    """Return whether auto-injection searches should request Platform reranking.

    The REST search endpoint does not rerank when ``rerank`` is omitted, so
    auto-injected context is ordered by raw vector similarity and the single
    most relevant memory can fall outside the injected top_k window. Reranking
    therefore defaults to ON for the hook-driven injection path (the extra
    ~150-200ms is well within the hook's curl budget), and users can opt out
    via the ``MEM0_RERANK`` environment variable.

    ``MEM0_RERANK`` is compared case-insensitively after stripping surrounding
    whitespace. ``0``, ``false``, ``no``, ``off`` and an empty (or
    whitespace-only) value disable reranking. Any other value, and an unset
    variable, enable it.

    Returns:
        ``False`` only when ``MEM0_RERANK`` is set to one of the opt-out
        values listed above; ``True`` otherwise.
    """
    raw = os.environ.get("MEM0_RERANK")
    if raw is None:
        # Unset means "use the default", which is reranking ON.
        return True
    # The trailing "" makes an empty value (e.g. ``MEM0_RERANK=``) an opt-out.
    return raw.strip().lower() not in ("0", "false", "no", "off", "")


def _do_search(api_key: str, payload: dict) -> list[dict]:
body = json.dumps(payload).encode()
req = urllib.request.Request(
Expand Down
3 changes: 2 additions & 1 deletion integrations/mem0-plugin/scripts/file_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from _formatting import TYPE_ICONS, format_age
from _identity import resolve_api_key, resolve_user_id
from _project import resolve_project_id
from _search import search_memories
from _search import search_memories, should_rerank

FILE_READ_GATE_MIN_BYTES = 1500
MAX_RESULTS = 5
Expand Down Expand Up @@ -93,6 +93,7 @@ def search_file_context(
api_key, user_id, project_id, query,
top_k=MAX_RESULTS, threshold=0.3,
global_search=global_search,
rerank=should_rerank(),
)

results = results[:MAX_RESULTS]
Expand Down
7 changes: 4 additions & 3 deletions integrations/mem0-plugin/scripts/on_bash_output.sh
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,16 @@ RESULTS=$(PYTHONPATH="$SCRIPT_DIR" MEM0_SEARCH_QUERY="$ERROR_QUERY" MEM0_SEARCH_
python3 -c "
import os, sys
sys.path.insert(0, os.environ.get('PYTHONPATH', '.'))
from _search import search_memories, format_results_for_context
from _search import search_memories, format_results_for_context, should_rerank

api_key = os.environ.get('MEM0_API_KEY', '')
user_id = os.environ.get('MEM0_SEARCH_USER', 'default')
project_id = os.environ.get('MEM0_PROJECT_ID', 'unknown')
query = os.environ.get('MEM0_SEARCH_QUERY', '')
rerank = should_rerank()

r1 = search_memories(api_key, user_id, project_id, query, metadata_type='anti_pattern', top_k=3)
r2 = search_memories(api_key, user_id, project_id, query, metadata_type='bug_fix', top_k=3)
r1 = search_memories(api_key, user_id, project_id, query, metadata_type='anti_pattern', top_k=3, rerank=rerank)
r2 = search_memories(api_key, user_id, project_id, query, metadata_type='bug_fix', top_k=3, rerank=rerank)

seen = set()
combined = []
Expand Down
7 changes: 4 additions & 3 deletions integrations/mem0-plugin/scripts/on_user_prompt.sh
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,15 @@ if [ -n "$HAS_RESUME" ]; then
RESUME_RESULTS=$(PYTHONPATH="$SCRIPT_DIR" MEM0_SEARCH_USER="$USER_ID" python3 -c "
import os, sys
sys.path.insert(0, os.environ.get('PYTHONPATH', '.'))
from _search import search_memories, format_results_for_context
from _search import search_memories, format_results_for_context, should_rerank

api_key = os.environ.get('MEM0_API_KEY', '')
user_id = os.environ.get('MEM0_SEARCH_USER', 'default')
project_id = os.environ.get('MEM0_PROJECT_ID', 'unknown')
rerank = should_rerank()

state = search_memories(api_key, user_id, project_id, 'session state current task', metadata_type='session_state', top_k=3)
decisions = search_memories(api_key, user_id, project_id, 'recent decisions and learnings', metadata_type='decision', top_k=3)
state = search_memories(api_key, user_id, project_id, 'session state current task', metadata_type='session_state', top_k=3, rerank=rerank)
decisions = search_memories(api_key, user_id, project_id, 'recent decisions and learnings', metadata_type='decision', top_k=3, rerank=rerank)

all_r = state + decisions
seen = set()
Expand Down
61 changes: 61 additions & 0 deletions integrations/mem0-plugin/tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,67 @@ def test_search_memories_no_api_key_returns_empty():
assert results == []


def test_search_memories_omits_rerank_by_default():
    """Regression for #5684: no ``rerank`` key is sent unless the caller asks."""
    from _search import search_memories

    sent_payload = {}

    def fake_urlopen(req, timeout=None):
        # Record the decoded JSON request body, then hand back a mock that
        # works as a context manager and yields an empty result set.
        sent_payload.update(json.loads(req.data.decode()))
        response = MagicMock()
        response.__enter__ = lambda s: s
        response.__exit__ = MagicMock(return_value=False)
        response.read.return_value = json.dumps({"results": []}).encode()
        return response

    with patch("urllib.request.urlopen", side_effect=fake_urlopen):
        search_memories("key", "user", "proj", "query")

    assert "rerank" not in sent_payload


def test_search_memories_forwards_rerank_true():
    """Regression for #5684: ``rerank=True`` must appear in the request body.

    The REST endpoint does not rerank when the field is omitted, so the flag
    has to be forwarded explicitly.
    """
    from _search import search_memories

    sent_payload = {}

    def fake_urlopen(req, timeout=None):
        # Record the decoded JSON request body, then hand back a mock that
        # works as a context manager and yields an empty result set.
        sent_payload.update(json.loads(req.data.decode()))
        response = MagicMock()
        response.__enter__ = lambda s: s
        response.__exit__ = MagicMock(return_value=False)
        response.read.return_value = json.dumps({"results": []}).encode()
        return response

    with patch("urllib.request.urlopen", side_effect=fake_urlopen):
        search_memories("key", "user", "proj", "query", rerank=True)

    assert sent_payload.get("rerank") is True


def test_should_rerank_defaults_true(monkeypatch):
    """Regression for #5684: with MEM0_RERANK unset, reranking is enabled."""
    from _search import should_rerank

    # raising=False: it is fine if the variable was never set.
    monkeypatch.delenv("MEM0_RERANK", raising=False)

    enabled = should_rerank()
    assert enabled is True


def test_should_rerank_opt_out_values(monkeypatch):
    """Opt-out spellings of MEM0_RERANK disable reranking; other values keep it on."""
    from _search import should_rerank

    disabling = ("0", "false", "False", "NO", "off", "")
    enabling = ("1", "true", "yes", "on")

    # Same order as before: every opt-out value first, then the enabling ones.
    cases = [(value, False) for value in disabling]
    cases += [(value, True) for value in enabling]

    for value, expected in cases:
        monkeypatch.setenv("MEM0_RERANK", value)
        assert should_rerank() is expected, value


def test_format_results_for_context():
from _search import format_results_for_context

Expand Down
8 changes: 7 additions & 1 deletion mem0-ts/src/oss/src/llms/anthropic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@ export class AnthropicLLM implements LLM {
if (!apiKey) {
throw new Error("Anthropic API key is required");
}
this.client = new Anthropic({ apiKey });
// Forward baseURL to the client when set so proxy/gateway users are
// honored (parity with the OpenAI provider and the Python fix in #5626).
const clientArgs: { apiKey: string; baseURL?: string } = { apiKey };
if (config.baseURL) {
clientArgs.baseURL = config.baseURL;
}
this.client = new Anthropic(clientArgs);
this.model = config.model || "claude-sonnet-4-6";
// Defaults mirror the Python provider's AnthropicConfig
// (max_tokens=2000, temperature=0.1, top_p omitted).
Expand Down
14 changes: 12 additions & 2 deletions mem0-ts/src/oss/src/memory/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,7 @@ export class Memory {
if (!infer) {
const returnedMemories: MemoryItem[] = [];
for (const message of messages) {
if (message.content === "system") {
if (message.role === "system") {
continue;
}
const memoryId = await this.createMemory(
Expand Down Expand Up @@ -1189,7 +1189,13 @@ export class Memory {
}
}

const result = { ...memoryItem, ...filters };
const result = {
...memoryItem,
...filters,
...(memory.payload.attributedTo && {
attributedTo: memory.payload.attributedTo,
}),
};
await this._displayFirstRunNotice("get");
return result;
}
Expand Down Expand Up @@ -1453,6 +1459,7 @@ export class Memory {
...(payload.user_id && { user_id: payload.user_id }),
...(payload.agent_id && { agent_id: payload.agent_id }),
...(payload.run_id && { run_id: payload.run_id }),
...(payload.attributedTo && { attributedTo: payload.attributedTo }),
...(scored.scoreDetails && { score_details: scored.scoreDetails }),
};
});
Expand Down Expand Up @@ -1687,6 +1694,9 @@ export class Memory {
...(mem.payload.user_id && { user_id: mem.payload.user_id }),
...(mem.payload.agent_id && { agent_id: mem.payload.agent_id }),
...(mem.payload.run_id && { run_id: mem.payload.run_id }),
...(mem.payload.attributedTo && {
attributedTo: mem.payload.attributedTo,
}),
}));

const result = { results };
Expand Down
1 change: 1 addition & 0 deletions mem0-ts/src/oss/src/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ export interface MemoryItem {
updatedAt?: string;
score?: number;
metadata?: Record<string, any>;
attributedTo?: string;
}

export interface SearchFilters {
Expand Down
37 changes: 33 additions & 4 deletions mem0-ts/src/oss/tests/anthropic-llm.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,46 @@
*/

const mockCreate = jest.fn();
const mockConstructor = jest.fn();

jest.mock("@anthropic-ai/sdk", () => {
return jest.fn().mockImplementation(() => ({
messages: { create: mockCreate },
}));
return jest.fn().mockImplementation((args) => {
mockConstructor(args);
return { messages: { create: mockCreate } };
});
});

import { AnthropicLLM } from "../src/llms/anthropic";

describe("AnthropicLLM (unit)", () => {
beforeEach(() => mockCreate.mockClear());
beforeEach(() => {
mockCreate.mockClear();
mockConstructor.mockClear();
});

// Regression #5665: a configured baseURL must reach the Anthropic client so
// proxy/gateway users are not silently bypassed (TS parity with #5626).
it("forwards baseURL to the Anthropic client when set", () => {
  const proxyUrl = "https://proxy.example/v1";
  new AnthropicLLM({ apiKey: "test-key", baseURL: proxyUrl });

  expect(mockConstructor).toHaveBeenCalledTimes(1);

  // First argument of the single recorded constructor call.
  const clientOptions = mockConstructor.mock.calls[0][0];
  expect(clientOptions.apiKey).toBe("test-key");
  expect(clientOptions.baseURL).toBe(proxyUrl);
});

// When no baseURL is configured the client must not receive a baseURL key
// (so the SDK default endpoint is used).
it("does NOT set baseURL when none is configured", () => {
  new AnthropicLLM({ apiKey: "test-key" });

  expect(mockConstructor).toHaveBeenCalledTimes(1);

  // First argument of the single recorded constructor call.
  const clientOptions = mockConstructor.mock.calls[0][0];
  expect(clientOptions.baseURL).toBeUndefined();
});

it("returns text when no tools are provided and model returns a text block", async () => {
mockCreate.mockResolvedValueOnce({
Expand Down
39 changes: 39 additions & 0 deletions mem0-ts/src/oss/tests/memory.crud.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,45 @@ describe("Memory - search()", () => {
});
});

// ─── attributedTo (#5666) ────────────────────────────────

describe("Memory - attributedTo round-trip (#5666)", () => {
  const userId = `attributed_test_${Date.now()}`;
  let memory: Memory;
  let id: string;

  // Store one memory up front; each test below reads it back through a
  // different retrieval API.
  beforeAll(async () => {
    memory = createMemory();
    // The mocked LLM tags every extracted fact with attributed_to: "user".
    const added: SearchResult = await memory.add("I love AI", { userId });
    id = added.results[0].id;
  });

  afterAll(async () => {
    await memory.reset();
  });

  test("get() surfaces attributedTo", async () => {
    const fetched: MemoryItem | null = await memory.get(id);
    expect(fetched!.attributedTo).toBe("user");
  });

  test("getAll() surfaces attributedTo", async () => {
    const listed: SearchResult = await memory.getAll({
      filters: { user_id: userId },
    });
    expect(listed.results[0].attributedTo).toBe("user");
  });

  test("search() surfaces attributedTo", async () => {
    const found: SearchResult = await memory.search("AI", {
      filters: { user_id: userId },
    });
    expect(found.results.length).toBeGreaterThan(0);
    expect(found.results[0].attributedTo).toBe("user");
  });
});

// ─── history() ───────────────────────────────────────────

describe("Memory - history()", () => {
Expand Down
6 changes: 5 additions & 1 deletion mem0/llms/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ def __init__(self, config: Optional[Union[BaseLlmConfig, AnthropicConfig, Dict]]
self.config.model = "claude-sonnet-4-6"

api_key = self.config.api_key or os.getenv("ANTHROPIC_API_KEY")
self.client = anthropic.Anthropic(api_key=api_key)
base_url = self.config.anthropic_base_url or os.getenv("ANTHROPIC_BASE_URL")
client_kwargs = {"api_key": api_key}
if base_url:
client_kwargs["base_url"] = base_url
self.client = anthropic.Anthropic(**client_kwargs)

def _get_common_params(self, **kwargs) -> Dict:
"""Get common parameters, avoiding sending both temperature and top_p together.
Expand Down
6 changes: 6 additions & 0 deletions mem0/memory/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,6 +1102,7 @@ def get(self, memory_id):
"run_id",
"actor_id",
"role",
"attributed_to",
]

core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", "text_lemmatized", "attributed_to", *promoted_payload_keys}
Expand Down Expand Up @@ -1215,6 +1216,7 @@ def _get_all_from_vector_store(self, filters, limit):
"run_id",
"actor_id",
"role",
"attributed_to",
]
core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", "text_lemmatized", "attributed_to", *promoted_payload_keys}

Expand Down Expand Up @@ -1550,6 +1552,7 @@ def _search_vector_store(self, query, filters, limit, threshold=0.1, explain=Fal
"run_id",
"actor_id",
"role",
"attributed_to",
]
core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", "text_lemmatized", "attributed_to", *promoted_payload_keys}

Expand Down Expand Up @@ -2638,6 +2641,7 @@ async def get(self, memory_id):
"run_id",
"actor_id",
"role",
"attributed_to",
]

core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", "text_lemmatized", "attributed_to", *promoted_payload_keys}
Expand Down Expand Up @@ -2751,6 +2755,7 @@ async def _get_all_from_vector_store(self, filters, limit):
"run_id",
"actor_id",
"role",
"attributed_to",
]
core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", "text_lemmatized", "attributed_to", *promoted_payload_keys}

Expand Down Expand Up @@ -3092,6 +3097,7 @@ async def _search_vector_store(self, query, filters, limit, threshold=0.1, expla
"run_id",
"actor_id",
"role",
"attributed_to",
]
core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", "text_lemmatized", "attributed_to", *promoted_payload_keys}

Expand Down
9 changes: 9 additions & 0 deletions mem0/utils/factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import importlib
import inspect
from typing import Dict, Optional, Union

from mem0.configs.embeddings.base import BaseEmbedderConfig
Expand Down Expand Up @@ -102,6 +103,14 @@ def create(cls, provider_name: str, config: Optional[Union[BaseLlmConfig, Dict]]
"vision_details": config.vision_details,
"http_client_proxies": config.http_client_proxies,
}
# Only forward reasoning fields to provider configs that accept them
# (explicitly or via **kwargs); others would raise on unexpected kwargs.
params = inspect.signature(config_class).parameters
accepts_kwargs = any(p.kind == p.VAR_KEYWORD for p in params.values())
if accepts_kwargs or "reasoning_effort" in params:
config_dict["reasoning_effort"] = config.reasoning_effort
if accepts_kwargs or "is_reasoning_model" in params:
config_dict["is_reasoning_model"] = config.is_reasoning_model
config_dict.update(kwargs)
config = config_class(**config_dict)
else:
Expand Down
Loading
Loading