diff --git a/docling_agent/agent/enricher.py b/docling_agent/agent/enricher.py index 13aa482..c118430 100644 --- a/docling_agent/agent/enricher.py +++ b/docling_agent/agent/enricher.py @@ -860,8 +860,9 @@ def _detect_key_entities( ) with self._timed_stage("entities: leaf entities"): + m_leaf = self._create_extraction_session() self._extract_entities_from_leaf_items( - m=m, + m=m_leaf, document=hier_doc, loop_budget=loop_budget, entity_targets=entity_targets, @@ -880,11 +881,20 @@ def _infer_entity_targets( if not task: return None - def _validate_entity_target_spec(content: str) -> bool: + def _parse_spec_dict(content: str) -> dict | None: matches = find_json_dicts(text=content) - if len(matches) != 1: + if len(matches) == 1 and isinstance(matches[0], dict): + return matches[0] + try: + parsed = json.loads(content.strip()) + except (json.JSONDecodeError, ValueError): + return None + return parsed if isinstance(parsed, dict) else None + + def _validate_entity_target_spec(content: str) -> bool: + spec = _parse_spec_dict(content) + if spec is None: return False - spec = matches[0] labels = spec.get("labels", []) focus_terms = spec.get("focus_terms", []) generic = spec.get("generic", False) @@ -923,8 +933,7 @@ def _validate_entity_target_spec(content: str) -> bool: if not answer: return None - specs = find_json_dicts(text=answer) - return specs[0] if specs else None + return _parse_spec_dict(answer) def _walk_and_extract_entities( self, @@ -1027,6 +1036,8 @@ def _validate_entities(content: str) -> bool: return False try: val = json.loads(match.group(1)) + if isinstance(val, dict) and isinstance(val.get("entities"), list): + val = val["entities"] if not isinstance(val, list): return False return all(isinstance(item, dict) and "text" in item for item in val) @@ -1035,17 +1046,28 @@ def _validate_entities(content: str) -> bool: target_clause = "" rewritten_task = task or "" + allowed_labels: list[str] = [] if entity_targets: - labels = entity_targets.get("labels", []) + allowed_labels = [str(lbl).strip() for lbl in entity_targets.get("labels", []) if str(lbl).strip()] focus_terms = entity_targets.get("focus_terms", []) rewritten_task = entity_targets.get("rewritten_task", rewritten_task) generic = entity_targets.get("generic", False) - if not generic or labels or focus_terms or rewritten_task: + if allowed_labels: + target_clause = ( + "\nHARD CONSTRAINT — you MUST obey this:\n" + f"Use ONLY these label values: {allowed_labels}.\n" + "Reject any other entity type. If a candidate mention does not fit one of those labels, " + "omit it entirely — do NOT output it under a different label, do NOT invent new labels.\n" + "Apply the same definitions to every paragraph and every section, including reference lists, " + "captions, and tables.\n" + f"\nExtraction brief (for context only, does not expand the label set):\n{rewritten_task}\n" + f"- focus_terms (hints, not exhaustive): {focus_terms}\n" + ) + elif not generic or focus_terms or rewritten_task: target_clause = ( "\nUse this rewritten extraction brief:\n" f"{rewritten_task}\n" "Focus on entities that match the brief, including obvious instances even if the wording differs.\n" - f"- labels: {labels}\n" f"- focus_terms: {focus_terms}\n" "If an entity is not relevant to that brief, omit it." ) @@ -1073,14 +1095,28 @@ def _validate_entities(content: str) -> bool: if result: match = re.search(r"```json\s*(.*?)\s*```", result, re.DOTALL) - if match: + json_text = match.group(1) if match else result.strip() + if json_text: try: - payload = json.loads(match.group(1)) + payload = json.loads(json_text) + if isinstance(payload, dict) and isinstance(payload.get("entities"), list): + payload = payload["entities"] + if not isinstance(payload, list): + payload = [] + allowed_lookup = {lbl.casefold() for lbl in allowed_labels} mentions: list[EntityMention] = [] + dropped_by_label: dict[str, int] = {} search_start = 0 for item in payload: if not isinstance(item, dict) or not str(item.get("text", "")).strip(): continue + if allowed_lookup: + raw_label = str(item.get("label", "")).strip() + if raw_label.casefold() not in allowed_lookup: + dropped_by_label[raw_label or ""] = ( + dropped_by_label.get(raw_label or "", 0) + 1 + ) + continue mention = self._make_entity_mention( item=item, source_text=text, @@ -1090,6 +1126,12 @@ def _validate_entities(content: str) -> bool: if mention.charspan is not None: search_start = mention.charspan[1] mentions.append(mention) + if dropped_by_label: + log_info( + "Dropped mentions with disallowed labels", + count=sum(dropped_by_label.values()), + by_label=dropped_by_label, + ) log_debug("Parsed entities", count=len(mentions)) if mentions: return EntitiesMetaField(mentions=mentions) diff --git a/docling_agent/agent/orchestrator.py b/docling_agent/agent/orchestrator.py index 145d93a..2083c8f 100644 --- a/docling_agent/agent/orchestrator.py +++ b/docling_agent/agent/orchestrator.py @@ -205,6 +205,7 @@ def _ensure_enriched( source_pairs: list[_SourcePair], library: DoclingLibrary, operations: list[str], + task: str = "", ) -> list[_SourcePair]: """Run enrichment on documents that are missing the requested enrichments. @@ -229,7 +230,7 @@ def _ensure_enriched( if needed: log_info(f"Enriching {doc.name!r} with operations={needed}") - enriched_doc = enricher.run(task="", document=doc, operations=needed) + enriched_doc = enricher.run(task=task, document=doc, operations=needed) # Persist enriched document back to library library.store(enriched_doc, entry.source_path if entry else "in-memory") # Update status flags @@ -278,7 +279,7 @@ def _run_rag( ) -> DoclingDocument: log_info(f"_run_rag: query={task.query!r}, docs={len(source_pairs)}") if task.enrich_before_rag: - source_pairs = self._ensure_enriched(source_pairs, library, operations=["summarize"]) + source_pairs = self._ensure_enriched(source_pairs, library, operations=["summarize"], task=task.query) docs: list[DoclingDocument | Path] = [doc for doc, _ in source_pairs] rag_agent = DoclingRAGAgent( @@ -364,7 +365,7 @@ def _run_enrich( enriched_pairs.append((enriched_doc, doc_id)) else: ops: list[str] = list(task.operations) - enriched_pairs = self._ensure_enriched(source_pairs, library, operations=ops) + enriched_pairs = self._ensure_enriched(source_pairs, library, operations=ops, task=task.query) # Return: single doc → return it directly; multiple → a composite summary doc if len(enriched_pairs) == 1: diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1667149 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,28 @@ +"""Pytest configuration and shared fixtures for docling-agent tests. + +This file is automatically discovered by pytest and makes fixtures +available to all test modules without explicit imports. +""" + +import json +from pathlib import Path + +import pytest +from docling_core.types.doc.document import DoclingDocument + +from .test_utils import MockBackend + + +@pytest.fixture(scope="module") +def mock_backend() -> MockBackend: + """Fixture providing a mocked backend instance.""" + return MockBackend() + + +@pytest.fixture(scope="module") +def test_document() -> DoclingDocument: + """Fixture providing the test document loaded from JSON.""" + json_path = Path("tests/data/2408.09869v5.json") + with open(json_path) as f: + doc_dict = json.load(f) + return DoclingDocument.model_validate(doc_dict) diff --git a/tests/test_enricher.py b/tests/test_enricher.py index ffb5671..21118c5 100644 --- a/tests/test_enricher.py +++ b/tests/test_enricher.py @@ -335,6 +335,295 @@ def _fake_generate_summary(self, *, m, text, loop_budget=5, style="sentences", s assert pages_with_summaries > 0 +def test_parse_spec_dict_with_find_json_dicts(enricher: DoclingEnrichingAgent) -> None: + """Test _parse_spec_dict helper function with various JSON formats.""" + # Simulate the _parse_spec_dict function behavior + import json + + from docling_agent.agent.base_functions import find_json_dicts + + def _parse_spec_dict(content: str) -> dict | None: + matches = find_json_dicts(text=content) + if len(matches) == 1 and isinstance(matches[0], dict): + return matches[0] + try: + parsed = json.loads(content.strip()) + except (json.JSONDecodeError, ValueError): + return None + return parsed if isinstance(parsed, dict) else None + + # Test 1: Valid JSON in code block (find_json_dicts path) + content1 = """```json +{ + "generic": false, + "labels": ["Person", "Organization"], + "focus_terms": ["CEO", "company"], + "rewritten_task": "Extract person and organization names" +} +```""" + result1 = _parse_spec_dict(content1) + assert result1 is not None + assert result1["labels"] == ["Person", "Organization"] + assert result1["generic"] is False + + # Test 2: Valid JSON without code block (json.loads fallback) + content2 = '{"generic": true, "labels": [], "focus_terms": [], "rewritten_task": ""}' + result2 = _parse_spec_dict(content2) + assert result2 is not None + assert result2["generic"] is True + + # Test 3: Invalid JSON + content3 = "{invalid json}" + result3 = _parse_spec_dict(content3) + assert result3 is None + + # Test 4: JSON array (should return None as we expect dict) + content4 = '[{"key": "value"}]' + result4 = _parse_spec_dict(content4) + assert result4 is None + + # Test 5: Multiple JSON blocks (find_json_dicts returns multiple, should fail) + content5 = """```json +{"key1": "value1"} +``` +```json +{"key2": "value2"} +```""" + result5 = _parse_spec_dict(content5) + assert result5 is None + + +def test_infer_entity_targets_with_task(monkeypatch: pytest.MonkeyPatch, enricher: DoclingEnrichingAgent) -> None: + """Test that _infer_entity_targets correctly parses entity specifications from LLM response.""" + + def _fake_generate_content(self, *, m, text, task_prompt, requirement_description, validation_fn, loop_budget=5): + # Simulate LLM returning entity target specification + return """```json +{ + "generic": false, + "labels": ["Person", "Email", "Phone"], + "focus_terms": ["name", "contact", "email address"], + "rewritten_task": "Extract personal information including person names, email addresses, and phone numbers" +} +```""" + + monkeypatch.setattr(DoclingEnrichingAgent, "_generate_content", _fake_generate_content) + + m = enricher._create_extraction_session() + entity_targets = enricher._infer_entity_targets( + m=m, task="Find personal information in the document", loop_budget=3 + ) + + assert entity_targets is not None + assert entity_targets["generic"] is False + assert entity_targets["labels"] == ["Person", "Email", "Phone"] + assert "name" in entity_targets["focus_terms"] + assert "Extract personal information" in entity_targets["rewritten_task"] + + +def test_infer_entity_targets_without_task(enricher: DoclingEnrichingAgent) -> None: + """Test that _infer_entity_targets returns None when no task is provided.""" + m = enricher._create_extraction_session() + entity_targets = enricher._infer_entity_targets(m=m, task=None, loop_budget=3) + assert entity_targets is None + + entity_targets = enricher._infer_entity_targets(m=m, task="", loop_budget=3) + assert entity_targets is None + + +def test_generate_entities_with_label_filtering( + monkeypatch: pytest.MonkeyPatch, enricher: DoclingEnrichingAgent +) -> None: + """Test that _generate_entities correctly filters entities by allowed labels.""" + + def _fake_generate_content(self, *, m, text, task_prompt, requirement_description, validation_fn, loop_budget=5): + # Simulate LLM returning entities with mixed labels (some allowed, some not) + return """```json +[ + {"text": "John Doe", "label": "Person"}, + {"text": "john@example.com", "label": "Email"}, + {"text": "Acme Corp", "label": "Organization"}, + {"text": "New York", "label": "Location"}, + {"text": "GPT-4", "label": "AI-Model"} +] +```""" + + monkeypatch.setattr(DoclingEnrichingAgent, "_generate_content", _fake_generate_content) + + m = enricher._create_extraction_session() + text = "John Doe (john@example.com) works at Acme Corp in New York using GPT-4." + + # Test with allowed labels - should filter out Organization, Location, AI-Model + entity_targets = { + "labels": ["Person", "Email"], + "focus_terms": [], + "rewritten_task": "Extract person names and emails", + "generic": False, + } + + result = enricher._generate_entities( + m=m, text=text, task="Find people and emails", entity_targets=entity_targets, loop_budget=3 + ) + + assert result is not None + assert len(result.mentions) == 2 # Only Person and Email should be kept + labels = {mention.label for mention in result.mentions} + assert labels == {"Person", "Email"} + assert "Organization" not in labels + assert "Location" not in labels + + +def test_generate_entities_case_insensitive_labels( + monkeypatch: pytest.MonkeyPatch, enricher: DoclingEnrichingAgent +) -> None: + """Test that label filtering is case-insensitive.""" + + def _fake_generate_content(self, *, m, text, task_prompt, requirement_description, validation_fn, loop_budget=5): + # LLM returns labels in different cases + return """```json +[ + {"text": "John Doe", "label": "PERSON"}, + {"text": "jane@example.com", "label": "email"}, + {"text": "Acme Corp", "label": "Organization"} +] +```""" + + monkeypatch.setattr(DoclingEnrichingAgent, "_generate_content", _fake_generate_content) + + m = enricher._create_extraction_session() + text = "John Doe and jane@example.com work at Acme Corp." + + # Allowed labels in mixed case + entity_targets = { + "labels": ["Person", "Email"], # lowercase in spec + "focus_terms": [], + "rewritten_task": "", + "generic": False, + } + + result = enricher._generate_entities(m=m, text=text, task="", entity_targets=entity_targets, loop_budget=3) + + assert result is not None + assert len(result.mentions) == 2 # PERSON and email should match Person and Email + labels = {mention.label for mention in result.mentions} + assert "PERSON" in labels or "Person" in labels + assert "email" in labels or "Email" in labels + + +def test_generate_entities_with_dict_wrapped_response( + monkeypatch: pytest.MonkeyPatch, enricher: DoclingEnrichingAgent +) -> None: + """Test that _generate_entities handles LLM responses that wrap entities in a dict.""" + + def _fake_generate_content(self, *, m, text, task_prompt, requirement_description, validation_fn, loop_budget=5): + # LLM returns {"entities": [...]} instead of just [...] + return """```json +{ + "entities": [ + {"text": "John Doe", "label": "Person"}, + {"text": "john@example.com", "label": "Email"} + ] +} +```""" + + monkeypatch.setattr(DoclingEnrichingAgent, "_generate_content", _fake_generate_content) + + m = enricher._create_extraction_session() + text = "John Doe (john@example.com)" + + result = enricher._generate_entities(m=m, text=text, task="", entity_targets=None, loop_budget=3) + + assert result is not None + assert len(result.mentions) == 2 + assert result.mentions[0].text == "John Doe" + assert result.mentions[1].text == "john@example.com" + + +def test_generate_entities_without_code_block(monkeypatch: pytest.MonkeyPatch, enricher: DoclingEnrichingAgent) -> None: + """Test that _generate_entities handles JSON without markdown code blocks.""" + + def _fake_generate_content(self, *, m, text, task_prompt, requirement_description, validation_fn, loop_budget=5): + # LLM returns raw JSON without code block + return '[{"text": "John Doe", "label": "Person"}]' + + monkeypatch.setattr(DoclingEnrichingAgent, "_generate_content", _fake_generate_content) + + m = enricher._create_extraction_session() + text = "John Doe is here." + + result = enricher._generate_entities(m=m, text=text, task="", entity_targets=None, loop_budget=3) + + assert result is not None + assert len(result.mentions) == 1 + assert result.mentions[0].text == "John Doe" + + +def test_generate_entities_hard_constraint_prompt( + monkeypatch: pytest.MonkeyPatch, enricher: DoclingEnrichingAgent +) -> None: + """Test that allowed_labels triggers HARD CONSTRAINT prompt.""" + captured_prompts = [] + + def _fake_generate_content(self, *, m, text, task_prompt, requirement_description, validation_fn, loop_budget=5): + captured_prompts.append(task_prompt) + return "[]" + + monkeypatch.setattr(DoclingEnrichingAgent, "_generate_content", _fake_generate_content) + + m = enricher._create_extraction_session() + + # Test with allowed labels + entity_targets = { + "labels": ["Person", "Email"], + "focus_terms": ["contact"], + "rewritten_task": "Extract contacts", + "generic": False, + } + + enricher._generate_entities(m=m, text="test", task="", entity_targets=entity_targets, loop_budget=3) + + assert len(captured_prompts) == 1 + prompt = captured_prompts[0] + assert "HARD CONSTRAINT" in prompt + assert "Use ONLY these label values" in prompt + assert "['Person', 'Email']" in prompt + assert "do NOT invent new labels" in prompt + + +def test_extract_entities_session_isolation( + monkeypatch: pytest.MonkeyPatch, test_document: DoclingDocument, enricher: DoclingEnrichingAgent +) -> None: + """Test that entity extraction uses separate session from entity target inference. + + This is a regression test for the bug fix where a separate session (m_leaf) + is created for leaf entity extraction to prevent state pollution. + """ + session_creation_count = [0] + original_create_session = enricher._create_extraction_session + + def _counting_create_session(): + session_creation_count[0] += 1 + return original_create_session() + + monkeypatch.setattr(enricher, "_create_extraction_session", _counting_create_session) + + # Mock the methods to avoid actual LLM calls + def _fake_infer_targets(self, *, m, task, loop_budget): + return {"labels": ["Person"], "focus_terms": [], "rewritten_task": "", "generic": False} + + def _fake_extract_from_leaf(self, *, m, document, loop_budget, entity_targets, task): + pass + + monkeypatch.setattr(DoclingEnrichingAgent, "_infer_entity_targets", _fake_infer_targets) + monkeypatch.setattr(DoclingEnrichingAgent, "_extract_entities_from_leaf_items", _fake_extract_from_leaf) + + # Run entity detection (which calls entity extraction internally) + enricher._detect_key_entities(document=test_document, task="Find people", loop_budget=3) + + # Should create 2 sessions: one for inference, one for extraction + assert session_creation_count[0] == 2 + + if __name__ == "__main__": print("=" * 70) print("TEST 1: Demonstrating the heading levels problem") diff --git a/tests/test_orchestrator.py b/tests/test_orchestrator.py new file mode 100644 index 0000000..4198401 --- /dev/null +++ b/tests/test_orchestrator.py @@ -0,0 +1,124 @@ +"""Tests for DoclingOrchestratorAgent. + +This module includes regression tests for the entity constraint bug fix, +specifically testing task propagation from orchestrator to enricher. +""" + +import tempfile +from pathlib import Path + +import pytest +from docling_core.types.doc.document import DoclingDocument + +from docling_agent.agent.enricher import DoclingEnrichingAgent +from docling_agent.agent.library import DoclingLibrary +from docling_agent.agent.orchestrator import DoclingOrchestratorAgent + + +@pytest.fixture +def enricher(mock_backend) -> DoclingEnrichingAgent: + """Fixture providing a DoclingEnrichingAgent instance.""" + return DoclingEnrichingAgent(backend=mock_backend, tools=[]) + + +@pytest.fixture +def orchestrator(mock_backend) -> DoclingOrchestratorAgent: + """Fixture providing a DoclingOrchestratorAgent instance.""" + return DoclingOrchestratorAgent(backend=mock_backend, tools=[]) + + +def test_ensure_enriched_task_propagation( + monkeypatch: pytest.MonkeyPatch, orchestrator: DoclingOrchestratorAgent, test_document: DoclingDocument +) -> None: + """Test that _ensure_enriched propagates task parameter to enricher.run(). + + This is a regression test for the bug fix where task was not being passed + from orchestrator to enricher, causing entity constraints to be ignored. + """ + captured_task = [] + + def _fake_enricher_run(self, task, document, operations): + captured_task.append(task) + return document + + # Mock the enricher's run method + monkeypatch.setattr(DoclingEnrichingAgent, "run", _fake_enricher_run) + + # Create a mock library with temporary directory + with tempfile.TemporaryDirectory() as tmpdir: + library = DoclingLibrary(path=Path(tmpdir)) + library.store(test_document, "test.json") + + # Call _ensure_enriched with a task + source_pairs = [(test_document, "test-doc-id")] + task_query = "Extract person names and email addresses" + + orchestrator._ensure_enriched( + source_pairs=source_pairs, library=library, operations=["entities"], task=task_query + ) + + # Verify task was passed to enricher + assert len(captured_task) == 1 + assert captured_task[0] == task_query + + +def test_ensure_enriched_empty_task_default( + monkeypatch: pytest.MonkeyPatch, orchestrator: DoclingOrchestratorAgent, test_document: DoclingDocument +) -> None: + """Test that _ensure_enriched uses empty string as default task.""" + captured_task = [] + + def _fake_enricher_run(self, task, document, operations): + captured_task.append(task) + return document + + monkeypatch.setattr(DoclingEnrichingAgent, "run", _fake_enricher_run) + + with tempfile.TemporaryDirectory() as tmpdir: + library = DoclingLibrary(path=Path(tmpdir)) + library.store(test_document, "test.json") + + source_pairs = [(test_document, "test-doc-id")] + + # Call without task parameter (should default to "") + orchestrator._ensure_enriched(source_pairs=source_pairs, library=library, operations=["entities"]) + + assert len(captured_task) == 1 + assert captured_task[0] == "" + + +def test_ensure_enriched_multiple_operations( + monkeypatch: pytest.MonkeyPatch, orchestrator: DoclingOrchestratorAgent, test_document: DoclingDocument +) -> None: + """Test that _ensure_enriched handles multiple operations correctly.""" + enricher_calls = [] + + def _fake_enricher_run(self, task, document, operations): + enricher_calls.append((task, operations)) + return document + + monkeypatch.setattr(DoclingEnrichingAgent, "run", _fake_enricher_run) + + with tempfile.TemporaryDirectory() as tmpdir: + library = DoclingLibrary(path=Path(tmpdir)) + library.store(test_document, "test.json") + + source_pairs = [(test_document, "test-doc-id")] + task_query = "Enrich with summaries and entities" + + # Request multiple operations + orchestrator._ensure_enriched( + source_pairs=source_pairs, + library=library, + operations=["summarize", "entities"], + task=task_query, + ) + + # Enricher should be called once with both operations + assert len(enricher_calls) == 1 + assert enricher_calls[0][0] == task_query + assert set(enricher_calls[0][1]) == {"summarize", "entities"} + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])