diff --git a/backend/apps/ai/admin.py b/backend/apps/ai/admin.py index d0852aeb48..2b66a7ff78 100644 --- a/backend/apps/ai/admin.py +++ b/backend/apps/ai/admin.py @@ -4,6 +4,7 @@ from apps.ai.models.chunk import Chunk from apps.ai.models.context import Context +from apps.ai.models.semantic_cache import SemanticCache class ChunkAdmin(admin.ModelAdmin): @@ -32,5 +33,20 @@ class ContextAdmin(admin.ModelAdmin): search_fields = ("content", "source") +class SemanticCacheAdmin(admin.ModelAdmin): + """Admin for SemanticCache model.""" + + list_display = ( + "confidence", + "id", + "intent", + "nest_created_at", + "query_text", + ) + list_filter = ("intent",) + search_fields = ("query_text", "response_text") + + admin.site.register(Chunk, ChunkAdmin) admin.site.register(Context, ContextAdmin) +admin.site.register(SemanticCache, SemanticCacheAdmin) diff --git a/backend/apps/ai/common/crewai_config.py b/backend/apps/ai/common/crewai_config.py new file mode 100644 index 0000000000..7b2beacb81 --- /dev/null +++ b/backend/apps/ai/common/crewai_config.py @@ -0,0 +1,12 @@ +"""CrewAI assistant configuration.""" + +from dataclasses import dataclass + + +@dataclass +class CrewAIConfig: + """CrewAI assistant configuration.""" + + semantic_cache_enabled: bool = True + semantic_cache_similarity_threshold: float = 0.95 + semantic_cache_ttl_seconds: int = 86400 # 24 hours diff --git a/backend/apps/ai/flows/assistant.py b/backend/apps/ai/flows/assistant.py index 6283afb996..e2b09eb136 100644 --- a/backend/apps/ai/flows/assistant.py +++ b/backend/apps/ai/flows/assistant.py @@ -14,6 +14,7 @@ from apps.ai.flows.collaborative import handle_collaborative_query from apps.ai.query_analyzer import analyze_query from apps.ai.router import route +from apps.ai.semantic_cache import get_cached_response, store_cached_response from apps.common.open_ai import OpenAi from apps.slack.constants import ( OWASP_COMMUNITY_CHANNEL_ID, @@ -265,6 +266,13 @@ def process_query( # noqa: PLR0911 
is_channel_suggestion=True, ) + # Check semantic cache + try: + if (cached := get_cached_response(query)) is not None: + return cached + except Exception: + logger.exception("Semantic cache lookup failed, proceeding without cache") + # Step 2: Analyze query complexity before routing try: query_analysis = analyze_query(query) @@ -283,7 +291,15 @@ def process_query( # noqa: PLR0911 # Step 3: Use collaborative flow for complex query if not query_analysis["is_simple"] and len(query_analysis["sub_queries"]) > 1: try: - return handle_collaborative_query(query, query_analysis["sub_queries"]) + if response := handle_collaborative_query(query, query_analysis["sub_queries"]): + try: + store_cached_response( + query=query, + response=response, + ) + except Exception: + logger.exception("Failed to store semantic cache entry") + return response except Exception: logger.exception( "Collaborative flow failed, falling back to single agent: %s", query @@ -383,11 +399,22 @@ def process_query( # noqa: PLR0911 agent = agent_factory() # Step 8: Execute task with agent - return execute_task(agent, query) + if response := execute_task(agent, query): + try: + store_cached_response( + query=query, + response=response, + intent=intent, + confidence=confidence, + ) + except Exception: + logger.exception("Failed to store semantic cache entry") except Exception: logger.exception("Failed to process query: %s", query) return get_fallback_response() + else: + return response def execute_task( diff --git a/backend/apps/ai/migrations/0011_semanticcache.py b/backend/apps/ai/migrations/0011_semanticcache.py new file mode 100644 index 0000000000..003b5433ee --- /dev/null +++ b/backend/apps/ai/migrations/0011_semanticcache.py @@ -0,0 +1,43 @@ +# Generated by Django 6.0.3 on 2026-04-13 14:06 + +import pgvector.django.vector +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("ai", "0010_alter_context_unique_together"), + ] + + operations = [ + 
migrations.CreateModel( + name="SemanticCache", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, primary_key=True, serialize=False, verbose_name="ID" + ), + ), + ("nest_created_at", models.DateTimeField(auto_now_add=True)), + ("nest_updated_at", models.DateTimeField(auto_now=True)), + ("confidence", models.FloatField(default=0.0, verbose_name="Confidence")), + ( + "intent", + models.CharField(blank=True, default="", max_length=50, verbose_name="Intent"), + ), + ( + "query_embedding", + pgvector.django.vector.VectorField( + dimensions=1536, verbose_name="Query Embedding" + ), + ), + ("query_text", models.TextField(verbose_name="Query Text")), + ("response_text", models.TextField(verbose_name="Response Text")), + ], + options={ + "verbose_name": "Semantic Cache", + "db_table": "ai_semantic_cache", + }, + ), + ] diff --git a/backend/apps/ai/models/__init__.py b/backend/apps/ai/models/__init__.py index 05907f2861..d6a9796d7f 100644 --- a/backend/apps/ai/models/__init__.py +++ b/backend/apps/ai/models/__init__.py @@ -1 +1,2 @@ from .chunk import Chunk +from .semantic_cache import SemanticCache diff --git a/backend/apps/ai/models/chunk.py b/backend/apps/ai/models/chunk.py index 4e7ec48d0b..b005255244 100644 --- a/backend/apps/ai/models/chunk.py +++ b/backend/apps/ai/models/chunk.py @@ -7,6 +7,8 @@ from apps.common.models import BulkSaveModel, TimestampedModel from apps.common.utils import truncate +EMBEDDING_DIMENSIONS = 1536 + class Chunk(TimestampedModel): """AI Chunk model for storing text chunks with embeddings.""" @@ -19,7 +21,7 @@ class Meta: unique_together = ("context", "text") context = models.ForeignKey(Context, on_delete=models.CASCADE, related_name="chunks") - embedding = VectorField(verbose_name="Embedding", dimensions=1536) + embedding = VectorField(verbose_name="Embedding", dimensions=EMBEDDING_DIMENSIONS) text = models.TextField(verbose_name="Text") def __str__(self): diff --git a/backend/apps/ai/models/semantic_cache.py 
b/backend/apps/ai/models/semantic_cache.py new file mode 100644 index 0000000000..8f475fa03c --- /dev/null +++ b/backend/apps/ai/models/semantic_cache.py @@ -0,0 +1,121 @@ +"""AI app semantic cache model.""" + +import logging +from datetime import UTC, datetime, timedelta + +from django.db import models +from pgvector.django import VectorField +from pgvector.django.functions import CosineDistance + +from apps.ai.models.chunk import EMBEDDING_DIMENSIONS +from apps.common.models import TimestampedModel +from apps.common.utils import truncate + +logger = logging.getLogger(__name__) + + +# TODO(rudransh-shrivastava): Add Cache Invalidation cron job. +class SemanticCache(TimestampedModel): + """Semantic cache model for storing query-response pairs with embeddings.""" + + class Meta: + """Model options.""" + + db_table = "ai_semantic_cache" + verbose_name = "Semantic Cache" + + confidence = models.FloatField(verbose_name="Confidence", default=0.0) + intent = models.CharField(verbose_name="Intent", blank=True, default="", max_length=50) + query_embedding = VectorField(verbose_name="Query Embedding", dimensions=EMBEDDING_DIMENSIONS) + query_text = models.TextField(verbose_name="Query Text") + response_text = models.TextField(verbose_name="Response Text") + + def __str__(self): + """Human readable representation.""" + return f"SemanticCache {self.id}: {truncate(self.query_text, 50)}" + + @staticmethod + def get_cached_response( + query: str, + *, + similarity_threshold: float = 0.95, + ttl_seconds: int = 86400, # 24 hours + ) -> str | None: + """Look up semantically similar cached response. + + Args: + query (str): User query text. + similarity_threshold (float): Minimum cosine similarity (0.0-1.0). + ttl_seconds (int): Maximum age of cached entries in seconds. + + Returns: + Cached response string if found, None otherwise. 
+ + """ + from apps.ai.embeddings.factory import get_embedder # noqa: PLC0415 + + ttl_cutoff = datetime.now(UTC) - timedelta(seconds=ttl_seconds) + max_distance = 1.0 - similarity_threshold + + result = ( + SemanticCache.objects.filter(nest_created_at__gte=ttl_cutoff) + .annotate( + distance=CosineDistance("query_embedding", get_embedder().embed_query(query)) + ) + .filter(distance__lte=max_distance) + .order_by("distance") + .first() + ) + + if result is not None: + logger.info( + "Semantic cache hit", + extra={ + "cache_id": result.id, + "distance": float(result.distance), + "query_preview": query[:100], + }, + ) + return result.response_text + + return None + + @staticmethod + def store_response( + query: str, + response: str, + intent: str = "", + confidence: float = 0.0, + ) -> "SemanticCache": + """Store query-response pair in semantic cache. + + Args: + query (str): Original query text. + response (str): Generated response text. + intent (str): Classified intent for the query. + confidence (float): Router confidence score. + + Returns: + Created SemanticCache instance. 
+ + """ + from apps.ai.embeddings.factory import get_embedder # noqa: PLC0415 + + entry = SemanticCache( + query_text=query, + query_embedding=get_embedder().embed_query(query), + response_text=response, + intent=intent, + confidence=confidence, + ) + entry.save() + + logger.info( + "Semantic cache stored", + extra={ + "cache_id": entry.id, + "intent": intent, + "query_preview": query[:100], + }, + ) + return entry diff --git a/backend/apps/ai/semantic_cache.py b/backend/apps/ai/semantic_cache.py new file mode 100644 index 0000000000..d262793eaf --- /dev/null +++ b/backend/apps/ai/semantic_cache.py @@ -0,0 +1,56 @@ +"""Semantic cache service for AI query responses.""" + +from apps.ai.common.crewai_config import CrewAIConfig +from apps.ai.models.semantic_cache import SemanticCache + +_config = CrewAIConfig() + + +def get_cached_response(query: str) -> str | None: + """Look up semantically similar cached response. + + Args: + query (str): User query text. + + Returns: + str: Cached response string if found within similarity threshold and TTL, + None otherwise. + + """ + if not _config.semantic_cache_enabled: + return None + + return SemanticCache.get_cached_response( + query, + similarity_threshold=_config.semantic_cache_similarity_threshold, + ttl_seconds=_config.semantic_cache_ttl_seconds, + ) + + +def store_cached_response( + query: str, + response: str, + intent: str = "", + confidence: float = 0.0, +) -> SemanticCache | None: + """Store query-response pair in semantic cache. + + Args: + query (str): Original query text. + response (str): Generated response text. + intent (str): Classified intent for the query. + confidence (float): Router confidence score. + + Returns: + Created SemanticCache instance. 
+ + """ + if not _config.semantic_cache_enabled: + return None + + return SemanticCache.store_response( + query=query, + response=response, + intent=intent, + confidence=confidence, + ) diff --git a/backend/tests/unit/apps/ai/common/crewai_config_test.py b/backend/tests/unit/apps/ai/common/crewai_config_test.py new file mode 100644 index 0000000000..d2cdd032c7 --- /dev/null +++ b/backend/tests/unit/apps/ai/common/crewai_config_test.py @@ -0,0 +1,25 @@ +"""Tests for CrewAI configuration.""" + +import math + +from apps.ai.common.crewai_config import CrewAIConfig + + +class TestCrewAIConfig: + def test_default_values(self): + config = CrewAIConfig() + + assert config.semantic_cache_enabled is True + assert math.isclose(config.semantic_cache_similarity_threshold, 0.95) + assert config.semantic_cache_ttl_seconds == 86400 + + def test_custom_values(self): + config = CrewAIConfig( + semantic_cache_enabled=False, + semantic_cache_similarity_threshold=0.8, + semantic_cache_ttl_seconds=3600, + ) + + assert config.semantic_cache_enabled is False + assert math.isclose(config.semantic_cache_similarity_threshold, 0.8) + assert config.semantic_cache_ttl_seconds == 3600 diff --git a/backend/tests/unit/apps/ai/flows/assistant_test.py b/backend/tests/unit/apps/ai/flows/assistant_test.py index ab66d98477..cc4e563c15 100644 --- a/backend/tests/unit/apps/ai/flows/assistant_test.py +++ b/backend/tests/unit/apps/ai/flows/assistant_test.py @@ -2,7 +2,8 @@ from unittest.mock import MagicMock, patch -from apps.ai.flows.assistant import process_query +from apps.ai.flows.assistant import normalize_channel_id, process_query +from apps.slack.constants import OWASP_COMMUNITY_CHANNEL_ID class TestProcessQueryImageEnrichment: @@ -29,7 +30,7 @@ def test_process_query_with_images_enriches_query( images = ["data:image/png;base64,abc123"] - process_query("What is this?", images=images) + process_query("What is this?", images=images) # NOSONAR duplicate string literal 
mock_openai_instance.set_images.assert_called_once_with(images) mock_openai_instance.complete.assert_called_once() @@ -77,7 +78,7 @@ def test_process_query_without_images_skips_vision( mock_route.return_value = {"intent": "rag", "confidence": 0.9} mock_execute_task.return_value = "Response" - process_query("What is OWASP?") + process_query("What is OWASP?") # NOSONAR duplicate string literal mock_openai_cls.assert_not_called() @@ -138,3 +139,181 @@ def test_low_confidence_rag_triggers_clarification_when_policy_removed( mock_create_clarify.assert_called_once() mock_execute_task.assert_called_once_with(clarification_agent, "Is this covered by OWASP?") assert res == "Clarify: please specify 'this'" + + +class TestCache: + @patch("apps.ai.flows.assistant.execute_task") + @patch("apps.ai.flows.assistant.route") + @patch("apps.ai.flows.assistant.analyze_query") + @patch("apps.ai.flows.assistant.get_cached_response") + def test_cache_hit_returns_early( + self, mock_get_cached_response, mock_analyze_query, mock_route, mock_execute_task + ): + """Semantic cache hit should return cached response without routing.""" + mock_get_cached_response.return_value = "cached answer" + + result = process_query("What is OWASP?") + + assert result == "cached answer" + mock_get_cached_response.assert_called_once_with("What is OWASP?") + mock_analyze_query.assert_not_called() + mock_route.assert_not_called() + mock_execute_task.assert_not_called() + + @patch("apps.ai.flows.assistant.analyze_query") + @patch("apps.ai.flows.assistant.route") + @patch("apps.ai.flows.assistant.execute_task") + @patch("apps.ai.flows.assistant.get_cached_response") + def test_cache_miss_proceeds_to_routing( + self, mock_get_cached_response, mock_execute_task, mock_route, mock_analyze_query + ): + """Cache miss should proceed with normal routing.""" + mock_get_cached_response.return_value = None + mock_analyze_query.return_value = {"is_simple": True, "sub_queries": []} + mock_route.return_value = {"intent": "rag", 
"confidence": 0.9} + mock_execute_task.return_value = "agent response 1" + + result = process_query("What is OWASP?") + + assert result == "agent response 1" + mock_get_cached_response.assert_called_once() + mock_route.assert_called_once() + + @patch("apps.ai.flows.assistant.analyze_query") + @patch("apps.ai.flows.assistant.route") + @patch("apps.ai.flows.assistant.execute_task") + @patch("apps.ai.flows.assistant.get_cached_response") + def test_cache_lookup_exception_proceeds_normally( + self, mock_get_cached_response, mock_execute_task, mock_route, mock_analyze_query + ): + """Cache lookup failure should log and proceed without crashing.""" + mock_get_cached_response.side_effect = Exception("Redis down") + mock_analyze_query.return_value = {"is_simple": True, "sub_queries": []} + mock_route.return_value = {"intent": "rag", "confidence": 0.9} + mock_execute_task.return_value = "agent response 2" + + result = process_query("What is OWASP?") + + assert result == "agent response 2" + + @patch("apps.ai.flows.assistant.store_cached_response") + @patch("apps.ai.flows.assistant.analyze_query") + @patch("apps.ai.flows.assistant.route") + @patch("apps.ai.flows.assistant.execute_task") + @patch("apps.ai.flows.assistant.get_cached_response") + def test_response_stored_in_cache_after_execution( + self, + mock_get_cached_response, + mock_execute_task, + mock_route, + mock_analyze_query, + mock_store_cached, + ): + """Successful agent response should be stored in semantic cache.""" + mock_get_cached_response.return_value = None + mock_analyze_query.return_value = {"is_simple": True, "sub_queries": []} + mock_route.return_value = {"intent": "rag", "confidence": 0.9} + mock_execute_task.return_value = "agent response 3" + + process_query("What is OWASP?") + + mock_store_cached.assert_called_once_with( + query="What is OWASP?", + response="agent response 3", + intent="rag", + confidence=0.9, + ) + + @patch("apps.ai.flows.assistant.store_cached_response") + 
@patch("apps.ai.flows.assistant.analyze_query") + @patch("apps.ai.flows.assistant.route") + @patch("apps.ai.flows.assistant.execute_task") + @patch("apps.ai.flows.assistant.get_cached_response") + def test_cache_store_failure_still_returns_response( + self, + mock_get_cached_response, + mock_execute_task, + mock_route, + mock_analyze_query, + mock_store_cached, + ): + """Cache store exception must not prevent response from being returned.""" + mock_get_cached_response.return_value = None + mock_analyze_query.return_value = {"is_simple": True, "sub_queries": []} + mock_route.return_value = {"intent": "rag", "confidence": 0.9} + mock_execute_task.return_value = "agent response" # NOSONAR duplicate string literal + mock_store_cached.side_effect = Exception("DB write failed") + + result = process_query("What is OWASP?") + + assert result == "agent response" + mock_store_cached.assert_called_once_with( + query="What is OWASP?", + response="agent response", + intent="rag", + confidence=0.9, + ) + + @patch("apps.ai.flows.assistant.store_cached_response") + @patch("apps.ai.flows.assistant.handle_collaborative_query") + @patch("apps.ai.flows.assistant.analyze_query") + @patch("apps.ai.flows.assistant.get_cached_response") + def test_collaborative_flow_stores_in_cache( + self, + mock_get_cached_response, + mock_analyze_query, + mock_collab, + mock_store_cached, + ): + """Collaborative flow response should be stored in cache.""" + mock_get_cached_response.return_value = None + mock_analyze_query.return_value = { + "is_simple": False, + "sub_queries": ["sub1", "sub2"], + } + mock_collab.return_value = "collaborative response" # NOSONAR duplicate string literal + + result = process_query("Complex multi-part question") + + assert result == "collaborative response" + mock_store_cached.assert_called_once_with( + query="Complex multi-part question", + response="collaborative response", + ) + + @patch("apps.ai.flows.assistant.logger") + 
@patch("apps.ai.flows.assistant.get_cached_response") + @patch("apps.ai.flows.assistant.create_channel_agent") + @patch("apps.ai.flows.assistant.execute_task") + def test_cache_skipped_for_community_channel( + self, mock_execute_task, mock_create_channel, mock_get_cached_response, mock_logger + ): + """Cache lookup should be skipped for owasp-community channel queries.""" + mock_create_channel.return_value = MagicMock() + mock_execute_task.return_value = "channel response 5" + + result = process_query( + "Where should I ask about ZAP?", + channel_id=normalize_channel_id(OWASP_COMMUNITY_CHANNEL_ID), + ) + + assert result == "channel response 5" + mock_get_cached_response.assert_not_called() + + @patch("apps.ai.flows.assistant.logger") + @patch("apps.ai.flows.assistant.store_cached_response") + @patch("apps.ai.flows.assistant.create_channel_agent") + @patch("apps.ai.flows.assistant.execute_task") + def test_community_channel_response_not_cached( + self, mock_execute_task, mock_create_channel, mock_store_cached_response, mock_logger + ): + """Community channel responses should not be stored in cache.""" + mock_create_channel.return_value = MagicMock() + mock_execute_task.return_value = "channel response" + + process_query( + "Where should I ask about ZAP?", + channel_id=normalize_channel_id(OWASP_COMMUNITY_CHANNEL_ID), + ) + + mock_store_cached_response.assert_not_called() diff --git a/backend/tests/unit/apps/ai/models/semantic_cache_test.py b/backend/tests/unit/apps/ai/models/semantic_cache_test.py new file mode 100644 index 0000000000..6f7363fbf4 --- /dev/null +++ b/backend/tests/unit/apps/ai/models/semantic_cache_test.py @@ -0,0 +1,155 @@ +"""Tests for SemanticCache model.""" + +import math +from unittest.mock import Mock, patch + +from apps.ai.models.chunk import EMBEDDING_DIMENSIONS +from apps.ai.models.semantic_cache import SemanticCache +from apps.common.models import TimestampedModel + + +class TestSemanticCacheModel: + def test_meta_class_attributes(self): + 
assert SemanticCache._meta.db_table == "ai_semantic_cache" + assert SemanticCache._meta.verbose_name == "Semantic Cache" + + def test_inheritance_from_timestamped_model(self): + assert issubclass(SemanticCache, TimestampedModel) + + def test_confidence_field_properties(self): + field = SemanticCache._meta.get_field("confidence") + assert math.isclose(field.default, 0.0) + assert field.verbose_name == "Confidence" + + def test_intent_field_properties(self): + field = SemanticCache._meta.get_field("intent") + assert field.max_length == 50 + assert field.blank is True + assert field.default == "" + assert field.verbose_name == "Intent" + + def test_query_embedding_field_properties(self): + field = SemanticCache._meta.get_field("query_embedding") + assert field.verbose_name == "Query Embedding" + assert field.dimensions == EMBEDDING_DIMENSIONS + + def test_query_text_field_properties(self): + field = SemanticCache._meta.get_field("query_text") + assert field.verbose_name == "Query Text" + + def test_response_text_field_properties(self): + field = SemanticCache._meta.get_field("response_text") + assert field.verbose_name == "Response Text" + + def test_str_method(self): + cache = SemanticCache() + cache.id = 42 + cache.query_text = "What is OWASP Top 10?" + result = str(cache) + assert "SemanticCache 42" in result + assert "What is OWASP Top 10?" 
in result + + def test_str_method_truncates_long_query(self): + cache = SemanticCache() + cache.id = 1 + cache.query_text = "A" * 200 + result = str(cache) + assert "SemanticCache 1" in result + assert len(result) < 200 + + +class TestSemanticCacheGetCachedResponse: + @patch("apps.ai.embeddings.factory.get_embedder") + @patch("apps.ai.models.semantic_cache.SemanticCache.objects") + def test_cache_hit_returns_response_text(self, mock_objects, mock_get_embedder): + mock_embedder = Mock() + mock_embedder.embed_query.return_value = [0.1] * EMBEDDING_DIMENSIONS + mock_get_embedder.return_value = mock_embedder + + mock_result = Mock() + mock_result.id = 1 + mock_result.distance = 0.02 + mock_result.response_text = "Cached response" + + ( + mock_objects.filter.return_value.annotate.return_value.filter + ).return_value.order_by.return_value.first.return_value = mock_result + + result = SemanticCache.get_cached_response("test query 1") + + assert result == "Cached response" + mock_embedder.embed_query.assert_called_once_with("test query 1") + + @patch("apps.ai.embeddings.factory.get_embedder") + @patch("apps.ai.models.semantic_cache.SemanticCache.objects") + def test_cache_miss_returns_none(self, mock_objects, mock_get_embedder): + mock_embedder = Mock() + mock_embedder.embed_query.return_value = [0.1] * EMBEDDING_DIMENSIONS + mock_get_embedder.return_value = mock_embedder + + ( + mock_objects.filter.return_value.annotate.return_value.filter + ).return_value.order_by.return_value.first.return_value = None + + result = SemanticCache.get_cached_response("unknown query") + + assert result is None + + @patch("apps.ai.embeddings.factory.get_embedder") + @patch("apps.ai.models.semantic_cache.SemanticCache.objects") + def test_custom_similarity_threshold(self, mock_objects, mock_get_embedder): + mock_embedder = Mock() + mock_embedder.embed_query.return_value = [0.1] * EMBEDDING_DIMENSIONS + mock_get_embedder.return_value = mock_embedder + + ( + 
mock_objects.filter.return_value.annotate.return_value.filter.return_value + ).order_by.return_value.first.return_value = None + + SemanticCache.get_cached_response("test", similarity_threshold=0.8, ttl_seconds=3600) + + mock_objects.filter.assert_called_once() + mock_embedder.embed_query.assert_called_once_with("test") + + second_filter = mock_objects.filter.return_value.annotate.return_value.filter + second_filter.assert_called_once() + filter_kwargs = second_filter.call_args[1] + assert math.isclose(filter_kwargs["distance__lte"], 0.2) + + +class TestSemanticCacheStoreResponse: + @patch("apps.ai.embeddings.factory.get_embedder") + @patch("apps.ai.models.semantic_cache.SemanticCache.save") + def test_store_response_creates_and_saves(self, mock_save, mock_get_embedder): + mock_embedder = Mock() + mock_embedder.embed_query.return_value = [0.1] * EMBEDDING_DIMENSIONS + mock_get_embedder.return_value = mock_embedder + + result = SemanticCache.store_response( + query="test query 2", # NOSONAR duplicate string literal + response="test response 1", + intent="rag", + confidence=0.9, + ) + + mock_save.assert_called_once() + mock_embedder.embed_query.assert_called_once_with("test query 2") + assert result.query_text == "test query 2" + assert result.response_text == "test response 1" + assert result.intent == "rag" + assert math.isclose(result.confidence, 0.9) + + @patch("apps.ai.embeddings.factory.get_embedder") + @patch("apps.ai.models.semantic_cache.SemanticCache.save") + def test_store_response_default_intent_and_confidence(self, mock_save, mock_get_embedder): + mock_embedder = Mock() + mock_embedder.embed_query.return_value = [0.1] * EMBEDDING_DIMENSIONS + mock_get_embedder.return_value = mock_embedder + + result = SemanticCache.store_response( + query="test query", + response="test response", + ) + + assert result.intent == "" + assert math.isclose(result.confidence, 0.0) diff --git a/backend/tests/unit/apps/ai/semantic_cache_test.py 
b/backend/tests/unit/apps/ai/semantic_cache_test.py new file mode 100644 index 0000000000..4592dc1443 --- /dev/null +++ b/backend/tests/unit/apps/ai/semantic_cache_test.py @@ -0,0 +1,87 @@ +"""Tests for semantic cache service.""" + +from unittest.mock import Mock, patch + +from apps.ai.models.semantic_cache import SemanticCache +from apps.ai.semantic_cache import get_cached_response, store_cached_response + + +class TestGetCachedResponse: + @patch("apps.ai.semantic_cache._config") + @patch("apps.ai.semantic_cache.SemanticCache.get_cached_response") + def test_delegates_to_model_when_enabled(self, mock_model_get, mock_config): + mock_config.semantic_cache_enabled = True + mock_config.semantic_cache_similarity_threshold = 0.95 + mock_config.semantic_cache_ttl_seconds = 86400 + mock_model_get.return_value = "cached response" + + result = get_cached_response("test query 1") + + assert result == "cached response" + mock_model_get.assert_called_once_with( + "test query 1", + similarity_threshold=0.95, + ttl_seconds=86400, + ) + + @patch("apps.ai.semantic_cache._config") + @patch("apps.ai.semantic_cache.SemanticCache.get_cached_response") + def test_returns_none_when_disabled(self, mock_model_get, mock_config): + mock_config.semantic_cache_enabled = False + + result = get_cached_response("test query 2") + + assert result is None + mock_model_get.assert_not_called() + + +class TestStoreCachedResponse: + @patch("apps.ai.semantic_cache._config") + @patch("apps.ai.semantic_cache.SemanticCache.store_response") + def test_delegates_to_model_when_enabled(self, mock_model_store, mock_config): + mock_config.semantic_cache_enabled = True + mock_entry = Mock(spec=SemanticCache) + mock_model_store.return_value = mock_entry + + result = store_cached_response( + query="test query 3", + response="test response 1", + intent="rag", + confidence=0.9, + ) + + assert result == mock_entry + mock_model_store.assert_called_once_with( + query="test query 3", + response="test response 1", + 
intent="rag", + confidence=0.9, + ) + + @patch("apps.ai.semantic_cache._config") + @patch("apps.ai.semantic_cache.SemanticCache.store_response") + def test_returns_none_when_disabled(self, mock_model_store, mock_config): + mock_config.semantic_cache_enabled = False + + result = store_cached_response( + query="test query", + response="test response", + ) + + assert result is None + mock_model_store.assert_not_called() + + @patch("apps.ai.semantic_cache._config") + @patch("apps.ai.semantic_cache.SemanticCache.store_response") + def test_default_intent_and_confidence(self, mock_model_store, mock_config): + mock_config.semantic_cache_enabled = True + mock_model_store.return_value = Mock(spec=SemanticCache) + + store_cached_response(query="test", response="resp") + + mock_model_store.assert_called_once_with( + query="test", + response="resp", + intent="", + confidence=0.0, + ) diff --git a/docker-compose/local/compose.yaml b/docker-compose/local/compose.yaml index 6e4e1cffb1..f7d25eb767 100644 --- a/docker-compose/local/compose.yaml +++ b/docker-compose/local/compose.yaml @@ -70,7 +70,7 @@ services: networks: - nest-network volumes: - - db-data-nestbot:/var/lib/postgresql/data + - db-data-nestbot:/var/lib/postgresql/data docs: container_name: nest-docs @@ -186,7 +186,7 @@ networks: volumes: backend-venv-nestbot: cache-data-nestbot: - db-data-nestbot: + db-data-nestbot: docs-venv: frontend-next: frontend-node-modules: