diff --git a/docs/changelog/sdk.mdx b/docs/changelog/sdk.mdx index 4c17e61e1f..688913910c 100644 --- a/docs/changelog/sdk.mdx +++ b/docs/changelog/sdk.mdx @@ -7,6 +7,13 @@ mode: "wide" + + +**Bug Fixes:** +- **Memory (OSS):** Improve entity extraction precision by avoiding sentence-start common noun noise, preserving useful topic phrases, and exact-deduplicating entity links before semantic matching ([#5829](https://github.com/mem0ai/mem0/pull/5829)) + + + **New Features:** @@ -1063,6 +1070,13 @@ See the [OSS v1 to v2 migration guide](https://docs.mem0.ai/migration/oss-v1-to- + + +**Bug Fixes:** +- **Memory (OSS):** Align entity extraction with Python by reducing generic entity noise, preserving useful topic phrases, and exact-deduplicating entity links before semantic matching ([#5829](https://github.com/mem0ai/mem0/pull/5829)) + + + **Bug Fixes:** diff --git a/mem0-ts/package.json b/mem0-ts/package.json index 460b792527..3b83eaf5e3 100644 --- a/mem0-ts/package.json +++ b/mem0-ts/package.json @@ -1,6 +1,6 @@ { "name": "mem0ai", - "version": "3.0.10", + "version": "3.0.11", "description": "The Memory Layer For Your AI Apps", "main": "./dist/index.js", "module": "./dist/index.mjs", diff --git a/mem0-ts/src/oss/src/memory/index.ts b/mem0-ts/src/oss/src/memory/index.ts index 947e8c77fb..363e31f01c 100644 --- a/mem0-ts/src/oss/src/memory/index.ts +++ b/mem0-ts/src/oss/src/memory/index.ts @@ -296,6 +296,44 @@ export class Memory { return filters; } + private _normalizeEntityText(value: string): string { + return value.trim().toLowerCase().replace(/\s+/g, " "); + } + + private async _existingEntitiesByText( + entityStore: VectorStore, + filters: Record, + ): Promise }>> { + const rowsByText = new Map< + string, + { id: string; payload: Record } + >(); + let rows: Array<{ id: string; payload: Record }> = []; + try { + const listed = await entityStore.list(filters, 10000); + rows = ( + Array.isArray(listed) && Array.isArray(listed[0]) + ? listed[0] + : (listed as any) + ) as Array<{ id: string; payload: Record }>; + } catch (e) { + console.debug( + `Exact entity lookup failed, falling back to semantic dedup: ${e}`, + ); + return rowsByText; + } + + for (const row of rows) { + const text = row.payload?.data; + if (typeof text !== "string") continue; + const key = this._normalizeEntityText(text); + if (key && !rowsByText.has(key)) { + rowsByText.set(key, row); + } + } + return rowsByText; + } + /** * Remove `memoryId` from every entity record scoped to `filters`. * If an entity's `linkedMemoryIds` becomes empty after removal, the @@ -393,6 +431,10 @@ export class Memory { if (entities.length === 0) return; const entityStore = await this.getEntityStore(); + const exactMatches = await this._existingEntitiesByText( + entityStore, + filters, + ); for (const entity of entities) { try { @@ -409,12 +451,21 @@ export class Memory { score?: number; payload: Record; }> = []; - try { - matches = await entityStore.search(entityVec, 1, filters); - } catch {} + const exactMatch = exactMatches.get( + this._normalizeEntityText(entity.text), + ); + if (!exactMatch) { + try { + matches = await entityStore.search(entityVec, 1, filters); + } catch {} + } - if (matches.length > 0 && (matches[0].score ?? 0) >= 0.95) { - const match = matches[0]; + const semanticMatch = + matches.length > 0 && (matches[0].score ?? 0) >= 0.95 + ? matches[0] + : undefined; + const match = exactMatch ?? semanticMatch; + if (match) { const payload = match.payload || {}; const linked = new Set( Array.isArray(payload.linkedMemoryIds) @@ -1062,6 +1113,10 @@ export class Memory { if (valid.length > 0) { const entityStore = await this.getEntityStore(); + const exactMatches = await this._existingEntitiesByText( + entityStore, + filters, + ); // 7c: Search for existing entities one by one (no batch search) const toInsertVectors: number[][] = []; @@ -1077,13 +1132,20 @@ export class Memory { score?: number; payload: Record; }> = []; - try { - matches = await entityStore.search(entityVec, 1, filters); - } catch {} + const exactMatch = exactMatches.get(key); + if (!exactMatch) { + try { + matches = await entityStore.search(entityVec, 1, filters); + } catch {} + } - if (matches.length > 0 && (matches[0].score ?? 0) >= 0.95) { + const semanticMatch = + matches.length > 0 && (matches[0].score ?? 0) >= 0.95 + ? matches[0] + : undefined; + const match = exactMatch ?? semanticMatch; + if (match) { // Update existing entity - const match = matches[0]; const payload = match.payload || {}; const linked = new Set(payload.linkedMemoryIds ?? []); for (const mid of memoryIds) linked.add(mid); diff --git a/mem0-ts/src/oss/src/utils/entity_extraction.ts b/mem0-ts/src/oss/src/utils/entity_extraction.ts index d6f6d27002..f58cdd4275 100644 --- a/mem0-ts/src/oss/src/utils/entity_extraction.ts +++ b/mem0-ts/src/oss/src/utils/entity_extraction.ts @@ -4,8 +4,8 @@ * Extracts four types of entities from text: * - PROPER: Capitalized multi-word sequences (person names, places, brands) * - QUOTED: Text in single or double quotes (titles, specific terms) - * - COMPOUND: Multi-word noun phrases with specific modifiers (e.g., "machine learning") - * - NOUN: Single nouns from circumstantial compound patterns + * - TOPIC: Multi-word noun/topic phrases with specific modifiers + * - IDENTIFIER: Dotted technical identifiers such as person.properties.email * * Uses the `compromise` npm package for NLP-based extraction when available. * Falls back to regex-only extraction if `compromise` is not installed. @@ -196,6 +196,25 @@ const NON_SPECIFIC_ADJ: Set = new Set([ "final", "initial", "side", + "top", +]); + +/** Leading words that frame a topic but are not part of the topic itself. */ +const TOPIC_PREFIX_WORDS: Set = new Set([ + "a", + "an", + "the", + "my", + "your", + "our", + "their", + "his", + "her", + "its", + "this", + "that", + "these", + "those", ]); /** Generic tail words to strip from compound entities. */ @@ -267,6 +286,25 @@ const GENERIC_CAPS: Set = new Set([ "disadvantages", ]); +/** Generic role/title words that should not become single-token entities. */ +const GENERIC_SINGLE_ENTITY_TERMS: Set = new Set([ + "user", + "assistant", + "agent", + "customer", + "client", + "person", + "people", + "human", + "memory", + "message", + "conversation", + "chat", + "session", + "system", + "top", +]); + /** Markdown/formatting markers to skip during extraction. */ const FORMATTING_MARKERS: Set = new Set([ "*", @@ -287,7 +325,7 @@ const FORMATTING_MARKERS: Set = new Set([ // --------------------------------------------------------------------------- export interface ExtractedEntity { - type: "PROPER" | "QUOTED" | "COMPOUND" | "NOUN"; + type: "PROPER" | "QUOTED" | "TOPIC" | "IDENTIFIER"; text: string; } @@ -338,32 +376,96 @@ function stripGenericEnding(words: string[]): string[] { return words; } -/** - * Determine if a token position is at the start of a sentence. - * Simple heuristic: index 0, or preceded by sentence-ending punctuation - * or formatting markers. - */ -function isSentenceStart( - tokens: string[], - idx: number, - rawText: string, -): boolean { - if (idx === 0) { - return true; - } - const prev = tokens[idx - 1]; - if (/[.!?:]$/.test(prev)) { - return true; +function stripTopicPrefix(words: string[]): string[] { + let start = 0; + while ( + start < words.length && + TOPIC_PREFIX_WORDS.has(words[start].toLowerCase()) + ) { + start++; } - if (FORMATTING_MARKERS.has(prev)) { - return true; - } - // Check for newline before this token in the raw text - const tokenStart = rawText.indexOf(tokens[idx]); - if (tokenStart > 0 && rawText.charAt(tokenStart - 1) === "\n") { + return words.slice(start); +} + +function cleanToken(token: string): string { + return token.replace(/^[^\w.]+|[^\w.]+$/g, ""); +} + +function tokenize(text: string): string[] { + return ( + text.match( + /[A-Za-z_][\w-]*(?:\.[A-Za-z_][\w-]*)*|\d[\d,]*(?:\.\d+)?|[,:;.!?&]/g, + ) ?? [] + ); +} + +function isCapitalized(token: string): boolean { + return /^[A-Z]/.test(token) && /[A-Za-z]/.test(token); +} + +function hasInternalCapOrDigit(token: string): boolean { + return ( + /\d/.test(token) || + /[A-Z]/.test(token.slice(1)) || + /^[A-Z]{2,}$/.test(token) + ); +} + +function isBadSingleNameToken(token: string): boolean { + const lower = token.toLowerCase(); + return GENERIC_SINGLE_ENTITY_TERMS.has(lower) || GENERIC_CAPS.has(lower); +} + +function looksLikeMetricCount(token: string): boolean { + return /^\d[\d,]*(?:\.\d+)?$/.test(token); +} + +function isMetricListContext(tokens: string[], idx: number): boolean { + const prev = idx > 0 ? tokens[idx - 1] : ""; + const next = idx + 1 < tokens.length ? tokens[idx + 1] : ""; + return [":", ",", ";"].includes(prev) || [",", ";"].includes(next); +} + +function isSentenceStart(tokens: string[], idx: number): boolean { + if (idx === 0) return true; + return ( + [".", "!", "?", ":"].includes(tokens[idx - 1]) || + FORMATTING_MARKERS.has(tokens[idx - 1]) + ); +} + +function isListItemNameToken(tokens: string[], idx: number): boolean { + const token = cleanToken(tokens[idx]); + if (!isCapitalized(token) || isBadSingleNameToken(token)) return false; + const next = idx + 1 < tokens.length ? cleanToken(tokens[idx + 1]) : ""; + if (!looksLikeMetricCount(next)) return false; + return ( + isMetricListContext(tokens, idx) || isMetricListContext(tokens, idx + 1) + ); +} + +function isNameToken(tokens: string[], idx: number): boolean { + const token = cleanToken(tokens[idx]); + if (!token || !isCapitalized(token) || isBadSingleNameToken(token)) + return false; + if (hasInternalCapOrDigit(token) || isListItemNameToken(tokens, idx)) return true; - } - return false; + return !isSentenceStart(tokens, idx); +} + +function cleanEntityText(text: string): string { + return text + .replace(/^\*+\s*|\s*\*+$/g, "") + .replace(/\s*:+$/g, "") + .replace(/^\d+\s*\.\s*/, "") + .replace(/\s+\d[\d,]*(?:\.\d+)?$/g, "") + .replace(/[.,;!?]+$/, "") + .trim() + .replace(/\s+/g, " "); +} + +function isCoordinatedNameTopic(text: string): boolean { + return /\b[A-Z][\w-]+\s+and\s+[A-Z][\w-]+\b/.test(text); } // --------------------------------------------------------------------------- @@ -397,86 +499,79 @@ function extractQuoted(text: string): ExtractedEntity[] { } /** - * Extract proper noun sequences using capitalization heuristics. - * Finds sequences of capitalized words that are not at sentence starts. + * Extract dotted technical identifiers such as person.properties.email. + */ +function extractIdentifiers(text: string): ExtractedEntity[] { + const entities: ExtractedEntity[] = []; + const identifierRe = /\b[A-Za-z_][\w-]*(?:\.[A-Za-z_][\w-]*)+\b/g; + let match: RegExpExecArray | null; + while ((match = identifierRe.exec(text)) !== null) { + entities.push({ type: "IDENTIFIER", text: match[0] }); + } + return entities; +} + +/** + * Extract proper names using capitalization and list-context heuristics. */ function extractProper(text: string): ExtractedEntity[] { const entities: ExtractedEntity[] = []; - // Tokenize on whitespace, preserving order - const tokens = text.split(/\s+/).filter(Boolean); - const functionWords = new Set([ - "'s", - "of", - "the", - "in", - "and", - "for", - "at", - "is", - ]); + const tokens = tokenize(text); + const innerConnectors = new Set(["of", "the", "in", "for", "at"]); let i = 0; while (i < tokens.length) { - const tok = tokens[i]; - // Skip formatting markers - if (FORMATTING_MARKERS.has(tok)) { + const token = cleanToken(tokens[i]); + const next = i + 1 < tokens.length ? tokens[i + 1] : ""; + const afterNext = i + 2 < tokens.length ? cleanToken(tokens[i + 2]) : ""; + if ( + token && + next === "&" && + afterNext && + isCapitalized(token) && + isCapitalized(afterNext) && + !isBadSingleNameToken(token) && + !isBadSingleNameToken(afterNext) + ) { + entities.push({ + type: "PROPER", + text: cleanEntityText(`${token} & ${afterNext}`), + }); + i += 3; + continue; + } + + if (!isNameToken(tokens, i)) { i++; continue; } - const isLabel = i + 1 < tokens.length && tokens[i + 1] === ":"; - const isCap = - tok.length > 0 && - tok.charAt(0) === tok.charAt(0).toUpperCase() && - /[A-Z]/.test(tok.charAt(0)); - - if (isCap && !isLabel) { - const seq: Array<{ token: string; idx: number }> = [ - { token: tok, idx: i }, - ]; - let j = i + 1; - while (j < tokens.length) { - const t = tokens[j]; - const tIsCap = - t.length > 0 && - t.charAt(0) === t.charAt(0).toUpperCase() && - /[A-Z]/.test(t.charAt(0)); - if (tIsCap || functionWords.has(t.toLowerCase())) { - seq.push({ token: t, idx: j }); - j++; - } else { - break; - } + const span = [cleanToken(tokens[i])]; + let j = i + 1; + while (j < tokens.length) { + const current = cleanToken(tokens[j]); + if (isNameToken(tokens, j)) { + span.push(current); + j++; + continue; } - - // Strip trailing function words - while ( - seq.length > 0 && - functionWords.has(seq[seq.length - 1].token.toLowerCase()) + if ( + innerConnectors.has(current.toLowerCase()) && + j + 1 < tokens.length && + isNameToken(tokens, j + 1) ) { - seq.pop(); + span.push(current, cleanToken(tokens[j + 1])); + j += 2; + continue; } + break; + } - if (seq.length > 0) { - // Check for at least one mid-sentence capitalized word - const hasMidCap = seq.some(({ token, idx: tokenIdx }) => { - const isCapWord = - /[A-Z]/.test(token.charAt(0)) && - !functionWords.has(token.toLowerCase()); - return isCapWord && !isSentenceStart(tokens, tokenIdx, text); - }); - - if (hasMidCap) { - const phrase = seq.map((s) => s.token).join(" "); - if (phrase.length > 2) { - entities.push({ type: "PROPER", text: phrase }); - } - } - } - i = j; - } else { - i++; + const phrase = cleanEntityText(span.join(" ")); + if (phrase.length > 2) { + entities.push({ type: "PROPER", text: phrase }); } + i = Math.max(j, i + 1); } return entities; @@ -484,7 +579,7 @@ function extractProper(text: string): ExtractedEntity[] { /** * Extract compound noun phrases using the `compromise` NLP library. - * Returns COMPOUND and NOUN entities derived from noun chunks. + * Returns TOPIC entities derived from noun chunks. */ function extractCompoundsWithNlp(text: string): ExtractedEntity[] { if (!nlp) { @@ -524,12 +619,12 @@ function extractCompoundsWithNlp(text: string): ExtractedEntity[] { const filtered = words.filter( (w) => !NON_SPECIFIC_ADJ.has(w.toLowerCase()), ); - const cleaned = stripGenericEnding(filtered); + const cleaned = stripGenericEnding(stripTopicPrefix(filtered)); if (cleaned.length >= 2) { - const phrase = cleaned.join(" "); + const phrase = cleanEntityText(cleaned.join(" ")); if (phrase.length > 3) { - entities.push({ type: "COMPOUND", text: phrase }); + entities.push({ type: "TOPIC", text: phrase }); } } } @@ -547,7 +642,7 @@ function extractCompoundsRegex(text: string): ExtractedEntity[] { // Multi-word sequences with at least one non-trivial word // Match sequences like "machine learning", "New York", "data science" const compoundRe = - /\b([A-Z][a-z]+(?:\s+(?:of|and|the|for|in)\s+)?[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)\b/g; + /\b([A-Z][a-z]+(?:\s+(?:of|the|for|in)\s+)?[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)\b/g; let match: RegExpExecArray | null; while ((match = compoundRe.exec(text)) !== null) { const phrase = match[1].trim(); @@ -558,9 +653,12 @@ function extractCompoundsRegex(text: string): ExtractedEntity[] { const filtered = words.filter( (w) => !NON_SPECIFIC_ADJ.has(w.toLowerCase()), ); - const cleaned = stripGenericEnding(filtered); + const cleaned = stripGenericEnding(stripTopicPrefix(filtered)); if (cleaned.length >= 2) { - entities.push({ type: "COMPOUND", text: cleaned.join(" ") }); + entities.push({ + type: "TOPIC", + text: cleanEntityText(cleaned.join(" ")), + }); } } } @@ -590,9 +688,12 @@ function extractCompoundsRegex(text: string): ExtractedEntity[] { const filtered = words.filter( (w) => !NON_SPECIFIC_ADJ.has(w.toLowerCase()), ); - const cleaned = stripGenericEnding(filtered); + const cleaned = stripGenericEnding(stripTopicPrefix(filtered)); if (cleaned.length >= 2) { - entities.push({ type: "COMPOUND", text: cleaned.join(" ") }); + entities.push({ + type: "TOPIC", + text: cleanEntityText(cleaned.join(" ")), + }); } } } @@ -614,9 +715,9 @@ function extractCompoundsRegex(text: string): ExtractedEntity[] { * * Entity types (in priority order for deduplication): * PROPER - Capitalized multi-word sequences not at sentence start - * COMPOUND - Multi-word noun phrases with specific modifiers + * IDENTIFIER - Dotted technical identifiers * QUOTED - Text in single or double quotes (min 3 chars) - * NOUN - Single nouns from circumstantial patterns + * TOPIC - Multi-word noun/topic phrases with specific modifiers * * @param text - Input text to extract entities from. * @returns Deduplicated list of extracted entities. @@ -630,7 +731,10 @@ export function extractEntities(text: string): ExtractedEntity[] { // 2. PROPER entities (capitalization heuristics) raw.push(...extractProper(text)); - // 3. COMPOUND entities (NLP or regex fallback) + // 3. IDENTIFIER entities + raw.push(...extractIdentifiers(text)); + + // 4. TOPIC entities (NLP or regex fallback) if (nlp) { raw.push(...extractCompoundsWithNlp(text)); } else { @@ -654,19 +758,17 @@ export function extractEntities(text: string): ExtractedEntity[] { const cleaned: ExtractedEntity[] = []; for (const entity of deduped) { let txt = entity.text.trim(); - // Strip leading/trailing asterisks - txt = txt.replace(/^\*+\s*|\s*\*+$/g, ""); - // Strip trailing colons - txt = txt.replace(/\s*:+$/, ""); - // Strip leading numbered list markers - txt = txt.replace(/^\d+\s*\.\s*/, ""); - // Strip trailing sentence punctuation (".", ",", ";", "!", "?") — otherwise - // "Paris." and "Paris" produce different embeddings and break entity dedup. - txt = txt.replace(/[.,;!?]+$/, "").trim(); + txt = cleanEntityText(txt); if (!txt || txt.length <= 2 || hasArtifacts(txt)) { continue; } + if ( + entity.type === "TOPIC" && + (/^\d/.test(txt) || isCoordinatedNameTopic(txt)) + ) { + continue; + } // Filter generic single-word PROPER nouns if ( @@ -680,12 +782,12 @@ export function extractEntities(text: string): ExtractedEntity[] { cleaned.push({ type: entity.type, text: txt }); } - // Keep best type per entity (PROPER > COMPOUND > QUOTED > NOUN) + // Keep best type per entity (PROPER > IDENTIFIER > QUOTED > TOPIC) const typePriority: Record = { PROPER: 0, - COMPOUND: 1, + IDENTIFIER: 1, QUOTED: 2, - NOUN: 3, + TOPIC: 3, }; const best = new Map(); for (const entity of cleaned) { @@ -700,14 +802,17 @@ export function extractEntities(text: string): ExtractedEntity[] { } const bestEntities = Array.from(best.values()); - // Remove entities that are substrings of longer entities - const allLower = bestEntities.map((e) => e.text.toLowerCase()); + // Remove entities that are token substrings of longer entities. return bestEntities.filter( (entity) => - !allLower.some( + !bestEntities.some( (other) => - entity.text.toLowerCase() !== other && - other.includes(entity.text.toLowerCase()), + entity.text.toLowerCase() !== other.text.toLowerCase() && + (typePriority[entity.type] ?? 99) >= + (typePriority[other.type] ?? 99) && + new RegExp( + `(^|\\s)${entity.text.toLowerCase().replace(/[.*+?^${}()|[\]\\]/g, "\\$&")}(\\s|$)`, + ).test(other.text.toLowerCase()), ), ); } diff --git a/mem0-ts/src/oss/tests/entity-extraction.test.ts b/mem0-ts/src/oss/tests/entity-extraction.test.ts new file mode 100644 index 0000000000..3ff16853ea --- /dev/null +++ b/mem0-ts/src/oss/tests/entity-extraction.test.ts @@ -0,0 +1,45 @@ +import { extractEntities } from "../src/utils/entity_extraction"; + +describe("extractEntities", () => { + it("handles product lists, coordinated names, and identifiers", () => { + const text = + "User reported top inbound integration pages: OpenClaw 25,443, " + + "Claude Code 8,916, Codex 2,573, Dify 656. " + + "User compared Cartesia and Deepgram. " + + "The email field for Mem0 lives at person.properties.email. " + + "The qwen endpoint uses person.properties.email. " + + "Johnson & Johnson was mentioned. " + + "Glasses around my window. " + + "On 2026-05-27 there were 90 days of stats."; + + const entityTexts = new Set( + extractEntities(text).map((entity) => entity.text), + ); + const normalized = new Set( + [...entityTexts].map((entityText) => entityText.toLowerCase()), + ); + + for (const expected of [ + "OpenClaw", + "Claude Code", + "Codex", + "Dify", + "Cartesia", + "Deepgram", + "Mem0", + ]) { + expect(entityTexts.has(expected)).toBe(true); + } + expect(entityTexts.has("person.properties.email")).toBe(true); + expect(entityTexts.has("qwen endpoint")).toBe(true); + expect(entityTexts.has("Johnson & Johnson")).toBe(true); + expect(entityTexts.has("Johnson")).toBe(false); + expect(normalized.has("top")).toBe(false); + expect(normalized.has("glasses")).toBe(false); + expect(entityTexts.has("Cartesia and Deepgram")).toBe(false); + expect(entityTexts.has("Claude Code 8,916")).toBe(false); + for (const rejected of ["8,916", "2,573", "656", "2026-05-27", "90"]) { + expect(entityTexts.has(rejected)).toBe(false); + } + }); +}); diff --git a/mem0/memory/main.py b/mem0/memory/main.py index b0335e501a..8aa2b59b77 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -81,6 +81,14 @@ logger = logging.getLogger(__name__) +def _vector_store_list_rows(listed): + if isinstance(listed, (list, tuple)) and listed and isinstance(listed[0], list): + return listed[0] + if isinstance(listed, (list, tuple)): + return listed + return [] + + # Fields that hold runtime auth/connection objects and must be preserved. # These are non-serializable objects (e.g. AWSV4SignerAuth, RequestsHttpConnection) # needed by clients like OpenSearch — not sensitive strings to redact. @@ -499,22 +507,49 @@ def entity_store(self): ) return self._entity_store + @staticmethod + def _normalize_entity_text(value: str) -> str: + return " ".join(value.strip().lower().split()) + + def _existing_entities_by_text(self, filters): + """Return existing entity rows keyed by normalized payload data.""" + try: + listed = self.entity_store.list(filters=filters, top_k=10000) + except Exception as e: + logger.debug(f"Exact entity lookup failed, falling back to semantic dedup: {e}") + return {} + + rows_by_text = {} + for row in _vector_store_list_rows(listed): + payload = getattr(row, "payload", None) or {} + text = payload.get("data") + if not isinstance(text, str): + continue + normalized = self._normalize_entity_text(text) + if normalized and normalized not in rows_by_text: + rows_by_text[normalized] = row + return rows_by_text + def _upsert_entity(self, entity_text, entity_type, memory_id, filters): """Upsert an entity into the entity store, linking it to a memory.""" try: entity_embedding = self.embedding_model.embed(entity_text, "add") search_filters = {k: v for k, v in filters.items() if k in ("user_id", "agent_id", "run_id") and v} + exact_match = self._existing_entities_by_text(search_filters).get(self._normalize_entity_text(entity_text)) + + existing = [] + if exact_match is None: + existing = self.entity_store.search( + query=entity_text, + vectors=entity_embedding, + top_k=1, + filters=search_filters, + ) - existing = self.entity_store.search( - query=entity_text, - vectors=entity_embedding, - top_k=1, - filters=search_filters, - ) - - if existing and existing[0].score >= 0.95: + semantic_match = existing[0] if existing and existing[0].score >= 0.95 else None + match = exact_match or semantic_match + if match: # Update existing entity's linked_memory_ids - match = existing[0] payload = match.payload or {} linked_ids = payload.get("linked_memory_ids", []) if memory_id not in linked_ids: @@ -609,7 +644,7 @@ def _link_entities_for_memory(self, memory_id, text, filters): return seen = set() for entity_type, entity_text in entities: - key = entity_text.strip().lower() + key = self._normalize_entity_text(entity_text) if not key or key in seen: continue seen.add(key) @@ -971,7 +1006,7 @@ def _add_to_vector_store(self, messages, metadata, filters, infer, prompt=None): for idx, (memory_id, text, embedding, payload) in enumerate(records): entities = all_entities[idx] if idx < len(all_entities) else [] for entity_type, entity_text in entities: - key = entity_text.strip().lower() + key = self._normalize_entity_text(entity_text) if key in global_entities: global_entities[key][2].add(memory_id) else: @@ -1009,6 +1044,7 @@ def _add_to_vector_store(self, messages, metadata, filters, infer, prompt=None): if valid: valid_indices, valid_keys = zip(*valid) valid_vectors = [entity_embeddings[i] for i in valid_indices] + exact_matches = self._existing_entities_by_text(search_filters) # 7c: Batch search for existing entities valid_texts = [global_entities[k][1] for k in valid_keys] @@ -1024,10 +1060,12 @@ def _add_to_vector_store(self, messages, metadata, filters, infer, prompt=None): for j, key in enumerate(valid_keys): entity_type, entity_text, memory_ids = global_entities[key] matches = existing_matches[j] if j < len(existing_matches) else [] + exact_match = exact_matches.get(key) - if matches and matches[0].score >= 0.95: + semantic_match = matches[0] if matches and matches[0].score >= 0.95 else None + match = exact_match or semantic_match + if match: # Update existing entity - match = matches[0] payload = match.payload or {} linked = set(payload.get("linked_memory_ids", [])) linked |= memory_ids @@ -1603,7 +1641,7 @@ def _compute_entity_boosts(self, query_entities, filters): seen = set() deduped = [] for entity_type, entity_text in query_entities[:8]: - key = entity_text.strip().lower() + key = self._normalize_entity_text(entity_text) if key and key not in seen: seen.add(key) deduped.append((entity_type, entity_text)) @@ -2045,22 +2083,51 @@ def entity_store(self): ) return self._entity_store + @staticmethod + def _normalize_entity_text(value: str) -> str: + return " ".join(value.strip().lower().split()) + + def _existing_entities_by_text(self, filters): + """Return existing entity rows keyed by normalized payload data.""" + try: + listed = self.entity_store.list(filters=filters, top_k=10000) + except Exception as e: + logger.debug(f"Exact entity lookup failed, falling back to semantic dedup: {e}") + return {} + + rows_by_text = {} + for row in _vector_store_list_rows(listed): + payload = getattr(row, "payload", None) or {} + text = payload.get("data") + if not isinstance(text, str): + continue + normalized = self._normalize_entity_text(text) + if normalized and normalized not in rows_by_text: + rows_by_text[normalized] = row + return rows_by_text + async def _upsert_entity_async(self, entity_text, entity_type, memory_id, filters): """Async variant of `_upsert_entity` — per-entity search-then-update-or-insert.""" try: entity_embedding = await asyncio.to_thread(self.embedding_model.embed, entity_text, "add") search_filters = {k: v for k, v in filters.items() if k in ("user_id", "agent_id", "run_id") and v} + exact_match = ( + await asyncio.to_thread(self._existing_entities_by_text, search_filters) + ).get(self._normalize_entity_text(entity_text)) + + existing = [] + if exact_match is None: + existing = await asyncio.to_thread( + self.entity_store.search, + query=entity_text, + vectors=entity_embedding, + top_k=1, + filters=search_filters, + ) - existing = await asyncio.to_thread( - self.entity_store.search, - query=entity_text, - vectors=entity_embedding, - top_k=1, - filters=search_filters, - ) - - if existing and existing[0].score >= 0.95: - match = existing[0] + semantic_match = existing[0] if existing and existing[0].score >= 0.95 else None + match = exact_match or semantic_match + if match: payload = match.payload or {} linked_ids = payload.get("linked_memory_ids", []) if memory_id not in linked_ids: @@ -2163,7 +2230,7 @@ async def _link_entities_for_memory(self, memory_id, text, filters): return seen = set() for entity_type, entity_text in entities: - key = entity_text.strip().lower() + key = self._normalize_entity_text(entity_text) if not key or key in seen: continue seen.add(key) @@ -2512,7 +2579,7 @@ async def _add_to_vector_store( for idx, (memory_id, text, embedding, payload) in enumerate(records): entities = all_entities[idx] if idx < len(all_entities) else [] for entity_type, entity_text in entities: - key = entity_text.strip().lower() + key = self._normalize_entity_text(entity_text) if key in global_entities: global_entities[key][2].add(memory_id) else: @@ -2547,6 +2614,7 @@ async def _add_to_vector_store( if valid: valid_indices, valid_keys = zip(*valid) valid_vectors = [entity_embeddings[i] for i in valid_indices] + exact_matches = await asyncio.to_thread(self._existing_entities_by_text, search_filters) # 7c: Batch search for existing entities valid_texts = [global_entities[k][1] for k in valid_keys] @@ -2563,9 +2631,11 @@ async def _add_to_vector_store( for j, key in enumerate(valid_keys): entity_type, entity_text, memory_ids = global_entities[key] matches = existing_matches[j] if j < len(existing_matches) else [] + exact_match = exact_matches.get(key) - if matches and matches[0].score >= 0.95: - match = matches[0] + semantic_match = matches[0] if matches and matches[0].score >= 0.95 else None + match = exact_match or semantic_match + if match: payload = match.payload or {} linked = set(payload.get("linked_memory_ids", [])) linked |= memory_ids @@ -3137,7 +3207,7 @@ async def _compute_entity_boosts_async(self, query_entities, filters): seen = set() deduped = [] for entity_type, entity_text in query_entities[:8]: - key = entity_text.strip().lower() + key = self._normalize_entity_text(entity_text) if key and key not in seen: seen.add(key) deduped.append((entity_type, entity_text)) diff --git a/mem0/utils/entity_extraction.py b/mem0/utils/entity_extraction.py index 1c8f7f83c6..23856f49a2 100644 --- a/mem0/utils/entity_extraction.py +++ b/mem0/utils/entity_extraction.py @@ -1,82 +1,352 @@ """ Entity extraction from text using spaCy NLP. -Extracts four types of entities from text: +Extracts three types of entities from a spaCy-processed document: - **Proper nouns**: Capitalized multi-word sequences (person names, places, brands) - **Quoted text**: Text in single or double quotes (titles, specific terms) - **Noun compounds**: Multi-word noun phrases with specific modifiers (e.g., "machine learning") -- **Noun fallback**: Single nouns from circumstantial compound patterns Public API: - extract_entities(text: str) -> List[Tuple[str, str]] + ``extract_entities(text)`` accepts a string and owns spaCy model loading. + ``extract_entities_batch(texts)`` uses ``nlp.pipe`` for batched extraction. -Internal: - _extract_entities_from_doc(doc) -> List[Tuple[str, str]] +Returns: + List of ``(entity_type, entity_text)`` tuples where entity_type is one of + PROPER, QUOTED, TOPIC, or IDENTIFIER. Returns ``[]`` if spaCy is unavailable. """ from __future__ import annotations -import logging +from dataclasses import dataclass import re -from typing import List, Tuple -logger = logging.getLogger(__name__) + +@dataclass(frozen=True) +class _EntityCandidate: + entity_type: str + text: str + source: str + start: int + end: int + confidence: float + priority: int + # Words that are too generic to be useful as entity heads _GENERIC_HEADS = { - "thing", "stuff", "way", "time", "experience", "situation", "case", - "fact", "matter", "issue", "idea", "thought", "feeling", "place", - "area", "part", "kind", "type", "sort", "lot", "bit", "day", "year", - "week", "month", "moment", "instance", "example", "technique", - "method", "approach", "process", "step", "tool", "result", "outcome", - "goal", "task", "item", "topic", "scale", "size", "level", "degree", - "amount", "number", "style", "look", "color", "colour", "shape", - "form", "piece", "section", "side", "end", "edge", "surface", "point", + "thing", + "stuff", + "way", + "time", + "experience", + "situation", + "case", + "fact", + "matter", + "issue", + "idea", + "thought", + "feeling", + "place", + "area", + "part", + "kind", + "type", + "sort", + "lot", + "bit", + "day", + "year", + "week", + "month", + "moment", + "instance", + "example", + "technique", + "method", + "approach", + "process", + "step", + "tool", + "result", + "outcome", + "goal", + "task", + "item", + "topic", + "scale", + "size", + "level", + "degree", + "amount", + "number", + "style", + "look", + "color", + "colour", + "shape", + "form", + "piece", + "section", + "side", + "end", + "edge", + "surface", + "point", +} + +# Entity labels emitted by spaCy that are usually safe to treat as named +# entities. Numeric and temporal labels are intentionally excluded. +_ACCEPTED_NER_LABELS = { + "PERSON", + "ORG", + "GPE", + "LOC", + "FAC", + "PRODUCT", + "WORK_OF_ART", + "EVENT", + "NORP", + "LAW", + "LANGUAGE", +} + +_REJECTED_NER_LABELS = { + "DATE", + "TIME", + "CARDINAL", + "ORDINAL", + "QUANTITY", + "MONEY", + "PERCENT", +} + +# Generic role words and title-cased English words that should not become +# single-token named entities just because spaCy tagged them as PROPN. +_GENERIC_SINGLE_ENTITY_TERMS = { + "user", + "assistant", + "agent", + "customer", + "client", + "person", + "people", + "human", + "memory", + "message", + "conversation", + "chat", + "session", + "system", + "top", } # Modifiers that describe circumstance, not content _CIRCUMSTANTIAL_MODS = { - "solo", "individual", "team", "group", "joint", "collaborative", - "first", "last", "next", "previous", "final", "initial", "main", "side", + "solo", + "individual", + "team", + "group", + "joint", + "collaborative", + "first", + "last", + "next", + "previous", + "final", + "initial", + "main", + "side", + "top", } # Adjectives too vague to make a compound entity specific _NON_SPECIFIC_ADJ = { - "many", "few", "several", "some", "any", "all", "most", "more", - "less", "much", "little", "enough", "various", "numerous", "multiple", - "countless", "great", "good", "bad", "nice", "terrible", "awful", - "awesome", "amazing", "wonderful", "horrible", "excellent", "poor", - "best", "worst", "fine", "okay", "new", "old", "recent", "past", - "future", "current", "previous", "next", "last", "first", "latest", - "early", "late", "former", "modern", "ancient", "big", "small", - "large", "tiny", "huge", "enormous", "long", "short", "tall", "high", - "low", "wide", "narrow", "thick", "thin", "deep", "shallow", - "similar", "different", "same", "other", "another", "such", "certain", - "important", "main", "major", "minor", "key", "primary", "real", - "actual", "true", "whole", "entire", "full", "complete", "total", - "basic", "simple", "interesting", "boring", "exciting", "special", - "particular", "general", "common", "unique", "rare", "typical", - "usual", "normal", "regular", "possible", "likely", "potential", - "available", "necessary", "only", "solo", "individual", "team", - "group", "joint", "collaborative", "final", "initial", "side", + "many", + "few", + "several", + "some", + "any", + "all", + "most", + "more", + "less", + "much", + "little", + "enough", + "various", + "numerous", + "multiple", + "countless", + "great", + "good", + "bad", + "nice", + "terrible", + "awful", + "awesome", + "amazing", + "wonderful", + "horrible", + "excellent", + "poor", + "best", + "worst", + "fine", + "okay", + "new", + "old", + "recent", + "past", + "future", + "current", + "previous", + "next", + "last", + "first", + "latest", + "early", + "late", + "former", + "modern", + "ancient", + "big", + "small", + "large", + "tiny", + "huge", + "enormous", + "long", + "short", + "tall", + "high", + "low", + "wide", + "narrow", + "thick", + "thin", + "deep", + "shallow", + "similar", + "different", + "same", + "other", + "another", + "such", + "certain", + "important", + "main", + "major", + "minor", + "key", + "primary", + "real", + "actual", + "true", + "whole", + "entire", + "full", + "complete", + "total", + "basic", + "simple", + "interesting", + "boring", + "exciting", + "special", + "particular", + "general", + "common", + "unique", + "rare", + "typical", + "usual", + "normal", + "regular", + "possible", + "likely", + "potential", + "available", + "necessary", + "only", + "solo", + "individual", + "team", + "group", + "joint", + "collaborative", + "final", + "initial", + "side", } # Generic tail words to strip from compound entities _GENERIC_ENDINGS = { - "work", "works", "job", "jobs", "task", "tasks", "stuff", "things", - "thing", "info", "information", "details", "data", "content", - "material", "materials", "activities", "activity", "efforts", "effort", - "options", "option", "choices", "choice", "results", "result", - "output", "outputs", "products", "product", "items", "item", + "work", + "works", + "job", + "jobs", + "task", + "tasks", + "stuff", + "things", + "thing", + "info", + "information", + "details", + "data", + "content", + "material", + "materials", + "activities", + "activity", + "efforts", + "effort", + "options", + "option", + "choices", + "choice", + "results", + "result", + "output", + "outputs", + "products", + "product", + "items", + "item", } # Capitalized single words that are too generic to be proper nouns _GENERIC_CAPS = { - "works", "items", "things", "stuff", "resources", "options", "tips", - "ideas", "steps", "ways", "methods", "tools", "features", "benefits", - "examples", "details", "notes", "instructions", "guidelines", - "recommendations", "suggestions", "overview", "summary", "conclusion", - "introduction", "pros", "cons", "advantages", "disadvantages", + "works", + "items", + "things", + "stuff", + "resources", + "options", + "tips", + "ideas", + "steps", + "ways", + "methods", + "tools", + "features", + "benefits", + "examples", + "details", + "notes", + "instructions", + "guidelines", + "recommendations", + "suggestions", + "overview", + "summary", + "conclusion", + "introduction", + "pros", + "cons", + "advantages", + "disadvantages", } # Markdown/formatting markers to skip during extraction @@ -120,121 +390,208 @@ def _has_artifacts(txt: str) -> bool: ) -def extract_entities(text: str) -> List[Tuple[str, str]]: - """Extract named entities, quoted text, and noun compounds from text. +def _clean_text(txt: str) -> str: + txt = re.sub(r"^\*+\s*|\s*\*+$", "", txt.strip()) + txt = re.sub(r"\s*:+$", "", txt) + txt = re.sub(r"^\d+\s*\.\s*", "", txt) + return " ".join(txt.split()) - This is the public API that accepts a string. It loads the spaCy model - internally and delegates to _extract_entities_from_doc(). - Args: - text: Input text to extract entities from. +def _norm_text(txt: str) -> str: + return " ".join(txt.lower().split()) - Returns: - Deduplicated list of (entity_type, entity_text) tuples. - Entity types: PROPER, QUOTED, COMPOUND, NOUN. - Returns empty list if spaCy is unavailable. - """ - from mem0.utils.spacy_models import get_nlp_full - nlp = get_nlp_full() - if nlp is None: - return [] +def _looks_like_technical_identifier(text: str) -> bool: + return bool(re.fullmatch(r"[A-Za-z_][\w-]*(?:\.[A-Za-z_][\w-]*)+", text)) - doc = nlp(text) - return _extract_entities_from_doc(doc) +def _has_internal_cap_or_digit(text: str) -> bool: + return any(ch.isdigit() for ch in text) or any(ch.isupper() for ch in text[1:]) -def extract_entities_batch(texts: List[str], batch_size: int = 32) -> List[List[Tuple[str, str]]]: - """Extract entities from multiple texts using spaCy's nlp.pipe() for batched NER. - Uses spaCy's efficient batch processing pipeline instead of calling - nlp() individually per text. Significantly faster for multiple texts. +def _looks_like_metric_count_token(tok) -> bool: + return tok.pos_ == "NUM" and bool(re.fullmatch(r"\d[\d,]*(?:\.\d+)?", tok.text)) - Args: - texts: List of input texts to extract entities from. - batch_size: Number of texts to process in each spaCy batch. - Returns: - List of entity lists, one per input text. Each entity list contains - (entity_type, entity_text) tuples. Returns list of empty lists if - spaCy is unavailable. - """ - if not texts: - return [] +def _is_metric_list_context(tokens: list, idx: int) -> bool: + prev_text = tokens[idx - 1].text if idx > 0 else "" + next_text = tokens[idx + 1].text if idx + 1 < len(tokens) else "" + return prev_text in {":", ",", ";"} or next_text in {",", ";"} - from mem0.utils.spacy_models import get_nlp_full - nlp = get_nlp_full() - if nlp is None: - return [[] for _ in texts] +def _strip_trailing_metric_counts(span_tokens: list, all_tokens: list) -> list: + while len(span_tokens) > 1 and _looks_like_metric_count_token(span_tokens[-1]): + tok = span_tokens[-1] + if "," not in tok.text and not _is_metric_list_context(all_tokens, tok.i): + break + span_tokens = span_tokens[:-1] + return span_tokens - results = [] - for doc in nlp.pipe(texts, batch_size=batch_size): - results.append(_extract_entities_from_doc(doc)) - return results +def _is_list_item_name_token(tokens: list, idx: int) -> bool: + tok = tokens[idx] + if not tok.text or tok.text in _FORMATTING_MARKERS or not tok.text[0].isupper(): + return False + if not any(ch.isalpha() for ch in tok.text) or _is_bad_single_name_token(tok): + return False + next_tok = tokens[idx + 1] if idx + 1 < len(tokens) else None + if not next_tok or not _looks_like_metric_count_token(next_tok): + return False + return _is_metric_list_context(tokens, idx) or _is_metric_list_context(tokens, idx + 1) + + +def _is_name_like_token(tok, tokens: list | None = None, idx: int | None = None) -> bool: + if not tok.text or tok.text in _FORMATTING_MARKERS: + return False + if not tok.text[0].isupper(): + return False + if not any(ch.isalpha() for ch in tok.text): + return False + if _is_bad_single_name_token(tok): + return False + if tok.pos_ == "PROPN" or tok.tag_ in {"NNP", "NNPS"}: + return True + if tokens is not None and idx is not None and _is_list_item_name_token(tokens, idx): + return True + if _has_internal_cap_or_digit(tok.text): + return True + return ( + tokens is not None + and idx is not None + and tok.pos_ == "NOUN" + and tok.dep_ not in {"compound", "amod"} + and not _is_sentence_start(tokens, idx) + ) -def _extract_entities_from_doc(doc) -> List[Tuple[str, str]]: - """Extract entities from a spaCy Doc object. - Ported from platform's shared.core.utils.entity_extraction.extract_entities(). - """ - entities: List[Tuple[str, str]] = [] - text = doc.text - tokens = list(doc) +def _is_bad_single_name_token(tok) -> bool: + lower = tok.text.lower() + return lower in _GENERIC_SINGLE_ENTITY_TERMS or lower in _GENERIC_CAPS or tok.is_stop + + +def _add_candidate( + candidates: list[_EntityCandidate], + entity_type: str, + text: str, + source: str, + start: int, + end: int, + confidence: float, + priority: int, +) -> None: + cleaned = _clean_text(text) + if not cleaned or len(cleaned) <= 2 or _has_artifacts(cleaned): + return + candidates.append( + _EntityCandidate( + entity_type=entity_type, + text=cleaned, + source=source, + start=start, + end=end, + confidence=confidence, + priority=priority, + ) + ) + - # === PROPER NOUN SEQUENCES === +def _add_ner_candidates(doc, candidates: list[_EntityCandidate]) -> None: + tokens = list(doc) + for ent in doc.ents: + if ent.label_ in _REJECTED_NER_LABELS or ent.label_ not in _ACCEPTED_NER_LABELS: + continue + ent_tokens = _strip_trailing_metric_counts(list(ent), tokens) + if not ent_tokens: + continue + if any(tok.pos_ == "CCONJ" and tok.text.lower() == "and" for tok in ent_tokens): + continue + if len(ent_tokens) == 1 and _is_bad_single_name_token(ent_tokens[0]): + continue + if ( + len(ent_tokens) == 1 + and ent_tokens[0].dep_ in {"compound", "amod"} + and ent_tokens[0].head.pos_ in {"NOUN", "PROPN"} + ): + continue + _add_candidate( + candidates, + "PROPER", + "".join(tok.text_with_ws for tok in ent_tokens).strip(), + "spacy_ner", + ent_tokens[0].i, + ent_tokens[-1].i + 1, + 0.95, + 0, + ) + + +def _add_technical_identifier_candidates(tokens: list, candidates: list[_EntityCandidate]) -> None: + for tok in tokens: + if _looks_like_technical_identifier(tok.text): + _add_candidate( + candidates, + "IDENTIFIER", + tok.text, + "technical_identifier", + tok.i, + tok.i + 1, + 0.9, + 1, + ) + + +def _add_proper_name_candidates(tokens: list, candidates: list[_EntityCandidate]) -> None: + allowed_inner_connectors = {"of", "the", "for", "at", "in"} i = 0 while i < len(tokens): tok = tokens[i] - if tok.text in _FORMATTING_MARKERS: + if not _is_name_like_token(tok, tokens, i): i += 1 continue - is_cap = tok.text and tok.text[0].isupper() - is_label = i + 1 < len(tokens) and tokens[i + 1].text == ":" - - if is_cap and not is_label and tok.pos_ in {"PROPN", "NOUN", "ADJ"}: - seq = [(tok, i)] - j = i + 1 - while j < len(tokens): - t = tokens[j] - if (t.text and t.text[0].isupper()) or t.text.lower() in { - "'s", "of", "the", "in", "and", "for", "at", "is", - }: - seq.append((t, j)) - j += 1 - else: - break - # Strip trailing function words - while seq and seq[-1][0].text.lower() in {"of", "the", "in", "and", "for", "at", "is", "'s"}: - seq.pop() - if seq: - has_mid_cap = any( - not _is_sentence_start(tokens, idx) - for (t, idx) in seq - if t.text[0].isupper() and t.text.lower() not in {"'s", "of", "the", "in", "and", "for", "at", "is"} - ) - if has_mid_cap: - phrase = "".join(t.text_with_ws for (t, idx) in seq).strip() - if len(phrase) > 2: - entities.append(("PROPER", phrase)) - i = j - else: - i += 1 - # === QUOTED TEXT === + span_tokens = [tok] + j = i + 1 + while j < len(tokens): + current = tokens[j] + if _is_name_like_token(current, tokens, j): + span_tokens.append(current) + j += 1 + continue + if ( + current.text.lower() in allowed_inner_connectors + and j + 1 < len(tokens) + and _is_name_like_token(tokens[j + 1], tokens, j + 1) + ): + span_tokens.extend([current, tokens[j + 1]]) + j += 2 + continue + break + + name_tokens = [ + t + for t in span_tokens + if _is_name_like_token(t, tokens, t.i) or (0 <= t.i < len(tokens) and _is_list_item_name_token(tokens, t.i)) + ] + if len(name_tokens) > 1 or not _is_bad_single_name_token(name_tokens[0]): + text = "".join(t.text_with_ws for t in span_tokens).strip() + _add_candidate(candidates, "PROPER", text, "proper_name_span", i, j, 0.8, 2) + i = max(j, i + 1) + + +def _add_quoted_candidates(text: str, candidates: list[_EntityCandidate]) -> None: for m in re.finditer(r'"([^"]+)"', text): if len(m.group(1).strip()) > 2: - entities.append(("QUOTED", m.group(1).strip())) + _add_candidate(candidates, "QUOTED", m.group(1).strip(), "quoted", -1, -1, 0.75, 3) for m in re.finditer(r"(?:^|[\s\(\[{,;])'([^']+)'(?=[\s\.,;:!?\)\]]|$)", text): if len(m.group(1).strip()) > 2: - entities.append(("QUOTED", m.group(1).strip())) + _add_candidate(candidates, "QUOTED", m.group(1).strip(), "quoted", -1, -1, 0.75, 3) + - # === NOUN-NOUN COMPOUNDS === +def _add_topic_phrase_candidates(doc, candidates: list[_EntityCandidate]) -> None: for chunk in doc.noun_chunks: chunk_tokens = list(chunk) - split_indices: list = [] - poss_splits: list = [] + split_indices: list[int] = [] + poss_splits: list[int] = [] for idx, tok in enumerate(chunk_tokens): if tok.dep_ == "case" and tok.text in {"'s", "\u2019s", "'"}: split_indices.append(idx) @@ -243,14 +600,14 @@ def _extract_entities_from_doc(doc) -> List[Tuple[str, str]]: split_indices.append(idx) if split_indices: - groups: list = [] + groups: list[list] = [] prev = 0 for split_idx in split_indices: if split_idx > prev: groups.append(chunk_tokens[prev:split_idx]) if split_idx in poss_splits: next_split = next((s for s in split_indices if s > split_idx), None) - owned = chunk_tokens[split_idx + 1: next_split if next_split else len(chunk_tokens)] + owned = chunk_tokens[split_idx + 1 : next_split if next_split else len(chunk_tokens)] if owned: first_content = next((t for t in owned if t.pos_ not in {"PUNCT", "PART"}), None) if not (first_content and first_content.text and first_content.text[0].isupper()): @@ -272,7 +629,8 @@ def _extract_entities_from_doc(doc) -> List[Tuple[str, str]]: content = [ t for t in group - if t.pos_ not in {"DET", "PRON", "PUNCT", "PART", "ADP", "SCONJ", "NUM"} and (t.pos_ == "ADJ" or not t.is_stop) + if t.pos_ not in {"DET", "PRON", "PUNCT", "PART", "ADP", "SCONJ", "NUM"} + and (t.pos_ == "ADJ" or not t.is_stop) ] if not content: continue @@ -286,78 +644,129 @@ def _extract_entities_from_doc(doc) -> List[Tuple[str, str]]: if compound_toks: is_circ = any(t.lemma_.lower() in _CIRCUMSTANTIAL_MODS for t in compound_toks) if is_circ: - val = head.lemma_ if head.pos_ == "NOUN" else head.text + val = head.text if len(val) > 2: - entities.append(("NOUN", val)) + _add_candidate( + candidates, + "TOPIC", + val, + "topic_phrase", + head.i, + head.i + 1, + 0.45, + 4, + ) else: filtered = _strip_generic_ending( [t for t in content if not (t.pos_ == "ADJ" and t.lemma_.lower() in _NON_SPECIFIC_ADJ)] ) if filtered: - phrase = _lemmatize_compound(filtered) + phrase = " ".join(t.text for t in filtered) if len(phrase) > 3 and " " in phrase: - entities.append(("COMPOUND", phrase)) + _add_candidate( + candidates, + "TOPIC", + phrase, + "topic_phrase", + filtered[0].i, + filtered[-1].i + 1, + 0.45, + 4, + ) elif len(content) > 1 and has_spec_adj: filtered = _strip_generic_ending( - [t for t in content if not ((t.pos_ == "ADJ" or t.dep_ == "amod") and t.lemma_.lower() in _NON_SPECIFIC_ADJ)] + [ + t + for t in content + if not ((t.pos_ == "ADJ" or t.dep_ == "amod") and t.lemma_.lower() in _NON_SPECIFIC_ADJ) + ] ) if filtered: - phrase = _lemmatize_compound(filtered) + phrase = " ".join(t.text for t in filtered) if len(phrase) > 3 and " " in phrase: - entities.append(("COMPOUND", phrase)) - - # === FALLBACK: Mis-tagged VERB heads === - processed = {e[1].lower() for e in entities if e[0] == "COMPOUND"} - generic_verb_heads = _GENERIC_HEADS | {"find", "buy", "purchase", "sale", "deal", "trip", "visit"} - - def collect_compounds(head): - return [t for t in doc if t.head == head and t.dep_ == "compound"] - - for tok in doc: - if tok.pos_ == "VERB" and tok.dep_ in {"pobj", "dobj", "nsubj"}: - comps = sorted(collect_compounds(tok), key=lambda t: t.i) - if comps: - phrase_toks = comps if tok.lemma_.lower() in generic_verb_heads else comps + [tok] - phrase = " ".join(t.text for t in phrase_toks) - if phrase.lower() not in processed and len(phrase) > 3 and " " in phrase: - entities.append(("COMPOUND", phrase)) - processed.add(phrase.lower()) - - # === DEDUPLICATION & CLEANUP === - seen: set = set() - deduped = [] - for t, e in entities: - k = e.lower().strip() - if k not in seen and len(k) > 2: - seen.add(k) - deduped.append((t, e)) - - cleaned: List[Tuple[str, str]] = [] - for etype, etext in deduped: - txt = re.sub(r"^\*+\s*|\s*\*+$", "", etext.strip()) - txt = re.sub(r"\s*:+$", "", txt) - txt = re.sub(r"^\d+\s*\.\s*", "", txt) - if not txt or len(txt) <= 2 or _has_artifacts(txt): - continue - if etype == "PROPER" and " " not in txt and txt.lower() in _GENERIC_CAPS: + _add_candidate( + candidates, + "TOPIC", + phrase, + "topic_phrase", + filtered[0].i, + filtered[-1].i + 1, + 0.45, + 4, + ) + + +def _spans_overlap(a: _EntityCandidate, b: _EntityCandidate) -> bool: + if a.start < 0 or b.start < 0: + return False + return a.start < b.end and b.start < a.end + + +def _resolve_candidates(candidates: list[_EntityCandidate]) -> list[tuple[str, str]]: + deduped_by_text: dict[str, _EntityCandidate] = {} + for candidate in candidates: + key = _norm_text(candidate.text) + current = deduped_by_text.get(key) + if current is None or (candidate.priority, -candidate.confidence) < (current.priority, -current.confidence): + deduped_by_text[key] = candidate + + ordered = sorted( + deduped_by_text.values(), + key=lambda c: (c.priority, -c.confidence, -(c.end - c.start), c.start), + ) + accepted: list[_EntityCandidate] = [] + for candidate in ordered: + if any( + _spans_overlap(candidate, existing) + and not (candidate.entity_type == "TOPIC" and " " in candidate.text and existing.entity_type == "PROPER") + for existing in accepted + ): continue - cleaned.append((etype, txt)) - - # Keep best type per entity (PROPER > COMPOUND > QUOTED > NOUN) - type_pri = {"PROPER": 0, "COMPOUND": 1, "QUOTED": 2, "NOUN": 3, "VERB": 4} - best: dict = {} - for t, e in cleaned: - k = e.lower() - if k not in best or type_pri.get(t, 99) < type_pri.get(best[k][0], 99): - best[k] = (t, e) - deduped = list(best.values()) - - # Remove entities that are whole-word substrings of longer entities. - # Word-boundary anchoring avoids dropping distinct entities that only share a - # leading substring (e.g. "Sam" must survive alongside "Samsung"). - all_lower = [e[1].lower() for e in deduped] - return [ - (t, e) - for t, e in deduped - if not any(e.lower() != o and re.search(rf"\b{re.escape(e.lower())}\b", o) for o in all_lower) - ] + accepted.append(candidate) + + accepted.sort(key=lambda c: (c.start if c.start >= 0 else 10**9, c.end, c.priority)) + return [(candidate.entity_type, candidate.text) for candidate in accepted] + + +def _extract_entities_from_doc(doc) -> list[tuple[str, str]]: + """Extract typed entity candidates from a spaCy Doc. + + Args: + doc: A spaCy ``Doc`` object (from ``nlp(text)``). + + Returns: + Deduplicated list of ``(entity_type, entity_text)`` tuples. + Entity types include PROPER, QUOTED, TOPIC, and IDENTIFIER. + """ + tokens = list(doc) + candidates: list[_EntityCandidate] = [] + _add_ner_candidates(doc, candidates) + _add_technical_identifier_candidates(tokens, candidates) + _add_proper_name_candidates(tokens, candidates) + _add_quoted_candidates(doc.text, candidates) + _add_topic_phrase_candidates(doc, candidates) + return _resolve_candidates(candidates) + + +def extract_entities(text: str) -> list[tuple[str, str]]: + """Extract typed entity candidates from text.""" + from mem0.utils.spacy_models import get_nlp_full + + nlp = get_nlp_full() + if nlp is None: + return [] + return _extract_entities_from_doc(nlp(text)) + + +def extract_entities_batch(texts: list[str], batch_size: int = 32) -> list[list[tuple[str, str]]]: + """Extract typed entity candidates from multiple texts.""" + if not texts: + return [] + + from mem0.utils.spacy_models import get_nlp_full + + nlp = get_nlp_full() + if nlp is None: + return [[] for _ in texts] + + return [_extract_entities_from_doc(doc) for doc in nlp.pipe(texts, batch_size=batch_size)] diff --git a/pyproject.toml b/pyproject.toml index c17764a1bb..a752be1dbc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "mem0ai" -version = "2.0.8" +version = "2.0.9" description = "Long-term memory for AI Agents" authors = [ { name = "Mem0", email = "support@mem0.ai" } diff --git a/tests/utils/test_entity_extraction.py b/tests/utils/test_entity_extraction.py index 499fce95b4..ab26cd172c 100644 --- a/tests/utils/test_entity_extraction.py +++ b/tests/utils/test_entity_extraction.py @@ -6,6 +6,7 @@ def _ensure_spacy(): """Skip tests if spaCy model is not available.""" try: import spacy + spacy.load("en_core_web_sm") except Exception: pytest.skip("spaCy en_core_web_sm model not available") @@ -33,8 +34,9 @@ def test_compound_nouns(self): entities = extract_entities("The machine learning engineer built a neural network") entity_texts = [e[1].lower() for e in entities] - has_compound = any("machine" in t and "learning" in t for t in entity_texts) or \ - any("neural" in t and "network" in t for t in entity_texts) + has_compound = any("machine" in t and "learning" in t for t in entity_texts) or any( + "neural" in t and "network" in t for t in entity_texts + ) assert has_compound, f"Expected compound nouns, got {entities}" def test_empty_string(self): @@ -76,9 +78,37 @@ def test_returns_tuples(self): for entity in entities: assert isinstance(entity, tuple) assert len(entity) == 2 - assert entity[0] in ("PROPER", "QUOTED", "COMPOUND", "NOUN") + assert entity[0] in ("PROPER", "QUOTED", "TOPIC", "IDENTIFIER") assert isinstance(entity[1], str) + def test_handles_names_lists_and_identifiers(self): + from mem0.utils.entity_extraction import extract_entities + + text = ( + "User reported top inbound integration pages: OpenClaw 25,443, " + "Claude Code 8,916, Codex 2,573, Dify 656. " + "User compared Cartesia and Deepgram. " + "The email field for Mem0 lives at person.properties.email. " + "The qwen endpoint uses person.properties.email. " + "Johnson & Johnson was mentioned. " + "Glasses around my window. " + "On 2026-05-27 there were 90 days of stats." + ) + + entities = extract_entities(text) + entity_texts = {entity_text for _, entity_text in entities} + normalized = {entity_text.lower() for entity_text in entity_texts} + + assert {"OpenClaw", "Claude Code", "Codex", "Dify", "Cartesia", "Deepgram", "Mem0"}.issubset(entity_texts) + assert "person.properties.email" in entity_texts + assert "qwen endpoint" in entity_texts + assert "Johnson & Johnson" in entity_texts + assert "top" not in normalized + assert "glasses" not in normalized + assert "Cartesia and Deepgram" not in entity_texts + assert "Claude Code 8,916" not in entity_texts + assert not {"8,916", "2,573", "656", "2026-05-27", "90"}.intersection(entity_texts) + class TestExtractEntitiesBatch: def test_batch_processing(self):