Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion crates/factor-outbound-redis/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ edition = { workspace = true }
[dependencies]
anyhow = { workspace = true }
redis = { workspace = true , features = ["tokio-comp", "tokio-native-tls-comp", "aio"] }
serde = { workspace = true }
spin-core = { path = "../core" }
spin-factor-otel = { path = "../factor-otel" }
spin-factor-outbound-networking = { path = "../factor-outbound-networking" }
spin-factors = { path = "../factors" }
spin-resource-table = { path = "../table" }
spin-world = { path = "../world" }
tokio = { workspace = true }
tokio = { workspace = true, features = ["sync"] }
tracing = { workspace = true }

[dev-dependencies]
Expand Down
58 changes: 44 additions & 14 deletions crates/factor-outbound-redis/src/host.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::net::SocketAddr;
use std::sync::Arc;

use anyhow::Result;
use redis::AsyncConnectionConfig;
Expand All @@ -11,6 +12,7 @@ use spin_world::MAX_HOST_BUFFERED_BYTES;
use spin_world::spin::redis::redis as v3;
use spin_world::v1::{redis as v1, redis_types};
use spin_world::v2::redis as v2;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tracing::field::Empty;
use tracing::{Level, instrument};

Expand All @@ -19,7 +21,9 @@ use crate::allowed_hosts::AllowedHostChecker;
pub struct InstanceState {
pub(crate) allowed_host_checker: AllowedHostChecker,
pub blocked_networks: BlockedNetworks,
pub connections: spin_resource_table::Table<MultiplexedConnection>,
pub connections:
spin_resource_table::Table<(MultiplexedConnection, Option<OwnedSemaphorePermit>)>,
pub connection_semaphore: Option<Arc<Semaphore>>,
pub otel: OtelFactorState,
}

Expand All @@ -32,6 +36,15 @@ impl InstanceState {
&mut self,
address: String,
) -> Result<Resource<v2::Connection>, v2::Error> {
let permit = match &self.connection_semaphore {
Some(sem) => Some(
Arc::clone(sem)
.acquire_owned()
.await
.map_err(|_| v2::Error::TooManyConnections)?,
),
None => None,
};
let config = AsyncConnectionConfig::new()
.set_dns_resolver(SpinDnsResolver(self.blocked_networks.clone()));
let conn = redis::Client::open(address.as_str())
Expand All @@ -40,7 +53,7 @@ impl InstanceState {
.await
.map_err(other_error_v2)?;
self.connections
.push(conn)
.push((conn, permit))
.map(Resource::new_own)
.map_err(|_| v2::Error::TooManyConnections)
}
Expand All @@ -51,6 +64,7 @@ impl InstanceState {
) -> Result<&mut MultiplexedConnection, v2::Error> {
self.connections
.get_mut(connection.rep())
.map(|(conn, _permit)| conn)
.ok_or(v2::Error::Other(
"could not find connection for resource".into(),
))
Expand All @@ -62,7 +76,7 @@ impl InstanceState {
) -> Result<MultiplexedConnection, v3::Error> {
self.connections
.get(connection.rep())
.cloned()
.map(|(conn, _permit)| conn.clone())
.ok_or(v3::Error::Other(
"could not find connection for resource".into(),
))
Expand Down Expand Up @@ -229,14 +243,16 @@ impl v3::HostConnectionWithStore for crate::RedisFactorData {
accessor: &Accessor<T, Self>,
address: String,
) -> Result<Resource<v3::Connection>, v3::Error> {
let (allowed_host_checker, blocked_networks) = accessor.with(|mut access| {
let host = access.get();
host.otel.reparent_tracing_span();
(
host.allowed_host_checker.clone(),
host.blocked_networks.clone(),
)
});
let (allowed_host_checker, blocked_networks, connection_semaphore) =
accessor.with(|mut access| {
let host = access.get();
host.otel.reparent_tracing_span();
(
host.allowed_host_checker.clone(),
host.blocked_networks.clone(),
host.connection_semaphore.clone(),
)
});

if !allowed_host_checker
.is_address_allowed(&address)
Expand All @@ -246,6 +262,15 @@ impl v3::HostConnectionWithStore for crate::RedisFactorData {
return Err(v3::Error::InvalidAddress);
}

let permit = match connection_semaphore {
Some(sem) => Some(
sem.acquire_owned()
.await
.map_err(|_| v3::Error::TooManyConnections)?,
),
None => None,
};

let config =
AsyncConnectionConfig::new().set_dns_resolver(SpinDnsResolver(blocked_networks));
let conn = redis::Client::open(address.as_str())
Expand All @@ -257,7 +282,7 @@ impl v3::HostConnectionWithStore for crate::RedisFactorData {
accessor.with(|mut access| {
let host = access.get();
host.connections
.push(conn)
.push((conn, permit))
.map(Resource::new_own)
.map_err(|_| v3::Error::TooManyConnections)
})
Expand Down Expand Up @@ -532,9 +557,14 @@ macro_rules! delegate {
Ok(c) => c,
Err(_) => return Err(v1::Error::Error),
};
<Self as v2::HostConnection>::$name($self, connection, $($arg),*)
// v1 has no persistent connections, so remove the table entry immediately
// after the call to release the semaphore permit.
let rep = connection.rep();
let result = <Self as v2::HostConnection>::$name($self, connection, $($arg),*)
.await
.map_err(|_| v1::Error::Error)
.map_err(|_| v1::Error::Error);
$self.connections.remove(rep);
result
}};
}

Expand Down
22 changes: 18 additions & 4 deletions crates/factor-outbound-redis/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
mod allowed_hosts;
mod host;
pub mod runtime_config;

use std::sync::Arc;

use host::InstanceState;
use runtime_config::RuntimeConfig;
use spin_factor_otel::OtelFactorState;
use spin_factor_outbound_networking::OutboundNetworkingFactor;
use spin_factors::{
ConfigureAppContext, Factor, FactorData, PrepareContext, RuntimeFactors, SelfInstanceBuilder,
anyhow,
};
use spin_world::spin::redis::redis as v3;
use tokio::sync::Semaphore;

use crate::allowed_hosts::AllowedHostChecker;

Expand All @@ -24,9 +29,14 @@ impl OutboundRedisFactor {
}
}

pub struct AppState {
/// A semaphore to limit the number of concurrent outbound Redis connections.
pub connection_semaphore: Option<Arc<Semaphore>>,
}

impl Factor for OutboundRedisFactor {
type RuntimeConfig = ();
type AppState = ();
type RuntimeConfig = RuntimeConfig;
type AppState = AppState;
type InstanceBuilder = InstanceState;

fn init(&mut self, ctx: &mut impl spin_factors::InitContext<Self>) -> anyhow::Result<()> {
Expand All @@ -38,9 +48,12 @@ impl Factor for OutboundRedisFactor {

fn configure_app<T: RuntimeFactors>(
&self,
_ctx: ConfigureAppContext<T, Self>,
mut ctx: ConfigureAppContext<T, Self>,
) -> anyhow::Result<Self::AppState> {
Ok(())
let config = ctx.take_runtime_config().unwrap_or_default();
Ok(AppState {
connection_semaphore: config.max_connections.map(|n| Arc::new(Semaphore::new(n))),
})
}

fn prepare<T: RuntimeFactors>(
Expand All @@ -54,6 +67,7 @@ impl Factor for OutboundRedisFactor {
allowed_host_checker: AllowedHostChecker::new(outbound_networking.allowed_hosts()),
blocked_networks: outbound_networking.blocked_networks(),
connections: spin_resource_table::Table::new(1024),
connection_semaphore: ctx.app_state().connection_semaphore.clone(),
otel,
})
}
Expand Down
8 changes: 8 additions & 0 deletions crates/factor-outbound-redis/src/runtime_config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
pub mod spin;

/// Runtime configuration for outbound Redis.
#[derive(Default)]
pub struct RuntimeConfig {
/// If set, limits the number of concurrent outbound Redis connections.
pub max_connections: Option<usize>,
}
29 changes: 29 additions & 0 deletions crates/factor-outbound-redis/src/runtime_config/spin.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use serde::Deserialize;
use spin_factors::runtime_config::toml::GetTomlValue;

/// Get the runtime configuration for outbound Redis from a TOML table.
///
/// Expects table to be in the format:
/// ```toml
/// [outbound_redis]
/// max_connections = 10 # optional, defaults to unlimited
/// ```
pub fn config_from_table(
table: &impl GetTomlValue,
) -> anyhow::Result<Option<super::RuntimeConfig>> {
if let Some(outbound_redis) = table.get("outbound_redis") {
let toml = outbound_redis.clone().try_into::<OutboundRedisToml>()?;
Ok(Some(super::RuntimeConfig {
max_connections: toml.max_connections,
}))
} else {
Ok(None)
}
}

#[derive(Debug, Default, Deserialize)]
#[serde(deny_unknown_fields)]
struct OutboundRedisToml {
#[serde(default)]
max_connections: Option<usize>,
}
6 changes: 4 additions & 2 deletions crates/runtime-config/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,10 @@ impl FactorRuntimeConfigSource<LlmFactor> for TomlRuntimeConfigSource<'_, '_> {
}

impl FactorRuntimeConfigSource<OutboundRedisFactor> for TomlRuntimeConfigSource<'_, '_> {
fn get_runtime_config(&mut self) -> anyhow::Result<Option<()>> {
Ok(None)
fn get_runtime_config(
&mut self,
) -> anyhow::Result<Option<<OutboundRedisFactor as spin_factors::Factor>::RuntimeConfig>> {
spin_factor_outbound_redis::runtime_config::spin::config_from_table(&self.toml.table)
}
}

Expand Down
Loading