Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions graphiti_core/driver/neo4j/operations/search_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ async def node_similarity_search(
filter_queries.append('n.group_id IN $group_ids')
filter_params['group_ids'] = group_ids

filter_queries.append(
'n.name_embedding IS NOT NULL AND size(n.name_embedding) = size($search_vector)'
)

filter_query = ''
if filter_queries:
filter_query = ' WHERE ' + (' AND '.join(filter_queries))
Expand Down Expand Up @@ -308,6 +312,10 @@ async def edge_similarity_search(
filter_params['target_uuid'] = target_node_uuid
filter_queries.append('m.uuid = $target_uuid')

filter_queries.append(
'e.fact_embedding IS NOT NULL AND size(e.fact_embedding) = size($search_vector)'
)

filter_query = ''
if filter_queries:
filter_query = ' WHERE ' + (' AND '.join(filter_queries))
Expand Down Expand Up @@ -489,12 +497,16 @@ async def community_similarity_search(
min_score: float = 0.6,
) -> list[CommunityNode]:
query_params: dict[str, Any] = {}
filter_queries = [
'c.name_embedding IS NOT NULL AND size(c.name_embedding) = size($search_vector)'
]

group_filter_query = ''
if group_ids is not None:
group_filter_query += ' WHERE c.group_id IN $group_ids'
filter_queries.insert(0, 'c.group_id IN $group_ids')
query_params['group_ids'] = group_ids

group_filter_query = ' WHERE ' + (' AND '.join(filter_queries))

cypher = (
'MATCH (c:Community)'
+ group_filter_query
Expand Down
51 changes: 50 additions & 1 deletion tests/utils/search/test_search_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@

from graphiti_core.driver.driver import GraphProvider
from graphiti_core.driver.falkordb.operations.search_ops import _build_falkor_fulltext_query
from graphiti_core.driver.neo4j.operations.search_ops import _build_neo4j_fulltext_query
from graphiti_core.driver.neo4j.operations.search_ops import (
Neo4jSearchOperations,
_build_neo4j_fulltext_query,
)
from graphiti_core.errors import GroupIdValidationError, NodeLabelValidationError
from graphiti_core.helpers import get_default_group_id, validate_group_id
from graphiti_core.search.search import search
Expand All @@ -19,6 +22,15 @@
from graphiti_core.search.search_utils import fulltext_query


class RecordingExecutor:
def __init__(self):
self.cypher = ''

async def execute_query(self, cypher, **kwargs):
self.cypher = cypher
return [], None, None


def test_search_filters_reject_unsafe_node_labels():
with pytest.raises(ValidationError, match='node_labels must start with a letter or underscore'):
SearchFilters(node_labels=['Entity`) WITH n MATCH (x) DETACH DELETE x //'])
Expand Down Expand Up @@ -109,3 +121,40 @@ def test_falkordb_fulltext_query_escapes_default_group_id():
built = _build_falkor_fulltext_query('hello', [default])
assert '@group_id:"\\_"' in built
assert '@group_id:"_"' not in built


@pytest.mark.asyncio
@pytest.mark.parametrize(
('method_name', 'expected_guard'),
[
(
'node_similarity_search',
'n.name_embedding IS NOT NULL AND size(n.name_embedding) = size($search_vector)',
),
(
'edge_similarity_search',
'e.fact_embedding IS NOT NULL AND size(e.fact_embedding) = size($search_vector)',
),
(
'community_similarity_search',
'c.name_embedding IS NOT NULL AND size(c.name_embedding) = size($search_vector)',
),
],
)
async def test_neo4j_similarity_search_filters_invalid_embeddings_before_cosine(
method_name, expected_guard
):
executor = RecordingExecutor()
operations = Neo4jSearchOperations()
method = getattr(operations, method_name)

if method_name == 'edge_similarity_search':
await method(executor, [0.1, 0.2], None, None, SearchFilters(), group_ids=['tenant'])
elif method_name == 'community_similarity_search':
await method(executor, [0.1, 0.2], group_ids=['tenant'])
else:
await method(executor, [0.1, 0.2], SearchFilters(), group_ids=['tenant'])

cosine_index = executor.cypher.index('vector.similarity.cosine')
guard_index = executor.cypher.index(expected_guard)
assert guard_index < cosine_index
Loading