From 6d0ff2fb51f121accdce8b89a6f28f906b17ad3f Mon Sep 17 00:00:00 2001 From: Ryan Levick Date: Thu, 21 May 2026 20:08:40 +0200 Subject: [PATCH 01/15] Cap max number of open sockets Signed-off-by: Ryan Levick --- Cargo.lock | 1 + crates/factor-outbound-networking/Cargo.toml | 2 + crates/factor-outbound-networking/src/lib.rs | 21 +- .../src/runtime_config.rs | 3 + .../src/runtime_config/spin.rs | 45 +- .../tests/factor_test.rs | 232 ++++++++++- crates/factor-wasi/Cargo.toml | 2 +- crates/factor-wasi/src/lib.rs | 383 +++++++++++++++++- 8 files changed, 650 insertions(+), 39 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3348380929..71f0e3b4e9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9069,6 +9069,7 @@ dependencies = [ "toml 0.8.19", "tracing", "url", + "wasmtime", "wasmtime-wasi", "webpki-root-certs", ] diff --git a/crates/factor-outbound-networking/Cargo.toml b/crates/factor-outbound-networking/Cargo.toml index f340c27093..42e362f3cc 100644 --- a/crates/factor-outbound-networking/Cargo.toml +++ b/crates/factor-outbound-networking/Cargo.toml @@ -20,6 +20,7 @@ spin-locked-app = { path = "../locked-app" } spin-manifest = { path = "../manifest" } spin-outbound-networking-config = { path = "../outbound-networking-config" } spin-serde = { path = "../serde" } +tokio = { workspace = true, features = ["sync"] } tracing = { workspace = true } opentelemetry-semantic-conventions = { workspace = true } url = { workspace = true } @@ -30,6 +31,7 @@ spin-factors-test = { path = "../factors-test" } tempfile = { workspace = true } tokio = { workspace = true, features = ["macros", "rt"] } toml = { workspace = true } +wasmtime = { workspace = true } wasmtime-wasi = { workspace = true } [features] diff --git a/crates/factor-outbound-networking/src/lib.rs b/crates/factor-outbound-networking/src/lib.rs index 5b20c46be3..590818a802 100644 --- a/crates/factor-outbound-networking/src/lib.rs +++ b/crates/factor-outbound-networking/src/lib.rs @@ -7,12 +7,13 @@ use std::{collections::HashMap, sync::Arc}; use futures_util::FutureExt as _; use opentelemetry_semantic_conventions::attribute::SERVER_PORT; use spin_factor_variables::VariablesFactor; -use spin_factor_wasi::{SocketAddrUse, WasiFactor}; +use spin_factor_wasi::{SocketAddrUse, SocketPermitState, WasiFactor}; use spin_factors::{ ConfigureAppContext, Error, Factor, FactorInstanceBuilder, PrepareContext, RuntimeFactors, anyhow::{self, Context}, }; use spin_outbound_networking_config::allowed_hosts::{DisallowedHostHandler, OutboundAllowedHosts}; +use tokio::sync::Semaphore; use url::Url; use crate::{ @@ -69,15 +70,18 @@ impl Factor for OutboundNetworkingFactor { client_tls_configs, blocked_ip_networks: block_networks, block_private_networks, + max_sockets_per_app, } = ctx.take_runtime_config().unwrap_or_default(); let blocked_networks = BlockedNetworks::new(block_networks, block_private_networks); let tls_client_configs = TlsClientConfigs::new(client_tls_configs)?; + let socket_quota = max_sockets_per_app.map(|n| Arc::new(Semaphore::new(n))); Ok(AppState { component_allowed_hosts, blocked_networks, tls_client_configs, + socket_quota, }) } @@ -123,10 +127,18 @@ impl Factor for OutboundNetworkingFactor { self.disallowed_host_handler.clone(), ); let blocked_networks = ctx.app_state().blocked_networks.clone(); + let permit_state = ctx + .app_state() + .socket_quota + .as_ref() + .map(|sem| SocketPermitState::new(Arc::clone(sem))); match ctx.instance_builder::() { Ok(wasi_builder) => { - // Update Wasi socket allowed ports + if let Some(state) = permit_state { + wasi_builder.set_socket_permit_state(state); + } + let allowed_hosts = allowed_hosts.clone(); wasi_builder.outbound_socket_addr_check(move |addr, addr_use| { let allowed_hosts = allowed_hosts.clone(); @@ -185,6 +197,10 @@ pub struct AppState { blocked_networks: BlockedNetworks, /// TLS client configs tls_client_configs: TlsClientConfigs, + /// App-wide semaphore capping total concurrent outbound socket connections + /// + /// `None` means unlimited. + socket_quota: Option>, } pub struct InstanceBuilder { @@ -193,6 +209,7 @@ pub struct InstanceBuilder { component_tls_client_configs: ComponentTlsClientConfigs, } + impl InstanceBuilder { pub fn allowed_hosts(&self) -> OutboundAllowedHosts { self.allowed_hosts.clone() diff --git a/crates/factor-outbound-networking/src/runtime_config.rs b/crates/factor-outbound-networking/src/runtime_config.rs index 887742febb..1520d7ba4a 100644 --- a/crates/factor-outbound-networking/src/runtime_config.rs +++ b/crates/factor-outbound-networking/src/runtime_config.rs @@ -12,6 +12,9 @@ pub struct RuntimeConfig { pub block_private_networks: bool, /// TLS client configs pub client_tls_configs: Vec, + /// Maximum number of outbound socket connections across all instances of this app. + /// `None` means unlimited (default). + pub max_sockets_per_app: Option, } /// TLS configuration for one or more component(s) and host(s). diff --git a/crates/factor-outbound-networking/src/runtime_config/spin.rs b/crates/factor-outbound-networking/src/runtime_config/spin.rs index f41c4a0d75..2e8824ad43 100644 --- a/crates/factor-outbound-networking/src/runtime_config/spin.rs +++ b/crates/factor-outbound-networking/src/runtime_config/spin.rs @@ -46,52 +46,48 @@ impl SpinRuntimeConfig { &self, table: &impl GetTomlValue, ) -> anyhow::Result> { - let maybe_blocked_networks = self - .blocked_networks_from_table(table) + let maybe_outbound_networking = self + .outbound_networking_from_table(table) .context("failed to parse [outbound_networking] table")?; let maybe_tls_configs = self .tls_configs_from_table(table) .context("failed to parse [[client_tls]] table")?; - if maybe_blocked_networks.is_none() && maybe_tls_configs.is_none() { + if maybe_outbound_networking.is_none() && maybe_tls_configs.is_none() { return Ok(None); } - let (blocked_ip_networks, block_private_networks) = - maybe_blocked_networks.unwrap_or_default(); - - let client_tls_configs = maybe_tls_configs.unwrap_or_default(); + let outbound_networking = maybe_outbound_networking.unwrap_or_default(); + let mut blocked_ip_networks = vec![]; + let mut block_private_networks = false; + for block_network in outbound_networking.block_networks { + match block_network { + CidrOrPrivate::Cidr(ip_network) => blocked_ip_networks.push(ip_network), + CidrOrPrivate::Private => { + block_private_networks = true; + } + } + } let runtime_config = super::RuntimeConfig { blocked_ip_networks, block_private_networks, - client_tls_configs, + client_tls_configs: maybe_tls_configs.unwrap_or_default(), + max_sockets_per_app: outbound_networking.max_sockets, }; Ok(Some(runtime_config)) } - /// Attempts to parse (blocked_ip_networks, block_private_networks) from a - /// `[outbound_networking]` table. - fn blocked_networks_from_table( + /// Attempts to parse the `[outbound_networking]` table. + fn outbound_networking_from_table( &self, table: &impl GetTomlValue, - ) -> anyhow::Result, bool)>> { + ) -> anyhow::Result> { let Some(value) = table.get("outbound_networking") else { return Ok(None); }; let outbound_networking: OutboundNetworkingToml = value.clone().try_into()?; - - let mut ip_networks = vec![]; - let mut private_networks = false; - for block_network in outbound_networking.block_networks { - match block_network { - CidrOrPrivate::Cidr(ip_network) => ip_networks.push(ip_network), - CidrOrPrivate::Private => { - private_networks = true; - } - } - } - Ok(Some((ip_networks, private_networks))) + Ok(Some(outbound_networking)) } fn tls_configs_from_table( @@ -225,6 +221,7 @@ fn deserialize_hosts<'de, D: Deserializer<'de>>(deserializer: D) -> Result, + max_sockets: Option, } #[derive(Debug)] diff --git a/crates/factor-outbound-networking/tests/factor_test.rs b/crates/factor-outbound-networking/tests/factor_test.rs index ce7f0bd479..15273f922d 100644 --- a/crates/factor-outbound-networking/tests/factor_test.rs +++ b/crates/factor-outbound-networking/tests/factor_test.rs @@ -1,10 +1,15 @@ use spin_factor_outbound_networking::OutboundNetworkingFactor; +use spin_factor_outbound_networking::runtime_config::RuntimeConfig; use spin_factor_outbound_networking::runtime_config::spin::SpinRuntimeConfig; use spin_factor_variables::VariablesFactor; use spin_factor_wasi::{DummyFilesMounter, WasiFactor}; -use spin_factors::{RuntimeFactors, anyhow}; +use spin_factors::anyhow::Context as _; +use spin_factors::{App, RuntimeFactors, anyhow}; use spin_factors_test::{TestEnvironment, toml}; use wasmtime_wasi::p2::bindings::sockets::instance_network::Host; +use wasmtime_wasi::p2::bindings::sockets::network::{ErrorCode, IpAddressFamily}; +use wasmtime_wasi::p2::bindings::sockets::tcp as p2_tcp; +use wasmtime_wasi::p2::bindings::sockets::tcp_create_socket as p2_tcp_create; use wasmtime_wasi::sockets::SocketAddrUse; #[derive(RuntimeFactors)] @@ -81,3 +86,228 @@ async fn wasi_factor_is_optional() -> anyhow::Result<()> { .await?; Ok(()) } + +#[tokio::test] +async fn socket_quota_blocks_excess_connections() -> anyhow::Result<()> { + let factors = TestFactors { + wasi: WasiFactor::new(DummyFilesMounter), + variables: VariablesFactor::default(), + networking: OutboundNetworkingFactor::new(), + }; + let env = TestEnvironment::new(factors) + .extend_manifest(toml! { + [component.test-component] + source = "does-not-exist.wasm" + allowed_outbound_hosts = ["*://123.0.2.1:12345"] + }) + .runtime_config(TestFactorsRuntimeConfig { + networking: Some(RuntimeConfig { + max_sockets_per_app: Some(2), + ..Default::default() + }), + ..Default::default() + })?; + + let mut state = env.build_instance_state().await?; + let mut sockets = WasiFactor::get_sockets_impl(&mut state).unwrap(); + let addr: std::net::SocketAddr = "123.0.2.1:12345".parse().unwrap(); + + // First two connections should be accepted (non-blocking connect initiated) + let net1 = sockets.instance_network()?; + let sock1 = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock1, net1, addr.into()).await?; + + let net2 = sockets.instance_network()?; + let sock2 = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock2, net2, addr.into()).await?; + + // Third should fail — quota exhausted + let net3 = sockets.instance_network()?; + let sock3 = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + let err = p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock3, net3, addr.into()) + .await + .unwrap_err(); + assert_eq!(err.downcast_ref(), Some(&ErrorCode::ConnectionRefused)); + Ok(()) +} + +#[tokio::test] +async fn socket_quota_releases_on_instance_drop() -> anyhow::Result<()> { + let factors = TestFactors { + wasi: WasiFactor::new(DummyFilesMounter), + variables: VariablesFactor::default(), + networking: OutboundNetworkingFactor::new(), + }; + let env = TestEnvironment::new(factors) + .extend_manifest(toml! { + [component.test-component] + source = "does-not-exist.wasm" + allowed_outbound_hosts = ["*://123.0.2.1:12345"] + }) + .runtime_config(TestFactorsRuntimeConfig { + networking: Some(RuntimeConfig { + max_sockets_per_app: Some(1), + ..Default::default() + }), + ..Default::default() + })?; + + let locked_app = env.build_locked_app().await?; + let TestEnvironment { + factors, + runtime_config, + .. + } = env; + let app = App::new("test-app", locked_app); + let configured_app = factors.configure_app(app, runtime_config)?; + let component_id = configured_app + .app() + .components() + .last() + .context("no components")? + .id() + .to_string(); + + let addr: std::net::SocketAddr = "123.0.2.1:12345".parse().unwrap(); + + // First instance: fill the quota (1 socket) + { + let builders = factors.prepare(&configured_app, &component_id)?; + let mut state = factors.build_instance_state(builders)?; + let mut sockets = WasiFactor::get_sockets_impl(&mut state).unwrap(); + let net = sockets.instance_network()?; + let sock = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock, net, addr.into()).await?; + // sockets then state drop here, releasing the permit back to the semaphore + } + + // Second instance: quota should be fully available again + let builders = factors.prepare(&configured_app, &component_id)?; + let mut state = factors.build_instance_state(builders)?; + let mut sockets = WasiFactor::get_sockets_impl(&mut state).unwrap(); + let net = sockets.instance_network()?; + let sock = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock, net, addr.into()).await?; + Ok(()) +} + +#[tokio::test] +async fn no_socket_quota_allows_unlimited() -> anyhow::Result<()> { + let factors = TestFactors { + wasi: WasiFactor::new(DummyFilesMounter), + variables: VariablesFactor::default(), + networking: OutboundNetworkingFactor::new(), + }; + let env = TestEnvironment::new(factors).extend_manifest(toml! { + [component.test-component] + source = "does-not-exist.wasm" + allowed_outbound_hosts = ["*://123.0.2.1:12345"] + }); + + let mut state = env.build_instance_state().await?; + let mut sockets = WasiFactor::get_sockets_impl(&mut state).unwrap(); + let addr: std::net::SocketAddr = "123.0.2.1:12345".parse().unwrap(); + + for _ in 0..10 { + let net = sockets.instance_network()?; + let sock = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock, net, addr.into()).await?; + } + Ok(()) +} + +#[tokio::test] +async fn socket_quota_still_enforces_allowed_hosts() -> anyhow::Result<()> { + let factors = TestFactors { + wasi: WasiFactor::new(DummyFilesMounter), + variables: VariablesFactor::default(), + networking: OutboundNetworkingFactor::new(), + }; + let env = TestEnvironment::new(factors) + .extend_manifest(toml! { + [component.test-component] + source = "does-not-exist.wasm" + allowed_outbound_hosts = ["*://123.0.2.1:12345"] + }) + .runtime_config(TestFactorsRuntimeConfig { + networking: Some(RuntimeConfig { + max_sockets_per_app: Some(10), + ..Default::default() + }), + ..Default::default() + })?; + + let mut state = env.build_instance_state().await?; + let mut sockets = WasiFactor::get_sockets_impl(&mut state).unwrap(); + + // Allowed host succeeds + let net = sockets.instance_network()?; + let sock = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + let allowed_addr: std::net::SocketAddr = "123.0.2.1:12345".parse().unwrap(); + p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock, net, allowed_addr.into()).await?; + + // Disallowed host is rejected even with quota available + let net = sockets.instance_network()?; + let sock = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + let disallowed_addr: std::net::SocketAddr = "1.2.3.4:80".parse().unwrap(); + assert!( + p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock, net, disallowed_addr.into()) + .await + .is_err() + ); + Ok(()) +} + +#[tokio::test] +async fn socket_quota_releases_on_socket_drop() -> anyhow::Result<()> { + let factors = TestFactors { + wasi: WasiFactor::new(DummyFilesMounter), + variables: VariablesFactor::default(), + networking: OutboundNetworkingFactor::new(), + }; + let env = TestEnvironment::new(factors) + .extend_manifest(toml! { + [component.test-component] + source = "does-not-exist.wasm" + allowed_outbound_hosts = ["*://123.0.2.1:12345"] + }) + .runtime_config(TestFactorsRuntimeConfig { + networking: Some(RuntimeConfig { + max_sockets_per_app: Some(1), + ..Default::default() + }), + ..Default::default() + })?; + + let mut state = env.build_instance_state().await?; + let mut sockets = WasiFactor::get_sockets_impl(&mut state).unwrap(); + let addr: std::net::SocketAddr = "123.0.2.1:12345".parse().unwrap(); + + // Acquire the only permit via start_connect. Save the rep so we can reconstruct + // a handle afterwards — start_connect consumes the Resource but leaves the socket + // alive in the ResourceTable. + let net1 = sockets.instance_network()?; + let sock1 = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + let sock1_rep = sock1.rep(); + p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock1, net1, addr.into()).await?; + + // A second start_connect should fail while the permit is held. + let net2 = sockets.instance_network()?; + let sock2 = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + let err = p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock2, net2, addr.into()) + .await + .unwrap_err(); + assert_eq!(err.downcast_ref(), Some(&ErrorCode::ConnectionRefused)); + + // Explicitly drop sock1 before finish_connect — this should release the permit. + let sock1_handle = + wasmtime::component::Resource::::new_own(sock1_rep); + p2_tcp::HostTcpSocket::drop(&mut sockets, sock1_handle)?; + + // After the drop the quota is free again, so a new start_connect must succeed. + let net3 = sockets.instance_network()?; + let sock3 = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock3, net3, addr.into()).await?; + + Ok(()) +} diff --git a/crates/factor-wasi/Cargo.toml b/crates/factor-wasi/Cargo.toml index 3647670beb..277a8b8b0a 100644 --- a/crates/factor-wasi/Cargo.toml +++ b/crates/factor-wasi/Cargo.toml @@ -9,7 +9,7 @@ async-trait = { workspace = true } bytes = { workspace = true } spin-common = { path = "../common" } spin-factors = { path = "../factors" } -tokio = { workspace = true } +tokio = { workspace = true, features = ["sync"] } wasmtime = { workspace = true } wasmtime-wasi = { workspace = true } diff --git a/crates/factor-wasi/src/lib.rs b/crates/factor-wasi/src/lib.rs index f86332764b..b3602b7275 100644 --- a/crates/factor-wasi/src/lib.rs +++ b/crates/factor-wasi/src/lib.rs @@ -4,10 +4,12 @@ mod wasi_2023_10_18; mod wasi_2023_11_10; use std::{ + collections::HashMap, future::Future, io::{Read, Write}, net::SocketAddr, path::Path, + sync::{Arc, Mutex}, }; use io::{PipeReadStream, PipedWriteStream}; @@ -15,16 +17,335 @@ use spin_factors::{ AppComponent, Factor, FactorInstanceBuilder, InitContext, PrepareContext, RuntimeFactors, RuntimeFactorsInstanceState, anyhow, }; -use wasmtime::component::HasData; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; +use wasmtime::component::{HasData, Resource}; use wasmtime_wasi::cli::{StdinStream, StdoutStream, WasiCli, WasiCliCtxView}; use wasmtime_wasi::clocks::{WasiClocks, WasiClocksCtxView}; use wasmtime_wasi::filesystem::{WasiFilesystem, WasiFilesystemCtxView}; +use wasmtime_wasi::p2::bindings::sockets::network::{ + ErrorCode as SocketErrorCode, Host as NetworkHost, Network, +}; +use wasmtime_wasi::p2::bindings::sockets::tcp::{self as p2_tcp, IpSocketAddress, ShutdownType}; +use wasmtime_wasi::p2::bindings::sockets::tcp_create_socket as p2_tcp_create; +use wasmtime_wasi::p2::{DynInputStream, DynOutputStream, DynPollable}; use wasmtime_wasi::random::{WasiRandom, WasiRandomCtx}; -use wasmtime_wasi::sockets::{WasiSockets, WasiSocketsCtxView}; +use wasmtime_wasi::sockets::{TcpSocket, WasiSockets, WasiSocketsCtxView}; use wasmtime_wasi::{DirPerms, FilePerms, ResourceTable, WasiCtx, WasiCtxBuilder, WasiCtxView}; pub use wasmtime_wasi::sockets::SocketAddrUse; +/// Shared state for tracking per-socket semaphore permits. Permits are +/// acquired in `start_connect` and released when the socket resource is dropped. +pub struct SocketPermitState { + semaphore: Arc, + /// Active permits keyed by socket resource rep (u32). Removed (and the + /// permit dropped/released) when the WASI socket resource is dropped. + active: Mutex>, +} + +impl SocketPermitState { + pub fn new(semaphore: Arc) -> Arc { + Arc::new(Self { + semaphore, + active: Mutex::new(HashMap::new()), + }) + } +} + +/// A view over WASI socket state that carries an optional per-instance socket +/// permit store, enabling per-connection quota tracking. +pub struct SpinSocketsView<'a> { + pub(crate) inner: WasiSocketsCtxView<'a>, + pub(crate) permit_state: Option>, +} + +impl<'a> std::ops::Deref for SpinSocketsView<'a> { + type Target = WasiSocketsCtxView<'a>; + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl std::ops::DerefMut for SpinSocketsView<'_> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +/// [`HasData`] accessor for [`SpinSocketsView`], used in place of [`WasiSockets`] +/// when registering TCP socket bindings so that `start_connect` and `drop` can +/// participate in socket quota tracking. +pub struct SpinSockets; + +impl HasData for SpinSockets { + type Data<'a> = SpinSocketsView<'a>; +} + +impl p2_tcp::Host for SpinSocketsView<'_> {} + +impl p2_tcp::HostTcpSocket for SpinSocketsView<'_> { + async fn start_bind( + &mut self, + this: Resource, + network: Resource, + local_address: IpSocketAddress, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::start_bind(&mut self.inner, this, network, local_address).await + } + + fn finish_bind(&mut self, this: Resource) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::finish_bind(&mut self.inner, this) + } + + async fn start_connect( + &mut self, + this: Resource, + network: Resource, + remote_address: IpSocketAddress, + ) -> wasmtime_wasi::p2::SocketResult<()> { + let socket_rep = this.rep(); + let permit = if let Some(state) = &self.permit_state { + let state = Arc::clone(state); + match state.semaphore.clone().try_acquire_owned() { + Ok(permit) => Some((state, permit)), + // wasi has no "quota exceeded" error code; ConnectionRefused is the closest available. + Err(_) => return Err(SocketErrorCode::ConnectionRefused.into()), + } + } else { + None + }; + let result = + p2_tcp::HostTcpSocket::start_connect(&mut self.inner, this, network, remote_address) + .await; + if let (Some((state, permit)), Ok(())) = (permit, &result) { + state + .active + .lock() + .unwrap_or_else(|e| e.into_inner()) + .insert(socket_rep, permit); + } + // On Err, any acquired permit is dropped here, returning it to the semaphore. + result + } + + fn finish_connect( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult<(Resource, Resource)> + { + p2_tcp::HostTcpSocket::finish_connect(&mut self.inner, this) + } + + fn start_listen(&mut self, this: Resource) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::start_listen(&mut self.inner, this) + } + + fn finish_listen(&mut self, this: Resource) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::finish_listen(&mut self.inner, this) + } + + fn accept( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult<( + Resource, + Resource, + Resource, + )> { + p2_tcp::HostTcpSocket::accept(&mut self.inner, this) + } + + fn local_address( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::local_address(&mut self.inner, this) + } + + fn remote_address( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::remote_address(&mut self.inner, this) + } + + fn is_listening(&mut self, this: Resource) -> wasmtime::Result { + p2_tcp::HostTcpSocket::is_listening(&mut self.inner, this) + } + + fn address_family( + &mut self, + this: Resource, + ) -> wasmtime::Result { + p2_tcp::HostTcpSocket::address_family(&mut self.inner, this) + } + + fn set_listen_backlog_size( + &mut self, + this: Resource, + value: u64, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::set_listen_backlog_size(&mut self.inner, this, value) + } + + fn keep_alive_enabled( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::keep_alive_enabled(&mut self.inner, this) + } + + fn set_keep_alive_enabled( + &mut self, + this: Resource, + value: bool, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::set_keep_alive_enabled(&mut self.inner, this, value) + } + + fn keep_alive_idle_time( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::keep_alive_idle_time(&mut self.inner, this) + } + + fn set_keep_alive_idle_time( + &mut self, + this: Resource, + value: u64, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::set_keep_alive_idle_time(&mut self.inner, this, value) + } + + fn keep_alive_interval( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::keep_alive_interval(&mut self.inner, this) + } + + fn set_keep_alive_interval( + &mut self, + this: Resource, + value: u64, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::set_keep_alive_interval(&mut self.inner, this, value) + } + + fn keep_alive_count( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::keep_alive_count(&mut self.inner, this) + } + + fn set_keep_alive_count( + &mut self, + this: Resource, + value: u32, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::set_keep_alive_count(&mut self.inner, this, value) + } + + fn hop_limit(&mut self, this: Resource) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::hop_limit(&mut self.inner, this) + } + + fn set_hop_limit( + &mut self, + this: Resource, + value: u8, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::set_hop_limit(&mut self.inner, this, value) + } + + fn receive_buffer_size( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::receive_buffer_size(&mut self.inner, this) + } + + fn set_receive_buffer_size( + &mut self, + this: Resource, + value: u64, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::set_receive_buffer_size(&mut self.inner, this, value) + } + + fn send_buffer_size( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::send_buffer_size(&mut self.inner, this) + } + + fn set_send_buffer_size( + &mut self, + this: Resource, + value: u64, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::set_send_buffer_size(&mut self.inner, this, value) + } + + fn subscribe(&mut self, this: Resource) -> wasmtime::Result> { + p2_tcp::HostTcpSocket::subscribe(&mut self.inner, this) + } + + fn shutdown( + &mut self, + this: Resource, + shutdown_type: ShutdownType, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::shutdown(&mut self.inner, this, shutdown_type) + } + + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + // Release the permit before dropping the socket resource. + if let Some(state) = &self.permit_state { + state + .active + .lock() + .unwrap_or_else(|e| e.into_inner()) + .remove(&this.rep()); + } + p2_tcp::HostTcpSocket::drop(&mut self.inner, this) + } +} + +impl NetworkHost for SpinSocketsView<'_> { + fn convert_error_code( + &mut self, + error: wasmtime_wasi::p2::SocketError, + ) -> wasmtime::Result { + NetworkHost::convert_error_code(&mut self.inner, error) + } + + fn network_error_code( + &mut self, + err: Resource, + ) -> wasmtime::Result> { + NetworkHost::network_error_code(&mut self.inner, err) + } +} + +impl wasmtime_wasi::p2::bindings::sockets::network::HostNetwork for SpinSocketsView<'_> { + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + wasmtime_wasi::p2::bindings::sockets::network::HostNetwork::drop(&mut self.inner, this) + } +} + +impl p2_tcp_create::Host for SpinSocketsView<'_> { + fn create_tcp_socket( + &mut self, + address_family: wasmtime_wasi::p2::bindings::sockets::network::IpAddressFamily, + ) -> wasmtime_wasi::p2::SocketResult> { + p2_tcp_create::Host::create_tcp_socket(&mut self.inner, address_family) + } +} + pub struct WasiFactor { files_mounter: Box, } @@ -58,11 +379,14 @@ impl WasiFactor { pub fn get_sockets_impl( runtime_instance_state: &mut impl RuntimeFactorsInstanceState, - ) -> Option> { + ) -> Option> { let (state, table) = runtime_instance_state.get_with_table::()?; - Some(WasiSocketsCtxView { - ctx: state.ctx.sockets(), - table, + Some(SpinSocketsView { + inner: WasiSocketsCtxView { + ctx: state.ctx.sockets(), + table, + }, + permit_state: state.socket_permit_state.clone(), }) } } @@ -176,6 +500,27 @@ trait InitContextExt: InitContext { add_to_linker(self.linker(), &O::default(), Self::get_sockets) } + fn get_spin_sockets(data: &mut Self::StoreData) -> SpinSocketsView<'_> { + let (state, table) = Self::get_data_with_table(data); + SpinSocketsView { + inner: WasiSocketsCtxView { + ctx: state.ctx.sockets(), + table, + }, + permit_state: state.socket_permit_state.clone(), + } + } + + fn link_tcp_bindings( + &mut self, + add_to_linker: fn( + &mut wasmtime::component::Linker, + fn(&mut Self::StoreData) -> SpinSocketsView<'_>, + ) -> wasmtime::Result<()>, + ) -> wasmtime::Result<()> { + add_to_linker(self.linker(), Self::get_spin_sockets) + } + fn link_io_bindings( &mut self, add_to_linker: fn( @@ -294,10 +639,11 @@ impl Factor for WasiFactor { ctx.link_cli_bindings(p3::bindings::cli::terminal_stdout::add_to_linker::<_, WasiCli>)?; ctx.link_cli_bindings(p2::bindings::cli::terminal_stderr::add_to_linker::<_, WasiCli>)?; ctx.link_cli_bindings(p3::bindings::cli::terminal_stderr::add_to_linker::<_, WasiCli>)?; - ctx.link_sockets_bindings(p2::bindings::sockets::tcp::add_to_linker::<_, WasiSockets>)?; - ctx.link_sockets_bindings( - p2::bindings::sockets::tcp_create_socket::add_to_linker::<_, WasiSockets>, + ctx.link_tcp_bindings(p2::bindings::sockets::tcp::add_to_linker::<_, SpinSockets>)?; + ctx.link_tcp_bindings( + p2::bindings::sockets::tcp_create_socket::add_to_linker::<_, SpinSockets>, )?; + // UDP sockets are not subject to the max_sockets_per_app quota — enforcement is TCP-only. ctx.link_sockets_bindings(p2::bindings::sockets::udp::add_to_linker::<_, WasiSockets>)?; ctx.link_sockets_bindings( p2::bindings::sockets::udp_create_socket::add_to_linker::<_, WasiSockets>, @@ -339,7 +685,10 @@ impl Factor for WasiFactor { self.files_mounter .mount_files(ctx.app_component(), mount_ctx)?; - let mut builder = InstanceBuilder { ctx: wasi_ctx }; + let mut builder = InstanceBuilder { + ctx: wasi_ctx, + socket_permit_state: None, + }; // Apply environment variables builder.env(ctx.app_component().environment()); @@ -396,6 +745,7 @@ impl MountFilesContext<'_> { pub struct InstanceBuilder { ctx: WasiCtxBuilder, + socket_permit_state: Option>, } impl InstanceBuilder { @@ -466,14 +816,24 @@ impl FactorInstanceBuilder for InstanceBuilder { type InstanceState = InstanceState; fn build(self) -> anyhow::Result { - let InstanceBuilder { ctx: mut wasi_ctx } = self; + let InstanceBuilder { + ctx: mut wasi_ctx, + socket_permit_state, + } = self; Ok(InstanceState { ctx: wasi_ctx.build(), + socket_permit_state, }) } } impl InstanceBuilder { + /// Sets the socket permit state for per-connection quota tracking. + /// Called by `OutboundNetworkingFactor` when `max_sockets_per_app` is configured. + pub fn set_socket_permit_state(&mut self, state: Arc) { + self.socket_permit_state = Some(state); + } + pub fn outbound_socket_addr_check(&mut self, check: F) where F: Fn(SocketAddr, SocketAddrUse) -> Fut + Send + Sync + Clone + 'static, @@ -496,4 +856,5 @@ impl InstanceBuilder { pub struct InstanceState { ctx: WasiCtx, + socket_permit_state: Option>, } From f6ec7feb0db1241ed0721633f37919f192621b15 Mon Sep 17 00:00:00 2001 From: Ryan Levick Date: Wed, 27 May 2026 12:55:30 +0200 Subject: [PATCH 02/15] Make sure UDP sockets are also capped Signed-off-by: Ryan Levick --- crates/factor-outbound-networking/src/lib.rs | 1 - .../tests/factor_test.rs | 86 +++++- crates/factor-wasi/src/lib.rs | 249 ++++++++++++++++-- 3 files changed, 306 insertions(+), 30 deletions(-) diff --git a/crates/factor-outbound-networking/src/lib.rs b/crates/factor-outbound-networking/src/lib.rs index 590818a802..04bc8d4c1c 100644 --- a/crates/factor-outbound-networking/src/lib.rs +++ b/crates/factor-outbound-networking/src/lib.rs @@ -209,7 +209,6 @@ pub struct InstanceBuilder { component_tls_client_configs: ComponentTlsClientConfigs, } - impl InstanceBuilder { pub fn allowed_hosts(&self) -> OutboundAllowedHosts { self.allowed_hosts.clone() diff --git a/crates/factor-outbound-networking/tests/factor_test.rs b/crates/factor-outbound-networking/tests/factor_test.rs index 15273f922d..476932e999 100644 --- a/crates/factor-outbound-networking/tests/factor_test.rs +++ b/crates/factor-outbound-networking/tests/factor_test.rs @@ -10,6 +10,7 @@ use wasmtime_wasi::p2::bindings::sockets::instance_network::Host; use wasmtime_wasi::p2::bindings::sockets::network::{ErrorCode, IpAddressFamily}; use wasmtime_wasi::p2::bindings::sockets::tcp as p2_tcp; use wasmtime_wasi::p2::bindings::sockets::tcp_create_socket as p2_tcp_create; +use wasmtime_wasi::p2::bindings::sockets::udp_create_socket as p2_udp_create; use wasmtime_wasi::sockets::SocketAddrUse; #[derive(RuntimeFactors)] @@ -178,7 +179,7 @@ async fn socket_quota_releases_on_instance_drop() -> anyhow::Result<()> { let net = sockets.instance_network()?; let sock = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock, net, addr.into()).await?; - // sockets then state drop here, releasing the permit back to the semaphore + // sockets state dropped here releasing the permit back to the semaphore } // Second instance: quota should be fully available again @@ -311,3 +312,86 @@ async fn socket_quota_releases_on_socket_drop() -> anyhow::Result<()> { Ok(()) } + +#[tokio::test] +async fn socket_quota_blocks_excess_udp_sockets() -> anyhow::Result<()> { + let factors = TestFactors { + wasi: WasiFactor::new(DummyFilesMounter), + variables: VariablesFactor::default(), + networking: OutboundNetworkingFactor::new(), + }; + let env = TestEnvironment::new(factors) + .extend_manifest(toml! { + [component.test-component] + source = "does-not-exist.wasm" + allowed_outbound_hosts = ["*://123.0.2.1:12345"] + }) + .runtime_config(TestFactorsRuntimeConfig { + networking: Some(RuntimeConfig { + max_sockets_per_app: Some(2), + ..Default::default() + }), + ..Default::default() + })?; + + let mut state = env.build_instance_state().await?; + let mut sockets = WasiFactor::get_sockets_impl(&mut state).unwrap(); + + // First two UDP socket creations should succeed. + p2_udp_create::Host::create_udp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + p2_udp_create::Host::create_udp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + + // Third should fail — quota exhausted. + let err = + p2_udp_create::Host::create_udp_socket(&mut sockets, IpAddressFamily::Ipv4).unwrap_err(); + assert_eq!(err.downcast_ref(), Some(&ErrorCode::ConnectionRefused)); + Ok(()) +} + +#[tokio::test] +async fn socket_quota_shared_between_tcp_and_udp() -> anyhow::Result<()> { + let factors = TestFactors { + wasi: WasiFactor::new(DummyFilesMounter), + variables: VariablesFactor::default(), + networking: OutboundNetworkingFactor::new(), + }; + let env = TestEnvironment::new(factors) + .extend_manifest(toml! { + [component.test-component] + source = "does-not-exist.wasm" + allowed_outbound_hosts = ["*://123.0.2.1:12345"] + }) + .runtime_config(TestFactorsRuntimeConfig { + networking: Some(RuntimeConfig { + max_sockets_per_app: Some(2), + ..Default::default() + }), + ..Default::default() + })?; + + let mut state = env.build_instance_state().await?; + let mut sockets = WasiFactor::get_sockets_impl(&mut state).unwrap(); + let addr: std::net::SocketAddr = "123.0.2.1:12345".parse().unwrap(); + + // Consume one permit with a TCP connection. + let net = sockets.instance_network()?; + let tcp_sock = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + p2_tcp::HostTcpSocket::start_connect(&mut sockets, tcp_sock, net, addr.into()).await?; + + // Consume the second permit with a UDP socket — quota now full. + p2_udp_create::Host::create_udp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + + // Any further allocation must fail — shared quota exhausted. + // UDP: + let err = + p2_udp_create::Host::create_udp_socket(&mut sockets, IpAddressFamily::Ipv4).unwrap_err(); + assert_eq!(err.downcast_ref(), Some(&ErrorCode::ConnectionRefused)); + // TCP: + let net = sockets.instance_network()?; + let tcp_sock2 = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + let err = p2_tcp::HostTcpSocket::start_connect(&mut sockets, tcp_sock2, net, addr.into()) + .await + .unwrap_err(); + assert_eq!(err.downcast_ref(), Some(&ErrorCode::ConnectionRefused)); + Ok(()) +} diff --git a/crates/factor-wasi/src/lib.rs b/crates/factor-wasi/src/lib.rs index b3602b7275..d8e3377c32 100644 --- a/crates/factor-wasi/src/lib.rs +++ b/crates/factor-wasi/src/lib.rs @@ -27,19 +27,23 @@ use wasmtime_wasi::p2::bindings::sockets::network::{ }; use wasmtime_wasi::p2::bindings::sockets::tcp::{self as p2_tcp, IpSocketAddress, ShutdownType}; use wasmtime_wasi::p2::bindings::sockets::tcp_create_socket as p2_tcp_create; +use wasmtime_wasi::p2::bindings::sockets::udp as p2_udp; +use wasmtime_wasi::p2::bindings::sockets::udp_create_socket as p2_udp_create; use wasmtime_wasi::p2::{DynInputStream, DynOutputStream, DynPollable}; use wasmtime_wasi::random::{WasiRandom, WasiRandomCtx}; -use wasmtime_wasi::sockets::{TcpSocket, WasiSockets, WasiSocketsCtxView}; +use wasmtime_wasi::sockets::{TcpSocket, UdpSocket, WasiSockets, WasiSocketsCtxView}; use wasmtime_wasi::{DirPerms, FilePerms, ResourceTable, WasiCtx, WasiCtxBuilder, WasiCtxView}; pub use wasmtime_wasi::sockets::SocketAddrUse; /// Shared state for tracking per-socket semaphore permits. Permits are -/// acquired in `start_connect` and released when the socket resource is dropped. +/// acquired when a socket is allocated (at `start_connect` for TCP, at +/// `create_udp_socket` for UDP) and released when the socket resource is dropped. pub struct SocketPermitState { semaphore: Arc, - /// Active permits keyed by socket resource rep (u32). Removed (and the - /// permit dropped/released) when the WASI socket resource is dropped. + /// Active permits keyed by socket resource rep. + /// + /// Permits are removed (and the permit released) when the WASI socket resource is dropped. active: Mutex>, } @@ -103,29 +107,26 @@ impl p2_tcp::HostTcpSocket for SpinSocketsView<'_> { network: Resource, remote_address: IpSocketAddress, ) -> wasmtime_wasi::p2::SocketResult<()> { - let socket_rep = this.rep(); - let permit = if let Some(state) = &self.permit_state { - let state = Arc::clone(state); - match state.semaphore.clone().try_acquire_owned() { - Ok(permit) => Some((state, permit)), - // wasi has no "quota exceeded" error code; ConnectionRefused is the closest available. - Err(_) => return Err(SocketErrorCode::ConnectionRefused.into()), - } - } else { - None - }; - let result = + if let Some(state) = &self.permit_state { + // If we have a permit state, we need to acquire a permit before allowing the connection to proceed. + let socket_rep = this.rep(); + let Ok(permit) = Arc::clone(&state.semaphore).try_acquire_owned() else { + // wasi has no "quota exceeded" error code. ConnectionRefused is the closest available. + return Err(SocketErrorCode::ConnectionRefused.into()); + }; p2_tcp::HostTcpSocket::start_connect(&mut self.inner, this, network, remote_address) - .await; - if let (Some((state, permit)), Ok(())) = (permit, &result) { + .await?; + // If the connection was successfully initiated, store the permit so it can be released when the socket is dropped. state .active .lock() .unwrap_or_else(|e| e.into_inner()) .insert(socket_rep, permit); + Ok(()) + } else { + p2_tcp::HostTcpSocket::start_connect(&mut self.inner, this, network, remote_address) + .await } - // On Err, any acquired permit is dropped here, returning it to the semaphore. - result } fn finish_connect( @@ -346,6 +347,196 @@ impl p2_tcp_create::Host for SpinSocketsView<'_> { } } +impl p2_udp::Host for SpinSocketsView<'_> {} + +impl p2_udp::HostUdpSocket for SpinSocketsView<'_> { + async fn start_bind( + &mut self, + this: Resource, + network: Resource, + local_address: p2_udp::IpSocketAddress, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_udp::HostUdpSocket::start_bind(&mut self.inner, this, network, local_address).await + } + + fn finish_bind( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_udp::HostUdpSocket::finish_bind(&mut self.inner, this) + } + + async fn stream( + &mut self, + this: Resource, + remote_address: Option, + ) -> wasmtime_wasi::p2::SocketResult<( + Resource, + Resource, + )> { + p2_udp::HostUdpSocket::stream(&mut self.inner, this, remote_address).await + } + + fn local_address( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_udp::HostUdpSocket::local_address(&mut self.inner, this) + } + + fn remote_address( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_udp::HostUdpSocket::remote_address(&mut self.inner, this) + } + + fn address_family( + &mut self, + this: Resource, + ) -> wasmtime::Result { + p2_udp::HostUdpSocket::address_family(&mut self.inner, this) + } + + fn unicast_hop_limit( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_udp::HostUdpSocket::unicast_hop_limit(&mut self.inner, this) + } + + fn set_unicast_hop_limit( + &mut self, + this: Resource, + value: u8, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_udp::HostUdpSocket::set_unicast_hop_limit(&mut self.inner, this, value) + } + + fn receive_buffer_size( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_udp::HostUdpSocket::receive_buffer_size(&mut self.inner, this) + } + + fn set_receive_buffer_size( + &mut self, + this: Resource, + value: u64, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_udp::HostUdpSocket::set_receive_buffer_size(&mut self.inner, this, value) + } + + fn send_buffer_size( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_udp::HostUdpSocket::send_buffer_size(&mut self.inner, this) + } + + fn set_send_buffer_size( + &mut self, + this: Resource, + value: u64, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_udp::HostUdpSocket::set_send_buffer_size(&mut self.inner, this, value) + } + + fn subscribe( + &mut self, + this: Resource, + ) -> wasmtime::Result> { + p2_udp::HostUdpSocket::subscribe(&mut self.inner, this) + } + + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + // Release the permit before dropping the socket resource. + if let Some(state) = &self.permit_state { + state + .active + .lock() + .unwrap_or_else(|e| e.into_inner()) + .remove(&this.rep()); + } + p2_udp::HostUdpSocket::drop(&mut self.inner, this) + } +} + +impl p2_udp::HostIncomingDatagramStream for SpinSocketsView<'_> { + fn receive( + &mut self, + this: Resource, + max_results: u64, + ) -> wasmtime_wasi::p2::SocketResult> { + p2_udp::HostIncomingDatagramStream::receive(&mut self.inner, this, max_results) + } + + fn subscribe( + &mut self, + this: Resource, + ) -> wasmtime::Result> { + p2_udp::HostIncomingDatagramStream::subscribe(&mut self.inner, this) + } + + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + p2_udp::HostIncomingDatagramStream::drop(&mut self.inner, this) + } +} + +impl p2_udp::HostOutgoingDatagramStream for SpinSocketsView<'_> { + fn check_send( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_udp::HostOutgoingDatagramStream::check_send(&mut self.inner, this) + } + + async fn send( + &mut self, + this: Resource, + datagrams: Vec, + ) -> wasmtime_wasi::p2::SocketResult { + p2_udp::HostOutgoingDatagramStream::send(&mut self.inner, this, datagrams).await + } + + fn subscribe( + &mut self, + this: Resource, + ) -> wasmtime::Result> { + p2_udp::HostOutgoingDatagramStream::subscribe(&mut self.inner, this) + } + + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + p2_udp::HostOutgoingDatagramStream::drop(&mut self.inner, this) + } +} + +impl p2_udp_create::Host for SpinSocketsView<'_> { + fn create_udp_socket( + &mut self, + address_family: wasmtime_wasi::p2::bindings::sockets::network::IpAddressFamily, + ) -> wasmtime_wasi::p2::SocketResult> { + if let Some(state) = &self.permit_state { + // If we have a permit state, we need to acquire a permit before allowing the socket creation to proceed. + let state = Arc::clone(state); + let permit = Arc::clone(&state.semaphore) + .try_acquire_owned() + .map_err(|_| SocketErrorCode::ConnectionRefused)?; + let sock = p2_udp_create::Host::create_udp_socket(&mut self.inner, address_family)?; + // If the socket was successfully created, store the permit so it can be released when the socket is dropped. + state + .active + .lock() + .unwrap_or_else(|e| e.into_inner()) + .insert(sock.rep(), permit); + Ok(sock) + } else { + p2_udp_create::Host::create_udp_socket(&mut self.inner, address_family) + } + } +} + pub struct WasiFactor { files_mounter: Box, } @@ -511,7 +702,7 @@ trait InitContextExt: InitContext { } } - fn link_tcp_bindings( + fn link_spin_sockets_bindings( &mut self, add_to_linker: fn( &mut wasmtime::component::Linker, @@ -639,14 +830,17 @@ impl Factor for WasiFactor { ctx.link_cli_bindings(p3::bindings::cli::terminal_stdout::add_to_linker::<_, WasiCli>)?; ctx.link_cli_bindings(p2::bindings::cli::terminal_stderr::add_to_linker::<_, WasiCli>)?; ctx.link_cli_bindings(p3::bindings::cli::terminal_stderr::add_to_linker::<_, WasiCli>)?; - ctx.link_tcp_bindings(p2::bindings::sockets::tcp::add_to_linker::<_, SpinSockets>)?; - ctx.link_tcp_bindings( + ctx.link_spin_sockets_bindings( + p2::bindings::sockets::tcp::add_to_linker::<_, SpinSockets>, + )?; + ctx.link_spin_sockets_bindings( p2::bindings::sockets::tcp_create_socket::add_to_linker::<_, SpinSockets>, )?; - // UDP sockets are not subject to the max_sockets_per_app quota — enforcement is TCP-only. - ctx.link_sockets_bindings(p2::bindings::sockets::udp::add_to_linker::<_, WasiSockets>)?; - ctx.link_sockets_bindings( - p2::bindings::sockets::udp_create_socket::add_to_linker::<_, WasiSockets>, + ctx.link_spin_sockets_bindings( + p2::bindings::sockets::udp::add_to_linker::<_, SpinSockets>, + )?; + ctx.link_spin_sockets_bindings( + p2::bindings::sockets::udp_create_socket::add_to_linker::<_, SpinSockets>, )?; ctx.link_sockets_bindings( p2::bindings::sockets::instance_network::add_to_linker::<_, WasiSockets>, @@ -829,7 +1023,6 @@ impl FactorInstanceBuilder for InstanceBuilder { impl InstanceBuilder { /// Sets the socket permit state for per-connection quota tracking. - /// Called by `OutboundNetworkingFactor` when `max_sockets_per_app` is configured. pub fn set_socket_permit_state(&mut self, state: Arc) { self.socket_permit_state = Some(state); } From d60ee5313442b894da5a27a9f0ef305080f08bed Mon Sep 17 00:00:00 2001 From: Ryan Levick Date: Thu, 28 May 2026 12:05:32 +0200 Subject: [PATCH 03/15] PR nits Signed-off-by: Ryan Levick --- Cargo.lock | 1 + .../tests/factor_test.rs | 10 +- crates/factor-wasi/Cargo.toml | 1 + crates/factor-wasi/src/lib.rs | 519 +---------------- crates/factor-wasi/src/sockets.rs | 536 ++++++++++++++++++ 5 files changed, 548 insertions(+), 519 deletions(-) create mode 100644 crates/factor-wasi/src/sockets.rs diff --git a/Cargo.lock b/Cargo.lock index 71f0e3b4e9..e0322526cb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9169,6 +9169,7 @@ dependencies = [ "spin-factors", "spin-factors-test", "tokio", + "tracing", "wasmtime", "wasmtime-wasi", ] diff --git a/crates/factor-outbound-networking/tests/factor_test.rs b/crates/factor-outbound-networking/tests/factor_test.rs index 476932e999..19c3b25296 100644 --- a/crates/factor-outbound-networking/tests/factor_test.rs +++ b/crates/factor-outbound-networking/tests/factor_test.rs @@ -128,7 +128,7 @@ async fn socket_quota_blocks_excess_connections() -> anyhow::Result<()> { let err = p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock3, net3, addr.into()) .await .unwrap_err(); - assert_eq!(err.downcast_ref(), Some(&ErrorCode::ConnectionRefused)); + assert_eq!(err.downcast_ref(), Some(&ErrorCode::NewSocketLimit)); Ok(()) } @@ -298,7 +298,7 @@ async fn socket_quota_releases_on_socket_drop() -> anyhow::Result<()> { let err = p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock2, net2, addr.into()) .await .unwrap_err(); - assert_eq!(err.downcast_ref(), Some(&ErrorCode::ConnectionRefused)); + assert_eq!(err.downcast_ref(), Some(&ErrorCode::NewSocketLimit)); // Explicitly drop sock1 before finish_connect — this should release the permit. let sock1_handle = @@ -344,7 +344,7 @@ async fn socket_quota_blocks_excess_udp_sockets() -> anyhow::Result<()> { // Third should fail — quota exhausted. let err = p2_udp_create::Host::create_udp_socket(&mut sockets, IpAddressFamily::Ipv4).unwrap_err(); - assert_eq!(err.downcast_ref(), Some(&ErrorCode::ConnectionRefused)); + assert_eq!(err.downcast_ref(), Some(&ErrorCode::NewSocketLimit)); Ok(()) } @@ -385,13 +385,13 @@ async fn socket_quota_shared_between_tcp_and_udp() -> anyhow::Result<()> { // UDP: let err = p2_udp_create::Host::create_udp_socket(&mut sockets, IpAddressFamily::Ipv4).unwrap_err(); - assert_eq!(err.downcast_ref(), Some(&ErrorCode::ConnectionRefused)); + assert_eq!(err.downcast_ref(), Some(&ErrorCode::NewSocketLimit)); // TCP: let net = sockets.instance_network()?; let tcp_sock2 = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; let err = p2_tcp::HostTcpSocket::start_connect(&mut sockets, tcp_sock2, net, addr.into()) .await .unwrap_err(); - assert_eq!(err.downcast_ref(), Some(&ErrorCode::ConnectionRefused)); + assert_eq!(err.downcast_ref(), Some(&ErrorCode::NewSocketLimit)); Ok(()) } diff --git a/crates/factor-wasi/Cargo.toml b/crates/factor-wasi/Cargo.toml index 277a8b8b0a..17acba531b 100644 --- a/crates/factor-wasi/Cargo.toml +++ b/crates/factor-wasi/Cargo.toml @@ -10,6 +10,7 @@ bytes = { workspace = true } spin-common = { path = "../common" } spin-factors = { path = "../factors" } tokio = { workspace = true, features = ["sync"] } +tracing = { workspace = true } wasmtime = { workspace = true } wasmtime-wasi = { workspace = true } diff --git a/crates/factor-wasi/src/lib.rs b/crates/factor-wasi/src/lib.rs index d8e3377c32..02c6ba876e 100644 --- a/crates/factor-wasi/src/lib.rs +++ b/crates/factor-wasi/src/lib.rs @@ -1,15 +1,15 @@ mod io; +pub mod sockets; pub mod spin; mod wasi_2023_10_18; mod wasi_2023_11_10; use std::{ - collections::HashMap, future::Future, io::{Read, Write}, net::SocketAddr, path::Path, - sync::{Arc, Mutex}, + sync::Arc, }; use io::{PipeReadStream, PipedWriteStream}; @@ -17,526 +17,17 @@ use spin_factors::{ AppComponent, Factor, FactorInstanceBuilder, InitContext, PrepareContext, RuntimeFactors, RuntimeFactorsInstanceState, anyhow, }; -use tokio::sync::{OwnedSemaphorePermit, Semaphore}; -use wasmtime::component::{HasData, Resource}; +use wasmtime::component::HasData; use wasmtime_wasi::cli::{StdinStream, StdoutStream, WasiCli, WasiCliCtxView}; use wasmtime_wasi::clocks::{WasiClocks, WasiClocksCtxView}; use wasmtime_wasi::filesystem::{WasiFilesystem, WasiFilesystemCtxView}; -use wasmtime_wasi::p2::bindings::sockets::network::{ - ErrorCode as SocketErrorCode, Host as NetworkHost, Network, -}; -use wasmtime_wasi::p2::bindings::sockets::tcp::{self as p2_tcp, IpSocketAddress, ShutdownType}; -use wasmtime_wasi::p2::bindings::sockets::tcp_create_socket as p2_tcp_create; -use wasmtime_wasi::p2::bindings::sockets::udp as p2_udp; -use wasmtime_wasi::p2::bindings::sockets::udp_create_socket as p2_udp_create; -use wasmtime_wasi::p2::{DynInputStream, DynOutputStream, DynPollable}; use wasmtime_wasi::random::{WasiRandom, WasiRandomCtx}; -use wasmtime_wasi::sockets::{TcpSocket, UdpSocket, WasiSockets, WasiSocketsCtxView}; +use wasmtime_wasi::sockets::{WasiSockets, WasiSocketsCtxView}; use wasmtime_wasi::{DirPerms, FilePerms, ResourceTable, WasiCtx, WasiCtxBuilder, WasiCtxView}; +pub use sockets::{SocketPermitState, SpinSockets, SpinSocketsView}; pub use wasmtime_wasi::sockets::SocketAddrUse; -/// Shared state for tracking per-socket semaphore permits. Permits are -/// acquired when a socket is allocated (at `start_connect` for TCP, at -/// `create_udp_socket` for UDP) and released when the socket resource is dropped. -pub struct SocketPermitState { - semaphore: Arc, - /// Active permits keyed by socket resource rep. - /// - /// Permits are removed (and the permit released) when the WASI socket resource is dropped. - active: Mutex>, -} - -impl SocketPermitState { - pub fn new(semaphore: Arc) -> Arc { - Arc::new(Self { - semaphore, - active: Mutex::new(HashMap::new()), - }) - } -} - -/// A view over WASI socket state that carries an optional per-instance socket -/// permit store, enabling per-connection quota tracking. -pub struct SpinSocketsView<'a> { - pub(crate) inner: WasiSocketsCtxView<'a>, - pub(crate) permit_state: Option>, -} - -impl<'a> std::ops::Deref for SpinSocketsView<'a> { - type Target = WasiSocketsCtxView<'a>; - fn deref(&self) -> &Self::Target { - &self.inner - } -} - -impl std::ops::DerefMut for SpinSocketsView<'_> { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.inner - } -} - -/// [`HasData`] accessor for [`SpinSocketsView`], used in place of [`WasiSockets`] -/// when registering TCP socket bindings so that `start_connect` and `drop` can -/// participate in socket quota tracking. -pub struct SpinSockets; - -impl HasData for SpinSockets { - type Data<'a> = SpinSocketsView<'a>; -} - -impl p2_tcp::Host for SpinSocketsView<'_> {} - -impl p2_tcp::HostTcpSocket for SpinSocketsView<'_> { - async fn start_bind( - &mut self, - this: Resource, - network: Resource, - local_address: IpSocketAddress, - ) -> wasmtime_wasi::p2::SocketResult<()> { - p2_tcp::HostTcpSocket::start_bind(&mut self.inner, this, network, local_address).await - } - - fn finish_bind(&mut self, this: Resource) -> wasmtime_wasi::p2::SocketResult<()> { - p2_tcp::HostTcpSocket::finish_bind(&mut self.inner, this) - } - - async fn start_connect( - &mut self, - this: Resource, - network: Resource, - remote_address: IpSocketAddress, - ) -> wasmtime_wasi::p2::SocketResult<()> { - if let Some(state) = &self.permit_state { - // If we have a permit state, we need to acquire a permit before allowing the connection to proceed. - let socket_rep = this.rep(); - let Ok(permit) = Arc::clone(&state.semaphore).try_acquire_owned() else { - // wasi has no "quota exceeded" error code. ConnectionRefused is the closest available. - return Err(SocketErrorCode::ConnectionRefused.into()); - }; - p2_tcp::HostTcpSocket::start_connect(&mut self.inner, this, network, remote_address) - .await?; - // If the connection was successfully initiated, store the permit so it can be released when the socket is dropped. - state - .active - .lock() - .unwrap_or_else(|e| e.into_inner()) - .insert(socket_rep, permit); - Ok(()) - } else { - p2_tcp::HostTcpSocket::start_connect(&mut self.inner, this, network, remote_address) - .await - } - } - - fn finish_connect( - &mut self, - this: Resource, - ) -> wasmtime_wasi::p2::SocketResult<(Resource, Resource)> - { - p2_tcp::HostTcpSocket::finish_connect(&mut self.inner, this) - } - - fn start_listen(&mut self, this: Resource) -> wasmtime_wasi::p2::SocketResult<()> { - p2_tcp::HostTcpSocket::start_listen(&mut self.inner, this) - } - - fn finish_listen(&mut self, this: Resource) -> wasmtime_wasi::p2::SocketResult<()> { - p2_tcp::HostTcpSocket::finish_listen(&mut self.inner, this) - } - - fn accept( - &mut self, - this: Resource, - ) -> wasmtime_wasi::p2::SocketResult<( - Resource, - Resource, - Resource, - )> { - p2_tcp::HostTcpSocket::accept(&mut self.inner, this) - } - - fn local_address( - &mut self, - this: Resource, - ) -> wasmtime_wasi::p2::SocketResult { - p2_tcp::HostTcpSocket::local_address(&mut self.inner, this) - } - - fn remote_address( - &mut self, - this: Resource, - ) -> wasmtime_wasi::p2::SocketResult { - p2_tcp::HostTcpSocket::remote_address(&mut self.inner, this) - } - - fn is_listening(&mut self, this: Resource) -> wasmtime::Result { - p2_tcp::HostTcpSocket::is_listening(&mut self.inner, this) - } - - fn address_family( - &mut self, - this: Resource, - ) -> wasmtime::Result { - p2_tcp::HostTcpSocket::address_family(&mut self.inner, this) - } - - fn set_listen_backlog_size( - &mut self, - this: Resource, - value: u64, - ) -> wasmtime_wasi::p2::SocketResult<()> { - p2_tcp::HostTcpSocket::set_listen_backlog_size(&mut self.inner, this, value) - } - - fn keep_alive_enabled( - &mut self, - this: Resource, - ) -> wasmtime_wasi::p2::SocketResult { - p2_tcp::HostTcpSocket::keep_alive_enabled(&mut self.inner, this) - } - - fn set_keep_alive_enabled( - &mut self, - this: Resource, - value: bool, - ) -> wasmtime_wasi::p2::SocketResult<()> { - p2_tcp::HostTcpSocket::set_keep_alive_enabled(&mut self.inner, this, value) - } - - fn keep_alive_idle_time( - &mut self, - this: Resource, - ) -> wasmtime_wasi::p2::SocketResult { - p2_tcp::HostTcpSocket::keep_alive_idle_time(&mut self.inner, this) - } - - fn set_keep_alive_idle_time( - &mut self, - this: Resource, - value: u64, - ) -> wasmtime_wasi::p2::SocketResult<()> { - p2_tcp::HostTcpSocket::set_keep_alive_idle_time(&mut self.inner, this, value) - } - - fn keep_alive_interval( - &mut self, - this: Resource, - ) -> wasmtime_wasi::p2::SocketResult { - p2_tcp::HostTcpSocket::keep_alive_interval(&mut self.inner, this) - } - - fn set_keep_alive_interval( - &mut self, - this: Resource, - value: u64, - ) -> wasmtime_wasi::p2::SocketResult<()> { - p2_tcp::HostTcpSocket::set_keep_alive_interval(&mut self.inner, this, value) - } - - fn keep_alive_count( - &mut self, - this: Resource, - ) -> wasmtime_wasi::p2::SocketResult { - p2_tcp::HostTcpSocket::keep_alive_count(&mut self.inner, this) - } - - fn set_keep_alive_count( - &mut self, - this: Resource, - value: u32, - ) -> wasmtime_wasi::p2::SocketResult<()> { - p2_tcp::HostTcpSocket::set_keep_alive_count(&mut self.inner, this, value) - } - - fn hop_limit(&mut self, this: Resource) -> wasmtime_wasi::p2::SocketResult { - p2_tcp::HostTcpSocket::hop_limit(&mut self.inner, this) - } - - fn set_hop_limit( - &mut self, - this: Resource, - value: u8, - ) -> wasmtime_wasi::p2::SocketResult<()> { - p2_tcp::HostTcpSocket::set_hop_limit(&mut self.inner, this, value) - } - - fn receive_buffer_size( - &mut self, - this: Resource, - ) -> wasmtime_wasi::p2::SocketResult { - p2_tcp::HostTcpSocket::receive_buffer_size(&mut self.inner, this) - } - - fn set_receive_buffer_size( - &mut self, - this: Resource, - value: u64, - ) -> wasmtime_wasi::p2::SocketResult<()> { - p2_tcp::HostTcpSocket::set_receive_buffer_size(&mut self.inner, this, value) - } - - fn send_buffer_size( - &mut self, - this: Resource, - ) -> wasmtime_wasi::p2::SocketResult { - p2_tcp::HostTcpSocket::send_buffer_size(&mut self.inner, this) - } - - fn set_send_buffer_size( - &mut self, - this: Resource, - value: u64, - ) -> wasmtime_wasi::p2::SocketResult<()> { - p2_tcp::HostTcpSocket::set_send_buffer_size(&mut self.inner, this, value) - } - - fn subscribe(&mut self, this: Resource) -> wasmtime::Result> { - p2_tcp::HostTcpSocket::subscribe(&mut self.inner, this) - } - - fn shutdown( - &mut self, - this: Resource, - shutdown_type: ShutdownType, - ) -> wasmtime_wasi::p2::SocketResult<()> { - p2_tcp::HostTcpSocket::shutdown(&mut self.inner, this, shutdown_type) - } - - fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { - // Release the permit before dropping the socket resource. - if let Some(state) = &self.permit_state { - state - .active - .lock() - .unwrap_or_else(|e| e.into_inner()) - .remove(&this.rep()); - } - p2_tcp::HostTcpSocket::drop(&mut self.inner, this) - } -} - -impl NetworkHost for SpinSocketsView<'_> { - fn convert_error_code( - &mut self, - error: wasmtime_wasi::p2::SocketError, - ) -> wasmtime::Result { - NetworkHost::convert_error_code(&mut self.inner, error) - } - - fn network_error_code( - &mut self, - err: Resource, - ) -> wasmtime::Result> { - NetworkHost::network_error_code(&mut self.inner, err) - } -} - -impl wasmtime_wasi::p2::bindings::sockets::network::HostNetwork for SpinSocketsView<'_> { - fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { - wasmtime_wasi::p2::bindings::sockets::network::HostNetwork::drop(&mut self.inner, this) - } -} - -impl p2_tcp_create::Host for SpinSocketsView<'_> { - fn create_tcp_socket( - &mut self, - address_family: wasmtime_wasi::p2::bindings::sockets::network::IpAddressFamily, - ) -> wasmtime_wasi::p2::SocketResult> { - p2_tcp_create::Host::create_tcp_socket(&mut self.inner, address_family) - } -} - -impl p2_udp::Host for SpinSocketsView<'_> {} - -impl p2_udp::HostUdpSocket for SpinSocketsView<'_> { - async fn start_bind( - &mut self, - this: Resource, - network: Resource, - local_address: p2_udp::IpSocketAddress, - ) -> wasmtime_wasi::p2::SocketResult<()> { - p2_udp::HostUdpSocket::start_bind(&mut self.inner, this, network, local_address).await - } - - fn finish_bind( - &mut self, - this: Resource, - ) -> wasmtime_wasi::p2::SocketResult<()> { - p2_udp::HostUdpSocket::finish_bind(&mut self.inner, this) - } - - async fn stream( - &mut self, - this: Resource, - remote_address: Option, - ) -> wasmtime_wasi::p2::SocketResult<( - Resource, - Resource, - )> { - p2_udp::HostUdpSocket::stream(&mut self.inner, this, remote_address).await - } - - fn local_address( - &mut self, - this: Resource, - ) -> wasmtime_wasi::p2::SocketResult { - p2_udp::HostUdpSocket::local_address(&mut self.inner, this) - } - - fn remote_address( - &mut self, - this: Resource, - ) -> wasmtime_wasi::p2::SocketResult { - p2_udp::HostUdpSocket::remote_address(&mut self.inner, this) - } - - fn address_family( - &mut self, - this: Resource, - ) -> wasmtime::Result { - p2_udp::HostUdpSocket::address_family(&mut self.inner, this) - } - - fn unicast_hop_limit( - &mut self, - this: Resource, - ) -> wasmtime_wasi::p2::SocketResult { - p2_udp::HostUdpSocket::unicast_hop_limit(&mut self.inner, this) - } - - fn set_unicast_hop_limit( - &mut self, - this: Resource, - value: u8, - ) -> wasmtime_wasi::p2::SocketResult<()> { - p2_udp::HostUdpSocket::set_unicast_hop_limit(&mut self.inner, this, value) - } - - fn receive_buffer_size( - &mut self, - this: Resource, - ) -> wasmtime_wasi::p2::SocketResult { - p2_udp::HostUdpSocket::receive_buffer_size(&mut self.inner, this) - } - - fn set_receive_buffer_size( - &mut self, - this: Resource, - value: u64, - ) -> wasmtime_wasi::p2::SocketResult<()> { - p2_udp::HostUdpSocket::set_receive_buffer_size(&mut self.inner, this, value) - } - - fn send_buffer_size( - &mut self, - this: Resource, - ) -> wasmtime_wasi::p2::SocketResult { - p2_udp::HostUdpSocket::send_buffer_size(&mut self.inner, this) - } - - fn set_send_buffer_size( - &mut self, - this: Resource, - value: u64, - ) -> wasmtime_wasi::p2::SocketResult<()> { - p2_udp::HostUdpSocket::set_send_buffer_size(&mut self.inner, this, value) - } - - fn subscribe( - &mut self, - this: Resource, - ) -> wasmtime::Result> { - p2_udp::HostUdpSocket::subscribe(&mut self.inner, this) - } - - fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { - // Release the permit before dropping the socket resource. - if let Some(state) = &self.permit_state { - state - .active - .lock() - .unwrap_or_else(|e| e.into_inner()) - .remove(&this.rep()); - } - p2_udp::HostUdpSocket::drop(&mut self.inner, this) - } -} - -impl p2_udp::HostIncomingDatagramStream for SpinSocketsView<'_> { - fn receive( - &mut self, - this: Resource, - max_results: u64, - ) -> wasmtime_wasi::p2::SocketResult> { - p2_udp::HostIncomingDatagramStream::receive(&mut self.inner, this, max_results) - } - - fn subscribe( - &mut self, - this: Resource, - ) -> wasmtime::Result> { - p2_udp::HostIncomingDatagramStream::subscribe(&mut self.inner, this) - } - - fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { - p2_udp::HostIncomingDatagramStream::drop(&mut self.inner, this) - } -} - -impl p2_udp::HostOutgoingDatagramStream for SpinSocketsView<'_> { - fn check_send( - &mut self, - this: Resource, - ) -> wasmtime_wasi::p2::SocketResult { - p2_udp::HostOutgoingDatagramStream::check_send(&mut self.inner, this) - } - - async fn send( - &mut self, - this: Resource, - datagrams: Vec, - ) -> wasmtime_wasi::p2::SocketResult { - p2_udp::HostOutgoingDatagramStream::send(&mut self.inner, this, datagrams).await - } - - fn subscribe( - &mut self, - this: Resource, - ) -> wasmtime::Result> { - p2_udp::HostOutgoingDatagramStream::subscribe(&mut self.inner, this) - } - - fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { - p2_udp::HostOutgoingDatagramStream::drop(&mut self.inner, this) - } -} - -impl p2_udp_create::Host for SpinSocketsView<'_> { - fn create_udp_socket( - &mut self, - address_family: wasmtime_wasi::p2::bindings::sockets::network::IpAddressFamily, - ) -> wasmtime_wasi::p2::SocketResult> { - if let Some(state) = &self.permit_state { - // If we have a permit state, we need to acquire a permit before allowing the socket creation to proceed. - let state = Arc::clone(state); - let permit = Arc::clone(&state.semaphore) - .try_acquire_owned() - .map_err(|_| SocketErrorCode::ConnectionRefused)?; - let sock = p2_udp_create::Host::create_udp_socket(&mut self.inner, address_family)?; - // If the socket was successfully created, store the permit so it can be released when the socket is dropped. - state - .active - .lock() - .unwrap_or_else(|e| e.into_inner()) - .insert(sock.rep(), permit); - Ok(sock) - } else { - p2_udp_create::Host::create_udp_socket(&mut self.inner, address_family) - } - } -} - pub struct WasiFactor { files_mounter: Box, } diff --git a/crates/factor-wasi/src/sockets.rs b/crates/factor-wasi/src/sockets.rs new file mode 100644 index 0000000000..e00a1ee4dc --- /dev/null +++ b/crates/factor-wasi/src/sockets.rs @@ -0,0 +1,536 @@ +//! Socket quota tracking and WASI socket host implementations. +//! +//! This module provides [`SocketPermitState`], [`SpinSocketsView`], and +//! [`SpinSockets`] — the types needed to intercept WASI TCP/UDP socket +//! creation and enforce a per-app cap on the number of concurrently open +//! sockets. + +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; + +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; +use wasmtime::component::{HasData, Resource}; +use wasmtime_wasi::p2::bindings::sockets::network::{ + ErrorCode as SocketErrorCode, Host as NetworkHost, Network, +}; +use wasmtime_wasi::p2::bindings::sockets::tcp::{self as p2_tcp, IpSocketAddress, ShutdownType}; +use wasmtime_wasi::p2::bindings::sockets::tcp_create_socket as p2_tcp_create; +use wasmtime_wasi::p2::bindings::sockets::udp as p2_udp; +use wasmtime_wasi::p2::bindings::sockets::udp_create_socket as p2_udp_create; +use wasmtime_wasi::p2::{DynInputStream, DynOutputStream, DynPollable}; +use wasmtime_wasi::sockets::{TcpSocket, UdpSocket, WasiSocketsCtxView}; + +/// Shared state for tracking per-socket semaphore permits. Permits are +/// acquired when a socket is allocated (at `start_connect` for TCP, at +/// `create_udp_socket` for UDP) and released when the socket resource is dropped. +pub struct SocketPermitState { + semaphore: Arc, + /// Active permits keyed by socket resource rep. + /// + /// Permits are removed (and the permit released) when the WASI socket resource is dropped. + active: Mutex>, +} + +impl SocketPermitState { + pub fn new(semaphore: Arc) -> Arc { + Arc::new(Self { + semaphore, + active: Mutex::new(HashMap::new()), + }) + } +} + +/// A view over WASI socket state that carries an optional per-instance socket +/// permit store, enabling per-connection quota tracking. +pub struct SpinSocketsView<'a> { + pub(crate) inner: WasiSocketsCtxView<'a>, + pub(crate) permit_state: Option>, +} + +impl<'a> std::ops::Deref for SpinSocketsView<'a> { + type Target = WasiSocketsCtxView<'a>; + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl std::ops::DerefMut for SpinSocketsView<'_> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +/// [`HasData`] accessor for [`SpinSocketsView`], used in place of [`WasiSockets`] +/// when registering TCP socket bindings so that `start_connect` and `drop` can +/// participate in socket quota tracking. +pub struct SpinSockets; + +impl HasData for SpinSockets { + type Data<'a> = SpinSocketsView<'a>; +} + +impl p2_tcp::Host for SpinSocketsView<'_> {} + +impl p2_tcp::HostTcpSocket for SpinSocketsView<'_> { + async fn start_bind( + &mut self, + this: Resource, + network: Resource, + local_address: IpSocketAddress, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::start_bind(&mut self.inner, this, network, local_address).await + } + + fn finish_bind(&mut self, this: Resource) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::finish_bind(&mut self.inner, this) + } + + async fn start_connect( + &mut self, + this: Resource, + network: Resource, + remote_address: IpSocketAddress, + ) -> wasmtime_wasi::p2::SocketResult<()> { + if let Some(state) = &self.permit_state { + let socket_rep = this.rep(); + // Unlike outbound HTTP (which queues when its permit pool is exhausted), + // sockets fail immediately. Waiting would risk deadlock if a component + // holds sockets open across async yield points, and raw-socket callers + // are better positioned to implement their own retry logic. The two + // limits are also configured separately, so different semantics are fine. + let Ok(permit) = Arc::clone(&state.semaphore).try_acquire_owned() else { + tracing::warn!("TCP socket connection refused: socket quota exhausted"); + // `new-socket-limit` maps to POSIX EMFILE/ENFILE: "a new socket + // resource could not be created because of a system limit." + return Err(SocketErrorCode::NewSocketLimit.into()); + }; + p2_tcp::HostTcpSocket::start_connect(&mut self.inner, this, network, remote_address) + .await?; + // If the connection was successfully initiated, store the permit so it can be released when the socket is dropped. + state + .active + .lock() + .unwrap_or_else(|e| e.into_inner()) + .insert(socket_rep, permit); + Ok(()) + } else { + p2_tcp::HostTcpSocket::start_connect(&mut self.inner, this, network, remote_address) + .await + } + } + + fn finish_connect( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult<(Resource, Resource)> + { + p2_tcp::HostTcpSocket::finish_connect(&mut self.inner, this) + } + + fn start_listen(&mut self, this: Resource) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::start_listen(&mut self.inner, this) + } + + fn finish_listen(&mut self, this: Resource) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::finish_listen(&mut self.inner, this) + } + + fn accept( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult<( + Resource, + Resource, + Resource, + )> { + p2_tcp::HostTcpSocket::accept(&mut self.inner, this) + } + + fn local_address( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::local_address(&mut self.inner, this) + } + + fn remote_address( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::remote_address(&mut self.inner, this) + } + + fn is_listening(&mut self, this: Resource) -> wasmtime::Result { + p2_tcp::HostTcpSocket::is_listening(&mut self.inner, this) + } + + fn address_family( + &mut self, + this: Resource, + ) -> wasmtime::Result { + p2_tcp::HostTcpSocket::address_family(&mut self.inner, this) + } + + fn set_listen_backlog_size( + &mut self, + this: Resource, + value: u64, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::set_listen_backlog_size(&mut self.inner, this, value) + } + + fn keep_alive_enabled( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::keep_alive_enabled(&mut self.inner, this) + } + + fn set_keep_alive_enabled( + &mut self, + this: Resource, + value: bool, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::set_keep_alive_enabled(&mut self.inner, this, value) + } + + fn keep_alive_idle_time( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::keep_alive_idle_time(&mut self.inner, this) + } + + fn set_keep_alive_idle_time( + &mut self, + this: Resource, + value: u64, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::set_keep_alive_idle_time(&mut self.inner, this, value) + } + + fn keep_alive_interval( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::keep_alive_interval(&mut self.inner, this) + } + + fn set_keep_alive_interval( + &mut self, + this: Resource, + value: u64, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::set_keep_alive_interval(&mut self.inner, this, value) + } + + fn keep_alive_count( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::keep_alive_count(&mut self.inner, this) + } + + fn set_keep_alive_count( + &mut self, + this: Resource, + value: u32, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::set_keep_alive_count(&mut self.inner, this, value) + } + + fn hop_limit(&mut self, this: Resource) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::hop_limit(&mut self.inner, this) + } + + fn set_hop_limit( + &mut self, + this: Resource, + value: u8, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::set_hop_limit(&mut self.inner, this, value) + } + + fn receive_buffer_size( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::receive_buffer_size(&mut self.inner, this) + } + + fn set_receive_buffer_size( + &mut self, + this: Resource, + value: u64, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::set_receive_buffer_size(&mut self.inner, this, value) + } + + fn send_buffer_size( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::send_buffer_size(&mut self.inner, this) + } + + fn set_send_buffer_size( + &mut self, + this: Resource, + value: u64, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::set_send_buffer_size(&mut self.inner, this, value) + } + + fn subscribe(&mut self, this: Resource) -> wasmtime::Result> { + p2_tcp::HostTcpSocket::subscribe(&mut self.inner, this) + } + + fn shutdown( + &mut self, + this: Resource, + shutdown_type: ShutdownType, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::shutdown(&mut self.inner, this, shutdown_type) + } + + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + // Release the permit before dropping the socket resource. + if let Some(state) = &self.permit_state { + state + .active + .lock() + .unwrap_or_else(|e| e.into_inner()) + .remove(&this.rep()); + } + p2_tcp::HostTcpSocket::drop(&mut self.inner, this) + } +} + +impl NetworkHost for SpinSocketsView<'_> { + fn convert_error_code( + &mut self, + error: wasmtime_wasi::p2::SocketError, + ) -> wasmtime::Result { + NetworkHost::convert_error_code(&mut self.inner, error) + } + + fn network_error_code( + &mut self, + err: Resource, + ) -> wasmtime::Result> { + NetworkHost::network_error_code(&mut self.inner, err) + } +} + +impl wasmtime_wasi::p2::bindings::sockets::network::HostNetwork for SpinSocketsView<'_> { + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + wasmtime_wasi::p2::bindings::sockets::network::HostNetwork::drop(&mut self.inner, this) + } +} + +impl p2_tcp_create::Host for SpinSocketsView<'_> { + fn create_tcp_socket( + &mut self, + address_family: wasmtime_wasi::p2::bindings::sockets::network::IpAddressFamily, + ) -> wasmtime_wasi::p2::SocketResult> { + p2_tcp_create::Host::create_tcp_socket(&mut self.inner, address_family) + } +} + +impl p2_udp::Host for SpinSocketsView<'_> {} + +impl p2_udp::HostUdpSocket for SpinSocketsView<'_> { + async fn start_bind( + &mut self, + this: Resource, + network: Resource, + local_address: p2_udp::IpSocketAddress, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_udp::HostUdpSocket::start_bind(&mut self.inner, this, network, local_address).await + } + + fn finish_bind( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_udp::HostUdpSocket::finish_bind(&mut self.inner, this) + } + + async fn stream( + &mut self, + this: Resource, + remote_address: Option, + ) -> wasmtime_wasi::p2::SocketResult<( + Resource, + Resource, + )> { + p2_udp::HostUdpSocket::stream(&mut self.inner, this, remote_address).await + } + + fn local_address( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_udp::HostUdpSocket::local_address(&mut self.inner, this) + } + + fn remote_address( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_udp::HostUdpSocket::remote_address(&mut self.inner, this) + } + + fn address_family( + &mut self, + this: Resource, + ) -> wasmtime::Result { + p2_udp::HostUdpSocket::address_family(&mut self.inner, this) + } + + fn unicast_hop_limit( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_udp::HostUdpSocket::unicast_hop_limit(&mut self.inner, this) + } + + fn set_unicast_hop_limit( + &mut self, + this: Resource, + value: u8, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_udp::HostUdpSocket::set_unicast_hop_limit(&mut self.inner, this, value) + } + + fn receive_buffer_size( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_udp::HostUdpSocket::receive_buffer_size(&mut self.inner, this) + } + + fn set_receive_buffer_size( + &mut self, + this: Resource, + value: u64, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_udp::HostUdpSocket::set_receive_buffer_size(&mut self.inner, this, value) + } + + fn send_buffer_size( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_udp::HostUdpSocket::send_buffer_size(&mut self.inner, this) + } + + fn set_send_buffer_size( + &mut self, + this: Resource, + value: u64, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_udp::HostUdpSocket::set_send_buffer_size(&mut self.inner, this, value) + } + + fn subscribe( + &mut self, + this: Resource, + ) -> wasmtime::Result> { + p2_udp::HostUdpSocket::subscribe(&mut self.inner, this) + } + + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + // Release the permit before dropping the socket resource. + if let Some(state) = &self.permit_state { + state + .active + .lock() + .unwrap_or_else(|e| e.into_inner()) + .remove(&this.rep()); + } + p2_udp::HostUdpSocket::drop(&mut self.inner, this) + } +} + +impl p2_udp::HostIncomingDatagramStream for SpinSocketsView<'_> { + fn receive( + &mut self, + this: Resource, + max_results: u64, + ) -> wasmtime_wasi::p2::SocketResult> { + p2_udp::HostIncomingDatagramStream::receive(&mut self.inner, this, max_results) + } + + fn subscribe( + &mut self, + this: Resource, + ) -> wasmtime::Result> { + p2_udp::HostIncomingDatagramStream::subscribe(&mut self.inner, this) + } + + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + p2_udp::HostIncomingDatagramStream::drop(&mut self.inner, this) + } +} + +impl p2_udp::HostOutgoingDatagramStream for SpinSocketsView<'_> { + fn check_send( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_udp::HostOutgoingDatagramStream::check_send(&mut self.inner, this) + } + + async fn send( + &mut self, + this: Resource, + datagrams: Vec, + ) -> wasmtime_wasi::p2::SocketResult { + p2_udp::HostOutgoingDatagramStream::send(&mut self.inner, this, datagrams).await + } + + fn subscribe( + &mut self, + this: Resource, + ) -> wasmtime::Result> { + p2_udp::HostOutgoingDatagramStream::subscribe(&mut self.inner, this) + } + + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + p2_udp::HostOutgoingDatagramStream::drop(&mut self.inner, this) + } +} + +impl p2_udp_create::Host for SpinSocketsView<'_> { + fn create_udp_socket( + &mut self, + address_family: wasmtime_wasi::p2::bindings::sockets::network::IpAddressFamily, + ) -> wasmtime_wasi::p2::SocketResult> { + if let Some(state) = &self.permit_state { + let state = Arc::clone(state); + // See the analogous comment in `start_connect` for why we fail + // immediately rather than waiting (as outbound HTTP does). + let permit = Arc::clone(&state.semaphore) + .try_acquire_owned() + .map_err(|_| { + tracing::warn!("UDP socket creation refused: socket quota exhausted"); + // `new-socket-limit` maps to POSIX EMFILE/ENFILE: "a new socket + // resource could not be created because of a system limit." + SocketErrorCode::NewSocketLimit + })?; + let sock = p2_udp_create::Host::create_udp_socket(&mut self.inner, address_family)?; + // If the socket was successfully created, store the permit so it can be released when the socket is dropped. + state + .active + .lock() + .unwrap_or_else(|e| e.into_inner()) + .insert(sock.rep(), permit); + Ok(sock) + } else { + p2_udp_create::Host::create_udp_socket(&mut self.inner, address_family) + } + } +} From d988e1cc4138db26ac72305c3323ca6a36b7786a Mon Sep 17 00:00:00 2001 From: Ryan Levick Date: Thu, 28 May 2026 15:23:49 +0200 Subject: [PATCH 04/15] Limit the number of redis connections an app can have across instances Signed-off-by: Ryan Levick --- Cargo.lock | 1 + crates/factor-outbound-redis/Cargo.toml | 3 +- crates/factor-outbound-redis/src/host.rs | 58 ++++++++++++++----- crates/factor-outbound-redis/src/lib.rs | 22 +++++-- .../src/runtime_config.rs | 8 +++ .../src/runtime_config/spin.rs | 29 ++++++++++ crates/runtime-config/src/lib.rs | 6 +- 7 files changed, 106 insertions(+), 21 deletions(-) create mode 100644 crates/factor-outbound-redis/src/runtime_config.rs create mode 100644 crates/factor-outbound-redis/src/runtime_config/spin.rs diff --git a/Cargo.lock b/Cargo.lock index e0322526cb..350e6badd1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9115,6 +9115,7 @@ version = "4.1.0-pre0" dependencies = [ "anyhow", "redis", + "serde", "spin-core", "spin-factor-otel", "spin-factor-outbound-networking", diff --git a/crates/factor-outbound-redis/Cargo.toml b/crates/factor-outbound-redis/Cargo.toml index 6518459d79..55e641ce6e 100644 --- a/crates/factor-outbound-redis/Cargo.toml +++ b/crates/factor-outbound-redis/Cargo.toml @@ -7,13 +7,14 @@ edition = { workspace = true } [dependencies] anyhow = { workspace = true } redis = { workspace = true , features = ["tokio-comp", "tokio-native-tls-comp", "aio"] } +serde = { workspace = true } spin-core = { path = "../core" } spin-factor-otel = { path = "../factor-otel" } spin-factor-outbound-networking = { path = "../factor-outbound-networking" } spin-factors = { path = "../factors" } spin-resource-table = { path = "../table" } spin-world = { path = "../world" } -tokio = { workspace = true } +tokio = { workspace = true, features = ["sync"] } tracing = { workspace = true } [dev-dependencies] diff --git a/crates/factor-outbound-redis/src/host.rs b/crates/factor-outbound-redis/src/host.rs index 61fd05708b..202584a8a1 100644 --- a/crates/factor-outbound-redis/src/host.rs +++ b/crates/factor-outbound-redis/src/host.rs @@ -1,4 +1,5 @@ use std::net::SocketAddr; +use std::sync::Arc; use anyhow::Result; use redis::AsyncConnectionConfig; @@ -11,6 +12,7 @@ use spin_world::MAX_HOST_BUFFERED_BYTES; use spin_world::spin::redis::redis as v3; use spin_world::v1::{redis as v1, redis_types}; use spin_world::v2::redis as v2; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; use tracing::field::Empty; use tracing::{Level, instrument}; @@ -19,7 +21,9 @@ use crate::allowed_hosts::AllowedHostChecker; pub struct InstanceState { pub(crate) allowed_host_checker: AllowedHostChecker, pub blocked_networks: BlockedNetworks, - pub connections: spin_resource_table::Table, + pub connections: + spin_resource_table::Table<(MultiplexedConnection, Option)>, + pub connection_semaphore: Option>, pub otel: OtelFactorState, } @@ -32,6 +36,15 @@ impl InstanceState { &mut self, address: String, ) -> Result, v2::Error> { + let permit = match &self.connection_semaphore { + Some(sem) => Some( + Arc::clone(sem) + .acquire_owned() + .await + .map_err(|_| v2::Error::TooManyConnections)?, + ), + None => None, + }; let config = AsyncConnectionConfig::new() .set_dns_resolver(SpinDnsResolver(self.blocked_networks.clone())); let conn = redis::Client::open(address.as_str()) @@ -40,7 +53,7 @@ impl InstanceState { .await .map_err(other_error_v2)?; self.connections - .push(conn) + .push((conn, permit)) .map(Resource::new_own) .map_err(|_| v2::Error::TooManyConnections) } @@ -51,6 +64,7 @@ impl InstanceState { ) -> Result<&mut MultiplexedConnection, v2::Error> { self.connections .get_mut(connection.rep()) + .map(|(conn, _permit)| conn) .ok_or(v2::Error::Other( "could not find connection for resource".into(), )) @@ -62,7 +76,7 @@ impl InstanceState { ) -> Result { self.connections .get(connection.rep()) - .cloned() + .map(|(conn, _permit)| conn.clone()) .ok_or(v3::Error::Other( "could not find connection for resource".into(), )) @@ -229,14 +243,16 @@ impl v3::HostConnectionWithStore for crate::RedisFactorData { accessor: &Accessor, address: String, ) -> Result, v3::Error> { - let (allowed_host_checker, blocked_networks) = accessor.with(|mut access| { - let host = access.get(); - host.otel.reparent_tracing_span(); - ( - host.allowed_host_checker.clone(), - host.blocked_networks.clone(), - ) - }); + let (allowed_host_checker, blocked_networks, connection_semaphore) = + accessor.with(|mut access| { + let host = access.get(); + host.otel.reparent_tracing_span(); + ( + host.allowed_host_checker.clone(), + host.blocked_networks.clone(), + host.connection_semaphore.clone(), + ) + }); if !allowed_host_checker .is_address_allowed(&address) @@ -246,6 +262,15 @@ impl v3::HostConnectionWithStore for crate::RedisFactorData { return Err(v3::Error::InvalidAddress); } + let permit = match connection_semaphore { + Some(sem) => Some( + sem.acquire_owned() + .await + .map_err(|_| v3::Error::TooManyConnections)?, + ), + None => None, + }; + let config = AsyncConnectionConfig::new().set_dns_resolver(SpinDnsResolver(blocked_networks)); let conn = redis::Client::open(address.as_str()) @@ -257,7 +282,7 @@ impl v3::HostConnectionWithStore for crate::RedisFactorData { accessor.with(|mut access| { let host = access.get(); host.connections - .push(conn) + .push((conn, permit)) .map(Resource::new_own) .map_err(|_| v3::Error::TooManyConnections) }) @@ -532,9 +557,14 @@ macro_rules! delegate { Ok(c) => c, Err(_) => return Err(v1::Error::Error), }; - ::$name($self, connection, $($arg),*) + // v1 has no persistent connections, so remove the table entry immediately + // after the call to release the semaphore permit. + let rep = connection.rep(); + let result = ::$name($self, connection, $($arg),*) .await - .map_err(|_| v1::Error::Error) + .map_err(|_| v1::Error::Error); + $self.connections.remove(rep); + result }}; } diff --git a/crates/factor-outbound-redis/src/lib.rs b/crates/factor-outbound-redis/src/lib.rs index 494c5ca800..4f9186e5e2 100644 --- a/crates/factor-outbound-redis/src/lib.rs +++ b/crates/factor-outbound-redis/src/lib.rs @@ -1,7 +1,11 @@ mod allowed_hosts; mod host; +pub mod runtime_config; + +use std::sync::Arc; use host::InstanceState; +use runtime_config::RuntimeConfig; use spin_factor_otel::OtelFactorState; use spin_factor_outbound_networking::OutboundNetworkingFactor; use spin_factors::{ @@ -9,6 +13,7 @@ use spin_factors::{ anyhow, }; use spin_world::spin::redis::redis as v3; +use tokio::sync::Semaphore; use crate::allowed_hosts::AllowedHostChecker; @@ -24,9 +29,14 @@ impl OutboundRedisFactor { } } +pub struct AppState { + /// A semaphore to limit the number of concurrent outbound Redis connections. + pub connection_semaphore: Option>, +} + impl Factor for OutboundRedisFactor { - type RuntimeConfig = (); - type AppState = (); + type RuntimeConfig = RuntimeConfig; + type AppState = AppState; type InstanceBuilder = InstanceState; fn init(&mut self, ctx: &mut impl spin_factors::InitContext) -> anyhow::Result<()> { @@ -38,9 +48,12 @@ impl Factor for OutboundRedisFactor { fn configure_app( &self, - _ctx: ConfigureAppContext, + mut ctx: ConfigureAppContext, ) -> anyhow::Result { - Ok(()) + let config = ctx.take_runtime_config().unwrap_or_default(); + Ok(AppState { + connection_semaphore: config.max_connections.map(|n| Arc::new(Semaphore::new(n))), + }) } fn prepare( @@ -54,6 +67,7 @@ impl Factor for OutboundRedisFactor { allowed_host_checker: AllowedHostChecker::new(outbound_networking.allowed_hosts()), blocked_networks: outbound_networking.blocked_networks(), connections: spin_resource_table::Table::new(1024), + connection_semaphore: ctx.app_state().connection_semaphore.clone(), otel, }) } diff --git a/crates/factor-outbound-redis/src/runtime_config.rs b/crates/factor-outbound-redis/src/runtime_config.rs new file mode 100644 index 0000000000..38d2d7ea7d --- /dev/null +++ b/crates/factor-outbound-redis/src/runtime_config.rs @@ -0,0 +1,8 @@ +pub mod spin; + +/// Runtime configuration for outbound Redis. +#[derive(Default)] +pub struct RuntimeConfig { + /// If set, limits the number of concurrent outbound Redis connections. + pub max_connections: Option, +} diff --git a/crates/factor-outbound-redis/src/runtime_config/spin.rs b/crates/factor-outbound-redis/src/runtime_config/spin.rs new file mode 100644 index 0000000000..82c0efeaff --- /dev/null +++ b/crates/factor-outbound-redis/src/runtime_config/spin.rs @@ -0,0 +1,29 @@ +use serde::Deserialize; +use spin_factors::runtime_config::toml::GetTomlValue; + +/// Get the runtime configuration for outbound Redis from a TOML table. +/// +/// Expects table to be in the format: +/// ```toml +/// [outbound_redis] +/// max_connections = 10 # optional, defaults to unlimited +/// ``` +pub fn config_from_table( + table: &impl GetTomlValue, +) -> anyhow::Result> { + if let Some(outbound_redis) = table.get("outbound_redis") { + let toml = outbound_redis.clone().try_into::()?; + Ok(Some(super::RuntimeConfig { + max_connections: toml.max_connections, + })) + } else { + Ok(None) + } +} + +#[derive(Debug, Default, Deserialize)] +#[serde(deny_unknown_fields)] +struct OutboundRedisToml { + #[serde(default)] + max_connections: Option, +} diff --git a/crates/runtime-config/src/lib.rs b/crates/runtime-config/src/lib.rs index b5b2f6b1b8..4f73e7cba6 100644 --- a/crates/runtime-config/src/lib.rs +++ b/crates/runtime-config/src/lib.rs @@ -368,8 +368,10 @@ impl FactorRuntimeConfigSource for TomlRuntimeConfigSource<'_, '_> { } impl FactorRuntimeConfigSource for TomlRuntimeConfigSource<'_, '_> { - fn get_runtime_config(&mut self) -> anyhow::Result> { - Ok(None) + fn get_runtime_config( + &mut self, + ) -> anyhow::Result::RuntimeConfig>> { + spin_factor_outbound_redis::runtime_config::spin::config_from_table(&self.toml.table) } } From fc9c62ee31319e7890ddf19510c01ea0b189c078 Mon Sep 17 00:00:00 2001 From: Ryan Levick Date: Thu, 28 May 2026 16:40:33 +0200 Subject: [PATCH 05/15] Limit the number of pg/mysql connections an app can have across instances Signed-off-by: Ryan Levick --- Cargo.lock | 2 + crates/factor-outbound-mysql/Cargo.toml | 3 +- crates/factor-outbound-mysql/src/host.rs | 97 ++++++++++++++----- crates/factor-outbound-mysql/src/lib.rs | 43 +++++--- .../src/runtime_config.rs | 8 ++ .../src/runtime_config/spin.rs | 29 ++++++ crates/factor-outbound-pg/Cargo.toml | 3 +- crates/factor-outbound-pg/src/host.rs | 68 ++++++++++--- crates/factor-outbound-pg/src/lib.rs | 31 ++++-- .../factor-outbound-pg/src/runtime_config.rs | 8 ++ .../src/runtime_config/spin.rs | 29 ++++++ crates/runtime-config/src/lib.rs | 20 +++- 12 files changed, 274 insertions(+), 67 deletions(-) create mode 100644 crates/factor-outbound-mysql/src/runtime_config.rs create mode 100644 crates/factor-outbound-mysql/src/runtime_config/spin.rs create mode 100644 crates/factor-outbound-pg/src/runtime_config.rs create mode 100644 crates/factor-outbound-pg/src/runtime_config/spin.rs diff --git a/Cargo.lock b/Cargo.lock index 350e6badd1..5943a3f4a9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9028,6 +9028,7 @@ dependencies = [ "anyhow", "futures", "mysql_async", + "serde", "spin-core", "spin-factor-otel", "spin-factor-outbound-networking", @@ -9088,6 +9089,7 @@ dependencies = [ "postgres-native-tls", "postgres_range", "rust_decimal", + "serde", "serde_json", "spin-common", "spin-core", diff --git a/crates/factor-outbound-mysql/Cargo.toml b/crates/factor-outbound-mysql/Cargo.toml index 64f51e3db0..b99020589e 100644 --- a/crates/factor-outbound-mysql/Cargo.toml +++ b/crates/factor-outbound-mysql/Cargo.toml @@ -10,6 +10,7 @@ doctest = false [dependencies] anyhow = { workspace = true } futures = { workspace = true } +serde = { workspace = true } # Removing default features for mysql_async to remove flate2/zlib feature mysql_async = { version = "0.35", default-features = false, features = [ "minimal-rust", @@ -23,7 +24,7 @@ spin-resource-table = { path = "../table" } spin-telemetry = { path = "../telemetry" } spin-world = { path = "../world" } spin-wasi-async = { path = "../wasi-async" } -tokio = { workspace = true, features = ["rt-multi-thread"] } +tokio = { workspace = true, features = ["rt-multi-thread", "sync"] } tracing = { workspace = true } url = { workspace = true } diff --git a/crates/factor-outbound-mysql/src/host.rs b/crates/factor-outbound-mysql/src/host.rs index 6f5577d80f..41900a395d 100644 --- a/crates/factor-outbound-mysql/src/host.rs +++ b/crates/factor-outbound-mysql/src/host.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use anyhow::Result; use spin_core::wasmtime::component::{Accessor, FutureReader, Resource, StreamReader}; use spin_telemetry::traces::{self, Blame}; @@ -6,8 +8,7 @@ use spin_world::spin::mysql::mysql as v3; use spin_world::v1::mysql as v1; use spin_world::v2::mysql as v2; use spin_world::v2::rdbms_types as v2_types; -use std::sync::Arc; -use tokio::sync::Mutex; +use tokio::sync::{Mutex, OwnedSemaphorePermit}; use tracing::field::Empty; use tracing::{Level, instrument}; @@ -15,7 +16,11 @@ use crate::client::Client; use crate::{InstanceState, InstanceStateInner, MysqlFactorData}; impl InstanceStateInner { - async fn open_connection(&mut self, address: &str) -> Result { + async fn open_connection( + &mut self, + address: &str, + permit: Option, + ) -> Result { spin_factor_outbound_networking::record_address_fields(address); if !self.is_address_allowed(address).await.map_err(|e| { @@ -40,7 +45,7 @@ impl InstanceStateInner { err })?; self.connections - .push(Arc::new(Mutex::new(client))) + .push((Arc::new(Mutex::new(client)), permit)) .map_err(|_| { // The guest exceeded the host-imposed connection limit. let err = v2::Error::ConnectionFailed("too many connections".into()); @@ -50,13 +55,16 @@ impl InstanceStateInner { } fn get_client(&mut self, connection: u32) -> Result>, v2::Error> { - self.connections.get(connection).cloned().ok_or_else(|| { - // The connection table is managed entirely by the host, so a - // missing handle indicates a host-side bug, not a guest mistake. - let err = v2::Error::ConnectionFailed("no connection found".into()); - traces::mark_as_error(&err, Some(Blame::Host)); - err - }) + self.connections + .get(connection) + .map(|(conn, _permit)| conn.clone()) + .ok_or_else(|| { + // The connection table is managed entirely by the host, so a + // missing handle indicates a host-side bug, not a guest mistake. + let err = v2::Error::ConnectionFailed("no connection found".into()); + traces::mark_as_error(&err, Some(Blame::Host)); + err + }) } async fn is_address_allowed(&self, address: &str) -> Result { @@ -72,7 +80,7 @@ impl v3::Host for InstanceState { impl v3::HostConnection for InstanceState { async fn drop(&mut self, connection: Resource) -> Result<()> { - let mut state = self.0.lock().await; + let mut state = self.inner.lock().await; state.connections.remove(connection.rep()); Ok(()) } @@ -90,10 +98,21 @@ impl v3::HostConnectionWithStore for MysqlFactorData { accessor: &Accessor, address: String, ) -> Result, v3::Error> { - let state = accessor.with(|mut access| access.get().0.clone()); - let mut state = state.lock().await; + let (state_arc, connection_semaphore) = accessor.with(|mut access| { + let host = access.get(); + (host.inner.clone(), host.connection_semaphore.clone()) + }); + let permit = match connection_semaphore { + Some(sem) => Some(sem.acquire_owned().await.map_err(|_| { + v3::Error::from(v2::Error::ConnectionFailed("too many connections".into())) + })?), + None => None, + }; + let mut state = state_arc.lock().await; state.otel.reparent_tracing_span(); - Ok(Resource::new_own(state.open_connection(&address).await?)) + Ok(Resource::new_own( + state.open_connection(&address, permit).await?, + )) } #[instrument(name = "spin_outbound_mysql.execute", skip(accessor, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", otel.name = statement))] @@ -103,7 +122,7 @@ impl v3::HostConnectionWithStore for MysqlFactorData { statement: String, params: Vec, ) -> Result<(), v3::Error> { - let state = accessor.with(|mut access| access.get().0.clone()); + let state = accessor.with(|mut access| access.get().inner.clone()); let client = { let mut state = state.lock().await; state.otel.reparent_tracing_span(); @@ -125,7 +144,7 @@ impl v3::HostConnectionWithStore for MysqlFactorData { statement: String, params: Vec, ) -> Result { - let state = accessor.with(|mut access| access.get().0.clone()); + let state = accessor.with(|mut access| access.get().inner.clone()); let client = { let mut state = state.lock().await; state.otel.reparent_tracing_span(); @@ -161,9 +180,21 @@ impl v2::Host for InstanceState {} impl v2::HostConnection for InstanceState { #[instrument(name = "spin_outbound_mysql.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", db.address = Empty, server.port = Empty, db.namespace = Empty))] async fn open(&mut self, address: String) -> Result, v2::Error> { - let mut state = self.0.lock().await; + let permit = match &self.connection_semaphore { + Some(sem) => Some( + Arc::clone(sem) + .acquire_owned() + .await + .map_err(|_| v2::Error::ConnectionFailed("too many connections".into()))?, + ), + None => None, + }; + let mut state = self.inner.lock().await; state.otel.reparent_tracing_span(); - state.open_connection(&address).await.map(Resource::new_own) + state + .open_connection(&address, permit) + .await + .map(Resource::new_own) } #[instrument(name = "spin_outbound_mysql.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", otel.name = statement))] @@ -173,7 +204,7 @@ impl v2::HostConnection for InstanceState { statement: String, params: Vec, ) -> Result<(), v2::Error> { - let mut state = self.0.lock().await; + let mut state = self.inner.lock().await; state.otel.reparent_tracing_span(); state .get_client(connection.rep())? @@ -191,7 +222,7 @@ impl v2::HostConnection for InstanceState { statement: String, params: Vec, ) -> Result { - let mut state = self.0.lock().await; + let mut state = self.inner.lock().await; state.otel.reparent_tracing_span(); state .get_client(connection.rep())? @@ -203,7 +234,7 @@ impl v2::HostConnection for InstanceState { } async fn drop(&mut self, connection: Resource) -> Result<()> { - let mut state = self.0.lock().await; + let mut state = self.inner.lock().await; state.connections.remove(connection.rep()); Ok(()) } @@ -218,13 +249,27 @@ impl v2_types::Host for InstanceState { /// Delegate a function call to the v2::HostConnection implementation macro_rules! delegate { ($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{ + let permit = match &$self.connection_semaphore { + Some(sem) => Some( + Arc::clone(sem) + .acquire_owned() + .await + .map_err(|_| v2::Error::ConnectionFailed("too many connections".into()))?, + ), + None => None, + }; let connection = { - let mut state = $self.0.lock().await; - Resource::new_own(state.open_connection(&$address).await?) + let mut state = $self.inner.lock().await; + Resource::new_own(state.open_connection(&$address, permit).await?) }; - ::$name($self, connection, $($arg),*) + // v1 has no persistent connections, so remove the table entry immediately + // after the call to release the semaphore permit. + let rep = connection.rep(); + let result = ::$name($self, connection, $($arg),*) .await - .map_err(Into::into) + .map_err(Into::into); + $self.inner.lock().await.connections.remove(rep); + result }}; } diff --git a/crates/factor-outbound-mysql/src/lib.rs b/crates/factor-outbound-mysql/src/lib.rs index 10a51f1e5f..34d08ac7e1 100644 --- a/crates/factor-outbound-mysql/src/lib.rs +++ b/crates/factor-outbound-mysql/src/lib.rs @@ -1,8 +1,12 @@ pub mod client; mod host; +pub mod runtime_config; + +use std::sync::Arc; use client::Client; use mysql_async::Conn as MysqlClient; +use runtime_config::RuntimeConfig; use spin_factor_otel::OtelFactorState; use spin_factor_outbound_networking::{ OutboundNetworkingFactor, config::allowed_hosts::OutboundAllowedHosts, @@ -11,16 +15,20 @@ use spin_factors::{Factor, FactorData, InitContext, RuntimeFactors, SelfInstance use spin_world::spin::mysql::mysql as v3; use spin_world::v1::mysql as v1; use spin_world::v2::mysql as v2; -use std::sync::Arc; -use tokio::sync::Mutex; +use tokio::sync::{Mutex, OwnedSemaphorePermit, Semaphore}; pub struct OutboundMysqlFactor { _phantom: std::marker::PhantomData, } +pub struct AppState { + /// A semaphore to limit the number of concurrent outbound MySQL connections. + pub connection_semaphore: Option>, +} + impl Factor for OutboundMysqlFactor { - type RuntimeConfig = (); - type AppState = (); + type RuntimeConfig = RuntimeConfig; + type AppState = AppState; type InstanceBuilder = InstanceState; fn init(&mut self, ctx: &mut impl InitContext) -> anyhow::Result<()> { @@ -32,9 +40,12 @@ impl Factor for OutboundMysqlFactor { fn configure_app( &self, - _ctx: spin_factors::ConfigureAppContext, + mut ctx: spin_factors::ConfigureAppContext, ) -> anyhow::Result { - Ok(()) + let config = ctx.take_runtime_config().unwrap_or_default(); + Ok(AppState { + connection_semaphore: config.max_connections.map(|n| Arc::new(Semaphore::new(n))), + }) } fn prepare( @@ -46,11 +57,14 @@ impl Factor for OutboundMysqlFactor { .allowed_hosts(); let otel = OtelFactorState::from_prepare_context(&mut ctx)?; - Ok(InstanceState(Arc::new(Mutex::new(InstanceStateInner { - allowed_hosts, - connections: Default::default(), - otel, - })))) + Ok(InstanceState { + inner: Arc::new(Mutex::new(InstanceStateInner { + allowed_hosts, + connections: Default::default(), + otel, + })), + connection_semaphore: ctx.app_state().connection_semaphore.clone(), + }) } } @@ -70,11 +84,14 @@ impl OutboundMysqlFactor { pub struct InstanceStateInner { allowed_hosts: OutboundAllowedHosts, - connections: spin_resource_table::Table>>, + connections: spin_resource_table::Table<(Arc>, Option)>, otel: OtelFactorState, } -pub struct InstanceState(Arc>>); +pub struct InstanceState { + pub(crate) inner: Arc>>, + pub connection_semaphore: Option>, +} impl SelfInstanceBuilder for InstanceState {} diff --git a/crates/factor-outbound-mysql/src/runtime_config.rs b/crates/factor-outbound-mysql/src/runtime_config.rs new file mode 100644 index 0000000000..5a96047a1f --- /dev/null +++ b/crates/factor-outbound-mysql/src/runtime_config.rs @@ -0,0 +1,8 @@ +pub mod spin; + +/// Runtime configuration for outbound MySQL. +#[derive(Default)] +pub struct RuntimeConfig { + /// If set, limits the number of concurrent outbound MySQL connections. + pub max_connections: Option, +} diff --git a/crates/factor-outbound-mysql/src/runtime_config/spin.rs b/crates/factor-outbound-mysql/src/runtime_config/spin.rs new file mode 100644 index 0000000000..85c38253cc --- /dev/null +++ b/crates/factor-outbound-mysql/src/runtime_config/spin.rs @@ -0,0 +1,29 @@ +use serde::Deserialize; +use spin_factors::runtime_config::toml::GetTomlValue; + +/// Get the runtime configuration for outbound MySQL from a TOML table. +/// +/// Expects table to be in the format: +/// ```toml +/// [outbound_mysql] +/// max_connections = 10 # optional, defaults to unlimited +/// ``` +pub fn config_from_table( + table: &impl GetTomlValue, +) -> anyhow::Result> { + if let Some(outbound_mysql) = table.get("outbound_mysql") { + let toml = outbound_mysql.clone().try_into::()?; + Ok(Some(super::RuntimeConfig { + max_connections: toml.max_connections, + })) + } else { + Ok(None) + } +} + +#[derive(Debug, Default, Deserialize)] +#[serde(deny_unknown_fields)] +struct OutboundMysqlToml { + #[serde(default)] + max_connections: Option, +} diff --git a/crates/factor-outbound-pg/Cargo.toml b/crates/factor-outbound-pg/Cargo.toml index 45dcc7a22f..b7cac624b8 100644 --- a/crates/factor-outbound-pg/Cargo.toml +++ b/crates/factor-outbound-pg/Cargo.toml @@ -7,6 +7,7 @@ edition = { workspace = true } [dependencies] anyhow = { workspace = true } bytes = {workspace = true } +serde = { workspace = true } chrono = { workspace = true } deadpool-postgres = { version = "0.14", features = ["rt_tokio_1"] } futures = { workspace = true } @@ -27,7 +28,7 @@ spin-resource-table = { path = "../table" } spin-telemetry = { path = "../telemetry" } spin-wasi-async = { path = "../wasi-async" } spin-world = { path = "../world" } -tokio = { workspace = true, features = ["rt-multi-thread"] } +tokio = { workspace = true, features = ["rt-multi-thread", "sync"] } tokio-postgres = { version = "0.7", features = ["with-chrono-0_4", "with-serde_json-1", "with-uuid-1"] } tracing = { workspace = true } url = { workspace = true } diff --git a/crates/factor-outbound-pg/src/host.rs b/crates/factor-outbound-pg/src/host.rs index 5faf7663a8..81ab16d677 100644 --- a/crates/factor-outbound-pg/src/host.rs +++ b/crates/factor-outbound-pg/src/host.rs @@ -1,5 +1,7 @@ #![allow(clippy::result_large_err)] +use std::sync::Arc; + use anyhow::Result; use spin_core::wasmtime::component::{Accessor, FutureReader, Resource, StreamReader}; use spin_telemetry::traces::{self, Blame}; @@ -24,6 +26,14 @@ impl InstanceState { address: &str, root_ca: Option, ) -> Result, v4::Error> { + let permit = match &self.connection_semaphore { + Some(sem) => Some(Arc::clone(sem).acquire_owned().await.map_err(|_| { + let err = v4::Error::ConnectionFailed("too many connections".into()); + traces::mark_as_error(&err, Some(Blame::Guest)); + err + })?), + None => None, + }; let client = self .client_factory .get_client(address, root_ca) @@ -37,7 +47,7 @@ impl InstanceState { err })?; self.connections - .push(client) + .push((client, permit)) .map_err(|_| { // The guest exceeded the host-imposed connection limit. let err = v4::Error::ConnectionFailed("too many connections".into()); @@ -51,13 +61,16 @@ impl InstanceState { &self, connection: Resource, ) -> Result<&CF::Client, v4::Error> { - self.connections.get(connection.rep()).ok_or_else(|| { - // The connection table is managed entirely by the host, so a - // missing handle indicates a host-side bug, not a guest mistake. - let err = v4::Error::ConnectionFailed("no connection found".into()); - traces::mark_as_error(&err, Some(Blame::Host)); - err - }) + self.connections + .get(connection.rep()) + .map(|(client, _permit)| client) + .ok_or_else(|| { + // The connection table is managed entirely by the host, so a + // missing handle indicates a host-side bug, not a guest mistake. + let err = v4::Error::ConnectionFailed("no connection found".into()); + traces::mark_as_error(&err, Some(Blame::Host)); + err + }) } fn allowed_host_checker(&self) -> AllowedHostChecker { @@ -260,7 +273,10 @@ impl spin_world::spin::postgres4_2_0::postgres::HostConnectio ) -> Result { let client = accessor.with(|mut access| { let host = access.get(); - host.connections.get(connection.rep()).unwrap().clone() + host.connections + .get(connection.rep()) + .map(|(client, _permit)| client.clone()) + .unwrap() }); client @@ -286,7 +302,10 @@ impl spin_world::spin::postgres4_2_0::postgres::HostConnectio > { let client = accessor.with(|mut access| { let host = access.get(); - host.connections.get(connection.rep()).unwrap().clone() + host.connections + .get(connection.rep()) + .map(|(client, _permit)| client.clone()) + .unwrap() }); let QueryAsyncResult { @@ -368,11 +387,23 @@ impl crate::PgFactorData { address: &str, root_ca: Option, ) -> Result, v4::Error> { - let cf = accessor.with(|mut access| { + let (cf, connection_semaphore) = accessor.with(|mut access| { let host = access.get(); - host.client_factory.clone() + ( + host.client_factory.clone(), + host.connection_semaphore.clone(), + ) }); + let permit = match connection_semaphore { + Some(sem) => Some(sem.acquire_owned().await.map_err(|_| { + let err = v4::Error::ConnectionFailed("too many connections".into()); + traces::mark_as_error(&err, Some(Blame::Guest)); + err + })?), + None => None, + }; + let client = cf.get_client(address, root_ca).await.map_err(|e| { let err = v4::Error::ConnectionFailed(format!("{e:?}")); traces::mark_as_error(&err, Some(Blame::Guest)); @@ -382,7 +413,7 @@ impl crate::PgFactorData { accessor.with(|mut access| { let host = access.get(); host.connections - .push(client) + .push((client, permit)) .map_err(|_| { let err = v4::Error::ConnectionFailed("too many connections".into()); traces::mark_as_error(&err, Some(Blame::Guest)); @@ -429,7 +460,7 @@ impl v4::Host for InstanceState { } } -/// Delegate a function call to the v3::HostConnection implementation +/// Delegate a function call to the v4::HostConnection implementation macro_rules! delegate { ($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{ $self.ensure_address_allowed(&$address).await?; @@ -437,9 +468,14 @@ macro_rules! delegate { Ok(c) => c, Err(e) => return Err(e.into()), }; - ::$name($self, connection, $($arg),*) + // v1 has no persistent connections, so remove the table entry immediately + // after the call to release the semaphore permit. + let rep = connection.rep(); + let result = ::$name($self, connection, $($arg),*) .await - .map_err(|e| e.into()) + .map_err(|e| e.into()); + $self.connections.remove(rep); + result }}; } diff --git a/crates/factor-outbound-pg/src/lib.rs b/crates/factor-outbound-pg/src/lib.rs index d20cfe492f..62714e1954 100644 --- a/crates/factor-outbound-pg/src/lib.rs +++ b/crates/factor-outbound-pg/src/lib.rs @@ -1,25 +1,34 @@ mod allowed_hosts; pub mod client; mod host; +pub mod runtime_config; mod types; use std::{collections::HashMap, sync::Arc}; use allowed_hosts::AllowedHostChecker; use client::ClientFactory; +use runtime_config::RuntimeConfig; use spin_factor_otel::OtelFactorState; use spin_factor_outbound_networking::OutboundNetworkingFactor; use spin_factors::{ ConfigureAppContext, Factor, PrepareContext, RuntimeFactors, SelfInstanceBuilder, anyhow, }; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; pub struct OutboundPgFactor { _phantom: std::marker::PhantomData, } +pub struct AppState { + pub client_factories: HashMap>, + /// A semaphore to limit the number of concurrent outbound PostgreSQL connections. + pub connection_semaphore: Option>, +} + impl Factor for OutboundPgFactor { - type RuntimeConfig = (); - type AppState = HashMap>; + type RuntimeConfig = RuntimeConfig; + type AppState = AppState; type InstanceBuilder = InstanceState; fn init(&mut self, ctx: &mut impl spin_factors::InitContext) -> anyhow::Result<()> { @@ -36,13 +45,17 @@ impl Factor for OutboundPgFactor { fn configure_app( &self, - ctx: ConfigureAppContext, + mut ctx: ConfigureAppContext, ) -> anyhow::Result { + let config = ctx.take_runtime_config().unwrap_or_default(); let mut client_factories = HashMap::new(); for comp in ctx.app().components() { client_factories.insert(comp.id().to_string(), Arc::new(CF::default())); } - Ok(client_factories) + Ok(AppState { + client_factories, + connection_semaphore: config.max_connections.map(|n| Arc::new(Semaphore::new(n))), + }) } fn prepare( @@ -53,7 +66,11 @@ impl Factor for OutboundPgFactor { .instance_builder::()? .allowed_hosts(); let otel = OtelFactorState::from_prepare_context(&mut ctx)?; - let cf = ctx.app_state().get(ctx.app_component().id()).unwrap(); + let cf = ctx + .app_state() + .client_factories + .get(ctx.app_component().id()) + .unwrap(); Ok(InstanceState { allowed_host_checker: AllowedHostChecker::new(allowed_hosts), @@ -61,6 +78,7 @@ impl Factor for OutboundPgFactor { connections: Default::default(), otel, builders: Default::default(), + connection_semaphore: ctx.app_state().connection_semaphore.clone(), }) } } @@ -82,9 +100,10 @@ impl OutboundPgFactor { pub struct InstanceState { allowed_host_checker: AllowedHostChecker, client_factory: Arc, - connections: spin_resource_table::Table, + connections: spin_resource_table::Table<(CF::Client, Option)>, otel: OtelFactorState, builders: spin_resource_table::Table, + pub connection_semaphore: Option>, } impl SelfInstanceBuilder for InstanceState {} diff --git a/crates/factor-outbound-pg/src/runtime_config.rs b/crates/factor-outbound-pg/src/runtime_config.rs new file mode 100644 index 0000000000..7cf9745c0e --- /dev/null +++ b/crates/factor-outbound-pg/src/runtime_config.rs @@ -0,0 +1,8 @@ +pub mod spin; + +/// Runtime configuration for outbound PostgreSQL. +#[derive(Default)] +pub struct RuntimeConfig { + /// If set, limits the number of concurrent outbound PostgreSQL connections. + pub max_connections: Option, +} diff --git a/crates/factor-outbound-pg/src/runtime_config/spin.rs b/crates/factor-outbound-pg/src/runtime_config/spin.rs new file mode 100644 index 0000000000..b82c60ea43 --- /dev/null +++ b/crates/factor-outbound-pg/src/runtime_config/spin.rs @@ -0,0 +1,29 @@ +use serde::Deserialize; +use spin_factors::runtime_config::toml::GetTomlValue; + +/// Get the runtime configuration for outbound PostgreSQL from a TOML table. +/// +/// Expects table to be in the format: +/// ```toml +/// [outbound_pg] +/// max_connections = 10 # optional, defaults to unlimited +/// ``` +pub fn config_from_table( + table: &impl GetTomlValue, +) -> anyhow::Result> { + if let Some(outbound_pg) = table.get("outbound_pg") { + let toml = outbound_pg.clone().try_into::()?; + Ok(Some(super::RuntimeConfig { + max_connections: toml.max_connections, + })) + } else { + Ok(None) + } +} + +#[derive(Debug, Default, Deserialize)] +#[serde(deny_unknown_fields)] +struct OutboundPgToml { + #[serde(default)] + max_connections: Option, +} diff --git a/crates/runtime-config/src/lib.rs b/crates/runtime-config/src/lib.rs index 4f73e7cba6..c9433a1781 100644 --- a/crates/runtime-config/src/lib.rs +++ b/crates/runtime-config/src/lib.rs @@ -79,6 +79,14 @@ impl ResolvedRuntimeConfig { summaries.push(format!("[llm_compute: {ty}")); } } + // [outbound_redis: max_connections=N], [outbound_pg: max_connections=N], [outbound_mysql: max_connections=N] + for key in ["outbound_redis", "outbound_pg", "outbound_mysql"] { + if let Some(table) = self.toml.get(key).and_then(Value::as_table) { + if let Some(max) = table.get("max_connections").and_then(Value::as_integer) { + summaries.push(format!("[{key}: max_connections={max}]")); + } + } + } if !summaries.is_empty() { let summaries = summaries.join(", "); let from_path = runtime_config_path @@ -350,14 +358,18 @@ impl FactorRuntimeConfigSource for TomlRuntimeConfigSource<'_, } impl FactorRuntimeConfigSource for TomlRuntimeConfigSource<'_, '_> { - fn get_runtime_config(&mut self) -> anyhow::Result> { - Ok(None) + fn get_runtime_config( + &mut self, + ) -> anyhow::Result::RuntimeConfig>> { + spin_factor_outbound_pg::runtime_config::spin::config_from_table(&self.toml.table) } } impl FactorRuntimeConfigSource for TomlRuntimeConfigSource<'_, '_> { - fn get_runtime_config(&mut self) -> anyhow::Result> { - Ok(None) + fn get_runtime_config( + &mut self, + ) -> anyhow::Result::RuntimeConfig>> { + spin_factor_outbound_mysql::runtime_config::spin::config_from_table(&self.toml.table) } } From 9a70b0c63cb37f8a47ed0b8ea59717df300d008a Mon Sep 17 00:00:00 2001 From: Ryan Levick Date: Fri, 29 May 2026 13:52:22 +0200 Subject: [PATCH 06/15] Introduce global connection limit Signed-off-by: Ryan Levick --- Cargo.lock | 12 + crates/connection-semaphore/Cargo.toml | 17 + crates/connection-semaphore/src/lib.rs | 290 ++++++++++++++++++ crates/factor-outbound-http/src/lib.rs | 102 ++---- crates/factor-outbound-http/src/spin.rs | 11 +- crates/factor-outbound-http/src/wasi.rs | 39 ++- crates/factor-outbound-mysql/src/host.rs | 46 ++- crates/factor-outbound-mysql/src/lib.rs | 32 +- crates/factor-outbound-networking/Cargo.toml | 1 + .../src/connection_semaphore.rs | 1 + crates/factor-outbound-networking/src/lib.rs | 44 ++- .../src/runtime_config.rs | 7 +- .../src/runtime_config/spin.rs | 6 +- .../tests/factor_test.rs | 12 +- crates/factor-outbound-pg/src/host.rs | 35 +-- crates/factor-outbound-pg/src/lib.rs | 36 ++- crates/factor-outbound-redis/src/host.rs | 56 ++-- crates/factor-outbound-redis/src/lib.rs | 28 +- crates/factor-wasi/Cargo.toml | 1 + crates/factor-wasi/src/sockets.rs | 39 +-- crates/runtime-config/src/lib.rs | 15 + 21 files changed, 578 insertions(+), 252 deletions(-) create mode 100644 crates/connection-semaphore/Cargo.toml create mode 100644 crates/connection-semaphore/src/lib.rs create mode 100644 crates/factor-outbound-networking/src/connection_semaphore.rs diff --git a/Cargo.lock b/Cargo.lock index 5943a3f4a9..831186f5ea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8794,6 +8794,16 @@ dependencies = [ "wac-graph 0.10.0", ] +[[package]] +name = "spin-connection-semaphore" +version = "4.1.0-pre0" +dependencies = [ + "anyhow", + "spin-telemetry", + "tokio", + "tracing", +] + [[package]] name = "spin-core" version = "4.1.0-pre0" @@ -9057,6 +9067,7 @@ dependencies = [ "rustls-pki-types", "rustls-platform-verifier", "serde", + "spin-connection-semaphore", "spin-factor-variables", "spin-factor-wasi", "spin-factors", @@ -9169,6 +9180,7 @@ dependencies = [ "async-trait", "bytes", "spin-common", + "spin-connection-semaphore", "spin-factors", "spin-factors-test", "tokio", diff --git a/crates/connection-semaphore/Cargo.toml b/crates/connection-semaphore/Cargo.toml new file mode 100644 index 0000000000..eeb77a7103 --- /dev/null +++ b/crates/connection-semaphore/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "spin-connection-semaphore" +version = { workspace = true } +authors = { workspace = true } +edition = { workspace = true } + +[dependencies] +anyhow = { workspace = true } +spin-telemetry = { path = "../telemetry" } +tokio = { workspace = true, features = ["sync"] } +tracing = { workspace = true } + +[dev-dependencies] +tokio = { workspace = true, features = ["macros", "rt"] } + +[lints] +workspace = true diff --git a/crates/connection-semaphore/src/lib.rs b/crates/connection-semaphore/src/lib.rs new file mode 100644 index 0000000000..183a04d9a5 --- /dev/null +++ b/crates/connection-semaphore/src/lib.rs @@ -0,0 +1,290 @@ +use std::sync::Arc; + +use anyhow::anyhow; +use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError}; + +/// Wraps an optional global and an optional factor-specific semaphore. +#[derive(Clone)] +pub struct ConnectionSemaphore { + global: Option>, + factor_specific: Option>, + factor: &'static str, +} + +impl ConnectionSemaphore { + /// Creates a new `ConnectionSemaphore`. + pub fn new( + global: Option>, + factor_specific_limit: Option, + factor: &'static str, + ) -> Self { + Self { + global, + factor_specific: factor_specific_limit.map(|n| Arc::new(Semaphore::new(n))), + factor, + } + } + + /// Creates a `ConnectionSemaphore` from pre-existing semaphore handles. + /// + /// This is intended for testing and internal use where an already-constructed + /// (and possibly partially acquired) semaphore must be used directly. + #[doc(hidden)] + pub fn from_raw( + global: Option>, + factor_specific: Option>, + factor: &'static str, + ) -> Self { + Self { + global, + factor_specific, + factor, + } + } + + /// Acquire both configured semaphore slots, returning a permit that holds + /// them until dropped. + /// + /// When both a global and a factor-specific semaphore are configured, this + /// method never holds one permit while blocking on the other, preventing global + /// permits from being tied up while waiting on a factor-specific backlog. + pub async fn acquire(&self) -> anyhow::Result { + /// Acquires a single permit from `sem`, trying non-blocking first. + /// + /// Sets `*waited = true` if a blocking wait was required. + async fn acquire_one( + sem: &Arc, + waited: &mut bool, + label: &str, + ) -> anyhow::Result { + match sem.clone().try_acquire_owned() { + Ok(p) => Ok(p), + Err(TryAcquireError::NoPermits) => { + *waited = true; + sem.clone() + .acquire_owned() + .await + .map_err(|_| anyhow!("{label} connection semaphore closed")) + } + Err(_) => Err(anyhow!("{label} connection semaphore closed")), + } + } + let mut waited = false; + + let (global, factor_specific) = match (&self.global, &self.factor_specific) { + (None, None) => (None, None), + (Some(g), None) => (Some(acquire_one(g, &mut waited, "global").await?), None), + (None, Some(f)) => (None, Some(acquire_one(f, &mut waited, "factor").await?)), + // Loop until we acquire both. We have to be careful to avoid holding one permit while waiting for the other. + (Some(g), Some(f)) => loop { + let global = acquire_one(g, &mut waited, "global").await?; + match f.clone().try_acquire_owned() { + Ok(factor) => break (Some(global), Some(factor)), + Err(TryAcquireError::NoPermits) => {} + Err(_) => anyhow::bail!("factor connection semaphore closed"), + } + // Factor specific has no free permits: release global so other connection types aren't blocked, + // then wait for factor-specific before trying global again. + drop(global); + waited = true; + let factor = acquire_one(f, &mut waited, "factor").await?; + match g.clone().try_acquire_owned() { + Ok(global) => break (Some(global), Some(factor)), + Err(TryAcquireError::NoPermits) => {} + Err(_) => anyhow::bail!("global connection semaphore closed"), + } + // Global has no free permits: release factor specific and retry from the top of the loop. + drop(factor); + }, + }; + + let factor = self.factor; + spin_telemetry::monotonic_counter!( + outbound_connection_permits_acquired = 1, + factor = factor, + waited = waited + ); + + Ok(ConnectionPermit { + _global: global, + _factor_specific: factor_specific, + }) + } + + /// Attempt to acquire both configured slots without waiting. + /// Returns `None` if either semaphore is exhausted. + /// + /// If the global permit is acquired but the factor-specific permit is not + /// available, the global permit is released before returning `None`. + pub fn try_acquire(&self) -> Option { + // Acquire global first. If it fails, nothing is consumed — return None. + let global = match &self.global { + Some(s) => Some(s.clone().try_acquire_owned().ok()?), + None => None, + }; + // Now attempt the factor-specific permit. + // If it fails, the global OwnedSemaphorePermit is dropped here, releasing + // the global slot before we return None. + let factor_specific = match &self.factor_specific { + Some(s) => Some(s.clone().try_acquire_owned().ok()?), + None => None, + }; + + let factor = self.factor; + spin_telemetry::monotonic_counter!( + outbound_connection_permits_acquired = 1, + factor = factor, + waited = false + ); + + Some(ConnectionPermit { + _global: global, + _factor_specific: factor_specific, + }) + } +} + +/// Holds up to two semaphore permits (global + factor-specific). +/// Both permits are released when this value is dropped. +/// All-`None` fields are valid and represent the no-limits case. +/// +/// Fields are intentionally prefixed with `_` — they exist solely to be dropped. +pub struct ConnectionPermit { + _global: Option, + _factor_specific: Option, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn no_limits_acquire_always_succeeds() { + let sem = ConnectionSemaphore::new(None, None, "test"); + let permit = sem.acquire().await.expect("should succeed"); + drop(permit); + let _permit2 = sem.acquire().await.expect("should succeed again"); + } + + #[test] + fn no_limits_try_acquire_always_succeeds() { + let sem = ConnectionSemaphore::new(None, None, "test"); + let permit = sem.try_acquire().expect("should succeed"); + drop(permit); + let _permit2 = sem.try_acquire().expect("should succeed again"); + } + + #[test] + fn global_limit_only_exhausted() { + let global = Arc::new(Semaphore::new(1)); + let sem = ConnectionSemaphore::new(Some(global.clone()), None, "test"); + let permit1 = sem.try_acquire().expect("first should succeed"); + assert!( + sem.try_acquire().is_none(), + "second should fail: global exhausted" + ); + drop(permit1); + assert_eq!(global.available_permits(), 1); + let _permit3 = sem.try_acquire().expect("after release should succeed"); + } + + #[test] + fn factor_limit_only_exhausted() { + let sem = ConnectionSemaphore::new(None, Some(1), "test"); + let permit1 = sem.try_acquire().expect("first should succeed"); + assert!( + sem.try_acquire().is_none(), + "second should fail: factor exhausted" + ); + drop(permit1); + let _permit3 = sem.try_acquire().expect("after release should succeed"); + } + + #[test] + fn both_limits_global_exhausted_first() { + let global = Arc::new(Semaphore::new(1)); + let factor = Arc::new(Semaphore::new(2)); + let sem = ConnectionSemaphore::from_raw(Some(global.clone()), Some(factor.clone()), "test"); + + let permit1 = sem.try_acquire().expect("first should succeed"); + // After permit1: global=0, factor=1 + let factor_before = factor.available_permits(); + + // Second try_acquire should fail because global is exhausted. + assert!(sem.try_acquire().is_none(), "should fail: global exhausted"); + // Factor must NOT have been consumed by the failed attempt. + assert_eq!( + factor.available_permits(), + factor_before, + "factor permits should not be consumed when global is exhausted" + ); + drop(permit1); + } + + #[test] + fn both_limits_factor_exhausted_global_released() { + let global = Arc::new(Semaphore::new(2)); + let factor = Arc::new(Semaphore::new(1)); + let sem = ConnectionSemaphore::from_raw(Some(global.clone()), Some(factor.clone()), "test"); + + let permit1 = sem.try_acquire().expect("first should succeed"); + // Global still has 1, factor exhausted + let result = sem.try_acquire(); + assert!(result.is_none(), "should fail: factor exhausted"); + // Global slot must have been released (back to 1) + assert_eq!(global.available_permits(), 1); + drop(permit1); + assert_eq!(global.available_permits(), 2); + } + + #[tokio::test] + async fn acquire_waits_for_release() { + let global = Arc::new(Semaphore::new(1)); + let sem = ConnectionSemaphore::new(Some(global.clone()), None, "test"); + + let permit = sem.try_acquire().expect("first should succeed"); + + let sem2 = sem.clone(); + let handle = tokio::spawn(async move { + let _p = sem2.acquire().await.expect("should eventually acquire"); + }); + + drop(permit); // release so the spawned task can proceed + handle.await.expect("task should complete"); + } + + /// Verifies that when factor-specific is exhausted, acquire() releases + /// the global permit while waiting — so other connection types aren't blocked. + #[tokio::test] + async fn acquire_releases_global_while_waiting_for_factor() { + let global = Arc::new(Semaphore::new(1)); + let factor = Arc::new(Semaphore::new(1)); + let sem = ConnectionSemaphore::from_raw(Some(global.clone()), Some(factor.clone()), "test"); + + // Exhaust factor-specific from outside. + let _factor_hold = factor.clone().acquire_owned().await.unwrap(); + + let global_clone = global.clone(); + let sem_clone = sem.clone(); + let handle = tokio::spawn(async move { + sem_clone + .acquire() + .await + .expect("should succeed after factor is released") + }); + + // Yield twice: first to let the spawned task run until it blocks waiting + // for factor-specific; second to confirm it has released the global permit. + tokio::task::yield_now().await; + tokio::task::yield_now().await; + + assert_eq!( + global_clone.available_permits(), + 1, + "global should be free while acquire() waits for factor-specific" + ); + + drop(_factor_hold); + handle.await.expect("task should complete"); + } +} diff --git a/crates/factor-outbound-http/src/lib.rs b/crates/factor-outbound-http/src/lib.rs index e7b43e0bed..fb0b42eed9 100644 --- a/crates/factor-outbound-http/src/lib.rs +++ b/crates/factor-outbound-http/src/lib.rs @@ -16,14 +16,13 @@ use intercept::OutboundHttpInterceptor; use runtime_config::RuntimeConfig; use spin_factor_otel::OtelFactorState; use spin_factor_outbound_networking::{ - ComponentTlsClientConfigs, OutboundNetworkingFactor, + ComponentTlsClientConfigs, ConnectionSemaphore, OutboundNetworkingFactor, config::{allowed_hosts::OutboundAllowedHosts, blocked_networks::BlockedNetworks}, }; use spin_factors::{ ConfigureAppContext, Factor, FactorData, PrepareContext, RuntimeFactors, SelfInstanceBuilder, anyhow, }; -use tokio::sync::Semaphore; use wasmtime_wasi_http::WasiHttpCtx; pub use wasmtime_wasi_http::p2::{ @@ -56,14 +55,30 @@ impl Factor for OutboundHttpFactor { mut ctx: ConfigureAppContext, ) -> anyhow::Result { let config = ctx.take_runtime_config().unwrap_or_default(); + + let networking = ctx.app_state::().ok(); + let global = networking.and_then(|s| s.global_connection_semaphore.clone()); + let global_total_limit = networking.and_then(|s| s.max_total_connections); + + if let (Some(per_factor), Some(global_limit)) = + (config.max_concurrent_connections, global_total_limit) + && per_factor > global_limit + { + tracing::warn!( + "outbound_http max_concurrent_requests ({per_factor}) exceeds global \ + max_total_connections ({global_limit}); the global limit will be the \ + effective cap" + ); + } + + // Permit count is the max concurrent connections + 1. + // i.e., 0 concurrent connections means 1 total connection. + let factor_specific_limit = config.max_concurrent_connections.map(|n| n + 1); + Ok(AppState { wasi_http_clients: wasi::HttpClients::new(config.connection_pooling_enabled), connection_pooling_enabled: config.connection_pooling_enabled, - concurrent_outbound_connections_semaphore: config - .max_concurrent_connections - // Permit count is the max concurrent connections + 1. - // i.e., 0 concurrent connections means 1 total connection. - .map(|n| Arc::new(Semaphore::new(n + 1))), + semaphore: ConnectionSemaphore::new(global, factor_specific_limit, "http"), }) } @@ -87,10 +102,7 @@ impl Factor for OutboundHttpFactor { spin_http_client: None, wasi_http_clients: ctx.app_state().wasi_http_clients.clone(), connection_pooling_enabled: ctx.app_state().connection_pooling_enabled, - concurrent_outbound_connections_semaphore: ctx - .app_state() - .concurrent_outbound_connections_semaphore - .clone(), + semaphore: ctx.app_state().semaphore.clone(), otel, }, }) @@ -121,8 +133,8 @@ struct InstanceHttpHooks { wasi_http_clients: wasi::HttpClients, /// Whether connection pooling is enabled for this instance. connection_pooling_enabled: bool, - /// A semaphore to limit the number of concurrent outbound connections. - concurrent_outbound_connections_semaphore: Option>, + /// Semaphore to limit concurrent outbound connections. + semaphore: ConnectionSemaphore, /// Manages access to the OtelFactor state. otel: OtelFactorState, } @@ -153,66 +165,6 @@ impl InstanceState { impl SelfInstanceBuilder for InstanceState {} -/// Helper module for acquiring permits from the outbound connections semaphore. -/// -/// This is used by the outbound HTTP implementations to limit concurrent outbound connections. -mod concurrent_outbound_connections { - use super::*; - - /// Acquires a semaphore permit for the given interface, if a semaphore is configured. - pub async fn acquire_semaphore<'a>( - interface: &str, - semaphore: &'a Option>, - ) -> Option> { - let s = semaphore.as_ref()?; - acquire(interface, || s.try_acquire(), async || s.acquire().await).await - } - - /// Acquires an owned semaphore permit for the given interface, if a semaphore is configured. - pub async fn acquire_owned_semaphore( - interface: &str, - semaphore: &Option>, - ) -> Option { - let s = semaphore.as_ref()?; - acquire( - interface, - || s.clone().try_acquire_owned(), - async || s.clone().acquire_owned().await, - ) - .await - } - - /// Helper function to acquire a semaphore permit, either immediately or by waiting. - /// - /// Allows getting either a borrowed or owned permit. - async fn acquire( - interface: &str, - try_acquire: impl Fn() -> Result, - acquire: impl AsyncFnOnce() -> Result, - ) -> Option { - // Try to acquire a permit without waiting first - // Keep track of whether we had to wait for metrics purposes. - let mut waited = false; - let permit = match try_acquire() { - Ok(p) => Ok(p), - // No available permits right now; wait for one - Err(tokio::sync::TryAcquireError::NoPermits) => { - waited = true; - acquire().await.map_err(|_| ()) - } - Err(_) => Err(()), - }; - if permit.is_ok() { - spin_telemetry::monotonic_counter!( - outbound_http.concurrent_connection_permits_acquired = 1, - interface = interface, - waited = waited - ); - } - permit.ok() - } -} - pub type Request = http::Request; pub type Response = http::Response; @@ -268,8 +220,8 @@ pub struct AppState { wasi_http_clients: wasi::HttpClients, /// Whether connection pooling is enabled for this app. connection_pooling_enabled: bool, - /// A semaphore to limit the number of concurrent outbound connections. - concurrent_outbound_connections_semaphore: Option>, + /// Semaphore to limit concurrent outbound connections. + semaphore: ConnectionSemaphore, } /// Removes IPs in the given [`BlockedNetworks`]. diff --git a/crates/factor-outbound-http/src/spin.rs b/crates/factor-outbound-http/src/spin.rs index 5c47204bfb..ef352d7e85 100644 --- a/crates/factor-outbound-http/src/spin.rs +++ b/crates/factor-outbound-http/src/spin.rs @@ -112,11 +112,12 @@ impl spin_http::Host for crate::InstanceState { // If we're limiting concurrent outbound requests, acquire a permit // Note: since we don't have access to the underlying connection, we can only // limit the number of concurrent requests, not connections. - let permit = crate::concurrent_outbound_connections::acquire_semaphore( - "spin", - &self.hooks.concurrent_outbound_connections_semaphore, - ) - .await; + let permit = self + .hooks + .semaphore + .acquire() + .await + .map_err(|_| HttpError::RuntimeError)?; let resp = client.execute(req).await.map_err(log_reqwest_error)?; drop(permit); diff --git a/crates/factor-outbound-http/src/wasi.rs b/crates/factor-outbound-http/src/wasi.rs index 2fc151562a..40957b6008 100644 --- a/crates/factor-outbound-http/src/wasi.rs +++ b/crates/factor-outbound-http/src/wasi.rs @@ -35,7 +35,6 @@ use spin_factors::RuntimeFactorsInstanceState; use tokio::{ io::{AsyncRead, AsyncWrite, ReadBuf}, net::TcpStream, - sync::{OwnedSemaphorePermit, Semaphore}, time::timeout, }; use tokio_rustls::client::TlsStream; @@ -53,6 +52,8 @@ use wasmtime_wasi_http::{ p3::{self, bindings::http::types as p3_types}, }; +use spin_factor_outbound_networking::{ConnectionPermit, ConnectionSemaphore}; + use crate::{ InstanceHttpHooks, OutboundHttpFactor, SelfRequestOrigin, intercept::{InterceptOutcome, OutboundHttpInterceptor}, @@ -184,9 +185,7 @@ impl p3::WasiHttpHooks for InstanceHttpHooks { self_request_origin: self.self_request_origin.clone(), blocked_networks: self.blocked_networks.clone(), http_clients: self.wasi_http_clients.clone(), - concurrent_outbound_connections_semaphore: self - .concurrent_outbound_connections_semaphore - .clone(), + semaphore: self.semaphore.clone(), }; let config = OutgoingRequestConfig { use_tls: request.uri().scheme() == Some(&Scheme::HTTPS), @@ -442,9 +441,7 @@ impl p2::WasiHttpHooks for InstanceHttpHooks { self_request_origin: self.self_request_origin.clone(), blocked_networks: self.blocked_networks.clone(), http_clients: self.wasi_http_clients.clone(), - concurrent_outbound_connections_semaphore: self - .concurrent_outbound_connections_semaphore - .clone(), + semaphore: self.semaphore.clone(), }; Ok(HostFutureIncomingResponse::Pending( wasmtime_wasi::runtime::spawn( @@ -470,7 +467,7 @@ struct RequestSender { self_request_origin: Option, request_interceptor: Option>, http_clients: HttpClients, - concurrent_outbound_connections_semaphore: Option>, + semaphore: ConnectionSemaphore, } impl RequestSender { @@ -624,8 +621,7 @@ impl RequestSender { connect_timeout, tls_client_config, override_connect_addr, - concurrent_outbound_connections_semaphore: self - .concurrent_outbound_connections_semaphore, + semaphore: self.semaphore, }, async move { if use_tls { @@ -719,8 +715,8 @@ struct ConnectOptions { tls_client_config: Option, /// If set, override the address to connect to instead of using the given `uri`'s authority. override_connect_addr: Option, - /// A semaphore to limit the number of concurrent outbound connections. - concurrent_outbound_connections_semaphore: Option>, + /// Semaphore to limit concurrent outbound connections. + semaphore: ConnectionSemaphore, } impl ConnectOptions { @@ -758,11 +754,7 @@ impl ConnectOptions { let connect = async { // If we're limiting concurrent outbound requests, acquire a permit - let permit = crate::concurrent_outbound_connections::acquire_owned_semaphore( - "wasi", - &self.concurrent_outbound_connections_semaphore, - ) - .await; + let permit = self.semaphore.acquire().await; (TcpStream::connect(&*socket_addrs).await, permit) }; @@ -771,6 +763,7 @@ impl ConnectOptions { let (stream, permit) = timeout(self.connect_timeout, connect) .await .map_err(|_| ErrorCode::ConnectionTimeout)?; + let permit = permit.map_err(|_| ErrorCode::ConnectionRefused)?; let stream = stream.map_err(|err| match err.kind() { std::io::ErrorKind::AddrNotAvailable => dns_error("address not available".into(), 0), _ => ErrorCode::ConnectionRefused, @@ -912,7 +905,7 @@ impl AsyncWrite for RustlsStream { } } -/// A TCP stream that holds an optional permit indicating that it is allowed to exist. +/// A TCP stream that holds a permit indicating that it is allowed to exist. struct PermittedTcpStream { /// The wrapped TCP stream. inner: TcpStream, @@ -920,7 +913,7 @@ struct PermittedTcpStream { /// /// When this stream is dropped, the permit is also dropped, allowing another /// connection to be established. - _permit: Option, + _permit: ConnectionPermit, } impl Connection for PermittedTcpStream { @@ -1219,12 +1212,18 @@ mod tests { /// `ConnectionTimeout` within the configured deadline. #[tokio::test] async fn connect_timeout_applies_to_permit_acquisition() { + use std::sync::Arc; + use tokio::sync::Semaphore; + // Create a semaphore with exactly 1 permit and hold it immediately, // leaving 0 permits available. This simulates all outbound-connection // slots being occupied. let semaphore = Arc::new(Semaphore::new(1)); let _held = semaphore.clone().try_acquire_owned().unwrap(); + // Build a ConnectionSemaphore with the exhausted semaphore as the factor-specific limit. + let conn_semaphore = ConnectionSemaphore::from_raw(None, Some(semaphore), "test"); + let options = ConnectOptions { // No blocked networks; we want the address to pass the filter. blocked_networks: BlockedNetworks::default(), @@ -1233,7 +1232,7 @@ mod tests { tls_client_config: None, // Skip DNS by supplying the address directly. override_connect_addr: Some("127.0.0.1:1".parse().unwrap()), - concurrent_outbound_connections_semaphore: Some(semaphore), + semaphore: conn_semaphore, }; // `connect_tcp` must time out while waiting for a permit rather than diff --git a/crates/factor-outbound-mysql/src/host.rs b/crates/factor-outbound-mysql/src/host.rs index 41900a395d..e86d2ecd9e 100644 --- a/crates/factor-outbound-mysql/src/host.rs +++ b/crates/factor-outbound-mysql/src/host.rs @@ -2,13 +2,14 @@ use std::sync::Arc; use anyhow::Result; use spin_core::wasmtime::component::{Accessor, FutureReader, Resource, StreamReader}; +use spin_factor_outbound_networking::ConnectionPermit; use spin_telemetry::traces::{self, Blame}; use spin_world::MAX_HOST_BUFFERED_BYTES; use spin_world::spin::mysql::mysql as v3; use spin_world::v1::mysql as v1; use spin_world::v2::mysql as v2; use spin_world::v2::rdbms_types as v2_types; -use tokio::sync::{Mutex, OwnedSemaphorePermit}; +use tokio::sync::Mutex; use tracing::field::Empty; use tracing::{Level, instrument}; @@ -19,7 +20,7 @@ impl InstanceStateInner { async fn open_connection( &mut self, address: &str, - permit: Option, + permit: ConnectionPermit, ) -> Result { spin_factor_outbound_networking::record_address_fields(address); @@ -98,16 +99,13 @@ impl v3::HostConnectionWithStore for MysqlFactorData { accessor: &Accessor, address: String, ) -> Result, v3::Error> { - let (state_arc, connection_semaphore) = accessor.with(|mut access| { + let (state_arc, semaphore) = accessor.with(|mut access| { let host = access.get(); - (host.inner.clone(), host.connection_semaphore.clone()) + (host.inner.clone(), host.semaphore.clone()) }); - let permit = match connection_semaphore { - Some(sem) => Some(sem.acquire_owned().await.map_err(|_| { - v3::Error::from(v2::Error::ConnectionFailed("too many connections".into())) - })?), - None => None, - }; + let permit = semaphore.acquire().await.map_err(|_| { + v3::Error::from(v2::Error::ConnectionFailed("too many connections".into())) + })?; let mut state = state_arc.lock().await; state.otel.reparent_tracing_span(); Ok(Resource::new_own( @@ -180,15 +178,11 @@ impl v2::Host for InstanceState {} impl v2::HostConnection for InstanceState { #[instrument(name = "spin_outbound_mysql.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", db.address = Empty, server.port = Empty, db.namespace = Empty))] async fn open(&mut self, address: String) -> Result, v2::Error> { - let permit = match &self.connection_semaphore { - Some(sem) => Some( - Arc::clone(sem) - .acquire_owned() - .await - .map_err(|_| v2::Error::ConnectionFailed("too many connections".into()))?, - ), - None => None, - }; + let permit = self + .semaphore + .acquire() + .await + .map_err(|_| v2::Error::ConnectionFailed("too many connections".into()))?; let mut state = self.inner.lock().await; state.otel.reparent_tracing_span(); state @@ -249,15 +243,11 @@ impl v2_types::Host for InstanceState { /// Delegate a function call to the v2::HostConnection implementation macro_rules! delegate { ($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{ - let permit = match &$self.connection_semaphore { - Some(sem) => Some( - Arc::clone(sem) - .acquire_owned() - .await - .map_err(|_| v2::Error::ConnectionFailed("too many connections".into()))?, - ), - None => None, - }; + let permit = $self + .semaphore + .acquire() + .await + .map_err(|_| v2::Error::ConnectionFailed("too many connections".into()))?; let connection = { let mut state = $self.inner.lock().await; Resource::new_own(state.open_connection(&$address, permit).await?) diff --git a/crates/factor-outbound-mysql/src/lib.rs b/crates/factor-outbound-mysql/src/lib.rs index 34d08ac7e1..7a10809e7a 100644 --- a/crates/factor-outbound-mysql/src/lib.rs +++ b/crates/factor-outbound-mysql/src/lib.rs @@ -9,21 +9,22 @@ use mysql_async::Conn as MysqlClient; use runtime_config::RuntimeConfig; use spin_factor_otel::OtelFactorState; use spin_factor_outbound_networking::{ - OutboundNetworkingFactor, config::allowed_hosts::OutboundAllowedHosts, + ConnectionPermit, ConnectionSemaphore, OutboundNetworkingFactor, + config::allowed_hosts::OutboundAllowedHosts, }; use spin_factors::{Factor, FactorData, InitContext, RuntimeFactors, SelfInstanceBuilder}; use spin_world::spin::mysql::mysql as v3; use spin_world::v1::mysql as v1; use spin_world::v2::mysql as v2; -use tokio::sync::{Mutex, OwnedSemaphorePermit, Semaphore}; +use tokio::sync::Mutex; pub struct OutboundMysqlFactor { _phantom: std::marker::PhantomData, } pub struct AppState { - /// A semaphore to limit the number of concurrent outbound MySQL connections. - pub connection_semaphore: Option>, + /// Semaphore(s) to limit concurrent outbound MySQL connections. + pub semaphore: ConnectionSemaphore, } impl Factor for OutboundMysqlFactor { @@ -43,8 +44,23 @@ impl Factor for OutboundMysqlFactor { mut ctx: spin_factors::ConfigureAppContext, ) -> anyhow::Result { let config = ctx.take_runtime_config().unwrap_or_default(); + + let networking = ctx.app_state::().ok(); + let global = networking.and_then(|s| s.global_connection_semaphore.clone()); + let global_total_limit = networking.and_then(|s| s.max_total_connections); + + if let (Some(per_factor), Some(global_limit)) = (config.max_connections, global_total_limit) + && per_factor > global_limit + { + tracing::warn!( + "outbound_mysql max_connections ({per_factor}) exceeds global \ + max_total_connections ({global_limit}); the global limit will be the \ + effective cap" + ); + } + Ok(AppState { - connection_semaphore: config.max_connections.map(|n| Arc::new(Semaphore::new(n))), + semaphore: ConnectionSemaphore::new(global, config.max_connections, "mysql"), }) } @@ -63,7 +79,7 @@ impl Factor for OutboundMysqlFactor { connections: Default::default(), otel, })), - connection_semaphore: ctx.app_state().connection_semaphore.clone(), + semaphore: ctx.app_state().semaphore.clone(), }) } } @@ -84,13 +100,13 @@ impl OutboundMysqlFactor { pub struct InstanceStateInner { allowed_hosts: OutboundAllowedHosts, - connections: spin_resource_table::Table<(Arc>, Option)>, + connections: spin_resource_table::Table<(Arc>, ConnectionPermit)>, otel: OtelFactorState, } pub struct InstanceState { pub(crate) inner: Arc>>, - pub connection_semaphore: Option>, + pub semaphore: ConnectionSemaphore, } impl SelfInstanceBuilder for InstanceState {} diff --git a/crates/factor-outbound-networking/Cargo.toml b/crates/factor-outbound-networking/Cargo.toml index 42e362f3cc..71dfc9865a 100644 --- a/crates/factor-outbound-networking/Cargo.toml +++ b/crates/factor-outbound-networking/Cargo.toml @@ -13,6 +13,7 @@ rustls = { workspace = true } rustls-pki-types = { workspace = true } rustls-platform-verifier = { workspace = true } serde = { workspace = true } +spin-connection-semaphore = { path = "../connection-semaphore" } spin-factor-variables = { path = "../factor-variables" } spin-factor-wasi = { path = "../factor-wasi" } spin-factors = { path = "../factors" } diff --git a/crates/factor-outbound-networking/src/connection_semaphore.rs b/crates/factor-outbound-networking/src/connection_semaphore.rs new file mode 100644 index 0000000000..f91c0f8583 --- /dev/null +++ b/crates/factor-outbound-networking/src/connection_semaphore.rs @@ -0,0 +1 @@ +pub use spin_connection_semaphore::{ConnectionPermit, ConnectionSemaphore}; diff --git a/crates/factor-outbound-networking/src/lib.rs b/crates/factor-outbound-networking/src/lib.rs index 04bc8d4c1c..9553a61052 100644 --- a/crates/factor-outbound-networking/src/lib.rs +++ b/crates/factor-outbound-networking/src/lib.rs @@ -1,4 +1,5 @@ mod allowed_hosts; +pub mod connection_semaphore; pub mod runtime_config; mod tls; @@ -20,6 +21,7 @@ use crate::{ allowed_hosts::allowed_outbound_hosts, runtime_config::RuntimeConfig, tls::TlsClientConfigs, }; pub use allowed_hosts::validate_service_chaining_for_components; +pub use connection_semaphore::{ConnectionPermit, ConnectionSemaphore}; pub use crate::tls::{ComponentTlsClientConfigs, TlsClientConfig}; use config::allowed_hosts::AllowedHostsConfig; @@ -70,18 +72,34 @@ impl Factor for OutboundNetworkingFactor { client_tls_configs, blocked_ip_networks: block_networks, block_private_networks, - max_sockets_per_app, + max_socket_connections, + max_total_connections, } = ctx.take_runtime_config().unwrap_or_default(); let blocked_networks = BlockedNetworks::new(block_networks, block_private_networks); let tls_client_configs = TlsClientConfigs::new(client_tls_configs)?; - let socket_quota = max_sockets_per_app.map(|n| Arc::new(Semaphore::new(n))); + let socket_quota = max_socket_connections.map(|n| Arc::new(Semaphore::new(n))); + let global_connection_semaphore = + max_total_connections.map(|n| Arc::new(Semaphore::new(n))); + + if let (Some(socket_cap), Some(global_cap)) = + (max_socket_connections, max_total_connections) + && socket_cap > global_cap + { + tracing::warn!( + "outbound_networking max_socket_connections ({socket_cap}) exceeds \ + max_total_connections ({global_cap}); the global limit will be the effective \ + cap for TCP/UDP sockets" + ); + } Ok(AppState { component_allowed_hosts, blocked_networks, tls_client_configs, socket_quota, + global_connection_semaphore, + max_total_connections, }) } @@ -127,11 +145,15 @@ impl Factor for OutboundNetworkingFactor { self.disallowed_host_handler.clone(), ); let blocked_networks = ctx.app_state().blocked_networks.clone(); - let permit_state = ctx - .app_state() - .socket_quota - .as_ref() - .map(|sem| SocketPermitState::new(Arc::clone(sem))); + let global_semaphore = ctx.app_state().global_connection_semaphore.clone(); + let socket_semaphore = ctx.app_state().socket_quota.clone(); + let permit_state = if global_semaphore.is_some() || socket_semaphore.is_some() { + let sem = + ConnectionSemaphore::from_raw(global_semaphore, socket_semaphore, "wasi-sockets"); + Some(SocketPermitState::new(sem)) + } else { + None + }; match ctx.instance_builder::() { Ok(wasi_builder) => { @@ -197,10 +219,14 @@ pub struct AppState { blocked_networks: BlockedNetworks, /// TLS client configs tls_client_configs: TlsClientConfigs, - /// App-wide semaphore capping total concurrent outbound socket connections - /// + /// App-wide semaphore capping concurrent outbound TCP/UDP socket connections. /// `None` means unlimited. socket_quota: Option>, + /// App-wide semaphore capping total concurrent outbound connections across ALL types. + /// `None` means unlimited. + pub global_connection_semaphore: Option>, + /// The configured global connection limit (for warning comparisons in other factors). + pub max_total_connections: Option, } pub struct InstanceBuilder { diff --git a/crates/factor-outbound-networking/src/runtime_config.rs b/crates/factor-outbound-networking/src/runtime_config.rs index 1520d7ba4a..278882a3e2 100644 --- a/crates/factor-outbound-networking/src/runtime_config.rs +++ b/crates/factor-outbound-networking/src/runtime_config.rs @@ -12,9 +12,12 @@ pub struct RuntimeConfig { pub block_private_networks: bool, /// TLS client configs pub client_tls_configs: Vec, - /// Maximum number of outbound socket connections across all instances of this app. + /// Maximum number of outbound TCP/UDP socket connections across all instances of this app. /// `None` means unlimited (default). - pub max_sockets_per_app: Option, + pub max_socket_connections: Option, + /// Maximum number of outbound connections across ALL connection types (global cap). + /// `None` means unlimited (default). + pub max_total_connections: Option, } /// TLS configuration for one or more component(s) and host(s). diff --git a/crates/factor-outbound-networking/src/runtime_config/spin.rs b/crates/factor-outbound-networking/src/runtime_config/spin.rs index 2e8824ad43..e97fba28cd 100644 --- a/crates/factor-outbound-networking/src/runtime_config/spin.rs +++ b/crates/factor-outbound-networking/src/runtime_config/spin.rs @@ -73,7 +73,8 @@ impl SpinRuntimeConfig { blocked_ip_networks, block_private_networks, client_tls_configs: maybe_tls_configs.unwrap_or_default(), - max_sockets_per_app: outbound_networking.max_sockets, + max_socket_connections: outbound_networking.max_socket_connections, + max_total_connections: outbound_networking.max_total_connections, }; Ok(Some(runtime_config)) } @@ -221,7 +222,8 @@ fn deserialize_hosts<'de, D: Deserializer<'de>>(deserializer: D) -> Result, - max_sockets: Option, + max_socket_connections: Option, + max_total_connections: Option, } #[derive(Debug)] diff --git a/crates/factor-outbound-networking/tests/factor_test.rs b/crates/factor-outbound-networking/tests/factor_test.rs index 19c3b25296..84d9be9db5 100644 --- a/crates/factor-outbound-networking/tests/factor_test.rs +++ b/crates/factor-outbound-networking/tests/factor_test.rs @@ -103,7 +103,7 @@ async fn socket_quota_blocks_excess_connections() -> anyhow::Result<()> { }) .runtime_config(TestFactorsRuntimeConfig { networking: Some(RuntimeConfig { - max_sockets_per_app: Some(2), + max_socket_connections: Some(2), ..Default::default() }), ..Default::default() @@ -147,7 +147,7 @@ async fn socket_quota_releases_on_instance_drop() -> anyhow::Result<()> { }) .runtime_config(TestFactorsRuntimeConfig { networking: Some(RuntimeConfig { - max_sockets_per_app: Some(1), + max_socket_connections: Some(1), ..Default::default() }), ..Default::default() @@ -232,7 +232,7 @@ async fn socket_quota_still_enforces_allowed_hosts() -> anyhow::Result<()> { }) .runtime_config(TestFactorsRuntimeConfig { networking: Some(RuntimeConfig { - max_sockets_per_app: Some(10), + max_socket_connections: Some(10), ..Default::default() }), ..Default::default() @@ -274,7 +274,7 @@ async fn socket_quota_releases_on_socket_drop() -> anyhow::Result<()> { }) .runtime_config(TestFactorsRuntimeConfig { networking: Some(RuntimeConfig { - max_sockets_per_app: Some(1), + max_socket_connections: Some(1), ..Default::default() }), ..Default::default() @@ -328,7 +328,7 @@ async fn socket_quota_blocks_excess_udp_sockets() -> anyhow::Result<()> { }) .runtime_config(TestFactorsRuntimeConfig { networking: Some(RuntimeConfig { - max_sockets_per_app: Some(2), + max_socket_connections: Some(2), ..Default::default() }), ..Default::default() @@ -363,7 +363,7 @@ async fn socket_quota_shared_between_tcp_and_udp() -> anyhow::Result<()> { }) .runtime_config(TestFactorsRuntimeConfig { networking: Some(RuntimeConfig { - max_sockets_per_app: Some(2), + max_socket_connections: Some(2), ..Default::default() }), ..Default::default() diff --git a/crates/factor-outbound-pg/src/host.rs b/crates/factor-outbound-pg/src/host.rs index 81ab16d677..f18a721b31 100644 --- a/crates/factor-outbound-pg/src/host.rs +++ b/crates/factor-outbound-pg/src/host.rs @@ -1,7 +1,5 @@ #![allow(clippy::result_large_err)] -use std::sync::Arc; - use anyhow::Result; use spin_core::wasmtime::component::{Accessor, FutureReader, Resource, StreamReader}; use spin_telemetry::traces::{self, Blame}; @@ -26,14 +24,11 @@ impl InstanceState { address: &str, root_ca: Option, ) -> Result, v4::Error> { - let permit = match &self.connection_semaphore { - Some(sem) => Some(Arc::clone(sem).acquire_owned().await.map_err(|_| { - let err = v4::Error::ConnectionFailed("too many connections".into()); - traces::mark_as_error(&err, Some(Blame::Guest)); - err - })?), - None => None, - }; + let permit = self.semaphore.acquire().await.map_err(|_| { + let err = v4::Error::ConnectionFailed("too many connections".into()); + traces::mark_as_error(&err, Some(Blame::Guest)); + err + })?; let client = self .client_factory .get_client(address, root_ca) @@ -387,22 +382,16 @@ impl crate::PgFactorData { address: &str, root_ca: Option, ) -> Result, v4::Error> { - let (cf, connection_semaphore) = accessor.with(|mut access| { + let (cf, semaphore) = accessor.with(|mut access| { let host = access.get(); - ( - host.client_factory.clone(), - host.connection_semaphore.clone(), - ) + (host.client_factory.clone(), host.semaphore.clone()) }); - let permit = match connection_semaphore { - Some(sem) => Some(sem.acquire_owned().await.map_err(|_| { - let err = v4::Error::ConnectionFailed("too many connections".into()); - traces::mark_as_error(&err, Some(Blame::Guest)); - err - })?), - None => None, - }; + let permit = semaphore.acquire().await.map_err(|_| { + let err = v4::Error::ConnectionFailed("too many connections".into()); + traces::mark_as_error(&err, Some(Blame::Guest)); + err + })?; let client = cf.get_client(address, root_ca).await.map_err(|e| { let err = v4::Error::ConnectionFailed(format!("{e:?}")); diff --git a/crates/factor-outbound-pg/src/lib.rs b/crates/factor-outbound-pg/src/lib.rs index 62714e1954..cb218e3525 100644 --- a/crates/factor-outbound-pg/src/lib.rs +++ b/crates/factor-outbound-pg/src/lib.rs @@ -4,17 +4,17 @@ mod host; pub mod runtime_config; mod types; -use std::{collections::HashMap, sync::Arc}; +use std::collections::HashMap; +use std::sync::Arc; use allowed_hosts::AllowedHostChecker; use client::ClientFactory; use runtime_config::RuntimeConfig; use spin_factor_otel::OtelFactorState; -use spin_factor_outbound_networking::OutboundNetworkingFactor; +use spin_factor_outbound_networking::{ConnectionSemaphore, OutboundNetworkingFactor}; use spin_factors::{ ConfigureAppContext, Factor, PrepareContext, RuntimeFactors, SelfInstanceBuilder, anyhow, }; -use tokio::sync::{OwnedSemaphorePermit, Semaphore}; pub struct OutboundPgFactor { _phantom: std::marker::PhantomData, @@ -22,8 +22,8 @@ pub struct OutboundPgFactor { pub struct AppState { pub client_factories: HashMap>, - /// A semaphore to limit the number of concurrent outbound PostgreSQL connections. - pub connection_semaphore: Option>, + /// Semaphore to limit concurrent outbound PostgreSQL connections. + pub semaphore: ConnectionSemaphore, } impl Factor for OutboundPgFactor { @@ -52,9 +52,24 @@ impl Factor for OutboundPgFactor { for comp in ctx.app().components() { client_factories.insert(comp.id().to_string(), Arc::new(CF::default())); } + + let networking = ctx.app_state::().ok(); + let global = networking.and_then(|s| s.global_connection_semaphore.clone()); + let global_total_limit = networking.and_then(|s| s.max_total_connections); + + if let (Some(per_factor), Some(global_limit)) = (config.max_connections, global_total_limit) + && per_factor > global_limit + { + tracing::warn!( + "outbound_pg max_connections ({per_factor}) exceeds global \ + max_total_connections ({global_limit}); the global limit will be the \ + effective cap" + ); + } + Ok(AppState { client_factories, - connection_semaphore: config.max_connections.map(|n| Arc::new(Semaphore::new(n))), + semaphore: ConnectionSemaphore::new(global, config.max_connections, "pg"), }) } @@ -78,7 +93,7 @@ impl Factor for OutboundPgFactor { connections: Default::default(), otel, builders: Default::default(), - connection_semaphore: ctx.app_state().connection_semaphore.clone(), + semaphore: ctx.app_state().semaphore.clone(), }) } } @@ -100,10 +115,13 @@ impl OutboundPgFactor { pub struct InstanceState { allowed_host_checker: AllowedHostChecker, client_factory: Arc, - connections: spin_resource_table::Table<(CF::Client, Option)>, + connections: spin_resource_table::Table<( + CF::Client, + spin_factor_outbound_networking::ConnectionPermit, + )>, otel: OtelFactorState, builders: spin_resource_table::Table, - pub connection_semaphore: Option>, + pub semaphore: ConnectionSemaphore, } impl SelfInstanceBuilder for InstanceState {} diff --git a/crates/factor-outbound-redis/src/host.rs b/crates/factor-outbound-redis/src/host.rs index 202584a8a1..0454e98537 100644 --- a/crates/factor-outbound-redis/src/host.rs +++ b/crates/factor-outbound-redis/src/host.rs @@ -1,5 +1,4 @@ use std::net::SocketAddr; -use std::sync::Arc; use anyhow::Result; use redis::AsyncConnectionConfig; @@ -7,12 +6,12 @@ use redis::io::AsyncDNSResolver; use redis::{AsyncCommands, FromRedisValue, Value, aio::MultiplexedConnection}; use spin_core::wasmtime::component::{Accessor, Resource}; use spin_factor_otel::OtelFactorState; +use spin_factor_outbound_networking::ConnectionSemaphore; use spin_factor_outbound_networking::config::blocked_networks::BlockedNetworks; use spin_world::MAX_HOST_BUFFERED_BYTES; use spin_world::spin::redis::redis as v3; use spin_world::v1::{redis as v1, redis_types}; use spin_world::v2::redis as v2; -use tokio::sync::{OwnedSemaphorePermit, Semaphore}; use tracing::field::Empty; use tracing::{Level, instrument}; @@ -21,9 +20,11 @@ use crate::allowed_hosts::AllowedHostChecker; pub struct InstanceState { pub(crate) allowed_host_checker: AllowedHostChecker, pub blocked_networks: BlockedNetworks, - pub connections: - spin_resource_table::Table<(MultiplexedConnection, Option)>, - pub connection_semaphore: Option>, + pub connections: spin_resource_table::Table<( + MultiplexedConnection, + spin_factor_outbound_networking::ConnectionPermit, + )>, + pub semaphore: ConnectionSemaphore, pub otel: OtelFactorState, } @@ -36,15 +37,11 @@ impl InstanceState { &mut self, address: String, ) -> Result, v2::Error> { - let permit = match &self.connection_semaphore { - Some(sem) => Some( - Arc::clone(sem) - .acquire_owned() - .await - .map_err(|_| v2::Error::TooManyConnections)?, - ), - None => None, - }; + let permit = self + .semaphore + .acquire() + .await + .map_err(|_| v2::Error::TooManyConnections)?; let config = AsyncConnectionConfig::new() .set_dns_resolver(SpinDnsResolver(self.blocked_networks.clone())); let conn = redis::Client::open(address.as_str()) @@ -243,16 +240,15 @@ impl v3::HostConnectionWithStore for crate::RedisFactorData { accessor: &Accessor, address: String, ) -> Result, v3::Error> { - let (allowed_host_checker, blocked_networks, connection_semaphore) = - accessor.with(|mut access| { - let host = access.get(); - host.otel.reparent_tracing_span(); - ( - host.allowed_host_checker.clone(), - host.blocked_networks.clone(), - host.connection_semaphore.clone(), - ) - }); + let (allowed_host_checker, blocked_networks, semaphore) = accessor.with(|mut access| { + let host = access.get(); + host.otel.reparent_tracing_span(); + ( + host.allowed_host_checker.clone(), + host.blocked_networks.clone(), + host.semaphore.clone(), + ) + }); if !allowed_host_checker .is_address_allowed(&address) @@ -262,14 +258,10 @@ impl v3::HostConnectionWithStore for crate::RedisFactorData { return Err(v3::Error::InvalidAddress); } - let permit = match connection_semaphore { - Some(sem) => Some( - sem.acquire_owned() - .await - .map_err(|_| v3::Error::TooManyConnections)?, - ), - None => None, - }; + let permit = semaphore + .acquire() + .await + .map_err(|_| v3::Error::TooManyConnections)?; let config = AsyncConnectionConfig::new().set_dns_resolver(SpinDnsResolver(blocked_networks)); diff --git a/crates/factor-outbound-redis/src/lib.rs b/crates/factor-outbound-redis/src/lib.rs index 4f9186e5e2..059729e7d1 100644 --- a/crates/factor-outbound-redis/src/lib.rs +++ b/crates/factor-outbound-redis/src/lib.rs @@ -2,18 +2,15 @@ mod allowed_hosts; mod host; pub mod runtime_config; -use std::sync::Arc; - use host::InstanceState; use runtime_config::RuntimeConfig; use spin_factor_otel::OtelFactorState; -use spin_factor_outbound_networking::OutboundNetworkingFactor; +use spin_factor_outbound_networking::{ConnectionSemaphore, OutboundNetworkingFactor}; use spin_factors::{ ConfigureAppContext, Factor, FactorData, PrepareContext, RuntimeFactors, SelfInstanceBuilder, anyhow, }; use spin_world::spin::redis::redis as v3; -use tokio::sync::Semaphore; use crate::allowed_hosts::AllowedHostChecker; @@ -30,8 +27,8 @@ impl OutboundRedisFactor { } pub struct AppState { - /// A semaphore to limit the number of concurrent outbound Redis connections. - pub connection_semaphore: Option>, + /// Semaphore(s) to limit concurrent outbound Redis connections. + pub semaphore: ConnectionSemaphore, } impl Factor for OutboundRedisFactor { @@ -51,8 +48,23 @@ impl Factor for OutboundRedisFactor { mut ctx: ConfigureAppContext, ) -> anyhow::Result { let config = ctx.take_runtime_config().unwrap_or_default(); + + let networking = ctx.app_state::().ok(); + let global = networking.and_then(|s| s.global_connection_semaphore.clone()); + let global_total_limit = networking.and_then(|s| s.max_total_connections); + + if let (Some(per_factor), Some(global_limit)) = (config.max_connections, global_total_limit) + && per_factor > global_limit + { + tracing::warn!( + "outbound_redis max_connections ({per_factor}) exceeds global \ + max_total_connections ({global_limit}); the global limit will be the \ + effective cap" + ); + } + Ok(AppState { - connection_semaphore: config.max_connections.map(|n| Arc::new(Semaphore::new(n))), + semaphore: ConnectionSemaphore::new(global, config.max_connections, "redis"), }) } @@ -67,7 +79,7 @@ impl Factor for OutboundRedisFactor { allowed_host_checker: AllowedHostChecker::new(outbound_networking.allowed_hosts()), blocked_networks: outbound_networking.blocked_networks(), connections: spin_resource_table::Table::new(1024), - connection_semaphore: ctx.app_state().connection_semaphore.clone(), + semaphore: ctx.app_state().semaphore.clone(), otel, }) } diff --git a/crates/factor-wasi/Cargo.toml b/crates/factor-wasi/Cargo.toml index 17acba531b..e957b22ba7 100644 --- a/crates/factor-wasi/Cargo.toml +++ b/crates/factor-wasi/Cargo.toml @@ -8,6 +8,7 @@ edition = { workspace = true } async-trait = { workspace = true } bytes = { workspace = true } spin-common = { path = "../common" } +spin-connection-semaphore = { path = "../connection-semaphore" } spin-factors = { path = "../factors" } tokio = { workspace = true, features = ["sync"] } tracing = { workspace = true } diff --git a/crates/factor-wasi/src/sockets.rs b/crates/factor-wasi/src/sockets.rs index e00a1ee4dc..efffa5cbae 100644 --- a/crates/factor-wasi/src/sockets.rs +++ b/crates/factor-wasi/src/sockets.rs @@ -10,7 +10,7 @@ use std::{ sync::{Arc, Mutex}, }; -use tokio::sync::{OwnedSemaphorePermit, Semaphore}; +use spin_connection_semaphore::{ConnectionPermit, ConnectionSemaphore}; use wasmtime::component::{HasData, Resource}; use wasmtime_wasi::p2::bindings::sockets::network::{ ErrorCode as SocketErrorCode, Host as NetworkHost, Network, @@ -26,15 +26,13 @@ use wasmtime_wasi::sockets::{TcpSocket, UdpSocket, WasiSocketsCtxView}; /// acquired when a socket is allocated (at `start_connect` for TCP, at /// `create_udp_socket` for UDP) and released when the socket resource is dropped. pub struct SocketPermitState { - semaphore: Arc, - /// Active permits keyed by socket resource rep. - /// - /// Permits are removed (and the permit released) when the WASI socket resource is dropped. - active: Mutex>, + semaphore: ConnectionSemaphore, + /// Active permits keyed by socket resource rep, released when the resource is dropped. + active: Mutex>, } impl SocketPermitState { - pub fn new(semaphore: Arc) -> Arc { + pub fn new(semaphore: ConnectionSemaphore) -> Arc { Arc::new(Self { semaphore, active: Mutex::new(HashMap::new()), @@ -98,17 +96,13 @@ impl p2_tcp::HostTcpSocket for SpinSocketsView<'_> { // Unlike outbound HTTP (which queues when its permit pool is exhausted), // sockets fail immediately. Waiting would risk deadlock if a component // holds sockets open across async yield points, and raw-socket callers - // are better positioned to implement their own retry logic. The two - // limits are also configured separately, so different semantics are fine. - let Ok(permit) = Arc::clone(&state.semaphore).try_acquire_owned() else { - tracing::warn!("TCP socket connection refused: socket quota exhausted"); - // `new-socket-limit` maps to POSIX EMFILE/ENFILE: "a new socket - // resource could not be created because of a system limit." + // are better positioned to implement their own retry logic. + let Some(permit) = state.semaphore.try_acquire() else { + tracing::warn!("TCP socket connection refused: connection quota exhausted"); return Err(SocketErrorCode::NewSocketLimit.into()); }; p2_tcp::HostTcpSocket::start_connect(&mut self.inner, this, network, remote_address) .await?; - // If the connection was successfully initiated, store the permit so it can be released when the socket is dropped. state .active .lock() @@ -296,7 +290,7 @@ impl p2_tcp::HostTcpSocket for SpinSocketsView<'_> { } fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { - // Release the permit before dropping the socket resource. + // Release both permits before dropping the socket resource. if let Some(state) = &self.permit_state { state .active @@ -443,7 +437,7 @@ impl p2_udp::HostUdpSocket for SpinSocketsView<'_> { } fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { - // Release the permit before dropping the socket resource. + // Release both permits before dropping the socket resource. if let Some(state) = &self.permit_state { state .active @@ -513,16 +507,11 @@ impl p2_udp_create::Host for SpinSocketsView<'_> { let state = Arc::clone(state); // See the analogous comment in `start_connect` for why we fail // immediately rather than waiting (as outbound HTTP does). - let permit = Arc::clone(&state.semaphore) - .try_acquire_owned() - .map_err(|_| { - tracing::warn!("UDP socket creation refused: socket quota exhausted"); - // `new-socket-limit` maps to POSIX EMFILE/ENFILE: "a new socket - // resource could not be created because of a system limit." - SocketErrorCode::NewSocketLimit - })?; + let Some(permit) = state.semaphore.try_acquire() else { + tracing::warn!("UDP socket creation refused: connection quota exhausted"); + return Err(SocketErrorCode::NewSocketLimit.into()); + }; let sock = p2_udp_create::Host::create_udp_socket(&mut self.inner, address_family)?; - // If the socket was successfully created, store the permit so it can be released when the socket is dropped. state .active .lock() diff --git a/crates/runtime-config/src/lib.rs b/crates/runtime-config/src/lib.rs index c9433a1781..4ebe9e5a91 100644 --- a/crates/runtime-config/src/lib.rs +++ b/crates/runtime-config/src/lib.rs @@ -79,6 +79,21 @@ impl ResolvedRuntimeConfig { summaries.push(format!("[llm_compute: {ty}")); } } + // [outbound_networking: max_total_connections=N] + if let Some(table) = self + .toml + .get("outbound_networking") + .and_then(Value::as_table) + { + if let Some(max) = table + .get("max_total_connections") + .and_then(Value::as_integer) + { + summaries.push(format!( + "[outbound_networking: max_total_connections={max}]" + )); + } + } // [outbound_redis: max_connections=N], [outbound_pg: max_connections=N], [outbound_mysql: max_connections=N] for key in ["outbound_redis", "outbound_pg", "outbound_mysql"] { if let Some(table) = self.toml.get(key).and_then(Value::as_table) { From 298ca3cd97f1e05ccba1aaea0324f7a937b07505 Mon Sep 17 00:00:00 2001 From: Ryan Levick Date: Fri, 29 May 2026 14:46:54 +0200 Subject: [PATCH 07/15] Deprecate max_concurrent_requests Signed-off-by: Ryan Levick --- Cargo.lock | 1 + crates/factor-outbound-http/Cargo.toml | 1 + crates/factor-outbound-http/src/lib.rs | 4 +-- .../src/runtime_config/spin.rs | 31 ++++++++++++++++--- 4 files changed, 30 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 831186f5ea..b8e3261a01 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9001,6 +9001,7 @@ dependencies = [ "spin-factors-test", "spin-telemetry", "spin-world", + "terminal", "tokio", "tokio-rustls 0.26.4", "tower-service", diff --git a/crates/factor-outbound-http/Cargo.toml b/crates/factor-outbound-http/Cargo.toml index 7ccfd3f63b..619a6cc382 100644 --- a/crates/factor-outbound-http/Cargo.toml +++ b/crates/factor-outbound-http/Cargo.toml @@ -22,6 +22,7 @@ spin-factor-outbound-networking = { path = "../factor-outbound-networking" } spin-factors = { path = "../factors" } spin-telemetry = { path = "../telemetry" } spin-world = { path = "../world" } +terminal = { path = "../terminal" } tokio = { workspace = true, features = ["macros", "rt", "net"] } tokio-rustls = { workspace = true } tower-service = { workspace = true } diff --git a/crates/factor-outbound-http/src/lib.rs b/crates/factor-outbound-http/src/lib.rs index fb0b42eed9..366ef9c908 100644 --- a/crates/factor-outbound-http/src/lib.rs +++ b/crates/factor-outbound-http/src/lib.rs @@ -71,9 +71,7 @@ impl Factor for OutboundHttpFactor { ); } - // Permit count is the max concurrent connections + 1. - // i.e., 0 concurrent connections means 1 total connection. - let factor_specific_limit = config.max_concurrent_connections.map(|n| n + 1); + let factor_specific_limit = config.max_concurrent_connections; Ok(AppState { wasi_http_clients: wasi::HttpClients::new(config.connection_pooling_enabled), diff --git a/crates/factor-outbound-http/src/runtime_config/spin.rs b/crates/factor-outbound-http/src/runtime_config/spin.rs index fc32c2cdcc..77e5c0800d 100644 --- a/crates/factor-outbound-http/src/runtime_config/spin.rs +++ b/crates/factor-outbound-http/src/runtime_config/spin.rs @@ -7,16 +7,36 @@ use spin_factors::runtime_config::toml::GetTomlValue; /// ```toml /// [outbound_http] /// connection_pooling = true # optional, defaults to true -/// max_concurrent_requests = 10 # optional, defaults to unlimited +/// max_connections = 10 # optional, defaults to unlimited; 0 = no connections allowed +/// # max_concurrent_requests is deprecated, use max_connections instead /// ``` pub fn config_from_table( table: &impl GetTomlValue, ) -> anyhow::Result> { if let Some(outbound_http) = table.get("outbound_http") { - let outbound_http_toml = outbound_http.clone().try_into::()?; + let toml = outbound_http.clone().try_into::()?; + + let max_connections = match (toml.max_connections, toml.max_concurrent_requests) { + (Some(_), Some(_)) => anyhow::bail!( + "cannot set both `max_connections` and `max_concurrent_requests` in \ + `[outbound_http]`; use `max_connections` only" + ), + (Some(n), None) => Some(n), + (None, Some(n)) => { + terminal::warn!( + "`max_concurrent_requests` in `[outbound_http]` is deprecated; \ + use `max_connections` instead (note: `max_connections = 0` blocks all \ + connections, whereas `max_concurrent_requests = 0` allowed 1 connection)" + ); + // Preserve old semaphore semantics: n+1 permits so that 0 allowed 1 connection + Some(n + 1) + } + (None, None) => None, + }; + Ok(Some(super::RuntimeConfig { - connection_pooling_enabled: outbound_http_toml.connection_pooling, - max_concurrent_connections: outbound_http_toml.max_concurrent_requests, + connection_pooling_enabled: toml.connection_pooling, + max_concurrent_connections: max_connections, })) } else { Ok(None) @@ -29,5 +49,8 @@ struct OutboundHttpToml { #[serde(default)] connection_pooling: bool, #[serde(default)] + max_connections: Option, + /// Deprecated. Use `max_connections` instead. + #[serde(default)] max_concurrent_requests: Option, } From 4b1b2ac657f9d95510f46a54b9dba50ebfe4b640 Mon Sep 17 00:00:00 2001 From: Ryan Levick Date: Mon, 1 Jun 2026 14:20:47 +0200 Subject: [PATCH 08/15] Wrap legacy wasi sockets implementations in connections quota. Signed-off-by: Ryan Levick --- crates/factor-wasi/src/lib.rs | 5 +- crates/factor-wasi/src/sockets.rs | 121 ++++++++++-------- crates/factor-wasi/src/wasi_2023_10_18.rs | 143 ++++++++++++++-------- crates/factor-wasi/src/wasi_2023_11_10.rs | 63 +++++----- 4 files changed, 200 insertions(+), 132 deletions(-) diff --git a/crates/factor-wasi/src/lib.rs b/crates/factor-wasi/src/lib.rs index 02c6ba876e..ff7e5fc1ea 100644 --- a/crates/factor-wasi/src/lib.rs +++ b/crates/factor-wasi/src/lib.rs @@ -235,7 +235,7 @@ trait InitContextExt: InitContext { fn(&mut Self::StoreData) -> WasiClocksCtxView<'_>, fn(&mut Self::StoreData) -> WasiCliCtxView<'_>, fn(&mut Self::StoreData) -> WasiFilesystemCtxView<'_>, - fn(&mut Self::StoreData) -> WasiSocketsCtxView<'_>, + fn(&mut Self::StoreData) -> SpinSocketsView<'_>, ) -> anyhow::Result<()>, ) -> anyhow::Result<()> { add_to_linker( @@ -245,7 +245,7 @@ trait InitContextExt: InitContext { Self::get_clocks, Self::get_cli, Self::get_filesystem, - Self::get_sockets, + Self::get_spin_sockets, ) } } @@ -345,6 +345,7 @@ impl Factor for WasiFactor { ctx.link_sockets_bindings( p3::bindings::sockets::ip_name_lookup::add_to_linker::<_, WasiSockets>, )?; + // TODO(rylev): switch to SpinSockets once possible ctx.link_sockets_bindings(p3::bindings::sockets::types::add_to_linker::<_, WasiSockets>)?; ctx.link_all_bindings(wasi_2023_10_18::add_to_linker)?; diff --git a/crates/factor-wasi/src/sockets.rs b/crates/factor-wasi/src/sockets.rs index efffa5cbae..9980c9c2ff 100644 --- a/crates/factor-wasi/src/sockets.rs +++ b/crates/factor-wasi/src/sockets.rs @@ -69,6 +69,46 @@ impl HasData for SpinSockets { type Data<'a> = SpinSocketsView<'a>; } +impl SpinSocketsView<'_> { + /// Attempts to acquire a connection permit from the semaphore. + /// + /// Returns `Ok(None)` when no quota is configured, `Ok(Some(permit))` on + /// success, or `Err(())` when the quota is exhausted. + /// + /// The returned permit is unregistered — call [`Self::register_permit`] once + /// the socket resource rep is known to tie its lifetime to the socket. + pub(crate) fn try_acquire(&self) -> Result, ()> { + let Some(state) = &self.permit_state else { + return Ok(None); + }; + state.semaphore.try_acquire().map(Some).ok_or(()) + } + + /// Registers `permit` under `socket_rep` so it is held until the socket is + /// dropped. No-op when `permit` is `None` (no quota configured). + pub(crate) fn register_permit(&self, socket_rep: u32, permit: Option) { + let (Some(state), Some(permit)) = (&self.permit_state, permit) else { + return; + }; + state + .active + .lock() + .unwrap_or_else(|e| e.into_inner()) + .insert(socket_rep, permit); + } + + /// Releases the connection permit for `socket_rep`, if any. + pub(crate) fn release_permit(&self, socket_rep: u32) { + if let Some(state) = &self.permit_state { + state + .active + .lock() + .unwrap_or_else(|e| e.into_inner()) + .remove(&socket_rep); + } + } +} + impl p2_tcp::Host for SpinSocketsView<'_> {} impl p2_tcp::HostTcpSocket for SpinSocketsView<'_> { @@ -91,28 +131,23 @@ impl p2_tcp::HostTcpSocket for SpinSocketsView<'_> { network: Resource, remote_address: IpSocketAddress, ) -> wasmtime_wasi::p2::SocketResult<()> { - if let Some(state) = &self.permit_state { - let socket_rep = this.rep(); - // Unlike outbound HTTP (which queues when its permit pool is exhausted), - // sockets fail immediately. Waiting would risk deadlock if a component - // holds sockets open across async yield points, and raw-socket callers - // are better positioned to implement their own retry logic. - let Some(permit) = state.semaphore.try_acquire() else { - tracing::warn!("TCP socket connection refused: connection quota exhausted"); - return Err(SocketErrorCode::NewSocketLimit.into()); - }; - p2_tcp::HostTcpSocket::start_connect(&mut self.inner, this, network, remote_address) - .await?; - state - .active - .lock() - .unwrap_or_else(|e| e.into_inner()) - .insert(socket_rep, permit); - Ok(()) - } else { + let socket_rep = this.rep(); + // Unlike outbound HTTP (which queues when its permit pool is exhausted), + // sockets fail immediately. Waiting would risk deadlock if a component + // holds sockets open across async yield points, and raw-socket callers + // are better positioned to implement their own retry logic. + let Ok(permit) = self.try_acquire() else { + tracing::warn!("TCP socket connection refused: connection quota exhausted"); + return Err(SocketErrorCode::NewSocketLimit.into()); + }; + let result = p2_tcp::HostTcpSocket::start_connect(&mut self.inner, this, network, remote_address) - .await + .await; + if result.is_ok() { + self.register_permit(socket_rep, permit); } + // On error, `permit` is dropped here, automatically releasing the semaphore slot. + result } fn finish_connect( @@ -290,14 +325,7 @@ impl p2_tcp::HostTcpSocket for SpinSocketsView<'_> { } fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { - // Release both permits before dropping the socket resource. - if let Some(state) = &self.permit_state { - state - .active - .lock() - .unwrap_or_else(|e| e.into_inner()) - .remove(&this.rep()); - } + self.release_permit(this.rep()); p2_tcp::HostTcpSocket::drop(&mut self.inner, this) } } @@ -437,14 +465,7 @@ impl p2_udp::HostUdpSocket for SpinSocketsView<'_> { } fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { - // Release both permits before dropping the socket resource. - if let Some(state) = &self.permit_state { - state - .active - .lock() - .unwrap_or_else(|e| e.into_inner()) - .remove(&this.rep()); - } + self.release_permit(this.rep()); p2_udp::HostUdpSocket::drop(&mut self.inner, this) } } @@ -503,23 +524,15 @@ impl p2_udp_create::Host for SpinSocketsView<'_> { &mut self, address_family: wasmtime_wasi::p2::bindings::sockets::network::IpAddressFamily, ) -> wasmtime_wasi::p2::SocketResult> { - if let Some(state) = &self.permit_state { - let state = Arc::clone(state); - // See the analogous comment in `start_connect` for why we fail - // immediately rather than waiting (as outbound HTTP does). - let Some(permit) = state.semaphore.try_acquire() else { - tracing::warn!("UDP socket creation refused: connection quota exhausted"); - return Err(SocketErrorCode::NewSocketLimit.into()); - }; - let sock = p2_udp_create::Host::create_udp_socket(&mut self.inner, address_family)?; - state - .active - .lock() - .unwrap_or_else(|e| e.into_inner()) - .insert(sock.rep(), permit); - Ok(sock) - } else { - p2_udp_create::Host::create_udp_socket(&mut self.inner, address_family) - } + // Check quota before allocating the socket resource. + // See the analogous comment in `start_connect` for why we fail + // immediately rather than waiting (as outbound HTTP does). + let Ok(permit) = self.try_acquire() else { + tracing::warn!("UDP socket creation refused: connection quota exhausted"); + return Err(SocketErrorCode::NewSocketLimit.into()); + }; + let sock = p2_udp_create::Host::create_udp_socket(&mut self.inner, address_family)?; + self.register_permit(sock.rep(), permit); + Ok(sock) } } diff --git a/crates/factor-wasi/src/wasi_2023_10_18.rs b/crates/factor-wasi/src/wasi_2023_10_18.rs index cc0a86111d..79e2c4373b 100644 --- a/crates/factor-wasi/src/wasi_2023_10_18.rs +++ b/crates/factor-wasi/src/wasi_2023_10_18.rs @@ -1,3 +1,4 @@ +use crate::sockets::{SpinSockets, SpinSocketsView}; use spin_factors::anyhow::Result; use std::mem; use wasmtime::component::{Linker, Resource, ResourceTable}; @@ -6,8 +7,8 @@ use wasmtime_wasi::cli::{WasiCli, WasiCliCtxView}; use wasmtime_wasi::clocks::{WasiClocks, WasiClocksCtxView}; use wasmtime_wasi::filesystem::{WasiFilesystem, WasiFilesystemCtxView}; use wasmtime_wasi::p2::DynPollable; +use wasmtime_wasi::p2::bindings::sockets::udp_create_socket as p2_udp_create; use wasmtime_wasi::random::{WasiRandom, WasiRandomCtx}; -use wasmtime_wasi::sockets::{WasiSockets, WasiSocketsCtxView}; mod latest { pub use wasmtime_wasi::p2::bindings::*; @@ -126,7 +127,7 @@ pub fn add_to_linker( clocks_closure: fn(&mut T) -> WasiClocksCtxView<'_>, cli_closure: fn(&mut T) -> WasiCliCtxView<'_>, filesystem_closure: fn(&mut T) -> WasiFilesystemCtxView<'_>, - sockets_closure: fn(&mut T) -> WasiSocketsCtxView<'_>, + sockets_closure: fn(&mut T) -> SpinSocketsView<'_>, ) -> Result<()> where T: Send + 'static, @@ -150,13 +151,13 @@ where wasi::cli::terminal_stdin::add_to_linker::<_, WasiCli>(linker, cli_closure)?; wasi::cli::terminal_stdout::add_to_linker::<_, WasiCli>(linker, cli_closure)?; wasi::cli::terminal_stderr::add_to_linker::<_, WasiCli>(linker, cli_closure)?; - wasi::sockets::tcp::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; - wasi::sockets::tcp_create_socket::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; - wasi::sockets::udp::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; - wasi::sockets::udp_create_socket::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; - wasi::sockets::instance_network::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; - wasi::sockets::network::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; - wasi::sockets::ip_name_lookup::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; + wasi::sockets::tcp::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; + wasi::sockets::tcp_create_socket::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; + wasi::sockets::udp::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; + wasi::sockets::udp_create_socket::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; + wasi::sockets::instance_network::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; + wasi::sockets::network::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; + wasi::sockets::ip_name_lookup::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; Ok(()) } @@ -900,9 +901,9 @@ impl wasi::cli::terminal_output::HostTerminalOutput for WasiCliCtxView<'_> { } } -impl wasi::sockets::tcp::Host for WasiSocketsCtxView<'_> {} +impl wasi::sockets::tcp::Host for SpinSocketsView<'_> {} -impl wasi::sockets::tcp::HostTcpSocket for WasiSocketsCtxView<'_> { +impl wasi::sockets::tcp::HostTcpSocket for SpinSocketsView<'_> { async fn start_bind( &mut self, self_: Resource, @@ -935,6 +936,10 @@ impl wasi::sockets::tcp::HostTcpSocket for WasiSocketsCtxView<'_> { network: Resource, remote_address: IpSocketAddress, ) -> wasmtime::Result> { + // Delegate to the P2 SpinSocketsView impl (passing `self`, not `&mut self.inner`). + // This snapshot uses the raw P2 TcpSocket type — the resource rep is the same at + // start_connect and drop time — so the P2 impl's quota acquire/register/release + // logic round-trips correctly without any wrapper-level bookkeeping here. convert_result( latest::sockets::tcp::HostTcpSocket::start_connect( self, @@ -1147,7 +1152,7 @@ impl wasi::sockets::tcp::HostTcpSocket for WasiSocketsCtxView<'_> { } } -impl wasi::sockets::tcp_create_socket::Host for WasiSocketsCtxView<'_> { +impl wasi::sockets::tcp_create_socket::Host for SpinSocketsView<'_> { fn create_tcp_socket( &mut self, address_family: IpAddressFamily, @@ -1159,7 +1164,7 @@ impl wasi::sockets::tcp_create_socket::Host for WasiSocketsCtxView<'_> { } } -impl wasi::sockets::udp::Host for WasiSocketsCtxView<'_> {} +impl wasi::sockets::udp::Host for SpinSocketsView<'_> {} /// Between the snapshot of WASI that this file is implementing and the current /// implementation of WASI UDP sockets were redesigned slightly to deal with @@ -1180,7 +1185,7 @@ pub enum UdpSocket { impl UdpSocket { async fn finish_connect( - table: &mut WasiSocketsCtxView<'_>, + table: &mut SpinSocketsView<'_>, socket: &Resource, explicit: bool, ) -> wasmtime::Result> { @@ -1197,8 +1202,12 @@ impl UdpSocket { }; let borrow = Resource::new_borrow(new_socket.rep()); let result = convert_result( - latest::sockets::udp::HostUdpSocket::stream(table, borrow, addr.map(|a| a.into())) - .await, + latest::sockets::udp::HostUdpSocket::stream( + &mut table.inner, + borrow, + addr.map(|a| a.into()), + ) + .await, )?; let (incoming, outgoing) = match result { Ok(pair) => pair, @@ -1223,7 +1232,7 @@ impl UdpSocket { } } -impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { +impl wasi::sockets::udp::HostUdpSocket for SpinSocketsView<'_> { async fn start_bind( &mut self, self_: Resource, @@ -1233,7 +1242,7 @@ impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { let socket = self.table.get(&self_)?.inner()?; convert_result( latest::sockets::udp::HostUdpSocket::start_bind( - self, + &mut self.inner, socket, network, local_address.into(), @@ -1248,7 +1257,8 @@ impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { ) -> wasmtime::Result> { let socket = self.table.get(&self_)?.inner()?; convert_result(latest::sockets::udp::HostUdpSocket::finish_bind( - self, socket, + &mut self.inner, + socket, )) } @@ -1358,7 +1368,8 @@ impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { ) -> wasmtime::Result> { let socket = self.table.get(&self_)?.inner()?; convert_result(latest::sockets::udp::HostUdpSocket::local_address( - self, socket, + &mut self.inner, + socket, )) } @@ -1368,13 +1379,15 @@ impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { ) -> wasmtime::Result> { let socket = self.table.get(&self_)?.inner()?; convert_result(latest::sockets::udp::HostUdpSocket::remote_address( - self, socket, + &mut self.inner, + socket, )) } fn address_family(&mut self, self_: Resource) -> wasmtime::Result { let socket = self.table.get(&self_)?.inner()?; - latest::sockets::udp::HostUdpSocket::address_family(self, socket).map(|e| e.into()) + latest::sockets::udp::HostUdpSocket::address_family(&mut self.inner, socket) + .map(|e| e.into()) } fn ipv6_only( @@ -1398,7 +1411,8 @@ impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { ) -> wasmtime::Result> { let socket = self.table.get(&self_)?.inner()?; convert_result(latest::sockets::udp::HostUdpSocket::unicast_hop_limit( - self, socket, + &mut self.inner, + socket, )) } @@ -1409,7 +1423,9 @@ impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { ) -> wasmtime::Result> { let socket = self.table.get(&self_)?.inner()?; convert_result(latest::sockets::udp::HostUdpSocket::set_unicast_hop_limit( - self, socket, value, + &mut self.inner, + socket, + value, )) } @@ -1419,7 +1435,8 @@ impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { ) -> wasmtime::Result> { let socket = self.table.get(&self_)?.inner()?; convert_result(latest::sockets::udp::HostUdpSocket::receive_buffer_size( - self, socket, + &mut self.inner, + socket, )) } @@ -1430,7 +1447,11 @@ impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { ) -> wasmtime::Result> { let socket = self.table.get(&self_)?.inner()?; convert_result( - latest::sockets::udp::HostUdpSocket::set_receive_buffer_size(self, socket, value), + latest::sockets::udp::HostUdpSocket::set_receive_buffer_size( + &mut self.inner, + socket, + value, + ), ) } @@ -1440,7 +1461,8 @@ impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { ) -> wasmtime::Result> { let socket = self.table.get(&self_)?.inner()?; convert_result(latest::sockets::udp::HostUdpSocket::send_buffer_size( - self, socket, + &mut self.inner, + socket, )) } @@ -1451,17 +1473,25 @@ impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { ) -> wasmtime::Result> { let socket = self.table.get(&self_)?.inner()?; convert_result(latest::sockets::udp::HostUdpSocket::set_send_buffer_size( - self, socket, value, + &mut self.inner, + socket, + value, )) } fn subscribe(&mut self, self_: Resource) -> wasmtime::Result> { let socket = self.table.get(&self_)?.inner()?; - latest::sockets::udp::HostUdpSocket::subscribe(self, socket) + latest::sockets::udp::HostUdpSocket::subscribe(&mut self.inner, socket) } fn drop(&mut self, rep: Resource) -> wasmtime::Result<()> { + let socket_rep = rep.rep(); + // Delete before releasing: the only error case that matters is `HasChildren`, + // where the socket still exists and the permit must stay held. `NotPresent` + // (double-drop) is unreachable from a guest, and `release_permit` is idempotent + // anyway since `HashMap::remove` is a no-op for absent keys. let me = self.table.delete(rep)?; + self.release_permit(socket_rep); let socket = match me { UdpSocket::Initial(s) => s, UdpSocket::Connecting(s, _) => s, @@ -1470,49 +1500,63 @@ impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { incoming, outgoing, } => { - latest::sockets::udp::HostIncomingDatagramStream::drop(self, incoming)?; - latest::sockets::udp::HostOutgoingDatagramStream::drop(self, outgoing)?; + latest::sockets::udp::HostIncomingDatagramStream::drop(&mut self.inner, incoming)?; + latest::sockets::udp::HostOutgoingDatagramStream::drop(&mut self.inner, outgoing)?; socket } UdpSocket::Dummy => return Ok(()), }; - latest::sockets::udp::HostUdpSocket::drop(self, socket) + // Drop the inner P2 socket directly, bypassing quota tracking for rep R. + latest::sockets::udp::HostUdpSocket::drop(&mut self.inner, socket) } } -impl wasi::sockets::udp_create_socket::Host for WasiSocketsCtxView<'_> { +impl wasi::sockets::udp_create_socket::Host for SpinSocketsView<'_> { fn create_udp_socket( &mut self, address_family: IpAddressFamily, ) -> wasmtime::Result, SocketErrorCode>> { - let result = convert_result(latest::sockets::udp_create_socket::Host::create_udp_socket( - self, + // Cannot delegate to the P2 SpinSocketsView impl here (unlike TCP). This snapshot + // wraps the P2 UdpSocket in a custom UdpSocket enum stored in a separate resource + // table, so the outer wrapper rep (used at drop time) differs from the inner P2 + // socket rep (which the P2 impl would register the permit under). Delegating would + // cause release_permit at drop time to look up the wrong rep and silently leak the + // semaphore slot. Instead, quota is checked explicitly here and the permit is + // registered under the wrapper rep. + let Ok(permit) = self.try_acquire() else { + tracing::warn!("UDP socket creation refused: connection quota exhausted"); + return Ok(Err(SocketErrorCode::NewSocketLimit)); + }; + // Create the inner P2 socket via self.inner to avoid charging quota at the P2 level. + let result = convert_result(p2_udp_create::Host::create_udp_socket( + &mut self.inner, address_family.into(), ))?; let socket = match result { Ok(socket) => socket, Err(e) => return Ok(Err(e)), }; - let socket = self.table.push(UdpSocket::Initial(socket))?; - Ok(Ok(socket)) + let wrapped = self.table.push(UdpSocket::Initial(socket))?; + self.register_permit(wrapped.rep(), permit); + Ok(Ok(wrapped)) } } -impl wasi::sockets::instance_network::Host for WasiSocketsCtxView<'_> { +impl wasi::sockets::instance_network::Host for SpinSocketsView<'_> { fn instance_network(&mut self) -> wasmtime::Result> { - latest::sockets::instance_network::Host::instance_network(self) + latest::sockets::instance_network::Host::instance_network(&mut self.inner) } } -impl wasi::sockets::network::Host for WasiSocketsCtxView<'_> {} +impl wasi::sockets::network::Host for SpinSocketsView<'_> {} -impl wasi::sockets::network::HostNetwork for WasiSocketsCtxView<'_> { +impl wasi::sockets::network::HostNetwork for SpinSocketsView<'_> { fn drop(&mut self, rep: Resource) -> wasmtime::Result<()> { - latest::sockets::network::HostNetwork::drop(self, rep) + latest::sockets::network::HostNetwork::drop(&mut self.inner, rep) } } -impl wasi::sockets::ip_name_lookup::Host for WasiSocketsCtxView<'_> { +impl wasi::sockets::ip_name_lookup::Host for SpinSocketsView<'_> { fn resolve_addresses( &mut self, network: Resource, @@ -1521,19 +1565,22 @@ impl wasi::sockets::ip_name_lookup::Host for WasiSocketsCtxView<'_> { _include_unavailable: bool, ) -> wasmtime::Result, SocketErrorCode>> { convert_result(latest::sockets::ip_name_lookup::Host::resolve_addresses( - self, network, name, + &mut self.inner, + network, + name, )) } } -impl wasi::sockets::ip_name_lookup::HostResolveAddressStream for WasiSocketsCtxView<'_> { +impl wasi::sockets::ip_name_lookup::HostResolveAddressStream for SpinSocketsView<'_> { fn resolve_next_address( &mut self, self_: Resource, ) -> wasmtime::Result, SocketErrorCode>> { convert_result( latest::sockets::ip_name_lookup::HostResolveAddressStream::resolve_next_address( - self, self_, + &mut self.inner, + self_, ) .map(|e| e.map(|e| e.into())), ) @@ -1543,11 +1590,11 @@ impl wasi::sockets::ip_name_lookup::HostResolveAddressStream for WasiSocketsCtxV &mut self, self_: Resource, ) -> wasmtime::Result> { - latest::sockets::ip_name_lookup::HostResolveAddressStream::subscribe(self, self_) + latest::sockets::ip_name_lookup::HostResolveAddressStream::subscribe(&mut self.inner, self_) } fn drop(&mut self, rep: Resource) -> wasmtime::Result<()> { - latest::sockets::ip_name_lookup::HostResolveAddressStream::drop(self, rep) + latest::sockets::ip_name_lookup::HostResolveAddressStream::drop(&mut self.inner, rep) } } diff --git a/crates/factor-wasi/src/wasi_2023_11_10.rs b/crates/factor-wasi/src/wasi_2023_11_10.rs index 81de4c6c74..c296a4dad0 100644 --- a/crates/factor-wasi/src/wasi_2023_11_10.rs +++ b/crates/factor-wasi/src/wasi_2023_11_10.rs @@ -1,11 +1,11 @@ use super::wasi_2023_10_18::{convert, convert_result}; +use crate::sockets::{SpinSockets, SpinSocketsView}; use spin_factors::anyhow::Result; use wasmtime::component::{Linker, Resource, ResourceTable}; use wasmtime_wasi::cli::{WasiCli, WasiCliCtxView}; use wasmtime_wasi::clocks::{WasiClocks, WasiClocksCtxView}; use wasmtime_wasi::filesystem::{WasiFilesystem, WasiFilesystemCtxView}; use wasmtime_wasi::random::{WasiRandom, WasiRandomCtx}; -use wasmtime_wasi::sockets::{WasiSockets, WasiSocketsCtxView}; mod latest { pub use wasmtime_wasi::p2::bindings::*; @@ -119,7 +119,7 @@ pub fn add_to_linker( clocks_closure: fn(&mut T) -> WasiClocksCtxView<'_>, cli_closure: fn(&mut T) -> WasiCliCtxView<'_>, filesystem_closure: fn(&mut T) -> WasiFilesystemCtxView<'_>, - sockets_closure: fn(&mut T) -> WasiSocketsCtxView<'_>, + sockets_closure: fn(&mut T) -> SpinSocketsView<'_>, ) -> Result<()> where T: Send + 'static, @@ -144,13 +144,13 @@ where wasi::cli::terminal_stdin::add_to_linker::<_, WasiCli>(linker, cli_closure)?; wasi::cli::terminal_stdout::add_to_linker::<_, WasiCli>(linker, cli_closure)?; wasi::cli::terminal_stderr::add_to_linker::<_, WasiCli>(linker, cli_closure)?; - wasi::sockets::tcp::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; - wasi::sockets::tcp_create_socket::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; - wasi::sockets::udp::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; - wasi::sockets::udp_create_socket::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; - wasi::sockets::instance_network::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; - wasi::sockets::network::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; - wasi::sockets::ip_name_lookup::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; + wasi::sockets::tcp::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; + wasi::sockets::tcp_create_socket::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; + wasi::sockets::udp::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; + wasi::sockets::udp_create_socket::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; + wasi::sockets::instance_network::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; + wasi::sockets::network::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; + wasi::sockets::ip_name_lookup::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; Ok(()) } @@ -830,9 +830,9 @@ impl wasi::cli::terminal_output::HostTerminalOutput for WasiCliCtxView<'_> { } } -impl wasi::sockets::tcp::Host for WasiSocketsCtxView<'_> {} +impl wasi::sockets::tcp::Host for SpinSocketsView<'_> {} -impl wasi::sockets::tcp::HostTcpSocket for WasiSocketsCtxView<'_> { +impl wasi::sockets::tcp::HostTcpSocket for SpinSocketsView<'_> { async fn start_bind( &mut self, self_: Resource, @@ -865,6 +865,10 @@ impl wasi::sockets::tcp::HostTcpSocket for WasiSocketsCtxView<'_> { network: Resource, remote_address: IpSocketAddress, ) -> wasmtime::Result> { + // Delegate to the P2 SpinSocketsView impl (passing `self`, not `&mut self.inner`). + // This snapshot uses the raw P2 TcpSocket type — the resource rep is the same at + // start_connect and drop time — so the P2 impl's quota acquire/register/release + // logic round-trips correctly without any wrapper-level bookkeeping here. convert_result( latest::sockets::tcp::HostTcpSocket::start_connect( self, @@ -1123,7 +1127,7 @@ impl wasi::sockets::tcp::HostTcpSocket for WasiSocketsCtxView<'_> { } } -impl wasi::sockets::tcp_create_socket::Host for WasiSocketsCtxView<'_> { +impl wasi::sockets::tcp_create_socket::Host for SpinSocketsView<'_> { fn create_tcp_socket( &mut self, address_family: IpAddressFamily, @@ -1135,9 +1139,9 @@ impl wasi::sockets::tcp_create_socket::Host for WasiSocketsCtxView<'_> { } } -impl wasi::sockets::udp::Host for WasiSocketsCtxView<'_> {} +impl wasi::sockets::udp::Host for SpinSocketsView<'_> {} -impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { +impl wasi::sockets::udp::HostUdpSocket for SpinSocketsView<'_> { async fn start_bind( &mut self, self_: Resource, @@ -1290,7 +1294,7 @@ impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { } } -impl wasi::sockets::udp::HostOutgoingDatagramStream for WasiSocketsCtxView<'_> { +impl wasi::sockets::udp::HostOutgoingDatagramStream for SpinSocketsView<'_> { fn check_send( &mut self, self_: Resource, @@ -1325,7 +1329,7 @@ impl wasi::sockets::udp::HostOutgoingDatagramStream for WasiSocketsCtxView<'_> { } } -impl wasi::sockets::udp::HostIncomingDatagramStream for WasiSocketsCtxView<'_> { +impl wasi::sockets::udp::HostIncomingDatagramStream for SpinSocketsView<'_> { fn receive( &mut self, self_: Resource, @@ -1351,7 +1355,7 @@ impl wasi::sockets::udp::HostIncomingDatagramStream for WasiSocketsCtxView<'_> { } } -impl wasi::sockets::udp_create_socket::Host for WasiSocketsCtxView<'_> { +impl wasi::sockets::udp_create_socket::Host for SpinSocketsView<'_> { fn create_udp_socket( &mut self, address_family: IpAddressFamily, @@ -1363,40 +1367,43 @@ impl wasi::sockets::udp_create_socket::Host for WasiSocketsCtxView<'_> { } } -impl wasi::sockets::instance_network::Host for WasiSocketsCtxView<'_> { +impl wasi::sockets::instance_network::Host for SpinSocketsView<'_> { fn instance_network(&mut self) -> wasmtime::Result> { - latest::sockets::instance_network::Host::instance_network(self) + latest::sockets::instance_network::Host::instance_network(&mut self.inner) } } -impl wasi::sockets::network::Host for WasiSocketsCtxView<'_> {} +impl wasi::sockets::network::Host for SpinSocketsView<'_> {} -impl wasi::sockets::network::HostNetwork for WasiSocketsCtxView<'_> { +impl wasi::sockets::network::HostNetwork for SpinSocketsView<'_> { fn drop(&mut self, rep: Resource) -> wasmtime::Result<()> { - latest::sockets::network::HostNetwork::drop(self, rep) + latest::sockets::network::HostNetwork::drop(&mut self.inner, rep) } } -impl wasi::sockets::ip_name_lookup::Host for WasiSocketsCtxView<'_> { +impl wasi::sockets::ip_name_lookup::Host for SpinSocketsView<'_> { fn resolve_addresses( &mut self, network: Resource, name: String, ) -> wasmtime::Result, SocketErrorCode>> { convert_result(latest::sockets::ip_name_lookup::Host::resolve_addresses( - self, network, name, + &mut self.inner, + network, + name, )) } } -impl wasi::sockets::ip_name_lookup::HostResolveAddressStream for WasiSocketsCtxView<'_> { +impl wasi::sockets::ip_name_lookup::HostResolveAddressStream for SpinSocketsView<'_> { fn resolve_next_address( &mut self, self_: Resource, ) -> wasmtime::Result, SocketErrorCode>> { convert_result( latest::sockets::ip_name_lookup::HostResolveAddressStream::resolve_next_address( - self, self_, + &mut self.inner, + self_, ) .map(|e| e.map(|e| e.into())), ) @@ -1406,11 +1413,11 @@ impl wasi::sockets::ip_name_lookup::HostResolveAddressStream for WasiSocketsCtxV &mut self, self_: Resource, ) -> wasmtime::Result> { - latest::sockets::ip_name_lookup::HostResolveAddressStream::subscribe(self, self_) + latest::sockets::ip_name_lookup::HostResolveAddressStream::subscribe(&mut self.inner, self_) } fn drop(&mut self, rep: Resource) -> wasmtime::Result<()> { - latest::sockets::ip_name_lookup::HostResolveAddressStream::drop(self, rep) + latest::sockets::ip_name_lookup::HostResolveAddressStream::drop(&mut self.inner, rep) } } From 66bf989861833bd51c6be4d27458047d48dc2821 Mon Sep 17 00:00:00 2001 From: Ryan Levick Date: Tue, 2 Jun 2026 14:41:14 +0200 Subject: [PATCH 09/15] Limit outbound MQTT connections Signed-off-by: Ryan Levick --- crates/factor-outbound-mqtt/Cargo.toml | 2 +- crates/factor-outbound-mqtt/src/host.rs | 48 +++++++--- crates/factor-outbound-mqtt/src/lib.rs | 25 ++++- .../src/runtime_config.rs | 5 + .../src/runtime_config/spin.rs | 6 +- .../factor-outbound-mqtt/tests/factor_test.rs | 92 ++++++++++++++++--- crates/factor-outbound-mysql/src/lib.rs | 2 +- crates/factor-outbound-redis/src/lib.rs | 2 +- crates/runtime-config/src/lib.rs | 9 +- 9 files changed, 156 insertions(+), 35 deletions(-) diff --git a/crates/factor-outbound-mqtt/Cargo.toml b/crates/factor-outbound-mqtt/Cargo.toml index 5561c30744..c72a17691b 100644 --- a/crates/factor-outbound-mqtt/Cargo.toml +++ b/crates/factor-outbound-mqtt/Cargo.toml @@ -6,9 +6,9 @@ edition = { workspace = true } [dependencies] anyhow = { workspace = true } -serde = { workspace = true } # Upstream hasn't been updating dependencies: https://github.com/bytebeamio/rumqtt/issues/1046 rumqttc = { git = "https://github.com/spinframework/rumqtt", rev = "65b7b39a70b12d1781acb61cc07f1f1b680e7643", default-features = false, features = ["use-rustls-no-provider", "url"] } +serde = { workspace = true, features = ["derive"] } spin-core = { path = "../core" } spin-factor-otel = { path = "../factor-otel" } spin-factor-outbound-networking = { path = "../factor-outbound-networking" } diff --git a/crates/factor-outbound-mqtt/src/host.rs b/crates/factor-outbound-mqtt/src/host.rs index efc6a6d229..5a090f3772 100644 --- a/crates/factor-outbound-mqtt/src/host.rs +++ b/crates/factor-outbound-mqtt/src/host.rs @@ -7,6 +7,7 @@ use spin_core::{ }; use spin_factor_otel::OtelFactorState; use spin_factor_outbound_networking::config::allowed_hosts::OutboundAllowedHosts; +use spin_factor_outbound_networking::{ConnectionPermit, ConnectionSemaphore}; use spin_world::spin::mqtt::mqtt as v3; use spin_world::v2::mqtt as v2; use tracing::{Level, instrument}; @@ -15,8 +16,9 @@ use crate::{ClientCreator, allowed_hosts::AllowedHostChecker}; pub struct InstanceState { allowed_hosts: AllowedHostChecker, - connections: spin_resource_table::Table>, + connections: spin_resource_table::Table<(Arc, ConnectionPermit)>, create_client: Arc, + semaphore: ConnectionSemaphore, otel: OtelFactorState, max_payload_size_bytes: Option, } @@ -25,6 +27,7 @@ impl InstanceState { pub fn new( allowed_hosts: OutboundAllowedHosts, create_client: Arc, + semaphore: ConnectionSemaphore, otel: OtelFactorState, max_payload_size_bytes: Option, ) -> Self { @@ -32,6 +35,7 @@ impl InstanceState { allowed_hosts: AllowedHostChecker::new(allowed_hosts), create_client, connections: spin_resource_table::Table::new(1024), + semaphore, otel, max_payload_size_bytes, } @@ -60,8 +64,15 @@ impl InstanceState { password: String, keep_alive_interval: Duration, ) -> Result, v2::Error> { + let permit = self + .semaphore + .acquire() + .await + .map_err(|_| v2::Error::TooManyConnections)?; + let client = + (self.create_client).create(address, username, password, keep_alive_interval)?; self.connections - .push((self.create_client).create(address, username, password, keep_alive_interval)?) + .push((client, permit)) .map(Resource::new_own) .map_err(|_| v2::Error::TooManyConnections) } @@ -72,7 +83,7 @@ impl InstanceState { .ok_or(v2::Error::Other( "could not find connection for resource".into(), )) - .map(|c| c.as_ref()) + .map(|(c, _permit)| c.as_ref()) } fn get_conn_v3( @@ -81,7 +92,7 @@ impl InstanceState { ) -> Result, v3::Error> { self.connections .get(connection.rep()) - .cloned() + .map(|(c, _permit)| c.clone()) .ok_or(v3::Error::Other( "could not find connection for resource".into(), )) @@ -110,10 +121,14 @@ impl v3::HostConnectionWithStore for crate::MqttFactorData { password: String, keep_alive_interval_in_secs: u64, ) -> Result, v3::Error> { - let (allowed_host_checker, create_client) = accessor.with(|mut access| { + let (allowed_host_checker, create_client, semaphore) = accessor.with(|mut access| { let host = access.get(); host.otel.reparent_tracing_span(); - (host.allowed_hosts.clone(), host.create_client.clone()) + ( + host.allowed_hosts.clone(), + host.create_client.clone(), + host.semaphore.clone(), + ) }); if !allowed_host_checker @@ -126,19 +141,22 @@ impl v3::HostConnectionWithStore for crate::MqttFactorData { ))); } - let client = create_client - .create( - address, - username, - password, - Duration::from_secs(keep_alive_interval_in_secs), - ) - .unwrap(); + let permit = semaphore + .acquire() + .await + .map_err(|_| v3::Error::TooManyConnections)?; + + let client = create_client.create( + address, + username, + password, + Duration::from_secs(keep_alive_interval_in_secs), + )?; accessor.with(|mut access| { let host = access.get(); host.connections - .push(client) + .push((client, permit)) .map(Resource::new_own) .map_err(|_| v3::Error::TooManyConnections) }) diff --git a/crates/factor-outbound-mqtt/src/lib.rs b/crates/factor-outbound-mqtt/src/lib.rs index b42444c788..2913d7bed7 100644 --- a/crates/factor-outbound-mqtt/src/lib.rs +++ b/crates/factor-outbound-mqtt/src/lib.rs @@ -9,9 +9,10 @@ use host::InstanceState; use rumqttc::{AsyncClient, Event, Incoming, Outgoing, QoS}; use spin_core::async_trait; use spin_factor_otel::OtelFactorState; -use spin_factor_outbound_networking::OutboundNetworkingFactor; +use spin_factor_outbound_networking::{ConnectionSemaphore, OutboundNetworkingFactor}; use spin_factors::{ ConfigureAppContext, Factor, FactorData, PrepareContext, RuntimeFactors, SelfInstanceBuilder, + anyhow, }; use spin_world::spin::mqtt::mqtt as v3; use spin_world::v2::mqtt as v2; @@ -20,6 +21,7 @@ use tokio::sync::Mutex; pub use host::MqttClient; use crate::host::other_error_v3; +use crate::runtime_config::RuntimeConfig; pub struct OutboundMqttFactor { create_client: Arc, @@ -32,11 +34,14 @@ impl OutboundMqttFactor { } pub struct AppState { + /// Optional maximum payload size in bytes for MQTT messages. If `None`, no limit is enforced. max_payload_size_bytes: Option, + /// Semaphore to limit concurrent outbound MQTT connections. + pub semaphore: ConnectionSemaphore, } impl Factor for OutboundMqttFactor { - type RuntimeConfig = runtime_config::RuntimeConfig; + type RuntimeConfig = RuntimeConfig; type AppState = AppState; type InstanceBuilder = InstanceState; @@ -51,7 +56,22 @@ impl Factor for OutboundMqttFactor { mut ctx: ConfigureAppContext, ) -> anyhow::Result { let config = ctx.take_runtime_config().unwrap_or_default(); + let networking = ctx.app_state::().ok(); + let global = networking.and_then(|s| s.global_connection_semaphore.clone()); + let global_total_limit = networking.and_then(|s| s.max_total_connections); + + if let (Some(per_factor), Some(global_limit)) = (config.max_connections, global_total_limit) + && per_factor > global_limit + { + tracing::warn!( + "outbound_mqtt max_connections ({per_factor}) exceeds global \ + max_total_connections ({global_limit}); the global limit will be the \ + effective cap" + ); + } + Ok(AppState { + semaphore: ConnectionSemaphore::new(global, config.max_connections, "mqtt"), max_payload_size_bytes: config.max_payload_size_bytes, }) } @@ -68,6 +88,7 @@ impl Factor for OutboundMqttFactor { Ok(InstanceState::new( allowed_hosts, self.create_client.clone(), + ctx.app_state().semaphore.clone(), otel, ctx.app_state().max_payload_size_bytes, )) diff --git a/crates/factor-outbound-mqtt/src/runtime_config.rs b/crates/factor-outbound-mqtt/src/runtime_config.rs index 786a2eb6f7..702f04233d 100644 --- a/crates/factor-outbound-mqtt/src/runtime_config.rs +++ b/crates/factor-outbound-mqtt/src/runtime_config.rs @@ -9,4 +9,9 @@ pub struct RuntimeConfig { /// should set this to prevent tenants from sending excessively large payloads. /// Configure via `[outbound_mqtt] max_payload_size_bytes` in the runtime config TOML. pub max_payload_size_bytes: Option, + /// If set, limits the number of concurrent outbound MQTT connections. + /// + /// When `None` (the default), no limit is enforced. Operators in multi-tenant deployments + /// should set this to prevent tenants from exhausting connection resources. + pub max_connections: Option, } diff --git a/crates/factor-outbound-mqtt/src/runtime_config/spin.rs b/crates/factor-outbound-mqtt/src/runtime_config/spin.rs index debe8d79e7..b7c3400193 100644 --- a/crates/factor-outbound-mqtt/src/runtime_config/spin.rs +++ b/crates/factor-outbound-mqtt/src/runtime_config/spin.rs @@ -1,4 +1,4 @@ -use anyhow::Context; +use anyhow::Context as _; use serde::Deserialize; use spin_factors::runtime_config::toml::GetTomlValue; @@ -8,6 +8,7 @@ use spin_factors::runtime_config::toml::GetTomlValue; /// ```toml /// [outbound_mqtt] /// max_payload_size_bytes = 65536 # optional, no limit by default +/// max_connections = 10 # optional, defaults to unlimited /// ``` pub fn config_from_table( table: &impl GetTomlValue, @@ -19,6 +20,7 @@ pub fn config_from_table( .context("failed to parse [outbound_mqtt] table")?; Ok(Some(super::RuntimeConfig { max_payload_size_bytes: toml.max_payload_size_bytes, + max_connections: toml.max_connections, })) } else { Ok(None) @@ -28,6 +30,6 @@ pub fn config_from_table( #[derive(Debug, Default, Deserialize)] #[serde(deny_unknown_fields)] struct OutboundMqttToml { - #[serde(default)] max_payload_size_bytes: Option, + max_connections: Option, } diff --git a/crates/factor-outbound-mqtt/tests/factor_test.rs b/crates/factor-outbound-mqtt/tests/factor_test.rs index b532fd5293..dbe4d6a08e 100644 --- a/crates/factor-outbound-mqtt/tests/factor_test.rs +++ b/crates/factor-outbound-mqtt/tests/factor_test.rs @@ -9,7 +9,7 @@ use spin_factor_variables::VariablesFactor; use spin_factors::{RuntimeFactors, anyhow}; use spin_factors_test::{TestEnvironment, toml}; use spin_world::spin::mqtt::mqtt::{Error, Qos}; -use spin_world::v2::mqtt as v2; +use spin_world::v2::mqtt as v2_mqtt; pub struct MockMqttClient {} @@ -62,7 +62,7 @@ fn test_env() -> TestEnvironment { #[tokio::test] async fn disallowed_host_fails() -> anyhow::Result<()> { - use v2::HostConnection; + use v2_mqtt::HostConnection; let env = TestEnvironment::new(factors()).extend_manifest(toml! { [component.test-component] @@ -82,14 +82,14 @@ async fn disallowed_host_fails() -> anyhow::Result<()> { let Err(err) = res else { bail!("expected Err, got Ok"); }; - assert!(matches!(err, v2::Error::ConnectionFailed(_))); + assert!(matches!(err, v2_mqtt::Error::ConnectionFailed(_))); Ok(()) } #[tokio::test] async fn allowed_host_succeeds() -> anyhow::Result<()> { - use v2::HostConnection; + use v2_mqtt::HostConnection; let mut state = test_env().build_instance_state().await?; @@ -111,7 +111,7 @@ async fn allowed_host_succeeds() -> anyhow::Result<()> { #[tokio::test] async fn exercise_publish() -> anyhow::Result<()> { - use v2::HostConnection; + use v2_mqtt::HostConnection; let mut state = test_env().build_instance_state().await?; @@ -131,7 +131,7 @@ async fn exercise_publish() -> anyhow::Result<()> { res, "message".to_string(), b"test message".to_vec(), - v2::Qos::ExactlyOnce, + v2_mqtt::Qos::ExactlyOnce, ) .await?; @@ -140,13 +140,14 @@ async fn exercise_publish() -> anyhow::Result<()> { #[tokio::test] async fn oversized_payload_rejected() -> anyhow::Result<()> { - use v2::HostConnection; + use v2_mqtt::HostConnection; const LIMIT: usize = 10; let env = test_env().runtime_config(TestFactorsRuntimeConfig { mqtt: Some(spin_factor_outbound_mqtt::runtime_config::RuntimeConfig { max_payload_size_bytes: Some(LIMIT), + ..Default::default() }), ..Default::default() })?; @@ -166,10 +167,15 @@ async fn oversized_payload_rejected() -> anyhow::Result<()> { let oversized = vec![0u8; LIMIT + 1]; let err = state .mqtt - .publish(conn, "topic".to_string(), oversized, v2::Qos::AtMostOnce) + .publish( + conn, + "topic".to_string(), + oversized, + v2_mqtt::Qos::AtMostOnce, + ) .await; assert!( - matches!(err, Err(v2::Error::Other(_))), + matches!(err, Err(v2_mqtt::Error::Other(_))), "expected Other error for oversized payload, got {err:?}" ); @@ -178,13 +184,14 @@ async fn oversized_payload_rejected() -> anyhow::Result<()> { #[tokio::test] async fn payload_at_limit_succeeds() -> anyhow::Result<()> { - use v2::HostConnection; + use v2_mqtt::HostConnection; const LIMIT: usize = 10; let env = test_env().runtime_config(TestFactorsRuntimeConfig { mqtt: Some(spin_factor_outbound_mqtt::runtime_config::RuntimeConfig { max_payload_size_bytes: Some(LIMIT), + ..Default::default() }), ..Default::default() })?; @@ -208,9 +215,72 @@ async fn payload_at_limit_succeeds() -> anyhow::Result<()> { conn, "topic".to_string(), exactly_limit, - v2::Qos::AtMostOnce, + v2_mqtt::Qos::AtMostOnce, + ) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn connection_limit_blocks_when_exhausted() -> anyhow::Result<()> { + use v2_mqtt::HostConnection; + + let env = TestEnvironment::new(factors()) + .extend_manifest(toml! { + [component.test-component] + source = "does-not-exist.wasm" + allowed_outbound_hosts = ["mqtt://*:*"] + }) + .runtime_config(TestFactorsRuntimeConfig { + mqtt: Some(spin_factor_outbound_mqtt::runtime_config::RuntimeConfig { + max_connections: Some(1), + ..Default::default() + }), + ..Default::default() + })?; + + let mut state = env.build_instance_state().await?; + + // Open first connection - should succeed immediately. + let conn1 = state + .mqtt + .open( + "mqtt://mqtt.test:1883".to_string(), + "username".to_string(), + "password".to_string(), + 1, + ) + .await?; + + // Second open should block (wait for a permit) since the limit is 1. + let timed_out = tokio::time::timeout( + Duration::from_millis(10), + state.mqtt.open( + "mqtt://mqtt.test:1883".to_string(), + "username".to_string(), + "password".to_string(), + 1, + ), + ) + .await + .is_err(); + assert!(timed_out, "expected second open to block when limit is 1"); + + // Releasing the first connection returns its permit to the semaphore. + state.mqtt.drop(conn1).await?; + + // Now a new connection should succeed. + let conn2 = state + .mqtt + .open( + "mqtt://mqtt.test:1883".to_string(), + "username".to_string(), + "password".to_string(), + 1, ) .await?; + state.mqtt.drop(conn2).await?; Ok(()) } diff --git a/crates/factor-outbound-mysql/src/lib.rs b/crates/factor-outbound-mysql/src/lib.rs index 7a10809e7a..2645489a0b 100644 --- a/crates/factor-outbound-mysql/src/lib.rs +++ b/crates/factor-outbound-mysql/src/lib.rs @@ -23,7 +23,7 @@ pub struct OutboundMysqlFactor { } pub struct AppState { - /// Semaphore(s) to limit concurrent outbound MySQL connections. + /// Semaphore to limit concurrent outbound MySQL connections. pub semaphore: ConnectionSemaphore, } diff --git a/crates/factor-outbound-redis/src/lib.rs b/crates/factor-outbound-redis/src/lib.rs index 059729e7d1..f23ec2ee4e 100644 --- a/crates/factor-outbound-redis/src/lib.rs +++ b/crates/factor-outbound-redis/src/lib.rs @@ -27,7 +27,7 @@ impl OutboundRedisFactor { } pub struct AppState { - /// Semaphore(s) to limit concurrent outbound Redis connections. + /// Semaphore to limit concurrent outbound Redis connections. pub semaphore: ConnectionSemaphore, } diff --git a/crates/runtime-config/src/lib.rs b/crates/runtime-config/src/lib.rs index 4ebe9e5a91..e26fc97202 100644 --- a/crates/runtime-config/src/lib.rs +++ b/crates/runtime-config/src/lib.rs @@ -94,8 +94,13 @@ impl ResolvedRuntimeConfig { )); } } - // [outbound_redis: max_connections=N], [outbound_pg: max_connections=N], [outbound_mysql: max_connections=N] - for key in ["outbound_redis", "outbound_pg", "outbound_mysql"] { + // [outbound_redis: max_connections=N], [outbound_pg: max_connections=N], [outbound_mysql: max_connections=N], [outbound_mqtt: max_connections=N] + for key in [ + "outbound_redis", + "outbound_pg", + "outbound_mysql", + "outbound_mqtt", + ] { if let Some(table) = self.toml.get(key).and_then(Value::as_table) { if let Some(max) = table.get("max_connections").and_then(Value::as_integer) { summaries.push(format!("[{key}: max_connections={max}]")); From 04621e220621688a10a852606870ef803ca1dd12 Mon Sep 17 00:00:00 2001 From: Ryan Levick Date: Tue, 2 Jun 2026 14:45:32 +0200 Subject: [PATCH 10/15] Print out outbound http max connections summary Signed-off-by: Ryan Levick --- crates/runtime-config/src/lib.rs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/crates/runtime-config/src/lib.rs b/crates/runtime-config/src/lib.rs index e26fc97202..f164d0dd50 100644 --- a/crates/runtime-config/src/lib.rs +++ b/crates/runtime-config/src/lib.rs @@ -94,12 +94,13 @@ impl ResolvedRuntimeConfig { )); } } - // [outbound_redis: max_connections=N], [outbound_pg: max_connections=N], [outbound_mysql: max_connections=N], [outbound_mqtt: max_connections=N] + // [outbound_redis: max_connections=N], [outbound_pg: max_connections=N], [outbound_mysql: max_connections=N], [outbound_mqtt: max_connections=N], [outbound_http: max_connections=N] for key in [ "outbound_redis", "outbound_pg", "outbound_mysql", "outbound_mqtt", + "outbound_http", ] { if let Some(table) = self.toml.get(key).and_then(Value::as_table) { if let Some(max) = table.get("max_connections").and_then(Value::as_integer) { @@ -107,6 +108,17 @@ impl ResolvedRuntimeConfig { } } } + // [outbound_http: max_concurrent_requests=N (deprecated)] + if let Some(table) = self.toml.get("outbound_http").and_then(Value::as_table) { + if let Some(max) = table + .get("max_concurrent_requests") + .and_then(Value::as_integer) + { + summaries.push(format!( + "[outbound_http: max_concurrent_requests={max} (deprecated, use max_connections)]" + )); + } + } if !summaries.is_empty() { let summaries = summaries.join(", "); let from_path = runtime_config_path From d9dfde6e4ad27ff6a5543fa12d7f6ca4a3a6ace6 Mon Sep 17 00:00:00 2001 From: Ryan Levick Date: Tue, 2 Jun 2026 15:08:33 +0200 Subject: [PATCH 11/15] Small refactoring Signed-off-by: Ryan Levick --- crates/connection-semaphore/src/lib.rs | 8 +-- crates/factor-outbound-http/src/lib.rs | 24 ++----- crates/factor-outbound-http/src/wasi.rs | 17 ++--- crates/factor-outbound-mqtt/src/lib.rs | 23 +++---- crates/factor-outbound-mysql/src/lib.rs | 22 ++----- crates/factor-outbound-networking/src/lib.rs | 66 +++++++++++++++----- crates/factor-outbound-pg/src/lib.rs | 24 +++---- crates/factor-outbound-redis/src/lib.rs | 24 +++---- 8 files changed, 94 insertions(+), 114 deletions(-) diff --git a/crates/connection-semaphore/src/lib.rs b/crates/connection-semaphore/src/lib.rs index 183a04d9a5..4395359776 100644 --- a/crates/connection-semaphore/src/lib.rs +++ b/crates/connection-semaphore/src/lib.rs @@ -25,12 +25,8 @@ impl ConnectionSemaphore { } } - /// Creates a `ConnectionSemaphore` from pre-existing semaphore handles. - /// - /// This is intended for testing and internal use where an already-constructed - /// (and possibly partially acquired) semaphore must be used directly. - #[doc(hidden)] - pub fn from_raw( + #[cfg(test)] + pub(crate) fn from_raw( global: Option>, factor_specific: Option>, factor: &'static str, diff --git a/crates/factor-outbound-http/src/lib.rs b/crates/factor-outbound-http/src/lib.rs index 366ef9c908..fc7d9cf609 100644 --- a/crates/factor-outbound-http/src/lib.rs +++ b/crates/factor-outbound-http/src/lib.rs @@ -17,6 +17,7 @@ use runtime_config::RuntimeConfig; use spin_factor_otel::OtelFactorState; use spin_factor_outbound_networking::{ ComponentTlsClientConfigs, ConnectionSemaphore, OutboundNetworkingFactor, + build_connection_semaphore, config::{allowed_hosts::OutboundAllowedHosts, blocked_networks::BlockedNetworks}, }; use spin_factors::{ @@ -56,27 +57,14 @@ impl Factor for OutboundHttpFactor { ) -> anyhow::Result { let config = ctx.take_runtime_config().unwrap_or_default(); - let networking = ctx.app_state::().ok(); - let global = networking.and_then(|s| s.global_connection_semaphore.clone()); - let global_total_limit = networking.and_then(|s| s.max_total_connections); - - if let (Some(per_factor), Some(global_limit)) = - (config.max_concurrent_connections, global_total_limit) - && per_factor > global_limit - { - tracing::warn!( - "outbound_http max_concurrent_requests ({per_factor}) exceeds global \ - max_total_connections ({global_limit}); the global limit will be the \ - effective cap" - ); - } - - let factor_specific_limit = config.max_concurrent_connections; - Ok(AppState { wasi_http_clients: wasi::HttpClients::new(config.connection_pooling_enabled), connection_pooling_enabled: config.connection_pooling_enabled, - semaphore: ConnectionSemaphore::new(global, factor_specific_limit, "http"), + semaphore: build_connection_semaphore( + ctx.app_state::().ok(), + "http", + config.max_concurrent_connections, + ), }) } diff --git a/crates/factor-outbound-http/src/wasi.rs b/crates/factor-outbound-http/src/wasi.rs index 40957b6008..11171e45a1 100644 --- a/crates/factor-outbound-http/src/wasi.rs +++ b/crates/factor-outbound-http/src/wasi.rs @@ -1212,17 +1212,12 @@ mod tests { /// `ConnectionTimeout` within the configured deadline. #[tokio::test] async fn connect_timeout_applies_to_permit_acquisition() { - use std::sync::Arc; - use tokio::sync::Semaphore; - - // Create a semaphore with exactly 1 permit and hold it immediately, - // leaving 0 permits available. This simulates all outbound-connection - // slots being occupied. - let semaphore = Arc::new(Semaphore::new(1)); - let _held = semaphore.clone().try_acquire_owned().unwrap(); - - // Build a ConnectionSemaphore with the exhausted semaphore as the factor-specific limit. - let conn_semaphore = ConnectionSemaphore::from_raw(None, Some(semaphore), "test"); + // Create a semaphore with exactly 1 permit and immediately exhaust it, leaving + // 0 permits available. This simulates all outbound-connection slots being occupied. + let conn_semaphore = ConnectionSemaphore::new(None, Some(1), "test"); + let _held = conn_semaphore + .try_acquire() + .expect("exhausting the single permit"); let options = ConnectOptions { // No blocked networks; we want the address to pass the filter. diff --git a/crates/factor-outbound-mqtt/src/lib.rs b/crates/factor-outbound-mqtt/src/lib.rs index 2913d7bed7..2b74331f1d 100644 --- a/crates/factor-outbound-mqtt/src/lib.rs +++ b/crates/factor-outbound-mqtt/src/lib.rs @@ -9,7 +9,9 @@ use host::InstanceState; use rumqttc::{AsyncClient, Event, Incoming, Outgoing, QoS}; use spin_core::async_trait; use spin_factor_otel::OtelFactorState; -use spin_factor_outbound_networking::{ConnectionSemaphore, OutboundNetworkingFactor}; +use spin_factor_outbound_networking::{ + ConnectionSemaphore, OutboundNetworkingFactor, build_connection_semaphore, +}; use spin_factors::{ ConfigureAppContext, Factor, FactorData, PrepareContext, RuntimeFactors, SelfInstanceBuilder, anyhow, @@ -56,22 +58,13 @@ impl Factor for OutboundMqttFactor { mut ctx: ConfigureAppContext, ) -> anyhow::Result { let config = ctx.take_runtime_config().unwrap_or_default(); - let networking = ctx.app_state::().ok(); - let global = networking.and_then(|s| s.global_connection_semaphore.clone()); - let global_total_limit = networking.and_then(|s| s.max_total_connections); - - if let (Some(per_factor), Some(global_limit)) = (config.max_connections, global_total_limit) - && per_factor > global_limit - { - tracing::warn!( - "outbound_mqtt max_connections ({per_factor}) exceeds global \ - max_total_connections ({global_limit}); the global limit will be the \ - effective cap" - ); - } Ok(AppState { - semaphore: ConnectionSemaphore::new(global, config.max_connections, "mqtt"), + semaphore: build_connection_semaphore( + ctx.app_state::().ok(), + "mqtt", + config.max_connections, + ), max_payload_size_bytes: config.max_payload_size_bytes, }) } diff --git a/crates/factor-outbound-mysql/src/lib.rs b/crates/factor-outbound-mysql/src/lib.rs index 2645489a0b..72966d610e 100644 --- a/crates/factor-outbound-mysql/src/lib.rs +++ b/crates/factor-outbound-mysql/src/lib.rs @@ -9,7 +9,7 @@ use mysql_async::Conn as MysqlClient; use runtime_config::RuntimeConfig; use spin_factor_otel::OtelFactorState; use spin_factor_outbound_networking::{ - ConnectionPermit, ConnectionSemaphore, OutboundNetworkingFactor, + ConnectionPermit, ConnectionSemaphore, OutboundNetworkingFactor, build_connection_semaphore, config::allowed_hosts::OutboundAllowedHosts, }; use spin_factors::{Factor, FactorData, InitContext, RuntimeFactors, SelfInstanceBuilder}; @@ -45,22 +45,12 @@ impl Factor for OutboundMysqlFactor { ) -> anyhow::Result { let config = ctx.take_runtime_config().unwrap_or_default(); - let networking = ctx.app_state::().ok(); - let global = networking.and_then(|s| s.global_connection_semaphore.clone()); - let global_total_limit = networking.and_then(|s| s.max_total_connections); - - if let (Some(per_factor), Some(global_limit)) = (config.max_connections, global_total_limit) - && per_factor > global_limit - { - tracing::warn!( - "outbound_mysql max_connections ({per_factor}) exceeds global \ - max_total_connections ({global_limit}); the global limit will be the \ - effective cap" - ); - } - Ok(AppState { - semaphore: ConnectionSemaphore::new(global, config.max_connections, "mysql"), + semaphore: build_connection_semaphore( + ctx.app_state::().ok(), + "mysql", + config.max_connections, + ), }) } diff --git a/crates/factor-outbound-networking/src/lib.rs b/crates/factor-outbound-networking/src/lib.rs index 9553a61052..cbf91a40e2 100644 --- a/crates/factor-outbound-networking/src/lib.rs +++ b/crates/factor-outbound-networking/src/lib.rs @@ -78,7 +78,6 @@ impl Factor for OutboundNetworkingFactor { let blocked_networks = BlockedNetworks::new(block_networks, block_private_networks); let tls_client_configs = TlsClientConfigs::new(client_tls_configs)?; - let socket_quota = max_socket_connections.map(|n| Arc::new(Semaphore::new(n))); let global_connection_semaphore = max_total_connections.map(|n| Arc::new(Semaphore::new(n))); @@ -93,11 +92,22 @@ impl Factor for OutboundNetworkingFactor { ); } + let socket_connection_semaphore = + if max_socket_connections.is_some() || global_connection_semaphore.is_some() { + Some(ConnectionSemaphore::new( + global_connection_semaphore.clone(), + max_socket_connections, + "wasi-sockets", + )) + } else { + None + }; + Ok(AppState { component_allowed_hosts, blocked_networks, tls_client_configs, - socket_quota, + socket_connection_semaphore, global_connection_semaphore, max_total_connections, }) @@ -145,15 +155,11 @@ impl Factor for OutboundNetworkingFactor { self.disallowed_host_handler.clone(), ); let blocked_networks = ctx.app_state().blocked_networks.clone(); - let global_semaphore = ctx.app_state().global_connection_semaphore.clone(); - let socket_semaphore = ctx.app_state().socket_quota.clone(); - let permit_state = if global_semaphore.is_some() || socket_semaphore.is_some() { - let sem = - ConnectionSemaphore::from_raw(global_semaphore, socket_semaphore, "wasi-sockets"); - Some(SocketPermitState::new(sem)) - } else { - None - }; + let permit_state = ctx + .app_state() + .socket_connection_semaphore + .clone() + .map(SocketPermitState::new); match ctx.instance_builder::() { Ok(wasi_builder) => { @@ -219,14 +225,42 @@ pub struct AppState { blocked_networks: BlockedNetworks, /// TLS client configs tls_client_configs: TlsClientConfigs, - /// App-wide semaphore capping concurrent outbound TCP/UDP socket connections. - /// `None` means unlimited. - socket_quota: Option>, + /// Pre-built semaphore for TCP/UDP socket quota enforcement (global + socket-specific). + /// `None` means no limits are configured. + socket_connection_semaphore: Option, /// App-wide semaphore capping total concurrent outbound connections across ALL types. /// `None` means unlimited. - pub global_connection_semaphore: Option>, + global_connection_semaphore: Option>, /// The configured global connection limit (for warning comparisons in other factors). - pub max_total_connections: Option, + max_total_connections: Option, +} + +/// Builds a [`ConnectionSemaphore`] for an outbound factor, incorporating the optional global +/// connection limit from the networking factor's app state. +/// +/// Emits a warning when the per-factor limit exceeds the global cap (the global limit would +/// be the effective ceiling in that case). +pub fn build_connection_semaphore( + networking: Option<&AppState>, + factor: &'static str, + factor_limit: Option, +) -> ConnectionSemaphore { + if let (Some(per_factor), Some(global_limit)) = ( + factor_limit, + networking.and_then(|n| n.max_total_connections), + ) && per_factor > global_limit + { + tracing::warn!( + "outbound_{factor} max_connections ({per_factor}) exceeds global \ + max_total_connections ({global_limit}); the global limit will be the \ + effective cap" + ); + } + ConnectionSemaphore::new( + networking.and_then(|n| n.global_connection_semaphore.clone()), + factor_limit, + factor, + ) } pub struct InstanceBuilder { diff --git a/crates/factor-outbound-pg/src/lib.rs b/crates/factor-outbound-pg/src/lib.rs index cb218e3525..b36198d3bc 100644 --- a/crates/factor-outbound-pg/src/lib.rs +++ b/crates/factor-outbound-pg/src/lib.rs @@ -11,7 +11,9 @@ use allowed_hosts::AllowedHostChecker; use client::ClientFactory; use runtime_config::RuntimeConfig; use spin_factor_otel::OtelFactorState; -use spin_factor_outbound_networking::{ConnectionSemaphore, OutboundNetworkingFactor}; +use spin_factor_outbound_networking::{ + ConnectionSemaphore, OutboundNetworkingFactor, build_connection_semaphore, +}; use spin_factors::{ ConfigureAppContext, Factor, PrepareContext, RuntimeFactors, SelfInstanceBuilder, anyhow, }; @@ -53,23 +55,13 @@ impl Factor for OutboundPgFactor { client_factories.insert(comp.id().to_string(), Arc::new(CF::default())); } - let networking = ctx.app_state::().ok(); - let global = networking.and_then(|s| s.global_connection_semaphore.clone()); - let global_total_limit = networking.and_then(|s| s.max_total_connections); - - if let (Some(per_factor), Some(global_limit)) = (config.max_connections, global_total_limit) - && per_factor > global_limit - { - tracing::warn!( - "outbound_pg max_connections ({per_factor}) exceeds global \ - max_total_connections ({global_limit}); the global limit will be the \ - effective cap" - ); - } - Ok(AppState { client_factories, - semaphore: ConnectionSemaphore::new(global, config.max_connections, "pg"), + semaphore: build_connection_semaphore( + ctx.app_state::().ok(), + "pg", + config.max_connections, + ), }) } diff --git a/crates/factor-outbound-redis/src/lib.rs b/crates/factor-outbound-redis/src/lib.rs index f23ec2ee4e..47a05ff745 100644 --- a/crates/factor-outbound-redis/src/lib.rs +++ b/crates/factor-outbound-redis/src/lib.rs @@ -5,7 +5,9 @@ pub mod runtime_config; use host::InstanceState; use runtime_config::RuntimeConfig; use spin_factor_otel::OtelFactorState; -use spin_factor_outbound_networking::{ConnectionSemaphore, OutboundNetworkingFactor}; +use spin_factor_outbound_networking::{ + ConnectionSemaphore, OutboundNetworkingFactor, build_connection_semaphore, +}; use spin_factors::{ ConfigureAppContext, Factor, FactorData, PrepareContext, RuntimeFactors, SelfInstanceBuilder, anyhow, @@ -49,22 +51,12 @@ impl Factor for OutboundRedisFactor { ) -> anyhow::Result { let config = ctx.take_runtime_config().unwrap_or_default(); - let networking = ctx.app_state::().ok(); - let global = networking.and_then(|s| s.global_connection_semaphore.clone()); - let global_total_limit = networking.and_then(|s| s.max_total_connections); - - if let (Some(per_factor), Some(global_limit)) = (config.max_connections, global_total_limit) - && per_factor > global_limit - { - tracing::warn!( - "outbound_redis max_connections ({per_factor}) exceeds global \ - max_total_connections ({global_limit}); the global limit will be the \ - effective cap" - ); - } - Ok(AppState { - semaphore: ConnectionSemaphore::new(global, config.max_connections, "redis"), + semaphore: build_connection_semaphore( + ctx.app_state::().ok(), + "redis", + config.max_connections, + ), }) } From 45eb731891764339aed73450577b31c782f069c8 Mon Sep 17 00:00:00 2001 From: Ryan Levick Date: Tue, 2 Jun 2026 15:28:19 +0200 Subject: [PATCH 12/15] Add cross factor test Signed-off-by: Ryan Levick --- Cargo.lock | 3 + crates/factor-outbound-networking/Cargo.toml | 3 + .../tests/factor_test.rs | 108 ++++++++++++++++++ 3 files changed, 114 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index b8e3261a01..b269512b47 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9060,6 +9060,7 @@ name = "spin-factor-outbound-networking" version = "4.1.0-pre0" dependencies = [ "anyhow", + "async-trait", "futures-util", "http 1.3.1", "ip_network", @@ -9069,6 +9070,7 @@ dependencies = [ "rustls-platform-verifier", "serde", "spin-connection-semaphore", + "spin-factor-outbound-mqtt", "spin-factor-variables", "spin-factor-wasi", "spin-factors", @@ -9077,6 +9079,7 @@ dependencies = [ "spin-manifest", "spin-outbound-networking-config", "spin-serde", + "spin-world", "tempfile", "tokio", "toml 0.8.19", diff --git a/crates/factor-outbound-networking/Cargo.toml b/crates/factor-outbound-networking/Cargo.toml index 71dfc9865a..9836e16cf1 100644 --- a/crates/factor-outbound-networking/Cargo.toml +++ b/crates/factor-outbound-networking/Cargo.toml @@ -28,7 +28,10 @@ url = { workspace = true } webpki-root-certs = "1.0.7" [dev-dependencies] +async-trait = { workspace = true } +spin-factor-outbound-mqtt = { path = "../factor-outbound-mqtt" } spin-factors-test = { path = "../factors-test" } +spin-world = { path = "../world" } tempfile = { workspace = true } tokio = { workspace = true, features = ["macros", "rt"] } toml = { workspace = true } diff --git a/crates/factor-outbound-networking/tests/factor_test.rs b/crates/factor-outbound-networking/tests/factor_test.rs index 84d9be9db5..bdeaa8fa79 100644 --- a/crates/factor-outbound-networking/tests/factor_test.rs +++ b/crates/factor-outbound-networking/tests/factor_test.rs @@ -1,3 +1,7 @@ +use std::sync::Arc; +use std::time::Duration; + +use spin_factor_outbound_mqtt::{ClientCreator, MqttClient, OutboundMqttFactor}; use spin_factor_outbound_networking::OutboundNetworkingFactor; use spin_factor_outbound_networking::runtime_config::RuntimeConfig; use spin_factor_outbound_networking::runtime_config::spin::SpinRuntimeConfig; @@ -6,6 +10,8 @@ use spin_factor_wasi::{DummyFilesMounter, WasiFactor}; use spin_factors::anyhow::Context as _; use spin_factors::{App, RuntimeFactors, anyhow}; use spin_factors_test::{TestEnvironment, toml}; +use spin_world::spin::mqtt::mqtt as v3_mqtt; +use spin_world::v2::mqtt as v2_mqtt; use wasmtime_wasi::p2::bindings::sockets::instance_network::Host; use wasmtime_wasi::p2::bindings::sockets::network::{ErrorCode, IpAddressFamily}; use wasmtime_wasi::p2::bindings::sockets::tcp as p2_tcp; @@ -13,6 +19,40 @@ use wasmtime_wasi::p2::bindings::sockets::tcp_create_socket as p2_tcp_create; use wasmtime_wasi::p2::bindings::sockets::udp_create_socket as p2_udp_create; use wasmtime_wasi::sockets::SocketAddrUse; +struct MockMqttClient; + +#[async_trait::async_trait] +impl MqttClient for MockMqttClient { + async fn publish_bytes( + &self, + _topic: String, + _qos: v3_mqtt::Qos, + _payload: Vec, + ) -> anyhow::Result<(), v3_mqtt::Error> { + Ok(()) + } +} + +impl ClientCreator for MockMqttClient { + fn create( + &self, + _address: String, + _username: String, + _password: String, + _keep_alive_interval: Duration, + ) -> anyhow::Result, v3_mqtt::Error> { + Ok(Arc::new(MockMqttClient)) + } +} + +#[derive(RuntimeFactors)] +struct TestFactorsWithMqtt { + wasi: WasiFactor, + variables: VariablesFactor, + networking: OutboundNetworkingFactor, + mqtt: OutboundMqttFactor, +} + #[derive(RuntimeFactors)] struct TestFactors { wasi: WasiFactor, @@ -395,3 +435,71 @@ async fn socket_quota_shared_between_tcp_and_udp() -> anyhow::Result<()> { assert_eq!(err.downcast_ref(), Some(&ErrorCode::NewSocketLimit)); Ok(()) } + +/// Verifies that the global connection limit is shared across factors: a permit +/// held by an MQTT connection blocks a WASI TCP socket (and vice-versa). +#[tokio::test] +async fn global_connection_limit_enforced_across_factors() -> anyhow::Result<()> { + use v2_mqtt::HostConnection as _; + + let factors = TestFactorsWithMqtt { + wasi: WasiFactor::new(DummyFilesMounter), + variables: VariablesFactor::default(), + networking: OutboundNetworkingFactor::new(), + mqtt: OutboundMqttFactor::new(Arc::new(MockMqttClient)), + }; + let env = TestEnvironment::new(factors) + .extend_manifest(toml! { + [component.test-component] + source = "does-not-exist.wasm" + allowed_outbound_hosts = ["mqtt://*:*", "*://123.0.2.1:12345"] + }) + .runtime_config(TestFactorsWithMqttRuntimeConfig { + networking: Some(RuntimeConfig { + max_total_connections: Some(1), + ..Default::default() + }), + ..Default::default() + })?; + + let mut state = env.build_instance_state().await?; + + // Acquire the single global permit via an MQTT connection. + let conn = state + .mqtt + .open( + "mqtt://mqtt.test:1883".to_string(), + "username".to_string(), + "password".to_string(), + 1, + ) + .await?; + + // With the global permit held by MQTT, a TCP socket start_connect must fail immediately. + let mut sockets = WasiFactor::get_sockets_impl(&mut state).unwrap(); + let addr: std::net::SocketAddr = "123.0.2.1:12345".parse().unwrap(); + let net = sockets.instance_network()?; + let sock = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + let err = p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock, net, addr.into()) + .await + .unwrap_err(); + assert_eq!( + err.downcast_ref(), + Some(&ErrorCode::NewSocketLimit), + "TCP socket should fail while global permit is held by MQTT" + ); + drop(sockets); + + // Releasing the MQTT connection returns the global permit. + state.mqtt.drop(conn).await?; + + // Now the TCP socket start_connect must succeed. + let mut sockets = WasiFactor::get_sockets_impl(&mut state).unwrap(); + let net = sockets.instance_network()?; + let sock = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock, net, addr.into()) + .await + .expect("TCP socket should succeed after MQTT connection is released"); + + Ok(()) +} From b824b556f7cd985418684a9734b540dfd332e09a Mon Sep 17 00:00:00 2001 From: Ryan Levick Date: Tue, 2 Jun 2026 15:59:46 +0200 Subject: [PATCH 13/15] Add telemetry for ConnectionSemaphore::try_acquire and for acquire wait times Signed-off-by: Ryan Levick --- crates/connection-semaphore/src/lib.rs | 58 +++++++++++++++++++------- 1 file changed, 44 insertions(+), 14 deletions(-) diff --git a/crates/connection-semaphore/src/lib.rs b/crates/connection-semaphore/src/lib.rs index 4395359776..48c186e4a6 100644 --- a/crates/connection-semaphore/src/lib.rs +++ b/crates/connection-semaphore/src/lib.rs @@ -66,6 +66,7 @@ impl ConnectionSemaphore { } } let mut waited = false; + let start = std::time::Instant::now(); let (global, factor_specific) = match (&self.global, &self.factor_specific) { (None, None) => (None, None), @@ -95,6 +96,12 @@ impl ConnectionSemaphore { }; let factor = self.factor; + if waited { + spin_telemetry::histogram!( + outbound_connection_permit_wait_duration_ms = start.elapsed().as_millis() as f64, + factor = factor + ); + } spin_telemetry::monotonic_counter!( outbound_connection_permits_acquired = 1, factor = factor, @@ -113,27 +120,50 @@ impl ConnectionSemaphore { /// If the global permit is acquired but the factor-specific permit is not /// available, the global permit is released before returning `None`. pub fn try_acquire(&self) -> Option { - // Acquire global first. If it fails, nothing is consumed — return None. + match self.try_acquire_permits() { + Ok(permit) => { + spin_telemetry::monotonic_counter!( + outbound_connection_permits_acquired = 1, + factor = self.factor, + waited = false + ); + Some(permit) + } + Err(limit) => { + spin_telemetry::monotonic_counter!( + outbound_connection_permits_rejected = 1, + factor = self.factor, + limit = limit + ); + None + } + } + } + + /// Inner logic for [`Self::try_acquire`], separated so the caller can emit + /// telemetry based on whether a permit was obtained. + /// + /// Returns `Err("global")` or `Err("factor")` to indicate which limit was + /// exhausted, so the caller can tag the rejection metric accordingly. + fn try_acquire_permits(&self) -> Result { + // Acquire global first. If it fails, nothing is consumed. let global = match &self.global { - Some(s) => Some(s.clone().try_acquire_owned().ok()?), + Some(s) => match s.clone().try_acquire_owned() { + Ok(p) => Some(p), + Err(_) => return Err("global"), + }, None => None, }; // Now attempt the factor-specific permit. - // If it fails, the global OwnedSemaphorePermit is dropped here, releasing - // the global slot before we return None. + // On failure, `global` is dropped here, releasing the global slot. let factor_specific = match &self.factor_specific { - Some(s) => Some(s.clone().try_acquire_owned().ok()?), + Some(s) => match s.clone().try_acquire_owned() { + Ok(p) => Some(p), + Err(_) => return Err("factor"), + }, None => None, }; - - let factor = self.factor; - spin_telemetry::monotonic_counter!( - outbound_connection_permits_acquired = 1, - factor = factor, - waited = false - ); - - Some(ConnectionPermit { + Ok(ConnectionPermit { _global: global, _factor_specific: factor_specific, }) From ad2152c0fa8a7dde86de59a443b30f1cb5e076aa Mon Sep 17 00:00:00 2001 From: Ryan Levick Date: Tue, 2 Jun 2026 18:19:55 +0200 Subject: [PATCH 14/15] Change metric label from 'factor' to 'kind' Signed-off-by: Ryan Levick --- crates/connection-semaphore/src/lib.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/connection-semaphore/src/lib.rs b/crates/connection-semaphore/src/lib.rs index 48c186e4a6..7ebf37605e 100644 --- a/crates/connection-semaphore/src/lib.rs +++ b/crates/connection-semaphore/src/lib.rs @@ -99,12 +99,12 @@ impl ConnectionSemaphore { if waited { spin_telemetry::histogram!( outbound_connection_permit_wait_duration_ms = start.elapsed().as_millis() as f64, - factor = factor + kind = factor ); } spin_telemetry::monotonic_counter!( outbound_connection_permits_acquired = 1, - factor = factor, + kind = factor, waited = waited ); @@ -124,7 +124,7 @@ impl ConnectionSemaphore { Ok(permit) => { spin_telemetry::monotonic_counter!( outbound_connection_permits_acquired = 1, - factor = self.factor, + kind = self.factor, waited = false ); Some(permit) @@ -132,7 +132,7 @@ impl ConnectionSemaphore { Err(limit) => { spin_telemetry::monotonic_counter!( outbound_connection_permits_rejected = 1, - factor = self.factor, + kind = self.factor, limit = limit ); None From 3e9fb87f5a1b4772de49746c0d02d91bda82cabf Mon Sep 17 00:00:00 2001 From: Ryan Levick Date: Wed, 3 Jun 2026 17:15:50 +0200 Subject: [PATCH 15/15] Adjust how we acquire blocking semaphores Signed-off-by: Ryan Levick --- crates/connection-semaphore/src/lib.rs | 40 +++++++++----------------- 1 file changed, 13 insertions(+), 27 deletions(-) diff --git a/crates/connection-semaphore/src/lib.rs b/crates/connection-semaphore/src/lib.rs index 7ebf37605e..5b900c6e10 100644 --- a/crates/connection-semaphore/src/lib.rs +++ b/crates/connection-semaphore/src/lib.rs @@ -42,8 +42,8 @@ impl ConnectionSemaphore { /// them until dropped. /// /// When both a global and a factor-specific semaphore are configured, this - /// method never holds one permit while blocking on the other, preventing global - /// permits from being tied up while waiting on a factor-specific backlog. + /// method acquires factor-specific first, then global, ensuring the global + /// permit is never held while blocking on a factor-specific backlog. pub async fn acquire(&self) -> anyhow::Result { /// Acquires a single permit from `sem`, trying non-blocking first. /// @@ -68,31 +68,17 @@ impl ConnectionSemaphore { let mut waited = false; let start = std::time::Instant::now(); - let (global, factor_specific) = match (&self.global, &self.factor_specific) { - (None, None) => (None, None), - (Some(g), None) => (Some(acquire_one(g, &mut waited, "global").await?), None), - (None, Some(f)) => (None, Some(acquire_one(f, &mut waited, "factor").await?)), - // Loop until we acquire both. We have to be careful to avoid holding one permit while waiting for the other. - (Some(g), Some(f)) => loop { - let global = acquire_one(g, &mut waited, "global").await?; - match f.clone().try_acquire_owned() { - Ok(factor) => break (Some(global), Some(factor)), - Err(TryAcquireError::NoPermits) => {} - Err(_) => anyhow::bail!("factor connection semaphore closed"), - } - // Factor specific has no free permits: release global so other connection types aren't blocked, - // then wait for factor-specific before trying global again. - drop(global); - waited = true; - let factor = acquire_one(f, &mut waited, "factor").await?; - match g.clone().try_acquire_owned() { - Ok(global) => break (Some(global), Some(factor)), - Err(TryAcquireError::NoPermits) => {} - Err(_) => anyhow::bail!("global connection semaphore closed"), - } - // Global has no free permits: release factor specific and retry from the top of the loop. - drop(factor); - }, + // Acquire factor-specific first, then global. This ensures we never hold + // the global permit while blocking on factor-specific backlog. + let factor_specific = match &self.factor_specific { + Some(f) => Some(acquire_one(f, &mut waited, "factor").await?), + None => None, + }; + // It's fine to hold the factor-specific permit while waiting for the global slot, since + // other consumers of the factor-specific would also end up waiting for the same global slot. + let global = match &self.global { + Some(g) => Some(acquire_one(g, &mut waited, "global").await?), + None => None, }; let factor = self.factor;