From 81b9a8de0c38113fe0d7f1eea1505d5a427177d0 Mon Sep 17 00:00:00 2001 From: Ana Daniele Date: Wed, 27 May 2026 12:49:59 +0200 Subject: [PATCH 1/6] fix(enrich): propagate task to enricher and enforce label whitelist MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two related bugs prevented YAML-declared entity-type constraints from reaching NuExtract: 1. `orchestrator._ensure_enriched` hard-coded `task=""` when the YAML used the `operations:` shortcut, dropping `task.query` on the floor. With no task, `_infer_entity_targets` returned None, so the prompt sent to the model never included the allowed labels. Now `_ensure_enriched` accepts `task` and both `_run_enrich` and `_run_rag` pass `task.query` through. 2. `enricher._generate_entities` only applied `entity_targets` as a soft hint, and had no post-filter on returned labels. When the model invented types (Project, Software, Algorithm, IP-Address, …) they flowed through into the rendered output. Now the prompt carries a HARD CONSTRAINT block listing the allowed labels, and the parser drops any mention whose case-folded label is not in the allowed set (with an INFO log line summarising drops per call). Verified on df3 with the `2026-05-22a_enrich_nuextract3_postpatch` run that the enricher-only patch was a no-op without (1); end-to-end re-run with both fixes is queued. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Ana Daniele --- docling_agent/agent/enricher.py | 30 ++++++++++++++++++++++++++--- docling_agent/agent/orchestrator.py | 7 ++++--- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/docling_agent/agent/enricher.py b/docling_agent/agent/enricher.py index 13aa482..ed13726 100644 --- a/docling_agent/agent/enricher.py +++ b/docling_agent/agent/enricher.py @@ -1035,17 +1035,28 @@ def _validate_entities(content: str) -> bool: target_clause = "" rewritten_task = task or "" + allowed_labels: list[str] = [] if entity_targets: - labels = entity_targets.get("labels", []) + allowed_labels = [str(lbl).strip() for lbl in entity_targets.get("labels", []) if str(lbl).strip()] focus_terms = entity_targets.get("focus_terms", []) rewritten_task = entity_targets.get("rewritten_task", rewritten_task) generic = entity_targets.get("generic", False) - if not generic or labels or focus_terms or rewritten_task: + if allowed_labels: + target_clause = ( + "\nHARD CONSTRAINT — you MUST obey this:\n" + f"Use ONLY these label values: {allowed_labels}.\n" + "Reject any other entity type. If a candidate mention does not fit one of those labels, " + "omit it entirely — do NOT output it under a different label, do NOT invent new labels.\n" + "Apply the same definitions to every paragraph and every section, including reference lists, " + "captions, and tables.\n" + f"\nExtraction brief (for context only, does not expand the label set):\n{rewritten_task}\n" + f"- focus_terms (hints, not exhaustive): {focus_terms}\n" + ) + elif not generic or focus_terms or rewritten_task: target_clause = ( "\nUse this rewritten extraction brief:\n" f"{rewritten_task}\n" "Focus on entities that match the brief, including obvious instances even if the wording differs.\n" - f"- labels: {labels}\n" f"- focus_terms: {focus_terms}\n" "If an entity is not relevant to that brief, omit it." ) @@ -1076,11 +1087,18 @@ def _validate_entities(content: str) -> bool: if match: try: payload = json.loads(match.group(1)) + allowed_lookup = {lbl.casefold() for lbl in allowed_labels} mentions: list[EntityMention] = [] + dropped_by_label: dict[str, int] = {} search_start = 0 for item in payload: if not isinstance(item, dict) or not str(item.get("text", "")).strip(): continue + if allowed_lookup: + raw_label = str(item.get("label", "")).strip() + if raw_label.casefold() not in allowed_lookup: + dropped_by_label[raw_label or ""] = dropped_by_label.get(raw_label or "", 0) + 1 + continue mention = self._make_entity_mention( item=item, source_text=text, @@ -1090,6 +1108,12 @@ def _validate_entities(content: str) -> bool: if mention.charspan is not None: search_start = mention.charspan[1] mentions.append(mention) + if dropped_by_label: + log_info( + "Dropped mentions with disallowed labels", + count=sum(dropped_by_label.values()), + by_label=dropped_by_label, + ) log_debug("Parsed entities", count=len(mentions)) if mentions: return EntitiesMetaField(mentions=mentions) diff --git a/docling_agent/agent/orchestrator.py b/docling_agent/agent/orchestrator.py index 145d93a..2083c8f 100644 --- a/docling_agent/agent/orchestrator.py +++ b/docling_agent/agent/orchestrator.py @@ -205,6 +205,7 @@ def _ensure_enriched( source_pairs: list[_SourcePair], library: DoclingLibrary, operations: list[str], + task: str = "", ) -> list[_SourcePair]: """Run enrichment on documents that are missing the requested enrichments. @@ -229,7 +230,7 @@ def _ensure_enriched( if needed: log_info(f"Enriching {doc.name!r} with operations={needed}") - enriched_doc = enricher.run(task="", document=doc, operations=needed) + enriched_doc = enricher.run(task=task, document=doc, operations=needed) # Persist enriched document back to library library.store(enriched_doc, entry.source_path if entry else "in-memory") # Update status flags @@ -278,7 +279,7 @@ def _run_rag( ) -> DoclingDocument: log_info(f"_run_rag: query={task.query!r}, docs={len(source_pairs)}") if task.enrich_before_rag: - source_pairs = self._ensure_enriched(source_pairs, library, operations=["summarize"]) + source_pairs = self._ensure_enriched(source_pairs, library, operations=["summarize"], task=task.query) docs: list[DoclingDocument | Path] = [doc for doc, _ in source_pairs] rag_agent = DoclingRAGAgent( @@ -364,7 +365,7 @@ def _run_enrich( enriched_pairs.append((enriched_doc, doc_id)) else: ops: list[str] = list(task.operations) - enriched_pairs = self._ensure_enriched(source_pairs, library, operations=ops) + enriched_pairs = self._ensure_enriched(source_pairs, library, operations=ops, task=task.query) # Return: single doc → return it directly; multiple → a composite summary doc if len(enriched_pairs) == 1: From 9f8afb1a413e7578c09247e853804046661b81af Mon Sep 17 00:00:00 2001 From: Ana Daniele Date: Mon, 1 Jun 2026 07:37:53 +0200 Subject: [PATCH 2/6] fix(enrich): accept raw JSON from rewrite LLM in _infer_entity_targets MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `_validate_entity_target_spec` (and the post-validation parse on line 688) used `find_json_dicts`, which only matches JSON wrapped in a ```json ...``` markdown block. NuExtract3 ignores that part of the prompt and returns a bare JSON object instead — well-formed, just unfenced. Validation then failed and `_infer_entity_targets` returned None, so `entity_targets` reached `_generate_entities` as None and the HARD CONSTRAINT clause from the previous commit was silently skipped. Introduce `_parse_spec_dict` inside `_infer_entity_targets`: try `find_json_dicts` first (preserves existing behaviour for models that do use the fence), and fall back to `json.loads(content.strip())` when no fence is found. Both the validation hook and the final parse use the same helper. Confirmed end-to-end on df3 (`2026-06-01b_..._postpatch_v3`): the brief now parses, `entity_targets["labels"] = ["MODEL", "DATASET", "KPI"]`, and every per-chunk LLM REQUEST carries the HARD CONSTRAINT block. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Ana Daniele --- docling_agent/agent/enricher.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/docling_agent/agent/enricher.py b/docling_agent/agent/enricher.py index ed13726..86dd90e 100644 --- a/docling_agent/agent/enricher.py +++ b/docling_agent/agent/enricher.py @@ -880,11 +880,20 @@ def _infer_entity_targets( if not task: return None - def _validate_entity_target_spec(content: str) -> bool: + def _parse_spec_dict(content: str) -> dict | None: matches = find_json_dicts(text=content) - if len(matches) != 1: + if len(matches) == 1 and isinstance(matches[0], dict): + return matches[0] + try: + parsed = json.loads(content.strip()) + except (json.JSONDecodeError, ValueError): + return None + return parsed if isinstance(parsed, dict) else None + + def _validate_entity_target_spec(content: str) -> bool: + spec = _parse_spec_dict(content) + if spec is None: return False - spec = matches[0] labels = spec.get("labels", []) focus_terms = spec.get("focus_terms", []) generic = spec.get("generic", False) @@ -923,8 +932,7 @@ def _validate_entity_target_spec(content: str) -> bool: if not answer: return None - specs = find_json_dicts(text=answer) - return specs[0] if specs else None + return _parse_spec_dict(answer) def _walk_and_extract_entities( self, From 7fe67be77e54cf71c1febb7f95813fb5e9697bcd Mon Sep 17 00:00:00 2001 From: Ana Daniele Date: Mon, 1 Jun 2026 08:16:52 +0200 Subject: [PATCH 3/6] fix(enrich): isolate sessions, tolerate raw entity responses, use docling-core fields Three issues surfaced once the prompt actually carried the HARD CONSTRAINT block and NuExtract started responding with the expected schema. 1. Session bleed. `_detect_key_entities` reused one LM Studio session for both `_infer_entity_targets` (which primes the model with the spec schema `{"generic":..., "labels":...}`) and the per-chunk extraction calls. NuExtract carried the prior turn's schema forward and answered every chunk with the spec dict instead of an entity array. Now the leaf stage opens its own session via `_create_extraction_session`. 2. Entity parser too strict. The response parser required a ```json ...``` fenced block and assumed the top-level payload was a list. NuExtract returns bare JSON, usually as `{"entities": [...]}`. Fall back to `result.strip()` when no fence is present and unwrap a single `entities` key into the list before iterating; coerce other shapes to an empty list so downstream code keeps its invariants. 3. EntityMention field names. The mention constructor used the old docling-core argument names (`original=`, `span=`). The current docling-core EntityMention expects `orig=` and `charspan=`; the mismatch raised a pydantic validation error on every successful response, which the surrounding `try` swallowed as "Failed to parse entities JSON". End-to-end on df3: `2026-06-01e_..._postpatch_v6` is the first run where parsed entities are non-empty, e.g. `[EntityMention( text='Nougat', label='MODEL', charspan=(0, 6)), ...]`, with zero pydantic warnings. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Ana Daniele --- docling_agent/agent/enricher.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/docling_agent/agent/enricher.py b/docling_agent/agent/enricher.py index 86dd90e..ba5dbba 100644 --- a/docling_agent/agent/enricher.py +++ b/docling_agent/agent/enricher.py @@ -860,8 +860,9 @@ def _detect_key_entities( ) with self._timed_stage("entities: leaf entities"): + m_leaf = self._create_extraction_session() self._extract_entities_from_leaf_items( - m=m, + m=m_leaf, document=hier_doc, loop_budget=loop_budget, entity_targets=entity_targets, @@ -1092,9 +1093,14 @@ def _validate_entities(content: str) -> bool: if result: match = re.search(r"```json\s*(.*?)\s*```", result, re.DOTALL) - if match: + json_text = match.group(1) if match else result.strip() + if json_text: try: - payload = json.loads(match.group(1)) + payload = json.loads(json_text) + if isinstance(payload, dict) and isinstance(payload.get("entities"), list): + payload = payload["entities"] + if not isinstance(payload, list): + payload = [] allowed_lookup = {lbl.casefold() for lbl in allowed_labels} mentions: list[EntityMention] = [] dropped_by_label: dict[str, int] = {} From cbb12a4dcb56a82f2bd704f3e778896e4709d336 Mon Sep 17 00:00:00 2001 From: Cesar Berrospi Ramis Date: Fri, 5 Jun 2026 11:27:07 +0200 Subject: [PATCH 4/6] style: fix formatting Signed-off-by: Cesar Berrospi Ramis --- docling_agent/agent/enricher.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docling_agent/agent/enricher.py b/docling_agent/agent/enricher.py index ba5dbba..157073f 100644 --- a/docling_agent/agent/enricher.py +++ b/docling_agent/agent/enricher.py @@ -1111,7 +1111,9 @@ def _validate_entities(content: str) -> bool: if allowed_lookup: raw_label = str(item.get("label", "")).strip() if raw_label.casefold() not in allowed_lookup: - dropped_by_label[raw_label or ""] = dropped_by_label.get(raw_label or "", 0) + 1 + dropped_by_label[raw_label or ""] = ( + dropped_by_label.get(raw_label or "", 0) + 1 + ) continue mention = self._make_entity_mention( item=item, From 1574ee1bb7f6383eb8bbf8dba3100383ce88bf4b Mon Sep 17 00:00:00 2001 From: Cesar Berrospi Ramis Date: Fri, 5 Jun 2026 11:58:51 +0200 Subject: [PATCH 5/6] chore(enricher): add defensive handling in '_validate_entities' Add defensive handling in '_validate_entities()' to unwrap responses ensuring validation consistency with the main parsing logic. Signed-off-by: Cesar Berrospi Ramis --- docling_agent/agent/enricher.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docling_agent/agent/enricher.py b/docling_agent/agent/enricher.py index 157073f..c118430 100644 --- a/docling_agent/agent/enricher.py +++ b/docling_agent/agent/enricher.py @@ -1036,6 +1036,8 @@ def _validate_entities(content: str) -> bool: return False try: val = json.loads(match.group(1)) + if isinstance(val, dict) and isinstance(val.get("entities"), list): + val = val["entities"] if not isinstance(val, list): return False return all(isinstance(item, dict) and "text" in item for item in val) From d4350a2e38e667d98ec32f2b4c0dda9ef08b7117 Mon Sep 17 00:00:00 2001 From: Cesar Berrospi Ramis Date: Fri, 5 Jun 2026 12:00:30 +0200 Subject: [PATCH 6/6] test: add regression tests for entity constraints Signed-off-by: Cesar Berrospi Ramis --- tests/conftest.py | 28 ++++ tests/test_enricher.py | 289 +++++++++++++++++++++++++++++++++++++ tests/test_orchestrator.py | 124 ++++++++++++++++ 3 files changed, 441 insertions(+) create mode 100644 tests/conftest.py create mode 100644 tests/test_orchestrator.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1667149 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,28 @@ +"""Pytest configuration and shared fixtures for docling-agent tests. + +This file is automatically discovered by pytest and makes fixtures +available to all test modules without explicit imports. +""" + +import json +from pathlib import Path + +import pytest +from docling_core.types.doc.document import DoclingDocument + +from .test_utils import MockBackend + + +@pytest.fixture(scope="module") +def mock_backend() -> MockBackend: + """Fixture providing a mocked backend instance.""" + return MockBackend() + + +@pytest.fixture(scope="module") +def test_document() -> DoclingDocument: + """Fixture providing the test document loaded from JSON.""" + json_path = Path("tests/data/2408.09869v5.json") + with open(json_path) as f: + doc_dict = json.load(f) + return DoclingDocument.model_validate(doc_dict) diff --git a/tests/test_enricher.py b/tests/test_enricher.py index ffb5671..21118c5 100644 --- a/tests/test_enricher.py +++ b/tests/test_enricher.py @@ -335,6 +335,295 @@ def _fake_generate_summary(self, *, m, text, loop_budget=5, style="sentences", s assert pages_with_summaries > 0 +def test_parse_spec_dict_with_find_json_dicts(enricher: DoclingEnrichingAgent) -> None: + """Test _parse_spec_dict helper function with various JSON formats.""" + # Simulate the _parse_spec_dict function behavior + import json + + from docling_agent.agent.base_functions import find_json_dicts + + def _parse_spec_dict(content: str) -> dict | None: + matches = find_json_dicts(text=content) + if len(matches) == 1 and isinstance(matches[0], dict): + return matches[0] + try: + parsed = json.loads(content.strip()) + except (json.JSONDecodeError, ValueError): + return None + return parsed if isinstance(parsed, dict) else None + + # Test 1: Valid JSON in code block (find_json_dicts path) + content1 = """```json +{ + "generic": false, + "labels": ["Person", "Organization"], + "focus_terms": ["CEO", "company"], + "rewritten_task": "Extract person and organization names" +} +```""" + result1 = _parse_spec_dict(content1) + assert result1 is not None + assert result1["labels"] == ["Person", "Organization"] + assert result1["generic"] is False + + # Test 2: Valid JSON without code block (json.loads fallback) + content2 = '{"generic": true, "labels": [], "focus_terms": [], "rewritten_task": ""}' + result2 = _parse_spec_dict(content2) + assert result2 is not None + assert result2["generic"] is True + + # Test 3: Invalid JSON + content3 = "{invalid json}" + result3 = _parse_spec_dict(content3) + assert result3 is None + + # Test 4: JSON array (should return None as we expect dict) + content4 = '[{"key": "value"}]' + result4 = _parse_spec_dict(content4) + assert result4 is None + + # Test 5: Multiple JSON blocks (find_json_dicts returns multiple, should fail) + content5 = """```json +{"key1": "value1"} +``` +```json +{"key2": "value2"} +```""" + result5 = _parse_spec_dict(content5) + assert result5 is None + + +def test_infer_entity_targets_with_task(monkeypatch: pytest.MonkeyPatch, enricher: DoclingEnrichingAgent) -> None: + """Test that _infer_entity_targets correctly parses entity specifications from LLM response.""" + + def _fake_generate_content(self, *, m, text, task_prompt, requirement_description, validation_fn, loop_budget=5): + # Simulate LLM returning entity target specification + return """```json +{ + "generic": false, + "labels": ["Person", "Email", "Phone"], + "focus_terms": ["name", "contact", "email address"], + "rewritten_task": "Extract personal information including person names, email addresses, and phone numbers" +} +```""" + + monkeypatch.setattr(DoclingEnrichingAgent, "_generate_content", _fake_generate_content) + + m = enricher._create_extraction_session() + entity_targets = enricher._infer_entity_targets( + m=m, task="Find personal information in the document", loop_budget=3 + ) + + assert entity_targets is not None + assert entity_targets["generic"] is False + assert entity_targets["labels"] == ["Person", "Email", "Phone"] + assert "name" in entity_targets["focus_terms"] + assert "Extract personal information" in entity_targets["rewritten_task"] + + +def test_infer_entity_targets_without_task(enricher: DoclingEnrichingAgent) -> None: + """Test that _infer_entity_targets returns None when no task is provided.""" + m = enricher._create_extraction_session() + entity_targets = enricher._infer_entity_targets(m=m, task=None, loop_budget=3) + assert entity_targets is None + + entity_targets = enricher._infer_entity_targets(m=m, task="", loop_budget=3) + assert entity_targets is None + + +def test_generate_entities_with_label_filtering( + monkeypatch: pytest.MonkeyPatch, enricher: DoclingEnrichingAgent +) -> None: + """Test that _generate_entities correctly filters entities by allowed labels.""" + + def _fake_generate_content(self, *, m, text, task_prompt, requirement_description, validation_fn, loop_budget=5): + # Simulate LLM returning entities with mixed labels (some allowed, some not) + return """```json +[ + {"text": "John Doe", "label": "Person"}, + {"text": "john@example.com", "label": "Email"}, + {"text": "Acme Corp", "label": "Organization"}, + {"text": "New York", "label": "Location"}, + {"text": "GPT-4", "label": "AI-Model"} +] +```""" + + monkeypatch.setattr(DoclingEnrichingAgent, "_generate_content", _fake_generate_content) + + m = enricher._create_extraction_session() + text = "John Doe (john@example.com) works at Acme Corp in New York using GPT-4." + + # Test with allowed labels - should filter out Organization, Location, AI-Model + entity_targets = { + "labels": ["Person", "Email"], + "focus_terms": [], + "rewritten_task": "Extract person names and emails", + "generic": False, + } + + result = enricher._generate_entities( + m=m, text=text, task="Find people and emails", entity_targets=entity_targets, loop_budget=3 + ) + + assert result is not None + assert len(result.mentions) == 2 # Only Person and Email should be kept + labels = {mention.label for mention in result.mentions} + assert labels == {"Person", "Email"} + assert "Organization" not in labels + assert "Location" not in labels + + +def test_generate_entities_case_insensitive_labels( + monkeypatch: pytest.MonkeyPatch, enricher: DoclingEnrichingAgent +) -> None: + """Test that label filtering is case-insensitive.""" + + def _fake_generate_content(self, *, m, text, task_prompt, requirement_description, validation_fn, loop_budget=5): + # LLM returns labels in different cases + return """```json +[ + {"text": "John Doe", "label": "PERSON"}, + {"text": "jane@example.com", "label": "email"}, + {"text": "Acme Corp", "label": "Organization"} +] +```""" + + monkeypatch.setattr(DoclingEnrichingAgent, "_generate_content", _fake_generate_content) + + m = enricher._create_extraction_session() + text = "John Doe and jane@example.com work at Acme Corp." + + # Allowed labels in mixed case + entity_targets = { + "labels": ["Person", "Email"], # lowercase in spec + "focus_terms": [], + "rewritten_task": "", + "generic": False, + } + + result = enricher._generate_entities(m=m, text=text, task="", entity_targets=entity_targets, loop_budget=3) + + assert result is not None + assert len(result.mentions) == 2 # PERSON and email should match Person and Email + labels = {mention.label for mention in result.mentions} + assert "PERSON" in labels or "Person" in labels + assert "email" in labels or "Email" in labels + + +def test_generate_entities_with_dict_wrapped_response( + monkeypatch: pytest.MonkeyPatch, enricher: DoclingEnrichingAgent +) -> None: + """Test that _generate_entities handles LLM responses that wrap entities in a dict.""" + + def _fake_generate_content(self, *, m, text, task_prompt, requirement_description, validation_fn, loop_budget=5): + # LLM returns {"entities": [...]} instead of just [...] + return """```json +{ + "entities": [ + {"text": "John Doe", "label": "Person"}, + {"text": "john@example.com", "label": "Email"} + ] +} +```""" + + monkeypatch.setattr(DoclingEnrichingAgent, "_generate_content", _fake_generate_content) + + m = enricher._create_extraction_session() + text = "John Doe (john@example.com)" + + result = enricher._generate_entities(m=m, text=text, task="", entity_targets=None, loop_budget=3) + + assert result is not None + assert len(result.mentions) == 2 + assert result.mentions[0].text == "John Doe" + assert result.mentions[1].text == "john@example.com" + + +def test_generate_entities_without_code_block(monkeypatch: pytest.MonkeyPatch, enricher: DoclingEnrichingAgent) -> None: + """Test that _generate_entities handles JSON without markdown code blocks.""" + + def _fake_generate_content(self, *, m, text, task_prompt, requirement_description, validation_fn, loop_budget=5): + # LLM returns raw JSON without code block + return '[{"text": "John Doe", "label": "Person"}]' + + monkeypatch.setattr(DoclingEnrichingAgent, "_generate_content", _fake_generate_content) + + m = enricher._create_extraction_session() + text = "John Doe is here." + + result = enricher._generate_entities(m=m, text=text, task="", entity_targets=None, loop_budget=3) + + assert result is not None + assert len(result.mentions) == 1 + assert result.mentions[0].text == "John Doe" + + +def test_generate_entities_hard_constraint_prompt( + monkeypatch: pytest.MonkeyPatch, enricher: DoclingEnrichingAgent +) -> None: + """Test that allowed_labels triggers HARD CONSTRAINT prompt.""" + captured_prompts = [] + + def _fake_generate_content(self, *, m, text, task_prompt, requirement_description, validation_fn, loop_budget=5): + captured_prompts.append(task_prompt) + return "[]" + + monkeypatch.setattr(DoclingEnrichingAgent, "_generate_content", _fake_generate_content) + + m = enricher._create_extraction_session() + + # Test with allowed labels + entity_targets = { + "labels": ["Person", "Email"], + "focus_terms": ["contact"], + "rewritten_task": "Extract contacts", + "generic": False, + } + + enricher._generate_entities(m=m, text="test", task="", entity_targets=entity_targets, loop_budget=3) + + assert len(captured_prompts) == 1 + prompt = captured_prompts[0] + assert "HARD CONSTRAINT" in prompt + assert "Use ONLY these label values" in prompt + assert "['Person', 'Email']" in prompt + assert "do NOT invent new labels" in prompt + + +def test_extract_entities_session_isolation( + monkeypatch: pytest.MonkeyPatch, test_document: DoclingDocument, enricher: DoclingEnrichingAgent +) -> None: + """Test that entity extraction uses separate session from entity target inference. + + This is a regression test for the bug fix where a separate session (m_leaf) + is created for leaf entity extraction to prevent state pollution. + """ + session_creation_count = [0] + original_create_session = enricher._create_extraction_session + + def _counting_create_session(): + session_creation_count[0] += 1 + return original_create_session() + + monkeypatch.setattr(enricher, "_create_extraction_session", _counting_create_session) + + # Mock the methods to avoid actual LLM calls + def _fake_infer_targets(self, *, m, task, loop_budget): + return {"labels": ["Person"], "focus_terms": [], "rewritten_task": "", "generic": False} + + def _fake_extract_from_leaf(self, *, m, document, loop_budget, entity_targets, task): + pass + + monkeypatch.setattr(DoclingEnrichingAgent, "_infer_entity_targets", _fake_infer_targets) + monkeypatch.setattr(DoclingEnrichingAgent, "_extract_entities_from_leaf_items", _fake_extract_from_leaf) + + # Run entity detection (which calls entity extraction internally) + enricher._detect_key_entities(document=test_document, task="Find people", loop_budget=3) + + # Should create 2 sessions: one for inference, one for extraction + assert session_creation_count[0] == 2 + + if __name__ == "__main__": print("=" * 70) print("TEST 1: Demonstrating the heading levels problem") diff --git a/tests/test_orchestrator.py b/tests/test_orchestrator.py new file mode 100644 index 0000000..4198401 --- /dev/null +++ b/tests/test_orchestrator.py @@ -0,0 +1,124 @@ +"""Tests for DoclingOrchestratorAgent. + +This module includes regression tests for the entity constraint bug fix, +specifically testing task propagation from orchestrator to enricher. +""" + +import tempfile +from pathlib import Path + +import pytest +from docling_core.types.doc.document import DoclingDocument + +from docling_agent.agent.enricher import DoclingEnrichingAgent +from docling_agent.agent.library import DoclingLibrary +from docling_agent.agent.orchestrator import DoclingOrchestratorAgent + + +@pytest.fixture +def enricher(mock_backend) -> DoclingEnrichingAgent: + """Fixture providing a DoclingEnrichingAgent instance.""" + return DoclingEnrichingAgent(backend=mock_backend, tools=[]) + + +@pytest.fixture +def orchestrator(mock_backend) -> DoclingOrchestratorAgent: + """Fixture providing a DoclingOrchestratorAgent instance.""" + return DoclingOrchestratorAgent(backend=mock_backend, tools=[]) + + +def test_ensure_enriched_task_propagation( + monkeypatch: pytest.MonkeyPatch, orchestrator: DoclingOrchestratorAgent, test_document: DoclingDocument +) -> None: + """Test that _ensure_enriched propagates task parameter to enricher.run(). + + This is a regression test for the bug fix where task was not being passed + from orchestrator to enricher, causing entity constraints to be ignored. + """ + captured_task = [] + + def _fake_enricher_run(self, task, document, operations): + captured_task.append(task) + return document + + # Mock the enricher's run method + monkeypatch.setattr(DoclingEnrichingAgent, "run", _fake_enricher_run) + + # Create a mock library with temporary directory + with tempfile.TemporaryDirectory() as tmpdir: + library = DoclingLibrary(path=Path(tmpdir)) + library.store(test_document, "test.json") + + # Call _ensure_enriched with a task + source_pairs = [(test_document, "test-doc-id")] + task_query = "Extract person names and email addresses" + + orchestrator._ensure_enriched( + source_pairs=source_pairs, library=library, operations=["entities"], task=task_query + ) + + # Verify task was passed to enricher + assert len(captured_task) == 1 + assert captured_task[0] == task_query + + +def test_ensure_enriched_empty_task_default( + monkeypatch: pytest.MonkeyPatch, orchestrator: DoclingOrchestratorAgent, test_document: DoclingDocument +) -> None: + """Test that _ensure_enriched uses empty string as default task.""" + captured_task = [] + + def _fake_enricher_run(self, task, document, operations): + captured_task.append(task) + return document + + monkeypatch.setattr(DoclingEnrichingAgent, "run", _fake_enricher_run) + + with tempfile.TemporaryDirectory() as tmpdir: + library = DoclingLibrary(path=Path(tmpdir)) + library.store(test_document, "test.json") + + source_pairs = [(test_document, "test-doc-id")] + + # Call without task parameter (should default to "") + orchestrator._ensure_enriched(source_pairs=source_pairs, library=library, operations=["entities"]) + + assert len(captured_task) == 1 + assert captured_task[0] == "" + + +def test_ensure_enriched_multiple_operations( + monkeypatch: pytest.MonkeyPatch, orchestrator: DoclingOrchestratorAgent, test_document: DoclingDocument +) -> None: + """Test that _ensure_enriched handles multiple operations correctly.""" + enricher_calls = [] + + def _fake_enricher_run(self, task, document, operations): + enricher_calls.append((task, operations)) + return document + + monkeypatch.setattr(DoclingEnrichingAgent, "run", _fake_enricher_run) + + with tempfile.TemporaryDirectory() as tmpdir: + library = DoclingLibrary(path=Path(tmpdir)) + library.store(test_document, "test.json") + + source_pairs = [(test_document, "test-doc-id")] + task_query = "Enrich with summaries and entities" + + # Request multiple operations + orchestrator._ensure_enriched( + source_pairs=source_pairs, + library=library, + operations=["summarize", "entities"], + task=task_query, + ) + + # Enricher should be called once with both operations + assert len(enricher_calls) == 1 + assert enricher_calls[0][0] == task_query + assert set(enricher_calls[0][1]) == {"summarize", "entities"} + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])