Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions backend/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ GOOGLE_EMBEDDINGS=gemini-embedding-001
HF_EMBEDDINGS=thenlper/gte-large
HF_RERANKER=BAAI/bge-reranker-base

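# Reranker type: 'HF' for HuggingFace CrossEncoder, 'VERTEX_AI' for Google Vertex AI Ranking API.
# VERTEX_AI_PROJECT_ID is required when RERANKER_TYPE=VERTEX_AI; VERTEX_AI_LOCATION defaults to 'global'.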
RERANKER_TYPE=HF
VERTEX_AI_PROJECT_ID=
VERTEX_AI_LOCATION=global

# FAISS database path
FAISS_DB_PATH=./.faissdb/faiss_index

Expand Down
1 change: 1 addition & 0 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ dependencies = [
"huggingface-hub[cli]==0.34.4",
"langchain==0.3.27",
"langchain-community==0.3.27",
"langchain-google-community[vertexaisearch]>=2.0.0",
"langchain-google-genai==2.1.9",
"langchain-google-vertexai==2.0.28",
"langchain-huggingface==0.3.1",
Expand Down
34 changes: 30 additions & 4 deletions backend/src/chains/hybrid_retriever_chain.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import logging
from typing import Optional, Union, Any

from langchain.retrievers import EnsembleRetriever
Expand Down Expand Up @@ -121,10 +122,35 @@ def create_hybrid_retriever(self) -> None:
)

if self.contextual_rerank:
compressor = CrossEncoderReranker(
model=HuggingFaceCrossEncoder(model_name=self.reranking_model_name),
top_n=self.search_k,
)
reranker_type = os.getenv("RERANKER_TYPE", "HF").upper()

if reranker_type == "VERTEX_AI":
from langchain_google_community.vertex_rank import VertexAIRank

project_id = os.getenv("VERTEX_AI_PROJECT_ID", "")
location_id = os.getenv("VERTEX_AI_LOCATION", "global")

if not project_id:
raise ValueError(
"VERTEX_AI_PROJECT_ID must be set when using RERANKER_TYPE=VERTEX_AI"
)

compressor = VertexAIRank(
project_id=project_id,
location_id=location_id,
ranking_config="default_ranking_config",
top_n=self.search_k,
)
logging.info("Using Vertex AI reranker")
else:
compressor = CrossEncoderReranker(
model=HuggingFaceCrossEncoder(
model_name=self.reranking_model_name
),
top_n=self.search_k,
)
logging.info("Using HuggingFace CrossEncoder reranker")

self.retriever = ContextualCompressionRetriever(
base_compressor=compressor, base_retriever=ensemble_retriever
)
Comment on lines +158 to 162
Copy link

Copilot AI Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`ensemble_retriever` can be referenced here even when it was never assigned. If `self.vector_db` is `None` or `processed_docs` is empty, `bm25_retriever` is never set, which prevents the `ensemble_retriever = EnsembleRetriever(...)` block from running; this code then still uses `ensemble_retriever`, leading to an `UnboundLocalError`. Initialize `bm25_retriever` and `ensemble_retriever` to `None`, and either raise a clear error when the ensemble cannot be constructed or provide a fallback retriever composition before reaching this point.

Copilot uses AI. Check for mistakes.
Expand Down
115 changes: 114 additions & 1 deletion backend/tests/test_hybrid_retriever_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def test_create_hybrid_retriever_with_provided_vector_db(

assert chain.retriever == mock_ensemble_instance

@patch.dict("os.environ", {"RERANKER_TYPE": "HF"})
@patch("src.chains.hybrid_retriever_chain.SimilarityRetrieverChain")
@patch("src.chains.hybrid_retriever_chain.MMRRetrieverChain")
@patch("src.chains.hybrid_retriever_chain.BM25RetrieverChain")
Expand All @@ -158,7 +159,7 @@ def test_create_hybrid_retriever_with_contextual_rerank(
mock_mmr_chain,
mock_sim_chain,
):
"""Test creating hybrid retriever with contextual reranking enabled."""
"""Test creating hybrid retriever with HF contextual reranking enabled."""
mock_vector_db = Mock()
mock_vector_db.processed_docs = [Mock(), Mock()]

Expand Down Expand Up @@ -209,6 +210,118 @@ def test_create_hybrid_retriever_with_contextual_rerank(

assert chain.retriever == mock_compression_instance

@patch.dict(
    "os.environ",
    {
        "RERANKER_TYPE": "VERTEX_AI",
        "VERTEX_AI_PROJECT_ID": "test-project",
        "VERTEX_AI_LOCATION": "global",
    },
)
@patch("src.chains.hybrid_retriever_chain.SimilarityRetrieverChain")
@patch("src.chains.hybrid_retriever_chain.MMRRetrieverChain")
@patch("src.chains.hybrid_retriever_chain.BM25RetrieverChain")
@patch("src.chains.hybrid_retriever_chain.EnsembleRetriever")
@patch("src.chains.hybrid_retriever_chain.ContextualCompressionRetriever")
def test_create_hybrid_retriever_with_vertex_ai_rerank(
    self,
    mock_compression,
    mock_ensemble,
    mock_bm25_chain,
    mock_mmr_chain,
    mock_sim_chain,
):
    """Check that RERANKER_TYPE=VERTEX_AI builds a VertexAIRank compressor.

    The environment selects the Vertex AI branch; the test then verifies
    the arguments passed to ``VertexAIRank`` and that the resulting
    compressor is handed to ``ContextualCompressionRetriever`` together
    with the ensemble retriever.
    """
    vector_db = Mock()
    vector_db.processed_docs = [Mock(), Mock()]

    chain = HybridRetrieverChain(
        vector_db=vector_db,
        contextual_rerank=True,
        search_k=5,
    )

    # Every patched sub-chain class yields an instance exposing `.retriever`.
    for chain_cls in (mock_sim_chain, mock_mmr_chain, mock_bm25_chain):
        chain_cls.return_value = Mock(retriever=Mock())

    ensemble_stub = Mock()
    mock_ensemble.return_value = ensemble_stub

    compression_stub = Mock()
    mock_compression.return_value = compression_stub

    # VertexAIRank is imported lazily inside the method under test, so it
    # is patched on its defining module rather than on the chain module.
    with patch(
        "langchain_google_community.vertex_rank.VertexAIRank"
    ) as vertex_rank_cls:
        ranker_stub = Mock()
        vertex_rank_cls.return_value = ranker_stub

        chain.create_hybrid_retriever()

        expected_rank_kwargs = {
            "project_id": "test-project",
            "location_id": "global",
            "ranking_config": "default_ranking_config",
            "top_n": 5,
        }
        vertex_rank_cls.assert_called_once_with(**expected_rank_kwargs)
        mock_compression.assert_called_once_with(
            base_compressor=ranker_stub,
            base_retriever=ensemble_stub,
        )

        assert chain.retriever == compression_stub

@patch.dict(
    "os.environ",
    {"RERANKER_TYPE": "VERTEX_AI", "VERTEX_AI_PROJECT_ID": ""},
)
@patch("src.chains.hybrid_retriever_chain.SimilarityRetrieverChain")
@patch("src.chains.hybrid_retriever_chain.MMRRetrieverChain")
@patch("src.chains.hybrid_retriever_chain.BM25RetrieverChain")
@patch("src.chains.hybrid_retriever_chain.EnsembleRetriever")
def test_vertex_ai_rerank_raises_without_project_id(
    self,
    mock_ensemble,
    mock_bm25_chain,
    mock_mmr_chain,
    mock_sim_chain,
):
    """Check that an empty VERTEX_AI_PROJECT_ID makes the Vertex AI path fail.

    With RERANKER_TYPE=VERTEX_AI and a blank project ID in the environment,
    ``create_hybrid_retriever`` is expected to raise ``ValueError``.
    """
    vector_db = Mock()
    vector_db.processed_docs = [Mock(), Mock()]

    chain = HybridRetrieverChain(
        vector_db=vector_db,
        contextual_rerank=True,
    )

    # Every patched sub-chain class yields an instance exposing `.retriever`.
    for chain_cls in (mock_sim_chain, mock_mmr_chain, mock_bm25_chain):
        chain_cls.return_value = Mock(retriever=Mock())

    mock_ensemble.return_value = Mock()

    expected_message = "VERTEX_AI_PROJECT_ID must be set"
    with pytest.raises(ValueError, match=expected_message):
        chain.create_hybrid_retriever()

@patch("src.chains.hybrid_retriever_chain.os.path.isdir")
@patch("src.chains.hybrid_retriever_chain.os.listdir")
@patch("src.chains.hybrid_retriever_chain.SimilarityRetrieverChain")
Expand Down
Loading