Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 53 additions & 11 deletions docling_agent/agent/enricher.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,8 +860,9 @@ def _detect_key_entities(
)

with self._timed_stage("entities: leaf entities"):
m_leaf = self._create_extraction_session()
self._extract_entities_from_leaf_items(
m=m,
m=m_leaf,
document=hier_doc,
loop_budget=loop_budget,
entity_targets=entity_targets,
Expand All @@ -880,11 +881,20 @@ def _infer_entity_targets(
if not task:
return None

def _validate_entity_target_spec(content: str) -> bool:
def _parse_spec_dict(content: str) -> dict | None:
matches = find_json_dicts(text=content)
if len(matches) != 1:
if len(matches) == 1 and isinstance(matches[0], dict):
return matches[0]
try:
parsed = json.loads(content.strip())
except (json.JSONDecodeError, ValueError):
return None
return parsed if isinstance(parsed, dict) else None

def _validate_entity_target_spec(content: str) -> bool:
spec = _parse_spec_dict(content)
if spec is None:
return False
spec = matches[0]
labels = spec.get("labels", [])
focus_terms = spec.get("focus_terms", [])
generic = spec.get("generic", False)
Expand Down Expand Up @@ -923,8 +933,7 @@ def _validate_entity_target_spec(content: str) -> bool:
if not answer:
return None

specs = find_json_dicts(text=answer)
return specs[0] if specs else None
return _parse_spec_dict(answer)

def _walk_and_extract_entities(
self,
Expand Down Expand Up @@ -1027,6 +1036,8 @@ def _validate_entities(content: str) -> bool:
return False
try:
val = json.loads(match.group(1))
if isinstance(val, dict) and isinstance(val.get("entities"), list):
val = val["entities"]
if not isinstance(val, list):
return False
return all(isinstance(item, dict) and "text" in item for item in val)
Expand All @@ -1035,17 +1046,28 @@ def _validate_entities(content: str) -> bool:

target_clause = ""
rewritten_task = task or ""
allowed_labels: list[str] = []
if entity_targets:
labels = entity_targets.get("labels", [])
allowed_labels = [str(lbl).strip() for lbl in entity_targets.get("labels", []) if str(lbl).strip()]
focus_terms = entity_targets.get("focus_terms", [])
rewritten_task = entity_targets.get("rewritten_task", rewritten_task)
generic = entity_targets.get("generic", False)
if not generic or labels or focus_terms or rewritten_task:
if allowed_labels:
target_clause = (
"\nHARD CONSTRAINT — you MUST obey this:\n"
f"Use ONLY these label values: {allowed_labels}.\n"
"Reject any other entity type. If a candidate mention does not fit one of those labels, "
"omit it entirely — do NOT output it under a different label, do NOT invent new labels.\n"
"Apply the same definitions to every paragraph and every section, including reference lists, "
"captions, and tables.\n"
f"\nExtraction brief (for context only, does not expand the label set):\n{rewritten_task}\n"
f"- focus_terms (hints, not exhaustive): {focus_terms}\n"
)
elif not generic or focus_terms or rewritten_task:
target_clause = (
"\nUse this rewritten extraction brief:\n"
f"{rewritten_task}\n"
"Focus on entities that match the brief, including obvious instances even if the wording differs.\n"
f"- labels: {labels}\n"
f"- focus_terms: {focus_terms}\n"
"If an entity is not relevant to that brief, omit it."
)
Expand Down Expand Up @@ -1073,14 +1095,28 @@ def _validate_entities(content: str) -> bool:

if result:
match = re.search(r"```json\s*(.*?)\s*```", result, re.DOTALL)
if match:
json_text = match.group(1) if match else result.strip()
if json_text:
try:
payload = json.loads(match.group(1))
payload = json.loads(json_text)
if isinstance(payload, dict) and isinstance(payload.get("entities"), list):
payload = payload["entities"]
if not isinstance(payload, list):
payload = []
Comment thread
ceberam marked this conversation as resolved.
allowed_lookup = {lbl.casefold() for lbl in allowed_labels}
mentions: list[EntityMention] = []
dropped_by_label: dict[str, int] = {}
search_start = 0
for item in payload:
if not isinstance(item, dict) or not str(item.get("text", "")).strip():
continue
if allowed_lookup:
raw_label = str(item.get("label", "")).strip()
if raw_label.casefold() not in allowed_lookup:
dropped_by_label[raw_label or "<empty>"] = (
dropped_by_label.get(raw_label or "<empty>", 0) + 1
)
continue
mention = self._make_entity_mention(
item=item,
source_text=text,
Expand All @@ -1090,6 +1126,12 @@ def _validate_entities(content: str) -> bool:
if mention.charspan is not None:
search_start = mention.charspan[1]
mentions.append(mention)
if dropped_by_label:
log_info(
"Dropped mentions with disallowed labels",
count=sum(dropped_by_label.values()),
by_label=dropped_by_label,
)
log_debug("Parsed entities", count=len(mentions))
if mentions:
return EntitiesMetaField(mentions=mentions)
Expand Down
7 changes: 4 additions & 3 deletions docling_agent/agent/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def _ensure_enriched(
source_pairs: list[_SourcePair],
library: DoclingLibrary,
operations: list[str],
task: str = "",
) -> list[_SourcePair]:
"""Run enrichment on documents that are missing the requested enrichments.

Expand All @@ -229,7 +230,7 @@ def _ensure_enriched(

if needed:
log_info(f"Enriching {doc.name!r} with operations={needed}")
enriched_doc = enricher.run(task="", document=doc, operations=needed)
enriched_doc = enricher.run(task=task, document=doc, operations=needed)
# Persist enriched document back to library
library.store(enriched_doc, entry.source_path if entry else "in-memory")
# Update status flags
Expand Down Expand Up @@ -278,7 +279,7 @@ def _run_rag(
) -> DoclingDocument:
log_info(f"_run_rag: query={task.query!r}, docs={len(source_pairs)}")
if task.enrich_before_rag:
source_pairs = self._ensure_enriched(source_pairs, library, operations=["summarize"])
source_pairs = self._ensure_enriched(source_pairs, library, operations=["summarize"], task=task.query)

docs: list[DoclingDocument | Path] = [doc for doc, _ in source_pairs]
rag_agent = DoclingRAGAgent(
Expand Down Expand Up @@ -364,7 +365,7 @@ def _run_enrich(
enriched_pairs.append((enriched_doc, doc_id))
else:
ops: list[str] = list(task.operations)
enriched_pairs = self._ensure_enriched(source_pairs, library, operations=ops)
enriched_pairs = self._ensure_enriched(source_pairs, library, operations=ops, task=task.query)

# Return: single doc → return it directly; multiple → a composite summary doc
if len(enriched_pairs) == 1:
Expand Down
28 changes: 28 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Pytest configuration and shared fixtures for docling-agent tests.

This file is automatically discovered by pytest and makes fixtures
available to all test modules without explicit imports.
"""

import json
from pathlib import Path

import pytest
from docling_core.types.doc.document import DoclingDocument

from .test_utils import MockBackend


@pytest.fixture(scope="module")
def mock_backend() -> MockBackend:
    """Build the mocked backend handed to tests that request this fixture.

    Returns:
        A newly constructed ``MockBackend``. The fixture is declared with
        ``scope="module"``, so one instance is shared by the tests of a module.
    """
    backend = MockBackend()
    return backend


@pytest.fixture(scope="module")
def test_document() -> DoclingDocument:
    """Load the sample Docling document used by the tests.

    The JSON data file is located relative to this conftest file rather than
    the current working directory, so the fixture works regardless of the
    directory pytest is launched from. The file is decoded explicitly as
    UTF-8 instead of relying on the platform's default encoding.

    Returns:
        The ``DoclingDocument`` validated from ``data/2408.09869v5.json``
        located next to this file.

    Raises:
        FileNotFoundError: If the JSON data file does not exist.
    """
    # Anchor on this file's directory; a bare relative path would only
    # resolve when pytest is started from the repository root.
    json_path = Path(__file__).resolve().parent / "data" / "2408.09869v5.json"
    with open(json_path, encoding="utf-8") as f:
        doc_dict = json.load(f)
    return DoclingDocument.model_validate(doc_dict)
Loading
Loading