diff --git a/packages/graphrag/graphrag/query/context_builder/dynamic_community_selection.py b/packages/graphrag/graphrag/query/context_builder/dynamic_community_selection.py index 0ef22b9176..8a2ef92445 100644 --- a/packages/graphrag/graphrag/query/context_builder/dynamic_community_selection.py +++ b/packages/graphrag/graphrag/query/context_builder/dynamic_community_selection.py @@ -67,8 +67,8 @@ def __init__( if community.short_id in self.reports: self.levels[community.level].append(community.short_id) - # start from root communities (level 0) - self.starting_communities = self.levels["0"] + # start from root communities (level 0), if any + self.starting_communities = self.levels.get("0", []) async def select(self, query: str) -> tuple[list[CommunityReport], dict[str, Any]]: """ diff --git a/tests/unit/query/context_builder/dynamic_community_selection.py b/tests/unit/query/context_builder/dynamic_community_selection.py index ba63f0c774..ff4e295c31 100644 --- a/tests/unit/query/context_builder/dynamic_community_selection.py +++ b/tests/unit/query/context_builder/dynamic_community_selection.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock +import pytest from graphrag.data_model.community import Community from graphrag.data_model.community_report import CommunityReport from graphrag.query.context_builder.dynamic_community_selection import ( @@ -203,3 +204,49 @@ def test_dynamic_community_selection_handles_str_children(): assert child_id in selector.reports, ( f"Child {child} (as '{child_id}') should be found in reports" ) + + +@pytest.mark.asyncio +async def test_dynamic_community_selection_handles_missing_level_zero(): + """Test that DynamicCommunitySelection does not require a level 0 community.""" + communities = [ + Community( + id="comm-1", + short_id="1", + title="Child Community 1", + level="1", + parent="0", + children=[], + ), + ] + reports = [ + CommunityReport( + id="report-1", + short_id="1", + title="Report 1", + community_id="1", + summary="Child 1 summary", + full_content="Child 1 full content", + rank=1.0, + ), + ] + + model = create_mock_model() + tokenizer = create_mock_tokenizer() + + selector = DynamicCommunitySelection( + community_reports=reports, + communities=communities, + model=model, + tokenizer=tokenizer, + threshold=1, + keep_parent=False, + max_level=2, + ) + + assert selector.starting_communities == [] + selected_reports, llm_info = await selector.select("query") + + assert selected_reports == [] + assert llm_info["llm_calls"] == 0 + model.assert_not_called()