diff --git a/graphiti_core/driver/neo4j/operations/search_ops.py b/graphiti_core/driver/neo4j/operations/search_ops.py index 9f0e308e39..40114c8c26 100644 --- a/graphiti_core/driver/neo4j/operations/search_ops.py +++ b/graphiti_core/driver/neo4j/operations/search_ops.py @@ -142,6 +142,10 @@ async def node_similarity_search( filter_queries.append('n.group_id IN $group_ids') filter_params['group_ids'] = group_ids + filter_queries.append( + 'n.name_embedding IS NOT NULL AND size(n.name_embedding) = size($search_vector)' + ) + filter_query = '' if filter_queries: filter_query = ' WHERE ' + (' AND '.join(filter_queries)) @@ -308,6 +312,10 @@ async def edge_similarity_search( filter_params['target_uuid'] = target_node_uuid filter_queries.append('m.uuid = $target_uuid') + filter_queries.append( + 'e.fact_embedding IS NOT NULL AND size(e.fact_embedding) = size($search_vector)' + ) + filter_query = '' if filter_queries: filter_query = ' WHERE ' + (' AND '.join(filter_queries)) @@ -489,12 +497,16 @@ async def community_similarity_search( min_score: float = 0.6, ) -> list[CommunityNode]: query_params: dict[str, Any] = {} + filter_queries = [ + 'c.name_embedding IS NOT NULL AND size(c.name_embedding) = size($search_vector)' + ] - group_filter_query = '' if group_ids is not None: - group_filter_query += ' WHERE c.group_id IN $group_ids' + filter_queries.insert(0, 'c.group_id IN $group_ids') query_params['group_ids'] = group_ids + group_filter_query = ' WHERE ' + (' AND '.join(filter_queries)) + cypher = ( 'MATCH (c:Community)' + group_filter_query diff --git a/tests/utils/search/test_search_security.py b/tests/utils/search/test_search_security.py index 17fc419db6..76ed20566e 100644 --- a/tests/utils/search/test_search_security.py +++ b/tests/utils/search/test_search_security.py @@ -6,7 +6,10 @@ from graphiti_core.driver.driver import GraphProvider from graphiti_core.driver.falkordb.operations.search_ops import _build_falkor_fulltext_query -from graphiti_core.driver.neo4j.operations.search_ops import _build_neo4j_fulltext_query +from graphiti_core.driver.neo4j.operations.search_ops import ( + Neo4jSearchOperations, + _build_neo4j_fulltext_query, +) from graphiti_core.errors import GroupIdValidationError, NodeLabelValidationError from graphiti_core.helpers import get_default_group_id, validate_group_id from graphiti_core.search.search import search @@ -19,6 +22,15 @@ from graphiti_core.search.search_utils import fulltext_query +class RecordingExecutor: + def __init__(self): + self.cypher = '' + + async def execute_query(self, cypher, **kwargs): + self.cypher = cypher + return [], None, None + + def test_search_filters_reject_unsafe_node_labels(): with pytest.raises(ValidationError, match='node_labels must start with a letter or underscore'): SearchFilters(node_labels=['Entity`) WITH n MATCH (x) DETACH DELETE x //']) @@ -109,3 +121,40 @@ def test_falkordb_fulltext_query_escapes_default_group_id(): built = _build_falkor_fulltext_query('hello', [default]) assert '@group_id:"\\_"' in built assert '@group_id:"_"' not in built + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ('method_name', 'expected_guard'), + [ + ( + 'node_similarity_search', + 'n.name_embedding IS NOT NULL AND size(n.name_embedding) = size($search_vector)', + ), + ( + 'edge_similarity_search', + 'e.fact_embedding IS NOT NULL AND size(e.fact_embedding) = size($search_vector)', + ), + ( + 'community_similarity_search', + 'c.name_embedding IS NOT NULL AND size(c.name_embedding) = size($search_vector)', + ), + ], +) +async def test_neo4j_similarity_search_filters_invalid_embeddings_before_cosine( + method_name, expected_guard +): + executor = RecordingExecutor() + operations = Neo4jSearchOperations() + method = getattr(operations, method_name) + + if method_name == 'edge_similarity_search': + await method(executor, [0.1, 0.2], None, None, SearchFilters(), group_ids=['tenant']) + elif method_name == 'community_similarity_search': + await method(executor, [0.1, 0.2], group_ids=['tenant']) + else: + await method(executor, [0.1, 0.2], SearchFilters(), group_ids=['tenant']) + + cosine_index = executor.cypher.index('vector.similarity.cosine') + guard_index = executor.cypher.index(expected_guard) + assert guard_index < cosine_index