Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions graphiti_core/search/search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1872,7 +1872,7 @@ async def episode_mentions_reranker(
sorted_uuids, _ = rrf(node_uuids)
scores: dict[str, float] = {}

# Find the shortest path to center node
# Count how often candidate nodes are mentioned by episodes.
results, _, _ = await driver.execute_query(
"""
UNWIND $node_uuids AS node_uuid
Expand All @@ -1888,10 +1888,11 @@ async def episode_mentions_reranker(

for uuid in sorted_uuids:
if uuid not in scores:
scores[uuid] = float('inf')
scores[uuid] = 0

# rerank on shortest distance
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
# Higher mention counts should rank first. Python's stable sort preserves the
# preliminary RRF order for ties.
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid], reverse=True)

return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [
scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score
Expand Down
24 changes: 15 additions & 9 deletions graphiti_core/utils/maintenance/community_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from graphiti_core.utils.text_utils import MAX_SUMMARY_CHARS, truncate_at_sentence

MAX_COMMUNITY_BUILD_CONCURRENCY = 10
LABEL_PROPAGATION_MAX_ITERATION_FACTOR = 10

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -98,10 +99,10 @@ def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]:
# 4. Continue until no communities change during propagation, or until the iteration cap (max_iterations) is reached

community_map = {uuid: i for i, uuid in enumerate(projection.keys())}
max_iterations = max(len(projection) * LABEL_PROPAGATION_MAX_ITERATION_FACTOR, 1)

while True:
no_change = True
new_community_map: dict[str, int] = {}
for _ in range(max_iterations):
changed_count = 0

for uuid, neighbors in projection.items():
curr_community = community_map[uuid]
Expand All @@ -120,15 +121,20 @@ def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]:
else:
new_community = max(community_candidate, curr_community)

new_community_map[uuid] = new_community

if new_community != curr_community:
no_change = False
changed_count += 1
community_map[uuid] = new_community

if no_change:
if changed_count == 0:
break

community_map = new_community_map
else:
logger.warning(
'Label propagation stopped before convergence',
extra={
'node_count': len(projection),
'max_iterations': max_iterations,
},
)

community_cluster_map = defaultdict(list)
for uuid, community in community_map.items():
Expand Down
32 changes: 32 additions & 0 deletions tests/utils/maintenance/test_community_operations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from graphiti_core.utils.maintenance.community_operations import Neighbor, label_propagation


def _cluster_sets(clusters: list[list[str]]) -> set[frozenset[str]]:
    """Convert a list of clusters into a set of frozensets.

    Both the order of the clusters and the order of the members inside each
    cluster are discarded, so assertions can compare cluster membership only.

    Args:
        clusters: Clusters expressed as lists of node identifiers.

    Returns:
        A set containing one frozenset per input cluster.
    """
    return {frozenset(cluster) for cluster in clusters}


def test_label_propagation_terminates_for_repeated_entity_edges():
    """Two nodes linked to each other with edge_count=2 must form one cluster.

    The call has to return at all for the assertion to run, so this also
    covers termination for a pair of mutually connected nodes.
    """
    projection = {
        'node-a': [Neighbor(node_uuid='node-b', edge_count=2)],
        'node-b': [Neighbor(node_uuid='node-a', edge_count=2)],
    }

    clusters = label_propagation(projection)

    # Compare as sets so cluster and member ordering do not matter.
    assert _cluster_sets(clusters) == {frozenset({'node-a', 'node-b'})}


def test_label_propagation_keeps_disconnected_components_separate():
    """Two unconnected node pairs must end up in two distinct clusters."""
    projection = {
        # First component: node-a <-> node-b.
        'node-a': [Neighbor(node_uuid='node-b', edge_count=2)],
        'node-b': [Neighbor(node_uuid='node-a', edge_count=2)],
        # Second component: node-c <-> node-d, with no edge to the first.
        'node-c': [Neighbor(node_uuid='node-d', edge_count=3)],
        'node-d': [Neighbor(node_uuid='node-c', edge_count=3)],
    }

    clusters = label_propagation(projection)

    # Compare as sets so cluster and member ordering do not matter.
    assert _cluster_sets(clusters) == {
        frozenset({'node-a', 'node-b'}),
        frozenset({'node-c', 'node-d'}),
    }
58 changes: 57 additions & 1 deletion tests/utils/search/search_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,26 @@

from graphiti_core.nodes import EntityNode
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import hybrid_node_search
from graphiti_core.search.search_utils import episode_mentions_reranker, hybrid_node_search


class _MentionCountDriver:
    """Minimal stand-in driver that answers queries with canned mention counts.

    ``execute_query`` ignores the query text and returns one row per requested
    node uuid that has a configured count.
    """

    # NOTE(review): presumably read by the code under test to decide whether a
    # custom search backend is configured -- confirm against search_utils.
    search_interface = None

    def __init__(self, mention_counts: dict[str, int]) -> None:
        """Store the uuid -> mention count mapping served by ``execute_query``."""
        self.mention_counts = mention_counts

    async def execute_query(
        self, _query: str, **kwargs
    ) -> tuple[list[dict[str, object]], None, None]:
        """Return ``{'uuid', 'score'}`` rows for the requested node uuids.

        Args:
            _query: Query text; unused by this fake.
            **kwargs: Must contain ``node_uuids``, the uuids to look up.

        Returns:
            A 3-tuple ``(rows, None, None)``. ``rows`` follows the order of
            ``node_uuids`` and omits uuids without a configured count.

        Raises:
            KeyError: If ``node_uuids`` is not supplied.
        """
        node_uuids = kwargs['node_uuids']
        return (
            [
                {'uuid': node_uuid, 'score': self.mention_counts[node_uuid]}
                for node_uuid in node_uuids
                if node_uuid in self.mention_counts
            ],
            None,
            None,
        )


@pytest.mark.asyncio
Expand Down Expand Up @@ -161,3 +180,40 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
mock_similarity_search.assert_called_with(
mock_driver, [0.1, 0.2, 0.3], SearchFilters(), ['1'], 4
)


@pytest.mark.asyncio
async def test_episode_mentions_reranker_orders_by_highest_mention_count():
    """Nodes must come back ordered from most to fewest mentions.

    The candidate list is passed in ascending-count order, so the expected
    result is the exact reverse, with the raw counts returned as scores.
    """
    driver = _MentionCountDriver(
        {
            'node-a': 20,
            'node-b': 5,
            'node-c': 1,
        }
    )

    ranked_uuids, scores = await episode_mentions_reranker(
        driver,
        [['node-c', 'node-b', 'node-a']],
    )

    assert ranked_uuids == ['node-a', 'node-b', 'node-c']
    assert scores == [20, 5, 1]


@pytest.mark.asyncio
async def test_episode_mentions_reranker_keeps_zero_mention_nodes_last():
    """A node with no mention row must score 0 and rank after mentioned nodes.

    'node-c' has no configured count, so the fake driver returns no row for
    it even though it is first in the candidate list.
    """
    driver = _MentionCountDriver(
        {
            'node-a': 2,
            'node-b': 1,
        }
    )

    ranked_uuids, scores = await episode_mentions_reranker(
        driver,
        [['node-c', 'node-b', 'node-a']],
    )

    assert ranked_uuids == ['node-a', 'node-b', 'node-c']
    assert scores == [2, 1, 0]
Loading