From 49c3270b537a44f02df7f612283199511a7e14c0 Mon Sep 17 00:00:00 2001 From: Ryan Levick Date: Thu, 28 May 2026 15:23:49 +0200 Subject: [PATCH] Limit the number of redis connections an app can have across instances Signed-off-by: Ryan Levick --- Cargo.lock | 1 + crates/factor-outbound-redis/Cargo.toml | 3 +- crates/factor-outbound-redis/src/host.rs | 58 ++++++++++++++----- crates/factor-outbound-redis/src/lib.rs | 22 +++++-- .../src/runtime_config.rs | 8 +++ .../src/runtime_config/spin.rs | 29 ++++++++++ crates/runtime-config/src/lib.rs | 6 +- 7 files changed, 106 insertions(+), 21 deletions(-) create mode 100644 crates/factor-outbound-redis/src/runtime_config.rs create mode 100644 crates/factor-outbound-redis/src/runtime_config/spin.rs diff --git a/Cargo.lock b/Cargo.lock index 0ae6c9554a..c6cb483d19 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9113,6 +9113,7 @@ version = "4.1.0-pre0" dependencies = [ "anyhow", "redis", + "serde", "spin-core", "spin-factor-otel", "spin-factor-outbound-networking", diff --git a/crates/factor-outbound-redis/Cargo.toml b/crates/factor-outbound-redis/Cargo.toml index 6518459d79..55e641ce6e 100644 --- a/crates/factor-outbound-redis/Cargo.toml +++ b/crates/factor-outbound-redis/Cargo.toml @@ -7,13 +7,14 @@ edition = { workspace = true } [dependencies] anyhow = { workspace = true } redis = { workspace = true , features = ["tokio-comp", "tokio-native-tls-comp", "aio"] } +serde = { workspace = true } spin-core = { path = "../core" } spin-factor-otel = { path = "../factor-otel" } spin-factor-outbound-networking = { path = "../factor-outbound-networking" } spin-factors = { path = "../factors" } spin-resource-table = { path = "../table" } spin-world = { path = "../world" } -tokio = { workspace = true } +tokio = { workspace = true, features = ["sync"] } tracing = { workspace = true } [dev-dependencies] diff --git a/crates/factor-outbound-redis/src/host.rs b/crates/factor-outbound-redis/src/host.rs index 61fd05708b..202584a8a1 100644 --- a/crates/factor-outbound-redis/src/host.rs +++ b/crates/factor-outbound-redis/src/host.rs @@ -1,4 +1,5 @@ use std::net::SocketAddr; +use std::sync::Arc; use anyhow::Result; use redis::AsyncConnectionConfig; @@ -11,6 +12,7 @@ use spin_world::MAX_HOST_BUFFERED_BYTES; use spin_world::spin::redis::redis as v3; use spin_world::v1::{redis as v1, redis_types}; use spin_world::v2::redis as v2; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; use tracing::field::Empty; use tracing::{Level, instrument}; @@ -19,7 +21,9 @@ use crate::allowed_hosts::AllowedHostChecker; pub struct InstanceState { pub(crate) allowed_host_checker: AllowedHostChecker, pub blocked_networks: BlockedNetworks, - pub connections: spin_resource_table::Table, + pub connections: + spin_resource_table::Table<(MultiplexedConnection, Option)>, + pub connection_semaphore: Option>, pub otel: OtelFactorState, } @@ -32,6 +36,15 @@ impl InstanceState { &mut self, address: String, ) -> Result, v2::Error> { + let permit = match &self.connection_semaphore { + Some(sem) => Some( + Arc::clone(sem) + .acquire_owned() + .await + .map_err(|_| v2::Error::TooManyConnections)?, + ), + None => None, + }; let config = AsyncConnectionConfig::new() .set_dns_resolver(SpinDnsResolver(self.blocked_networks.clone())); let conn = redis::Client::open(address.as_str()) @@ -40,7 +53,7 @@ impl InstanceState { .await .map_err(other_error_v2)?; self.connections - .push(conn) + .push((conn, permit)) .map(Resource::new_own) .map_err(|_| v2::Error::TooManyConnections) } @@ -51,6 +64,7 @@ impl InstanceState { ) -> Result<&mut MultiplexedConnection, v2::Error> { self.connections .get_mut(connection.rep()) + .map(|(conn, _permit)| conn) .ok_or(v2::Error::Other( "could not find connection for resource".into(), )) @@ -62,7 +76,7 @@ impl InstanceState { ) -> Result { self.connections .get(connection.rep()) - .cloned() + .map(|(conn, _permit)| conn.clone()) .ok_or(v3::Error::Other( "could not find connection for resource".into(), )) @@ -229,14 +243,16 @@ impl v3::HostConnectionWithStore for crate::RedisFactorData { accessor: &Accessor, address: String, ) -> Result, v3::Error> { - let (allowed_host_checker, blocked_networks) = accessor.with(|mut access| { - let host = access.get(); - host.otel.reparent_tracing_span(); - ( - host.allowed_host_checker.clone(), - host.blocked_networks.clone(), - ) - }); + let (allowed_host_checker, blocked_networks, connection_semaphore) = + accessor.with(|mut access| { + let host = access.get(); + host.otel.reparent_tracing_span(); + ( + host.allowed_host_checker.clone(), + host.blocked_networks.clone(), + host.connection_semaphore.clone(), + ) + }); if !allowed_host_checker .is_address_allowed(&address) @@ -246,6 +262,15 @@ impl v3::HostConnectionWithStore for crate::RedisFactorData { return Err(v3::Error::InvalidAddress); } + let permit = match connection_semaphore { + Some(sem) => Some( + sem.acquire_owned() + .await + .map_err(|_| v3::Error::TooManyConnections)?, + ), + None => None, + }; + let config = AsyncConnectionConfig::new().set_dns_resolver(SpinDnsResolver(blocked_networks)); let conn = redis::Client::open(address.as_str()) @@ -257,7 +282,7 @@ impl v3::HostConnectionWithStore for crate::RedisFactorData { accessor.with(|mut access| { let host = access.get(); host.connections - .push(conn) + .push((conn, permit)) .map(Resource::new_own) .map_err(|_| v3::Error::TooManyConnections) }) @@ -532,9 +557,14 @@ macro_rules! delegate { Ok(c) => c, Err(_) => return Err(v1::Error::Error), }; - ::$name($self, connection, $($arg),*) + // v1 has no persistent connections, so remove the table entry immediately + // after the call to release the semaphore permit. + let rep = connection.rep(); + let result = ::$name($self, connection, $($arg),*) .await - .map_err(|_| v1::Error::Error) + .map_err(|_| v1::Error::Error); + $self.connections.remove(rep); + result }}; } diff --git a/crates/factor-outbound-redis/src/lib.rs b/crates/factor-outbound-redis/src/lib.rs index 494c5ca800..4f9186e5e2 100644 --- a/crates/factor-outbound-redis/src/lib.rs +++ b/crates/factor-outbound-redis/src/lib.rs @@ -1,7 +1,11 @@ mod allowed_hosts; mod host; +pub mod runtime_config; + +use std::sync::Arc; use host::InstanceState; +use runtime_config::RuntimeConfig; use spin_factor_otel::OtelFactorState; use spin_factor_outbound_networking::OutboundNetworkingFactor; use spin_factors::{ @@ -9,6 +13,7 @@ use spin_factors::{ anyhow, }; use spin_world::spin::redis::redis as v3; +use tokio::sync::Semaphore; use crate::allowed_hosts::AllowedHostChecker; @@ -24,9 +29,14 @@ impl OutboundRedisFactor { } } +pub struct AppState { + /// A semaphore to limit the number of concurrent outbound Redis connections. + pub connection_semaphore: Option>, +} + impl Factor for OutboundRedisFactor { - type RuntimeConfig = (); - type AppState = (); + type RuntimeConfig = RuntimeConfig; + type AppState = AppState; type InstanceBuilder = InstanceState; fn init(&mut self, ctx: &mut impl spin_factors::InitContext) -> anyhow::Result<()> { @@ -38,9 +48,12 @@ impl Factor for OutboundRedisFactor { fn configure_app( &self, - _ctx: ConfigureAppContext, + mut ctx: ConfigureAppContext, ) -> anyhow::Result { - Ok(()) + let config = ctx.take_runtime_config().unwrap_or_default(); + Ok(AppState { + connection_semaphore: config.max_connections.map(|n| Arc::new(Semaphore::new(n))), + }) } fn prepare( @@ -54,6 +67,7 @@ impl Factor for OutboundRedisFactor { allowed_host_checker: AllowedHostChecker::new(outbound_networking.allowed_hosts()), blocked_networks: outbound_networking.blocked_networks(), connections: spin_resource_table::Table::new(1024), + connection_semaphore: ctx.app_state().connection_semaphore.clone(), otel, }) } diff --git a/crates/factor-outbound-redis/src/runtime_config.rs b/crates/factor-outbound-redis/src/runtime_config.rs new file mode 100644 index 0000000000..38d2d7ea7d --- /dev/null +++ b/crates/factor-outbound-redis/src/runtime_config.rs @@ -0,0 +1,8 @@ +pub mod spin; + +/// Runtime configuration for outbound Redis. +#[derive(Default)] +pub struct RuntimeConfig { + /// If set, limits the number of concurrent outbound Redis connections. + pub max_connections: Option, +} diff --git a/crates/factor-outbound-redis/src/runtime_config/spin.rs b/crates/factor-outbound-redis/src/runtime_config/spin.rs new file mode 100644 index 0000000000..82c0efeaff --- /dev/null +++ b/crates/factor-outbound-redis/src/runtime_config/spin.rs @@ -0,0 +1,29 @@ +use serde::Deserialize; +use spin_factors::runtime_config::toml::GetTomlValue; + +/// Get the runtime configuration for outbound Redis from a TOML table. +/// +/// Expects table to be in the format: +/// ```toml +/// [outbound_redis] +/// max_connections = 10 # optional, defaults to unlimited +/// ``` +pub fn config_from_table( + table: &impl GetTomlValue, +) -> anyhow::Result> { + if let Some(outbound_redis) = table.get("outbound_redis") { + let toml = outbound_redis.clone().try_into::()?; + Ok(Some(super::RuntimeConfig { + max_connections: toml.max_connections, + })) + } else { + Ok(None) + } +} + +#[derive(Debug, Default, Deserialize)] +#[serde(deny_unknown_fields)] +struct OutboundRedisToml { + #[serde(default)] + max_connections: Option, +} diff --git a/crates/runtime-config/src/lib.rs b/crates/runtime-config/src/lib.rs index f3169beb6a..6ca9c9954f 100644 --- a/crates/runtime-config/src/lib.rs +++ b/crates/runtime-config/src/lib.rs @@ -368,8 +368,10 @@ impl FactorRuntimeConfigSource for TomlRuntimeConfigSource<'_, '_> { } impl FactorRuntimeConfigSource for TomlRuntimeConfigSource<'_, '_> { - fn get_runtime_config(&mut self) -> anyhow::Result> { - Ok(None) + fn get_runtime_config( + &mut self, + ) -> anyhow::Result::RuntimeConfig>> { + spin_factor_outbound_redis::runtime_config::spin::config_from_table(&self.toml.table) } }