From d1a82435d0b268fd1bf945bd1471074a28131fbd Mon Sep 17 00:00:00 2001 From: Ming Wen Date: Tue, 5 May 2026 17:15:24 +0800 Subject: [PATCH 1/2] =?UTF-8?q?feat(cache):=20Stage=204b=20=E2=80=94=20pgv?= =?UTF-8?q?ector=20semantic=20cache=20backend=20(DP)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Stage 4b data-plane half: route chat requests with a matched `backend = pgvector` policy through dp-manager's `/dp/cache/{lookup,put}` endpoints for cosine-ANN cache hits. - aisix-cache: new `pgvector` module — `PgvectorCache` HTTP client for the dp-manager handler, with fail-open `SemanticCacheError` and `SemanticHit { response, prompt_tokens, completion_tokens, similarity }`. - aisix-cache: new `embed` module — `embed_prompt` reuses the env's first OpenAI Model's provider creds via `Hub::get(Provider::Openai)` and `Bridge::embed`, swapping in the policy's `embedding_model` on the wire. - aisix-core: `CacheBackend::Pgvector` enum variant. - aisix-proxy::state: `pgvector_cache: Option<Arc<PgvectorCache>>` field + `with_pgvector_cache` builder. - aisix-proxy::chat: dispatch path — when matched policy's backend is `pgvector`, embed the last user message, lookup against the vector index; on miss continue to upstream and PUT the result on success. Embedding/transport failures fall open with `cache_status = Disabled` per the Stage 4b design note. Streaming requests bypass semantic cache (existing behaviour). The embed call reuses the chat request id so it correlates in upstream logs. Tests: pgvector wiremock coverage (hit, miss, handler-error, put, base-url normalisation), embed snapshot resolver coverage, proxy cache-policy applies_to matcher coverage. All 219 tests green across aisix-cache (20), aisix-core (94), aisix-proxy (105). 
--- Cargo.lock | 2 + crates/aisix-cache/Cargo.toml | 4 + crates/aisix-cache/src/embed.rs | 141 +++++++ crates/aisix-cache/src/lib.rs | 4 + crates/aisix-cache/src/pgvector.rs | 399 +++++++++++++++++++ crates/aisix-core/src/models/cache_policy.rs | 13 +- crates/aisix-proxy/src/chat.rs | 249 ++++++++++-- crates/aisix-proxy/src/state.rs | 22 +- 8 files changed, 807 insertions(+), 27 deletions(-) create mode 100644 crates/aisix-cache/src/embed.rs create mode 100644 crates/aisix-cache/src/pgvector.rs diff --git a/Cargo.lock b/Cargo.lock index 97aa0e29..6499942c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -79,11 +79,13 @@ dependencies = [ "async-trait", "moka", "redis", + "reqwest", "serde", "serde_json", "thiserror 1.0.69", "tokio", "tracing", + "wiremock", ] [[package]] diff --git a/crates/aisix-cache/Cargo.toml b/crates/aisix-cache/Cargo.toml index 8115c1e7..c4634921 100644 --- a/crates/aisix-cache/Cargo.toml +++ b/crates/aisix-cache/Cargo.toml @@ -18,10 +18,14 @@ serde_json.workspace = true thiserror.workspace = true tracing.workspace = true moka.workspace = true +# pgvector backend talks to dp-manager over HTTP(S) — reuse the +# workspace reqwest so the binary doesn't pull in two TLS stacks. +reqwest = { workspace = true, default-features = false, features = ["json", "rustls-tls"] } redis = { workspace = true, optional = true } [dev-dependencies] tokio = { workspace = true, features = ["macros", "rt", "time"] } +wiremock.workspace = true [features] default = ["memory"] diff --git a/crates/aisix-cache/src/embed.rs b/crates/aisix-cache/src/embed.rs new file mode 100644 index 00000000..a36865ea --- /dev/null +++ b/crates/aisix-cache/src/embed.rs @@ -0,0 +1,141 @@ +//! Embedding helper for the pgvector semantic cache (Stage 4b). +//! +//! When the proxy matches a `cache_policy` with `backend = pgvector`, +//! we need a vector to look up against. The embedding is computed on +//! the data plane (so embedding-provider credentials stay on DP per +//! 
the Stage 4b design note); this module is the glue between +//! `aisix_proxy::chat` and the existing `Bridge::embed` surface. +//! +//! Provider key resolution: we pick the first OpenAI Model in the +//! current snapshot and reuse its provider_config (api_key + api_base). +//! Per the Stage 4b decision (option C in the design doc), there is +//! no separate `embedding_provider_key_id` field on `CachePolicy` +//! today — that's a follow-up if operators need to bill embedding +//! against a different key. +//! +//! Failure mode is fail-open: any error from this module surfaces to +//! the proxy as `EmbedError`, which the chat handler maps to +//! `CacheStatus::Disabled` + skip-the-lookup. The caller's request +//! still reaches the upstream, just without semantic-cache benefit. + +use aisix_core::models::{Model, Provider}; +use aisix_core::resource::ResourceEntry; +use aisix_core::AisixSnapshot; +use aisix_gateway::{BridgeContext, EmbeddingRequest, Hub}; +use std::sync::Arc; + +/// Errors the embedder can surface. The proxy logs these at +/// `tracing::warn!` and then falls open — the caller's request still +/// reaches the upstream, just without a cache lookup. +#[derive(Debug, thiserror::Error)] +pub enum EmbedError { + /// No Model in the env snapshot has provider == OpenAI. Operator + /// needs to add an OpenAI provider_key + Model to the env before + /// the pgvector backend can produce embeddings. + #[error("no OpenAI model in snapshot to source embedding credentials from")] + NoOpenAiModel, + /// `Bridge::embed` failed (transport error, upstream error, + /// decode failure). Carries the upstream message for the warn + /// log so an operator can debug why semantic caching went dark. + #[error("embedding bridge call failed: {0}")] + Bridge(String), + /// The provider returned a successful response but with no + /// embedding data (e.g. empty `data: []`). Shouldn't happen + /// against OpenAI but the wire allows it. 
+ #[error("embedding response had no data")] + EmptyResponse, +} + +/// Compute an embedding for the prompt text using the env's first +/// OpenAI Model as the credentials source. +/// +/// `embedding_model` is the model name from the policy (e.g. +/// `"text-embedding-3-small"`). The OpenAI Model's api_key + api_base +/// are reused — the policy's embedding model swaps in for the chat +/// model's name on the embeddings endpoint. +/// +/// `request_id` is the chat request's id, threaded through so the +/// embedding call shows up under the same id in upstream logs. +pub async fn embed_prompt( + snapshot: &AisixSnapshot, + hub: &Hub, + embedding_model: &str, + prompt_text: &str, + request_id: &str, +) -> Result, EmbedError> { + let openai_model = first_openai_model(snapshot).ok_or(EmbedError::NoOpenAiModel)?; + let bridge = hub + .get(Provider::Openai) + .ok_or(EmbedError::NoOpenAiModel)?; + + // The chat model carries provider_config (api_key + api_base); + // we override the model name for the embeddings call so the + // bridge sends the policy's embedding_model on the wire instead + // of the chat model name. + let mut model_for_embed = openai_model.value.clone(); + model_for_embed.model = format!("openai/{embedding_model}"); + model_for_embed.name = format!("__embedder__{embedding_model}"); + let model_arc = Arc::new(model_for_embed); + let ctx = BridgeContext::new(request_id, model_arc); + + let req = EmbeddingRequest { + model: embedding_model.to_string(), + input: vec![prompt_text.to_string()], + encoding_format: None, + dimensions: None, + }; + let resp = bridge + .embed(&req, &ctx) + .await + .map_err(|e| EmbedError::Bridge(e.to_string()))?; + let first = resp.data.into_iter().next().ok_or(EmbedError::EmptyResponse)?; + Ok(first.embedding) +} + +/// Pick the first Model in the snapshot whose provider is OpenAI. 
+/// "First" is by `entries()` order (which is by id) — stable across +/// snapshot rebuilds for the same set of models, so the same key gets +/// charged for embeddings consistently. +fn first_openai_model(snapshot: &AisixSnapshot) -> Option>> { + snapshot + .models + .entries() + .into_iter() + .find(|entry| matches!(entry.value.provider(), Some(Provider::Openai))) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn snap_with(model_json: &str) -> AisixSnapshot { + let s = AisixSnapshot::new(); + let m: Model = serde_json::from_str(model_json).unwrap(); + s.models.insert(ResourceEntry::new("m-1", m, 1)); + s + } + + #[test] + fn picks_openai_model_when_present() { + let s = snap_with( + r#"{"name":"gpt","model":"openai/gpt-4o","provider_config":{"api_key":"sk-x"}}"#, + ); + let picked = first_openai_model(&s); + assert!(picked.is_some()); + assert_eq!(picked.unwrap().value.name, "gpt"); + } + + #[test] + fn returns_none_when_only_non_openai_models() { + let s = snap_with( + r#"{"name":"c","model":"anthropic/claude","provider_config":{"api_key":"k"}}"#, + ); + assert!(first_openai_model(&s).is_none()); + } + + #[test] + fn returns_none_on_empty_snapshot() { + let s = AisixSnapshot::new(); + assert!(first_openai_model(&s).is_none()); + } +} diff --git a/crates/aisix-cache/src/lib.rs b/crates/aisix-cache/src/lib.rs index e757e4c6..4a48a8cb 100644 --- a/crates/aisix-cache/src/lib.rs +++ b/crates/aisix-cache/src/lib.rs @@ -18,14 +18,18 @@ #![deny(rust_2018_idioms)] mod cache; +mod embed; mod key; mod memory; +mod pgvector; #[cfg(feature = "redis")] mod redis; pub use cache::{Cache, CacheError, CacheOutcome}; +pub use embed::{embed_prompt, EmbedError}; pub use key::CacheKey; pub use memory::{MemoryCache, DEFAULT_CAPACITY, DEFAULT_TTL}; +pub use pgvector::{PgvectorCache, SemanticCacheError, SemanticHit}; #[cfg(feature = "redis")] pub use redis::{ RedisCache, DEFAULT_PREFIX as REDIS_DEFAULT_PREFIX, DEFAULT_TTL as REDIS_DEFAULT_TTL, diff --git 
a/crates/aisix-cache/src/pgvector.rs b/crates/aisix-cache/src/pgvector.rs new file mode 100644 index 00000000..112f896f --- /dev/null +++ b/crates/aisix-cache/src/pgvector.rs @@ -0,0 +1,399 @@ +//! pgvector-backed semantic cache (Stage 4b of cache-policies). +//! +//! Wire path: the DP sends precomputed embeddings to dp-manager's +//! `/dp/cache/{lookup,put}` mTLS endpoints; dp-manager owns the PG +//! connection and runs the cosine-ANN search over +//! `cache_entries_semantic`. Multi-tenant isolation lives at the +//! env_id check inside the cp-api handler — the DP never gets a PG +//! client, which keeps the trust boundary thin. +//! +//! The DP is responsible for: +//! 1. computing the embedding (so embedding-provider credentials +//! stay on the data plane — see `crate::embed`) +//! 2. POSTing it to `/dp/cache/lookup` +//! 3. on miss, calling the upstream model and POSTing the result +//! to `/dp/cache/put` +//! +//! This module is just the HTTP client layer. The chat handler in +//! `aisix-proxy::chat` orchestrates 1 + 2 + 3. + +use aisix_gateway::ChatResponse; +use reqwest::{Client, StatusCode}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use std::time::Duration; + +/// Errors the pgvector cache can surface to the proxy. The proxy's +/// chat handler maps `LookupFailed` / `PutFailed` to a tracing::warn! +/// and cache-miss fallthrough (fail-open per the Stage 4b design +/// note); upstream callers never see these. +#[derive(Debug, thiserror::Error)] +pub enum SemanticCacheError { + #[error("dp-manager /dp/cache/lookup failed: {0}")] + LookupFailed(String), + #[error("dp-manager /dp/cache/put failed: {0}")] + PutFailed(String), + /// The handler returned a non-2xx with the canonical + /// `{error: {code, message}}` envelope. Surfaced as the message + /// for log readability; the proxy still falls open. 
+ #[error("dp-manager error {status}: {code} — {message}")] + HandlerError { + status: u16, + code: String, + message: String, + }, +} + +/// One semantic cache hit, materialised from the lookup envelope. +/// Mirrors the fields the chat handler needs at the cache-hit return +/// site (response body, the original usage so the dashboard's +/// "tokens saved" stats reflect the original event). +#[derive(Debug, Clone)] +pub struct SemanticHit { + pub response: ChatResponse, + pub prompt_tokens: u32, + pub completion_tokens: u32, + /// Cosine similarity of this match (1 - cosine_distance) against + /// the lookup query. Logged on hit but not surfaced to clients. + pub similarity: f32, +} + +/// HTTP client wrapper for the dp-manager-side semantic cache. Cheap +/// to clone (`reqwest::Client` is internally `Arc<…>`, the wrapper +/// just adds a base URL). +#[derive(Debug, Clone)] +pub struct PgvectorCache { + client: Client, + /// Absolute base URL of the dp-manager listener, e.g. + /// `https://dp-manager.aisix.svc:7944`. Trailing slash NOT + /// included — `format_url` joins with `/dp/cache/...` directly. + base_url: Arc, +} + +impl PgvectorCache { + /// Build with an externally-configured mTLS-presenting client. + /// The client must already have the DP's client cert + the + /// dp-manager's CA loaded — same bundle the telemetry sender + /// uses (see `aisix-server::heartbeat::build_mtls_client`). + pub fn new(client: Client, base_url: impl Into) -> Self { + let mut url = base_url.into(); + while url.ends_with('/') { + url.pop(); + } + Self { + client, + base_url: Arc::from(url.as_str()), + } + } + + /// Look up a semantically-similar entry. `threshold` is the + /// cosine-similarity floor (0..=1). 
Returns: + /// - `Ok(Some(hit))` when the best entry matches above threshold + /// - `Ok(None)` when no entry matches (or all matches are + /// below the threshold) + /// - `Err(...)` on transport / handler failure (proxy maps + /// these to a warn + cache miss fallthrough) + pub async fn lookup( + &self, + policy_id: &str, + embedding: &[f32], + threshold: Option, + ) -> Result, SemanticCacheError> { + let body = LookupRequest { + policy_id, + embedding, + threshold, + }; + let url = self.format_url("/dp/cache/lookup"); + let resp = self + .client + .post(&url) + .json(&body) + .timeout(Duration::from_secs(5)) + .send() + .await + .map_err(|e| SemanticCacheError::LookupFailed(e.to_string()))?; + + let status = resp.status(); + if !status.is_success() { + return Err(handler_error(status, resp).await); + } + let parsed: LookupResponse = resp + .json() + .await + .map_err(|e| SemanticCacheError::LookupFailed(format!("decode: {e}")))?; + if !parsed.hit { + return Ok(None); + } + Ok(Some(SemanticHit { + response: parsed.response.ok_or_else(|| { + SemanticCacheError::LookupFailed("hit=true but response missing".into()) + })?, + prompt_tokens: parsed.prompt_tokens, + completion_tokens: parsed.completion_tokens, + similarity: parsed.similarity, + })) + } + + /// Persist a fresh entry. Called from the proxy's post-success + /// path. `prompt_text` is the canonical text we embedded + /// (typically the last user message); cp-api stores it verbatim + /// for "show me what this entry caches" debugging. + /// + /// `ttl_seconds = None` makes the cp-api handler use the + /// policy's stored TTL — that's the normal case. The override is + /// kept on the wire for future per-request TTL hints. 
+ #[allow(clippy::too_many_arguments)] + pub async fn put( + &self, + policy_id: &str, + prompt_text: &str, + embedding: &[f32], + response: &ChatResponse, + prompt_tokens: u32, + completion_tokens: u32, + ttl_seconds: Option, + ) -> Result<(), SemanticCacheError> { + let body = PutRequest { + policy_id, + prompt_text, + embedding, + response, + prompt_tokens, + completion_tokens, + ttl_seconds, + }; + let url = self.format_url("/dp/cache/put"); + let resp = self + .client + .post(&url) + .json(&body) + .timeout(Duration::from_secs(5)) + .send() + .await + .map_err(|e| SemanticCacheError::PutFailed(e.to_string()))?; + let status = resp.status(); + if !status.is_success() { + return Err(handler_error(status, resp).await); + } + // Body is `{entry_id, expires_at}` — we don't need either on + // the proxy side, so don't decode. + Ok(()) + } + + fn format_url(&self, path: &str) -> String { + format!("{}{path}", self.base_url) + } +} + +// --- Wire types ----------------------------------------------------- + +#[derive(Serialize)] +struct LookupRequest<'a> { + policy_id: &'a str, + embedding: &'a [f32], + #[serde(skip_serializing_if = "Option::is_none")] + threshold: Option, +} + +#[derive(Deserialize)] +struct LookupResponse { + hit: bool, + #[serde(default)] + response: Option, + #[serde(default)] + prompt_tokens: u32, + #[serde(default)] + completion_tokens: u32, + #[serde(default)] + similarity: f32, +} + +#[derive(Serialize)] +struct PutRequest<'a> { + policy_id: &'a str, + prompt_text: &'a str, + embedding: &'a [f32], + response: &'a ChatResponse, + #[serde(skip_serializing_if = "is_zero_u32")] + prompt_tokens: u32, + #[serde(skip_serializing_if = "is_zero_u32")] + completion_tokens: u32, + #[serde(skip_serializing_if = "Option::is_none")] + ttl_seconds: Option, +} + +#[inline] +fn is_zero_u32(n: &u32) -> bool { + *n == 0 +} + +/// Decode a non-2xx response into the canonical +/// `{error: {code, message}}` envelope. 
Falls back to a generic +/// HandlerError if the body isn't shaped like that — never panics. +async fn handler_error(status: StatusCode, resp: reqwest::Response) -> SemanticCacheError { + #[derive(Deserialize)] + struct Envelope { + error: Option, + } + #[derive(Deserialize)] + struct EnvelopeBody { + code: String, + message: String, + } + match resp.json::().await { + Ok(env) => match env.error { + Some(b) => SemanticCacheError::HandlerError { + status: status.as_u16(), + code: b.code, + message: b.message, + }, + None => SemanticCacheError::HandlerError { + status: status.as_u16(), + code: "UNKNOWN".into(), + message: "no error envelope".into(), + }, + }, + Err(_) => SemanticCacheError::HandlerError { + status: status.as_u16(), + code: "UNKNOWN".into(), + message: "non-JSON response".into(), + }, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use aisix_gateway::{ChatMessage, ChatResponse, FinishReason, UsageStats}; + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + fn sample_response() -> ChatResponse { + ChatResponse { + id: "cmpl-test".into(), + model: "gpt-4o".into(), + message: ChatMessage::assistant("cached"), + finish_reason: FinishReason::Stop, + usage: UsageStats { + prompt_tokens: 1, + completion_tokens: 2, + ..Default::default() + }, + } + } + + #[tokio::test] + async fn lookup_hit_returns_response_and_similarity() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/dp/cache/lookup")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "hit": true, + "response": { + "id": "cmpl-cached", + "model": "gpt-4o", + "message": {"role": "assistant", "content": "from cache"}, + "finish_reason": "stop", + "usage": {"prompt_tokens": 5, "completion_tokens": 6, "total_tokens": 11} + }, + "prompt_tokens": 5, + "completion_tokens": 6, + "similarity": 0.97 + }))) + .mount(&server) + .await; + + let cache = PgvectorCache::new(reqwest::Client::new(), 
server.uri()); + let hit = cache + .lookup("policy-1", &[0.0_f32; 3], Some(0.92)) + .await + .unwrap() + .unwrap(); + assert_eq!(hit.response.message.content, "from cache"); + assert_eq!(hit.prompt_tokens, 5); + assert_eq!(hit.completion_tokens, 6); + assert!((hit.similarity - 0.97).abs() < 1e-6); + } + + #[tokio::test] + async fn lookup_miss_returns_none() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/dp/cache/lookup")) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(serde_json::json!({"hit": false, "similarity": 0.4})), + ) + .mount(&server) + .await; + + let cache = PgvectorCache::new(reqwest::Client::new(), server.uri()); + let res = cache + .lookup("policy-1", &[0.0_f32; 3], None) + .await + .unwrap(); + assert!(res.is_none()); + } + + #[tokio::test] + async fn lookup_handler_error_surfaces_envelope() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/dp/cache/lookup")) + .respond_with( + ResponseTemplate::new(404).set_body_json(serde_json::json!({ + "error": {"code": "NOT_FOUND", "message": "cache_policy not found"} + })), + ) + .mount(&server) + .await; + + let cache = PgvectorCache::new(reqwest::Client::new(), server.uri()); + let err = cache + .lookup("policy-1", &[0.0_f32; 3], None) + .await + .unwrap_err(); + match err { + SemanticCacheError::HandlerError { status, code, .. 
} => { + assert_eq!(status, 404); + assert_eq!(code, "NOT_FOUND"); + } + other => panic!("expected HandlerError, got {other:?}"), + } + } + + #[tokio::test] + async fn put_returns_ok_on_success() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/dp/cache/put")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "entry_id": "11111111-1111-1111-1111-111111111111", + "expires_at": "2026-01-01T00:00:00Z" + }))) + .expect(1) + .mount(&server) + .await; + + let cache = PgvectorCache::new(reqwest::Client::new(), server.uri()); + cache + .put( + "policy-1", + "the prompt", + &[0.0_f32; 3], + &sample_response(), + 1, + 2, + None, + ) + .await + .unwrap(); + } + + #[test] + fn base_url_strips_trailing_slashes() { + let cache = PgvectorCache::new(reqwest::Client::new(), "https://dpmgr///"); + assert_eq!(cache.format_url("/dp/cache/lookup"), "https://dpmgr/dp/cache/lookup"); + } +} diff --git a/crates/aisix-core/src/models/cache_policy.rs b/crates/aisix-core/src/models/cache_policy.rs index 1fd98e14..c064f5c5 100644 --- a/crates/aisix-core/src/models/cache_policy.rs +++ b/crates/aisix-core/src/models/cache_policy.rs @@ -23,9 +23,15 @@ use serde::{Deserialize, Serialize}; use crate::resource::Resource; -/// Cache backend choice. Stage 2 only enforces `Memory`. The other -/// variants persist in cp-api + ship through kine but the DP falls -/// back to memory until each backend wires up. +/// Cache backend choice. The DP enforces: +/// - `Memory` (Stage 2) — exact-match in-process LRU +/// - `Pgvector` (Stage 4b) — semantic cosine-similarity over PG + +/// pgvector via dp-manager `/dp/cache/{lookup,put}` +/// +/// `Redis` / `RedisSemantic` / `Qdrant` persist in cp-api + ship +/// through kine but the DP falls through to the memory path for +/// them. They stay enum members so a future migration doesn't need +/// a schema change to pick them up. 
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum CacheBackend { @@ -33,6 +39,7 @@ pub enum CacheBackend { Memory, Redis, RedisSemantic, + Pgvector, Qdrant, } diff --git a/crates/aisix-proxy/src/chat.rs b/crates/aisix-proxy/src/chat.rs index df68e4f3..8e4411da 100644 --- a/crates/aisix-proxy/src/chat.rs +++ b/crates/aisix-proxy/src/chat.rs @@ -16,7 +16,8 @@ //! line. Errors surface via [`ProxyError`] which carries the right //! status, error type, and (for rate-limits) Retry-After. -use aisix_cache::CacheKey; +use aisix_cache::{embed_prompt, CacheKey}; +use aisix_core::models::CacheBackend; use aisix_gateway::{BridgeContext, BridgeError, ChatFormat}; use aisix_guardrails::GuardrailVerdict; use aisix_obs::{AccessLog, LangfuseEvent, Metrics, RequestOutcome, UsageEvent}; @@ -437,26 +438,20 @@ async fn dispatch( }); } - // Policy gate (Stage 3): the cache is only consulted when at - // least one enabled `CachePolicy` in the snapshot has an - // `applies_to` clause that matches THIS request. cp-api owns - // the policy CRUD surface (`/api/environments/:env/cache_policies`, - // see Stage 1); kine fans out the rows; the loader populates - // `snapshot.cache_policies` (see aisix-etcd). Stage 4 will add - // per-policy `ttl_seconds` propagation into the cache backend - // and the semantic backends. + // Policy gate (Stage 3 + 4b): pick the first enabled `cache_policy` + // whose `applies_to` clause accepts THIS request's + // (model, api_key_id) pair. Stage 4b extends the gate to dispatch + // by `backend`: `pgvector` → semantic lookup via dp-manager + // `/dp/cache/{lookup,put}`; everything else (memory / unsupported + // preview backends) → existing moka path. // - // Match order: first enabled policy whose `parsed_applies_to()` - // accepts (req.model, auth.entry.id) wins. 
Iteration order on a - // ResourceTable isn't stable, but the first-match-wins rule is - // deterministic enough for the cache_status / x-aisix-cache header - // to stay stable across requests; ties only matter when an env - // legitimately has overlapping policies. - let cache_active_by_policy = snapshot + // `applies_to` parsing + match logic lives on `CachePolicy` itself + // (see aisix-core::cache_policy::AppliesTo). + let matched_policy = snapshot .cache_policies .entries() - .iter() - .any(|entry| { + .into_iter() + .find(|entry| { entry.value.enabled && entry .value @@ -464,18 +459,176 @@ async fn dispatch( .matches(&req.model, &auth.entry.id) }); + // Branch: pgvector backend uses the semantic cache; any other + // matched backend falls back to the moka path. The two paths + // share the cache_status emission shape so the dashboard /logs + // column reads identically across both. + let pgvector_match = matched_policy + .as_ref() + .filter(|p| matches!(p.value.backend, CacheBackend::Pgvector)); + + // ─── pgvector semantic path ───────────────────────────────────── + // + // Embeds the last user message via the env's first OpenAI model + // (see aisix-cache::embed_prompt) and looks up against + // dp-manager's /dp/cache/lookup. On hit we return the cached + // response with cache_status=Hit + x-aisix-cache=hit. On miss we + // capture the embedding so the post-success path can call /put + // without re-embedding. + // + // Failure mode is fail-open per the Stage 4b design note: any + // error here logs a warn + falls through to the upstream as if + // the cache were disabled (cache_status stays Disabled in that + // case, NOT Miss — operators reading /logs see "the gate is + // closed" rather than "we tried and missed"). 
+ let pgvector_embedding: Option>; + let pgvector_prompt_text: Option; + let pgvector_policy_id: Option; + let pgvector_ttl_seconds: Option; + if let (Some(matched), Some(pgvector)) = + (pgvector_match, state.pgvector_cache.as_ref()) + { + let embedding_model = matched + .value + .embedding_model + .as_deref() + .unwrap_or("text-embedding-3-small"); + let prompt_text = last_user_message_text(req).unwrap_or_default(); + if prompt_text.is_empty() { + // No user turn to embed — fall through with Disabled + // status. Shouldn't normally happen (the empty-messages + // check at function entry already rejects), but keep the + // graceful path so an exotic system-only request doesn't + // panic. + pgvector_embedding = None; + pgvector_prompt_text = None; + pgvector_policy_id = None; + pgvector_ttl_seconds = None; + } else { + match embed_prompt( + &snapshot, + &state.hub, + embedding_model, + &prompt_text, + request_id, + ) + .await + { + Ok(vec) => { + let policy_id = matched.id.clone(); + let threshold = matched.value.similarity_threshold; + match pgvector.lookup(&policy_id, &vec, threshold).await { + Ok(Some(hit)) => { + reservation.commit_tokens(0); + let prompt = hit.prompt_tokens as u64; + let completion = hit.completion_tokens as u64; + let total = prompt.saturating_add(completion); + let cached_prompt_tokens = hit.response.usage.cached_prompt_tokens; + let reasoning_tokens = hit.response.usage.reasoning_tokens; + let cache_creation_tokens = + hit.response.usage.cache_creation_tokens; + let cache_read_tokens = hit.response.usage.cache_read_tokens; + let provider_label = attempt_models[0] + .provider() + .map(|p| format!("{p:?}").to_lowercase()) + .unwrap_or_else(|| "unknown".into()); + tracing::debug!( + request_id = %request_id, + policy_id = %policy_id, + similarity = hit.similarity, + "pgvector cache hit", + ); + let mut response = + Json(render_response(now, hit.response)).into_response(); + response + .headers_mut() + .insert(CACHE_HEADER, 
HeaderValue::from_static("hit")); + return Ok(Success { + response, + provider: provider_label, + model_id: model_id.clone(), + prompt_tokens: Some(prompt), + completion_tokens: Some(completion), + total_tokens: Some(total), + cached_prompt_tokens, + reasoning_tokens, + cache_creation_tokens, + cache_read_tokens, + provider_request_id: String::new(), + provider_model_version: String::new(), + finish_reason: String::new(), + cost_usd: 0.0, + bypass_reason: bypass_reason.clone(), + cache_status: CacheStatus::Hit, + }); + } + Ok(None) => { + // Miss — record state for the post-success put. + pgvector_embedding = Some(vec); + pgvector_prompt_text = Some(prompt_text); + pgvector_policy_id = Some(policy_id); + pgvector_ttl_seconds = Some(matched.value.ttl_seconds); + } + Err(err) => { + tracing::warn!( + error = %err, + policy_id = %policy_id, + "pgvector lookup failed; falling open to upstream", + ); + pgvector_embedding = None; + pgvector_prompt_text = None; + pgvector_policy_id = None; + pgvector_ttl_seconds = None; + } + } + } + Err(err) => { + tracing::warn!( + error = %err, + embedding_model = %embedding_model, + "pgvector embedding failed; falling open to upstream", + ); + pgvector_embedding = None; + pgvector_prompt_text = None; + pgvector_policy_id = None; + pgvector_ttl_seconds = None; + } + } + } + } else { + pgvector_embedding = None; + pgvector_prompt_text = None; + pgvector_policy_id = None; + pgvector_ttl_seconds = None; + } + + // For the moka path: only consult the cache when the matched + // policy's backend is Memory (or the matched policy isn't + // pgvector AND the moka cache is configured). Stage 2's + // any-policy gate is replaced here by the per-backend dispatch. + let cache_active_by_policy = matched_policy + .as_ref() + .map(|p| matches!(p.value.backend, CacheBackend::Memory)) + .unwrap_or(false); + // Cache lookup keyed on the *virtual* model name so a re-request // hits the cache regardless of which target served the original. 
- // Even with `cache_active_by_policy = false` we still build the - // key to keep the cache_status path uniform — `disabled` is the - // outcome when the gate is closed, but the request itself is - // shaped the same way. let cache_key = state .cache .as_ref() .map(|_| CacheKey::from_request(req).fingerprint()); - let cache_status = if cache_active_by_policy && state.cache.is_some() { + let cache_status = if pgvector_policy_id.is_some() { + // We tried the pgvector path and missed (or the embedding + // / lookup errored after the gate was open). Either way the + // request is heading upstream; record Miss when we'll write + // back, Disabled when we won't. + if pgvector_embedding.is_some() { + CacheStatus::Miss + } else { + CacheStatus::Disabled + } + } else if cache_active_by_policy && state.cache.is_some() { CacheStatus::Miss } else { CacheStatus::Disabled @@ -654,6 +807,39 @@ async fn dispatch( } } + // pgvector backend: if the lookup at the top of dispatch missed + // (and the embedding succeeded), persist the upstream response + // so the next semantically-similar request can hit. Errors on + // /dp/cache/put are warned but don't fail the request — the + // caller already got their answer. 
+ if let (Some(pgvector), Some(embedding), Some(prompt_text), Some(policy_id)) = ( + state.pgvector_cache.as_ref(), + pgvector_embedding.as_deref(), + pgvector_prompt_text.as_deref(), + pgvector_policy_id.as_deref(), + ) { + let prompt_tokens = upstream.usage.prompt_tokens; + let completion_tokens = upstream.usage.completion_tokens; + if let Err(err) = pgvector + .put( + policy_id, + prompt_text, + embedding, + &upstream, + prompt_tokens, + completion_tokens, + pgvector_ttl_seconds, + ) + .await + { + tracing::warn!( + error = %err, + policy_id = %policy_id, + "pgvector put failed; entry not stored", + ); + } + } + let mut response = Json(render_response(now, upstream)).into_response(); if matches!(cache_status, CacheStatus::Miss) { // Miss header only when the cache was actually consulted — @@ -698,6 +884,23 @@ fn finish_reason_label(reason: &aisix_gateway::FinishReason) -> String { } } +/// Pick the text we embed for the pgvector semantic cache: the +/// content of the last `user` message in the request. Per the +/// Stage 4b design note (option A), embedding the full conversation +/// is a possible future option; today the last user turn is the +/// stable signal that captures "what the caller is asking right +/// now". Returns `None` when no user message is present so the +/// caller can short-circuit out of the embedding path. +fn last_user_message_text(req: &ChatFormat) -> Option { + use aisix_gateway::Role; + req.messages + .iter() + .rev() + .find(|m| matches!(m.role, Role::User)) + .map(|m| m.content.clone()) + .filter(|s| !s.is_empty()) +} + fn record_success( metrics: &Metrics, provider: &str, diff --git a/crates/aisix-proxy/src/state.rs b/crates/aisix-proxy/src/state.rs index d367488b..df81d227 100644 --- a/crates/aisix-proxy/src/state.rs +++ b/crates/aisix-proxy/src/state.rs @@ -14,7 +14,7 @@ //! //! Cheap to clone: every field is either an `Arc` or a small Copy scalar. 
-use aisix_cache::{Cache, MemoryCache}; +use aisix_cache::{Cache, MemoryCache, PgvectorCache}; use aisix_core::snapshot::SnapshotHandle; use aisix_core::{AisixSnapshot, ProxyConfig}; use aisix_gateway::Hub; @@ -34,6 +34,13 @@ pub struct ProxyState { pub limiter: Arc<Limiter>, pub metrics: Arc<Metrics>, pub cache: Option<Arc<dyn Cache>>, + /// Optional pgvector-backed semantic cache. When `Some`, the chat + /// handler routes requests with a matched policy of + /// `backend = pgvector` through this client. None disables the + /// pgvector path — matching policies surface as `Disabled` in + /// the cache_status telemetry. Cheap to clone — `PgvectorCache` + /// is `reqwest::Client` + a base URL. + pub pgvector_cache: Option<Arc<PgvectorCache>>, pub routing: Arc<RoutingRegistry>, /// Content-policy hooks. Default is an empty chain (no-op); the /// server bootstrap loads a real chain from config. @@ -69,6 +76,7 @@ impl ProxyState { limiter: Arc::new(Limiter::new()), metrics: Arc::new(Metrics::new(false)), cache: Some(Arc::new(MemoryCache::with_defaults())), + pgvector_cache: None, routing: Arc::new(RoutingRegistry::new()), guardrails: Arc::new(GuardrailChain::empty()), budgets: Arc::new(BudgetClient::disabled()), @@ -94,6 +102,7 @@ impl ProxyState { limiter, metrics: Arc::new(Metrics::new(false)), cache: Some(Arc::new(MemoryCache::with_defaults())), + pgvector_cache: None, routing: Arc::new(RoutingRegistry::new()), guardrails: Arc::new(GuardrailChain::empty()), budgets: Arc::new(BudgetClient::disabled()), @@ -122,6 +131,7 @@ impl ProxyState { limiter, metrics, cache, + pgvector_cache: None, routing: Arc::new(RoutingRegistry::new()), guardrails: Arc::new(GuardrailChain::empty()), budgets: Arc::new(BudgetClient::disabled()), @@ -168,4 +178,14 @@ impl ProxyState { self.budgets = client; self } + + /// Attach a pgvector semantic cache. The chat handler routes + /// requests with a matched policy of `backend = pgvector` + /// through this client. 
None disables the pgvector path — + /// matching policies surface as `Disabled` in the cache_status + /// telemetry so operators see the gate is closed. + pub fn with_pgvector_cache(mut self, client: Arc<PgvectorCache>) -> Self { + self.pgvector_cache = Some(client); + self + } } From c000d6f7f7f3eef1636afff2407330ff79adc752 Mon Sep 17 00:00:00 2001 From: Ming Wen Date: Tue, 5 May 2026 17:19:34 +0800 Subject: [PATCH 2/2] feat(server): wire PgvectorCache into proxy bootstrap MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Build a `PgvectorCache` against the dpmgr origin reusing the heartbeat mTLS bundle (same pattern as `BudgetClient`), and attach it to the `ProxyState`. In self-hosted dev (no heartbeat_cfg) and on bundle-build failure the proxy falls back to surfacing matched pgvector policies as `cache_status = Disabled` — no traffic impact. --- crates/aisix-server/src/main.rs | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/crates/aisix-server/src/main.rs b/crates/aisix-server/src/main.rs index d751e99f..d80dea8a 100644 --- a/crates/aisix-server/src/main.rs +++ b/crates/aisix-server/src/main.rs @@ -21,7 +21,7 @@ mod register; mod telemetry; use aisix_admin::{AdminState, ConfigStore, EtcdConfigStore}; -use aisix_cache::{Cache, MemoryCache, RedisCache}; +use aisix_cache::{Cache, MemoryCache, PgvectorCache, RedisCache}; use aisix_core::models::Provider; use aisix_core::{CacheBackend, Config, EtcdConfig, EtcdTlsConfig}; use aisix_etcd::{EtcdConfigProvider, SnapshotCache, Supervisor}; @@ -325,6 +325,27 @@ async fn run(mut cfg: Config) -> anyhow::Result<()> { } } }); + // Pgvector semantic cache client. Reuses the heartbeat mTLS bundle + // and dpmgr origin; piggybacks on the same /dp routes mounted by + // dp-manager (see AISIX-Cloud `internal/dpmgr/api/cache.go`). 
When + // the bundle build fails or we're outside managed mode the proxy + // falls back to surfacing matched pgvector policies as + // cache_status=Disabled — no traffic impact, just no semantic-cache + // benefit. See `aisix_cache::pgvector` for the wire client. + let pgvector_cache = heartbeat_cfg.as_ref().and_then(|h| { + let dpmgr_base = h + .url + .strip_suffix("/dp/heartbeat") + .unwrap_or(h.url.as_str()) + .to_string(); + match heartbeat::build_mtls_client(&h.mtls) { + Ok(http) => Some(Arc::new(PgvectorCache::new(http, dpmgr_base))), + Err(e) => { + tracing::warn!(error = %e, "pgvector cache disabled: mTLS client build failed"); + None + } + } + }); let heartbeat_task = heartbeat_cfg.map(|h| heartbeat::spawn(h, cancel_rx.clone())); let (usage_sink, telemetry_task) = match telemetry_cfg { Some(cfg) => { @@ -390,6 +411,9 @@ async fn run(mut cfg: Config) -> anyhow::Result<()> { if let Some(client) = budget_client { proxy_state = proxy_state.with_budget_client(client); } + if let Some(client) = pgvector_cache { + proxy_state = proxy_state.with_pgvector_cache(client); + } // Live guardrail chain: rebuilds itself whenever the etcd watch // supervisor stores a fresh snapshot, so dashboard mutations // (`/guardrails` create / enable / delete) take effect within