From 6ebe91b7c40e311f6d2bab6196981bc77e6b613d Mon Sep 17 00:00:00 2001 From: Ritwij Aryan Parmar Date: Wed, 10 Jun 2026 11:26:14 -0400 Subject: [PATCH] fix: stabilize graph reranking and community propagation --- graphiti_core/search/search_utils.py | 9 +-- .../utils/maintenance/community_operations.py | 24 +++++--- .../maintenance/test_community_operations.py | 32 ++++++++++ tests/utils/search/search_utils_test.py | 58 ++++++++++++++++++- 4 files changed, 109 insertions(+), 14 deletions(-) create mode 100644 tests/utils/maintenance/test_community_operations.py diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 2c9d63882e..98f580d0de 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -1872,7 +1872,7 @@ async def episode_mentions_reranker( sorted_uuids, _ = rrf(node_uuids) scores: dict[str, float] = {} - # Find the shortest path to center node + # Count how often candidate nodes are mentioned by episodes. results, _, _ = await driver.execute_query( """ UNWIND $node_uuids AS node_uuid @@ -1888,10 +1888,11 @@ async def episode_mentions_reranker( for uuid in sorted_uuids: if uuid not in scores: - scores[uuid] = float('inf') + scores[uuid] = 0 - # rerank on shortest distance - sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid]) + # Higher mention counts should rank first. Python's stable sort preserves the + # preliminary RRF order for ties. + sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid], reverse=True) return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [ scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score diff --git a/graphiti_core/utils/maintenance/community_operations.py b/graphiti_core/utils/maintenance/community_operations.py index 8c96bd79f4..f75806adea 100644 --- a/graphiti_core/utils/maintenance/community_operations.py +++ b/graphiti_core/utils/maintenance/community_operations.py @@ -18,6 +18,7 @@ from graphiti_core.utils.text_utils import MAX_SUMMARY_CHARS, truncate_at_sentence MAX_COMMUNITY_BUILD_CONCURRENCY = 10 +LABEL_PROPAGATION_MAX_ITERATION_FACTOR = 10 logger = logging.getLogger(__name__) @@ -98,10 +99,10 @@ def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]: # 4. Continue until no communities change during propagation community_map = {uuid: i for i, uuid in enumerate(projection.keys())} + max_iterations = max(len(projection) * LABEL_PROPAGATION_MAX_ITERATION_FACTOR, 1) - while True: - no_change = True - new_community_map: dict[str, int] = {} + for _ in range(max_iterations): + changed_count = 0 for uuid, neighbors in projection.items(): curr_community = community_map[uuid] @@ -120,15 +121,20 @@ def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]: else: new_community = max(community_candidate, curr_community) - new_community_map[uuid] = new_community - if new_community != curr_community: - no_change = False + changed_count += 1 + community_map[uuid] = new_community - if no_change: + if changed_count == 0: break - - community_map = new_community_map + else: + logger.warning( + 'Label propagation stopped before convergence', + extra={ + 'node_count': len(projection), + 'max_iterations': max_iterations, + }, + ) community_cluster_map = defaultdict(list) for uuid, community in community_map.items(): diff --git a/tests/utils/maintenance/test_community_operations.py b/tests/utils/maintenance/test_community_operations.py new file mode 100644 index 0000000000..6cfff46a22 --- /dev/null +++ b/tests/utils/maintenance/test_community_operations.py @@ -0,0 +1,32 @@ +from graphiti_core.utils.maintenance.community_operations import Neighbor, label_propagation + + +def _cluster_sets(clusters: list[list[str]]) -> set[frozenset[str]]: + return {frozenset(cluster) for cluster in clusters} + + +def test_label_propagation_terminates_for_repeated_entity_edges(): + projection = { + 'node-a': [Neighbor(node_uuid='node-b', edge_count=2)], + 'node-b': [Neighbor(node_uuid='node-a', edge_count=2)], + } + + clusters = label_propagation(projection) + + assert _cluster_sets(clusters) == {frozenset({'node-a', 'node-b'})} + + +def test_label_propagation_keeps_disconnected_components_separate(): + projection = { + 'node-a': [Neighbor(node_uuid='node-b', edge_count=2)], + 'node-b': [Neighbor(node_uuid='node-a', edge_count=2)], + 'node-c': [Neighbor(node_uuid='node-d', edge_count=3)], + 'node-d': [Neighbor(node_uuid='node-c', edge_count=3)], + } + + clusters = label_propagation(projection) + + assert _cluster_sets(clusters) == { + frozenset({'node-a', 'node-b'}), + frozenset({'node-c', 'node-d'}), + } diff --git a/tests/utils/search/search_utils_test.py b/tests/utils/search/search_utils_test.py index 6b97daab1f..70529ed4d4 100644 --- a/tests/utils/search/search_utils_test.py +++ b/tests/utils/search/search_utils_test.py @@ -4,7 +4,26 @@ from graphiti_core.nodes import EntityNode from graphiti_core.search.search_filters import SearchFilters -from graphiti_core.search.search_utils import hybrid_node_search +from graphiti_core.search.search_utils import episode_mentions_reranker, hybrid_node_search + + +class _MentionCountDriver: + search_interface = None + + def __init__(self, mention_counts: dict[str, int]) -> None: + self.mention_counts = mention_counts + + async def execute_query(self, _query: str, **kwargs): + node_uuids = kwargs['node_uuids'] + return ( + [ + {'uuid': node_uuid, 'score': self.mention_counts[node_uuid]} + for node_uuid in node_uuids + if node_uuid in self.mention_counts + ], + None, + None, + ) @pytest.mark.asyncio @@ -161,3 +180,40 @@ async def test_hybrid_node_search_with_limit_and_duplicates(): mock_similarity_search.assert_called_with( mock_driver, [0.1, 0.2, 0.3], SearchFilters(), ['1'], 4 ) + + +@pytest.mark.asyncio +async def test_episode_mentions_reranker_orders_by_highest_mention_count(): + driver = _MentionCountDriver( + { + 'node-a': 20, + 'node-b': 5, + 'node-c': 1, + } + ) + + ranked_uuids, scores = await episode_mentions_reranker( + driver, + [['node-c', 'node-b', 'node-a']], + ) + + assert ranked_uuids == ['node-a', 'node-b', 'node-c'] + assert scores == [20, 5, 1] + + +@pytest.mark.asyncio +async def test_episode_mentions_reranker_keeps_zero_mention_nodes_last(): + driver = _MentionCountDriver( + { + 'node-a': 2, + 'node-b': 1, + } + ) + + ranked_uuids, scores = await episode_mentions_reranker( + driver, + [['node-c', 'node-b', 'node-a']], + ) + + assert ranked_uuids == ['node-a', 'node-b', 'node-c'] + assert scores == [2, 1, 0]