Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion graphiti_core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -1451,7 +1451,10 @@ async def build_communities(
await remove_communities(driver)

community_nodes, community_edges = await build_communities(
driver, self.llm_client, group_ids
driver,
self.llm_client,
group_ids,
max_coroutines=self.max_coroutines,
)

await semaphore_gather(
Expand Down
22 changes: 17 additions & 5 deletions graphiti_core/utils/maintenance/community_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,9 @@ async def generate_summary_description(llm_client: LLMClient, summary: str) -> s


async def build_community(
llm_client: LLMClient, community_cluster: list[EntityNode]
llm_client: LLMClient,
community_cluster: list[EntityNode],
max_coroutines: int | None = None,
) -> tuple[CommunityNode, list[CommunityEdge]]:
summaries = [entity.summary for entity in community_cluster]
length = len(summaries)
Expand All @@ -188,7 +190,8 @@ async def build_community(
for left_summary, right_summary in zip(
summaries[: int(length / 2)], summaries[int(length / 2) :], strict=False
)
]
],
max_coroutines=max_coroutines,
)
)
if odd_one_out is not None:
Expand Down Expand Up @@ -217,18 +220,27 @@ async def build_communities(
driver: GraphDriver,
llm_client: LLMClient,
group_ids: list[str] | None,
max_coroutines: int | None = None,
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
community_clusters = await get_community_clusters(driver, group_ids)

semaphore = asyncio.Semaphore(MAX_COMMUNITY_BUILD_CONCURRENCY)
outer_limit = (
min(max_coroutines, MAX_COMMUNITY_BUILD_CONCURRENCY)
if max_coroutines is not None
else MAX_COMMUNITY_BUILD_CONCURRENCY
)
semaphore = asyncio.Semaphore(outer_limit)

async def limited_build_community(cluster):
async with semaphore:
return await build_community(llm_client, cluster)
return await build_community(
llm_client, cluster, max_coroutines=max_coroutines
)

communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
await semaphore_gather(
*[limited_build_community(cluster) for cluster in community_clusters]
*[limited_build_community(cluster) for cluster in community_clusters],
max_coroutines=outer_limit,
)
)

Expand Down
104 changes: 104 additions & 0 deletions tests/utils/maintenance/test_community_operations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock

import pytest

from graphiti_core.graphiti import Graphiti
from graphiti_core.utils.maintenance import community_operations as community_ops


@pytest.mark.asyncio
async def test_build_community_passes_max_coroutines_to_summary_gather(monkeypatch):
gather_limits: list[int | None] = []

async def immediate_gather(*aws, max_coroutines=None):
gather_limits.append(max_coroutines)
return [await aw for aw in aws]

monkeypatch.setattr(community_ops, 'semaphore_gather', immediate_gather)
monkeypatch.setattr(community_ops, 'summarize_pair', AsyncMock(side_effect=['ab', 'cd', 'abcd']))
monkeypatch.setattr(
community_ops, 'generate_summary_description', AsyncMock(return_value='community')
)
monkeypatch.setattr(community_ops, 'build_community_edges', MagicMock(return_value=[]))

cluster = [
SimpleNamespace(summary='a', group_id='group-1'),
SimpleNamespace(summary='b', group_id='group-1'),
SimpleNamespace(summary='c', group_id='group-1'),
SimpleNamespace(summary='d', group_id='group-1'),
]

await community_ops.build_community(
MagicMock(),
cluster,
max_coroutines=3,
)

assert gather_limits == [3, 3]


@pytest.mark.asyncio
async def test_build_communities_passes_max_coroutines_to_nested_calls(monkeypatch):
received_limits: list[int | None] = []

async def fake_get_community_clusters(driver, group_ids):
return [[SimpleNamespace(group_id='group-1')], [SimpleNamespace(group_id='group-1')]]

async def fake_build_community(llm_client, cluster, max_coroutines=None):
received_limits.append(max_coroutines)
return (f'community-{len(received_limits)}', [f'edge-{len(received_limits)}'])

async def immediate_gather(*aws, max_coroutines=None):
received_limits.append(max_coroutines)
return [await aw for aw in aws]

monkeypatch.setattr(community_ops, 'get_community_clusters', fake_get_community_clusters)
monkeypatch.setattr(community_ops, 'build_community', fake_build_community)
monkeypatch.setattr(community_ops, 'semaphore_gather', immediate_gather)

community_nodes, community_edges = await community_ops.build_communities(
MagicMock(),
MagicMock(),
None,
max_coroutines=3,
)

assert received_limits[0] == 3
assert received_limits[1:] == [3, 3]
assert community_nodes == ['community-2', 'community-3']
assert community_edges == ['edge-2', 'edge-3']


@pytest.mark.asyncio
async def test_graphiti_build_communities_passes_instance_max_coroutines(monkeypatch):
captured: dict[str, int | None] = {}

async def fake_remove_communities(driver):
return None

async def fake_build_communities(driver, llm_client, group_ids, max_coroutines=None):
captured['max_coroutines'] = max_coroutines
return ([], [])

async def immediate_gather(*aws, max_coroutines=None):
return [await aw for aw in aws]

monkeypatch.setattr('graphiti_core.graphiti.remove_communities', fake_remove_communities)
monkeypatch.setattr('graphiti_core.graphiti.build_communities', fake_build_communities)
monkeypatch.setattr('graphiti_core.graphiti.semaphore_gather', immediate_gather)
monkeypatch.setattr(
'graphiti_core.graphiti.GraphitiClients', lambda **kwargs: SimpleNamespace(**kwargs)
)

graphiti = Graphiti(
graph_driver=MagicMock(),
llm_client=MagicMock(),
embedder=MagicMock(),
cross_encoder=MagicMock(),
max_coroutines=2,
)

await graphiti.build_communities()

assert captured['max_coroutines'] == 2
Loading