Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/gaia/agents/base/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,17 +636,35 @@ def _get_embedder(self) -> Any:
f"Failed to initialize Lemonade embedding provider: {e}"
) from e

def _get_embedding_cache(self):
"""Lazy-init the content-keyed embedding cache (per-instance)."""
cache = getattr(self, "_embedding_cache", None)
if cache is None:
from gaia.llm.embedding_cache import EmbeddingCache

cache = EmbeddingCache()
self._embedding_cache = cache
return cache

def _embed_text(self, text: str) -> np.ndarray:
"""Embed text via Lemonade (nomic-embed-text-v2-moe-GGUF, 768-dim).

Required, not optional. Raises RuntimeError if embedding fails.

Identical text is served from a content-keyed cache, so repeated
query embeds (same recall query across turns) skip the Lemonade call.

Args:
text: Text to embed.

Returns:
L2-normalized float32 numpy array of shape (768,).
"""
cache = self._get_embedding_cache()
cached = cache.get(EMBEDDING_MODEL, EMBEDDING_DIM, text)
if cached is not None:
return cached

embedder = self._get_embedder()
try:
# LemonadeProvider.embed() returns list[list[float]]
Expand All @@ -658,6 +676,7 @@ def _embed_text(self, text: str) -> np.ndarray:
if norm > 0:
vec = vec / norm

cache.put(EMBEDDING_MODEL, EMBEDDING_DIM, text, vec)
return vec
except Exception as e:
raise RuntimeError(f"Embedding failed: {e}") from e
Expand Down
81 changes: 81 additions & 0 deletions src/gaia/llm/embedding_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
"""Content-keyed LRU cache for text embeddings.

Skips redundant per-turn embeds: the same query text re-embedded across
turns (or by two tool calls in one turn) pays the Lemonade embed cost once.

The cache key *is* the content — ``(model_id, dim, sha256(text))`` — so a hit
is never stale, and swapping the embedding model invalidates by construction
(the model_id component changes). ``dim`` may be ``None`` when the caller does
not track an expected dimensionality; the model_id alone still guarantees
correctness.
"""

import hashlib
import threading
from collections import OrderedDict
from typing import Optional, Tuple

import numpy as np

DEFAULT_MAX_ENTRIES = 512


class EmbeddingCache:
    """Thread-safe, bounded LRU cache mapping text content to its embedding.

    Entries are keyed by ``(model_id, dim, sha256(text))``. Every access to
    the backing ``OrderedDict`` and to the hit/miss counters happens while
    holding a single lock.

    Vectors are stored and returned as copies, so callers can mutate the
    returned array (e.g. L2-normalize in place) without corrupting the cache.
    """

    def __init__(self, max_entries: int = DEFAULT_MAX_ENTRIES):
        """Create an empty cache.

        Args:
            max_entries: Maximum number of vectors kept. Once exceeded, the
                least recently used entries are dropped.

        Raises:
            ValueError: If ``max_entries`` is less than 1.
        """
        if max_entries < 1:
            raise ValueError(f"max_entries must be >= 1, got {max_entries}")
        self._max_entries = max_entries
        # Dict order doubles as recency order: oldest first, newest last.
        self._store: "OrderedDict[Tuple[str, Optional[int], str], np.ndarray]" = (
            OrderedDict()
        )
        self._lock = threading.Lock()
        # Lookup counters; both are reset by clear().
        self.hits = 0
        self.misses = 0

    @staticmethod
    def _make_key(
        model_id: str, dim: Optional[int], text: str
    ) -> Tuple[str, Optional[int], str]:
        """Build the cache key ``(model_id, dim, sha256-hex-of-text)``.

        Args:
            model_id: Identifier of the embedding model.
            dim: Expected dimensionality, or ``None`` when not tracked.
            text: The content being embedded.

        Returns:
            A hashable tuple; the text itself is replaced by its digest.
        """
        # "surrogatepass" makes hashing work for every Python str. Strict
        # UTF-8 raises UnicodeEncodeError on a lone surrogate (which a
        # truncated JSON unicode escape can produce), and a cache lookup
        # must not fail on text the embedding backend might still accept.
        # For text without surrogates the bytes, and so the keys, are
        # identical to strict UTF-8.
        encoded = text.encode("utf-8", "surrogatepass")
        text_hash = hashlib.sha256(encoded).hexdigest()
        return (model_id, dim, text_hash)

    def get(self, model_id: str, dim: Optional[int], text: str) -> Optional[np.ndarray]:
        """Return a copy of the cached vector, or ``None`` on a miss.

        A hit marks the entry as most recently used and increments
        ``hits``; a miss increments ``misses``.
        """
        key = self._make_key(model_id, dim, text)
        with self._lock:
            vec = self._store.get(key)
            if vec is None:
                self.misses += 1
                return None
            self._store.move_to_end(key)
            self.hits += 1
            return vec.copy()

    def put(
        self, model_id: str, dim: Optional[int], text: str, vector: np.ndarray
    ) -> None:
        """Store a copy of ``vector`` under the content key, evicting LRU entries.

        Re-putting an existing key overwrites it and marks it most recently
        used.
        """
        key = self._make_key(model_id, dim, text)
        with self._lock:
            self._store[key] = np.asarray(vector).copy()
            self._store.move_to_end(key)
            # Drop from the oldest end until the size bound holds again.
            while len(self._store) > self._max_entries:
                self._store.popitem(last=False)

    def clear(self) -> None:
        """Remove every entry and reset the hit/miss counters to zero."""
        with self._lock:
            self._store.clear()
            self.hits = 0
            self.misses = 0

    def __len__(self) -> int:
        """Return the number of vectors currently stored."""
        with self._lock:
            return len(self._store)
33 changes: 31 additions & 2 deletions src/gaia/rag/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def __init__(self, config: Optional[RAGConfig] = None):
self.embedder = None
self.llm_client = None
self.use_lemonade_embeddings = False
self._embedding_cache = None # content-keyed query-embed cache (lazy)
self.index = None
self.chunks = []
self.indexed_files = set()
Expand Down Expand Up @@ -630,6 +631,34 @@ def _encode_texts(
# Convert to numpy array
return np.array(all_embeddings, dtype=np.float32)

def _get_embedding_cache(self):
"""Lazy-init the content-keyed embedding cache (per-instance)."""
cache = getattr(self, "_embedding_cache", None)
if cache is None:
from gaia.llm.embedding_cache import EmbeddingCache

cache = EmbeddingCache()
self._embedding_cache = cache
return cache

def _encode_query(self, query: str) -> "np.ndarray":
"""Encode a single query to a (1, dim) array, served from the
content-keyed cache so identical queries across turns skip the embed.

Falls back to ``_encode_texts`` on a miss. Stored/doc-chunk vectors
are persisted elsewhere; this targets repeated *query* embeds only.
"""
cache = self._get_embedding_cache()
model_id = self.config.embedding_model
cached = cache.get(model_id, None, query)
if cached is not None:
return np.array([cached], dtype=np.float32)

embedding = self._encode_texts([query], show_progress=False)
if embedding.shape[0] > 0:
cache.put(model_id, None, query, embedding[0])
return embedding

def _get_file_type(self, file_path: str) -> str:
"""Detect file type from extension."""
ext = Path(file_path).suffix.lower()
Expand Down Expand Up @@ -2855,7 +2884,7 @@ def _retrieve_chunks_from_file(self, query: str, file_path: str) -> tuple:

# Encode query only once
self._load_embedder()
query_embedding = self._encode_texts([query], show_progress=False)
query_embedding = self._encode_query(query)

# Search in cached file-specific index
k = min(self.config.max_chunks, len(file_chunks))
Expand Down Expand Up @@ -2933,7 +2962,7 @@ def _search_chunks(self, query: str) -> Dict[str, Any]:
if self.config.show_stats:
print(f"🔍 Searching through {len(chunks_snapshot)} chunks...")
self.log.debug(f"Encoding query: {query[:50]}...")
query_embedding = self._encode_texts([query], show_progress=False)
query_embedding = self._encode_query(query)

# Search for similar chunks
k = min(self.config.max_chunks, len(chunks_snapshot))
Expand Down
128 changes: 128 additions & 0 deletions tests/unit/test_embedding_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
"""Unit tests for the content-keyed embedding cache (issue #1743).

Covers the cache itself plus its wiring into MemoryMixin._embed_text and
RAGSDK._encode_query: a second identical embed must make zero backend calls,
and a model/dim change must invalidate (miss).
"""

from unittest.mock import MagicMock

import numpy as np
import pytest

from gaia.llm.embedding_cache import EmbeddingCache


class TestEmbeddingCache:
    """Behavioural checks for ``EmbeddingCache`` on its own."""

    def test_hit_returns_stored_vector(self):
        """A stored vector comes back equal, and both counters advance."""
        store = EmbeddingCache()
        expected = np.array([1.0, 2.0, 3.0], dtype=np.float32)

        assert store.get("m", 3, "hello") is None
        store.put("m", 3, "hello", expected)
        found = store.get("m", 3, "hello")

        assert found is not None
        np.testing.assert_array_equal(found, expected)
        assert store.hits == 1
        # The lookup made before the put is the single miss.
        assert store.misses == 1

    def test_model_change_invalidates(self):
        """The same text under another model id is a miss."""
        store = EmbeddingCache()
        store.put("model-a", 3, "hello", np.zeros(3, dtype=np.float32))
        assert store.get("model-b", 3, "hello") is None

    def test_dim_change_invalidates(self):
        """The same text under another dimensionality is a miss."""
        store = EmbeddingCache()
        store.put("m", 768, "hello", np.zeros(768, dtype=np.float32))
        assert store.get("m", 384, "hello") is None

    def test_distinct_text_misses(self):
        """Different text under the same model and dim is a miss."""
        store = EmbeddingCache()
        store.put("m", 3, "hello", np.zeros(3, dtype=np.float32))
        assert store.get("m", 3, "goodbye") is None

    def test_returned_vector_is_a_copy(self):
        """Mutating a returned array leaves the stored one untouched."""
        store = EmbeddingCache()
        store.put("m", 3, "t", np.array([1.0, 2.0, 3.0], dtype=np.float32))

        handed_out = store.get("m", 3, "t")
        handed_out[0] = 99.0  # scribble on the caller's copy only

        pristine = np.array([1.0, 2.0, 3.0], dtype=np.float32)
        np.testing.assert_array_equal(store.get("m", 3, "t"), pristine)

    def test_lru_eviction(self):
        """With room for two, the least recently used entry is dropped."""
        store = EmbeddingCache(max_entries=2)
        store.put("m", 1, "a", np.zeros(1, dtype=np.float32))
        store.put("m", 1, "b", np.zeros(1, dtype=np.float32))
        store.get("m", 1, "a")  # refresh 'a', leaving 'b' as the oldest
        store.put("m", 1, "c", np.zeros(1, dtype=np.float32))  # pushes 'b' out

        assert store.get("m", 1, "a") is not None
        assert store.get("m", 1, "c") is not None
        assert store.get("m", 1, "b") is None
        assert len(store) == 2

    def test_invalid_max_entries(self):
        """A capacity of zero is rejected at construction."""
        with pytest.raises(ValueError):
            EmbeddingCache(max_entries=0)


class TestMemoryEmbedTextCaching:
    """MemoryMixin._embed_text serves identical text from the cache."""

    def _host(self):
        """Build a ``MemoryMixin`` whose embedder is a mock.

        Returns:
            ``(mixin, backend)``. ``backend.embed`` always returns one
            random vector of length ``EMBEDDING_DIM`` as a nested list.
        """
        from gaia.agents.base.memory import EMBEDDING_DIM, MemoryMixin

        backend = MagicMock()
        backend.embed.return_value = [
            np.random.rand(EMBEDDING_DIM).astype(np.float32).tolist()
        ]

        mixin = MemoryMixin()
        mixin._embedder = backend
        return mixin, backend

    def test_second_identical_embed_makes_zero_backend_calls(self):
        """Embedding the same text twice reaches the backend only once."""
        mixin, backend = self._host()

        initial = mixin._embed_text("the same query")
        repeat = mixin._embed_text("the same query")

        assert backend.embed.call_count == 1
        np.testing.assert_array_equal(initial, repeat)

    def test_distinct_text_calls_backend_again(self):
        """Two different texts each cost one backend call."""
        mixin, backend = self._host()

        mixin._embed_text("query one")
        mixin._embed_text("query two")

        assert backend.embed.call_count == 2


class TestRagEncodeQueryCaching:
    """RAGSDK._encode_query serves identical queries from the cache."""

    def _sdk(self):
        """Build a bare ``RAGSDK`` with only what ``_encode_query`` reads.

        ``__init__`` is bypassed via ``__new__``. ``_encode_texts`` is
        replaced by a mock returning one fixed random ``(1, 768)`` float32
        array.
        """
        from gaia.rag.sdk import RAGSDK

        bare = RAGSDK.__new__(RAGSDK)  # avoid the heavyweight constructor
        bare.config = MagicMock()
        bare.config.embedding_model = "nomic-embed-text-v2-moe-GGUF"
        bare._embedding_cache = None

        canned = np.random.rand(1, 768).astype(np.float32)
        bare._encode_texts = MagicMock(return_value=canned)
        return bare

    def test_second_identical_query_makes_zero_encode_calls(self):
        """Encoding the same query twice calls ``_encode_texts`` once."""
        rag = self._sdk()

        initial = rag._encode_query("what is gaia?")
        repeat = rag._encode_query("what is gaia?")

        assert rag._encode_texts.call_count == 1
        assert initial.shape == (1, 768)
        np.testing.assert_array_equal(initial, repeat)

    def test_distinct_query_encodes_again(self):
        """Two different queries each trigger an encode call."""
        rag = self._sdk()

        rag._encode_query("query a")
        rag._encode_query("query b")

        assert rag._encode_texts.call_count == 2
Loading