Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion devolutions-agent/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ mod service;
use std::env;
use std::io::{self, BufRead};
use std::sync::mpsc;
use std::time::Duration;

use anyhow::{Context as _, Result, bail};
use ceviche::Service;
Expand Down Expand Up @@ -277,7 +278,13 @@ fn main() {
&command.enrollment_token,
command.advertise_subnets,
)
.await
.await?;

// Enrollment only proves HTTPS/TCP; fail the install now if the QUIC/UDP tunnel
// path is blocked, while the operator is still here to fix the firewall.
let conf = ConfHandle::init().context("load agent configuration for connectivity probe")?;
devolutions_agent::tunnel::probe_connectivity(&conf.get_conf().tunnel, Duration::from_secs(15))
.await
});

if let Err(error) = result {
Expand Down
258 changes: 176 additions & 82 deletions devolutions-agent/src/tunnel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ impl Task for TunnelTask {
return Ok(());
}
Ok(ConnectionOutcome::CertRenewed) => {
// Renewal is a successful "completion", not a failure — skip
// the backoff and reconnect immediately with the new cert.
// Renewal is a completion, not a failure.
info!("Certificate renewed; reconnecting with new cert immediately");
backoff.reset();
continue;
Expand Down Expand Up @@ -194,14 +193,11 @@ enum ConnectionOutcome {
///
/// - `Ok(Shutdown)`: graceful shutdown, exit the task.
/// - `Ok(CertRenewed)`: certificate renewed; caller should reconnect immediately.
/// - `Err(...)`: connection lost or handshake failed — caller should retry with backoff.
/// - `Err(_)`: connection lost or handshake failed — caller should retry with backoff.
async fn run_single_connection(
conf_handle: &ConfHandle,
shutdown_signal: &mut ShutdownSignal,
) -> anyhow::Result<ConnectionOutcome> {
// Ensure rustls crypto provider is installed (ring).
let _ = rustls::crypto::ring::default_provider().install_default();

let agent_conf = conf_handle.get_conf();
let tunnel_conf = &agent_conf.tunnel;

Expand Down Expand Up @@ -260,23 +256,112 @@ async fn run_single_connection(
"Advertising subnets and domains"
);

let (_endpoint, connection) = connect_to_gateway(tunnel_conf).await?;

// -- Open control stream --

let mut ctrl: ControlStream<_, _> = connection.open_bi().await.context("open control stream")?.into();

// Send initial RouteAdvertise.
let epoch = 1u64;
let msg = ControlMessage::route_advertise(epoch, advertise_subnets.clone(), advertise_domains.clone());

ctrl.send(&msg).await.context("send initial RouteAdvertise")?;

info!(epoch, "Sent initial RouteAdvertise");

// -- Certificate renewal (post-connect, pre-traffic) --
//
// Run once per reconnect rather than on a periodic timer: the QUIC session
// has a 120s idle timeout and 15s keep-alive, so any blip / VPN reconnect
// / host sleep / gateway restart drops the connection within minutes and
// sends us back through this path. With a 1-year cert and a 15-day
// threshold, the renewal window will be hit on the first reconnect after
// T-15d, which is more than often enough in any real deployment.
if let Some(outcome) = try_renew_certificate(&mut ctrl, &connection, cert_path, key_path, ca_path).await? {
return Ok(outcome);
}

// Split: recv half goes to a reader task, send half stays for periodic messages.
let (mut ctrl_send, ctrl_recv) = ctrl.into_split();
let mut task_handles = tokio::task::JoinSet::new();
task_handles.spawn(run_control_reader(ctrl_recv));

// -- Main loop: accept incoming session streams + periodic tasks --

let route_interval = tunnel_conf.route_advertise_interval_secs;
let heartbeat_interval_secs = tunnel_conf.heartbeat_interval_secs;
let mut route_tick = tokio::time::interval(Duration::from_secs(route_interval));
let mut heartbeat_tick = tokio::time::interval(Duration::from_secs(heartbeat_interval_secs));
// Skip the first immediate tick (we already sent the initial RouteAdvertise).
route_tick.tick().await;
heartbeat_tick.tick().await;

loop {
tokio::select! {
biased;

_ = shutdown_signal.wait() => {
info!("Tunnel task shutting down");
connection.close(0u32.into(), b"shutting down");
break;
}

_ = route_tick.tick() => {
let msg = ControlMessage::route_advertise(epoch, advertise_subnets.clone(), advertise_domains.clone());
let _ = ctrl_send.send(&msg).await
.inspect(|_| trace!(epoch, "Sent RouteAdvertise (refresh)"))
.inspect_err(|e| error!(%e, "Failed to send RouteAdvertise"));
}

_ = heartbeat_tick.tick() => {
// TODO: track actual active_stream_count instead of hardcoded 0.
let msg = ControlMessage::heartbeat(current_time_millis(), 0);
let _ = ctrl_send.send(&msg).await
.inspect(|_| trace!("Sent Heartbeat"))
.inspect_err(|e| error!(%e, "Failed to send Heartbeat"));
}

result = connection.accept_bi() => {
let (send, recv) = result.context("accept incoming bidi stream")?;
let subnets = advertise_subnets.clone();
task_handles.spawn(run_session_proxy(subnets, send, recv));
}

// Reap completed session tasks.
Some(_) = task_handles.join_next() => {}
}
}

task_handles.shutdown().await;

Ok(ConnectionOutcome::Shutdown)
}

/// Build the mTLS client config, resolve the gateway endpoint, and perform the
/// QUIC handshake, returning the live endpoint and connection.
async fn connect_to_gateway(
tunnel_conf: &crate::config::TunnelConf,
) -> anyhow::Result<(quinn::Endpoint, quinn::Connection)> {
// Ensure rustls crypto provider is installed (ring).
let _ = rustls::crypto::ring::default_provider().install_default();
// -- Build rustls ClientConfig --

let certs: Vec<rustls_pki_types::CertificateDer<'static>> = rustls_pemfile::certs(&mut std::io::BufReader::new(
std::fs::File::open(cert_path.as_str()).context("open client cert file")?,
std::fs::File::open(tunnel_conf.client_cert_path.as_str()).context("open client cert file")?,
))
.collect::<Result<Vec<_>, _>>()
.context("parse client certificates")?;

let key = rustls_pemfile::private_key(&mut std::io::BufReader::new(
std::fs::File::open(key_path.as_str()).context("open client key file")?,
std::fs::File::open(tunnel_conf.client_key_path.as_str()).context("open client key file")?,
))
.context("parse private key file")?
.context("no private key found in file")?;

let mut roots = rustls::RootCertStore::empty();
let ca_certs: Vec<rustls_pki_types::CertificateDer<'static>> = rustls_pemfile::certs(&mut std::io::BufReader::new(
std::fs::File::open(ca_path.as_str()).context("open CA cert file")?,
std::fs::File::open(tunnel_conf.gateway_ca_cert_path.as_str()).context("open CA cert file")?,
))
.collect::<Result<Vec<_>, _>>()
.context("parse CA certificates")?;
Expand Down Expand Up @@ -363,84 +448,26 @@ async fn run_single_connection(

info!("QUIC connection established");

// -- Open control stream --

let mut ctrl: ControlStream<_, _> = connection.open_bi().await.context("open control stream")?.into();

// Send initial RouteAdvertise.
let epoch = 1u64;
let msg = ControlMessage::route_advertise(epoch, advertise_subnets.clone(), advertise_domains.clone());

ctrl.send(&msg).await.context("send initial RouteAdvertise")?;

info!(epoch, "Sent initial RouteAdvertise");
Ok((endpoint, connection))
}

// -- Certificate renewal (post-connect, pre-traffic) --
//
// Run once per reconnect rather than on a periodic timer: the QUIC session
// has a 120s idle timeout and 15s keep-alive, so any blip / VPN reconnect
// / host sleep / gateway restart drops the connection within minutes and
// sends us back through this path. With a 1-year cert and a 15-day
// threshold, the renewal window will be hit on the first reconnect after
// T-15d, which is more than often enough in any real deployment.
if let Some(outcome) = try_renew_certificate(&mut ctrl, &connection, cert_path, key_path, ca_path).await? {
return Ok(outcome);
/// Confirm the QUIC/UDP path to the gateway is open by completing one mTLS+QUIC handshake within
/// `timeout`, then draining the connection (a best-effort teardown that adds up to ~3s).
pub async fn probe_connectivity(tunnel_conf: &crate::config::TunnelConf, timeout: Duration) -> anyhow::Result<()> {
if !tunnel_conf.enabled {
bail!("agent tunnel is not enabled");
}

// Split: recv half goes to a reader task, send half stays for periodic messages.
let (mut ctrl_send, ctrl_recv) = ctrl.into_split();
let mut task_handles = tokio::task::JoinSet::new();
task_handles.spawn(run_control_reader(ctrl_recv));

// -- Main loop: accept incoming session streams + periodic tasks --

let route_interval = tunnel_conf.route_advertise_interval_secs;
let heartbeat_interval_secs = tunnel_conf.heartbeat_interval_secs;
let mut route_tick = tokio::time::interval(Duration::from_secs(route_interval));
let mut heartbeat_tick = tokio::time::interval(Duration::from_secs(heartbeat_interval_secs));
// Skip the first immediate tick (we already sent the initial RouteAdvertise).
route_tick.tick().await;
heartbeat_tick.tick().await;

loop {
tokio::select! {
biased;

_ = shutdown_signal.wait() => {
info!("Tunnel task shutting down");
connection.close(0u32.into(), b"shutting down");
break;
}

_ = route_tick.tick() => {
let msg = ControlMessage::route_advertise(epoch, advertise_subnets.clone(), advertise_domains.clone());
let _ = ctrl_send.send(&msg).await
.inspect(|_| trace!(epoch, "Sent RouteAdvertise (refresh)"))
.inspect_err(|e| error!(%e, "Failed to send RouteAdvertise"));
}

_ = heartbeat_tick.tick() => {
// TODO: track actual active_stream_count instead of hardcoded 0.
let msg = ControlMessage::heartbeat(current_time_millis(), 0);
let _ = ctrl_send.send(&msg).await
.inspect(|_| trace!("Sent Heartbeat"))
.inspect_err(|e| error!(%e, "Failed to send Heartbeat"));
}

result = connection.accept_bi() => {
let (send, recv) = result.context("accept incoming bidi stream")?;
let subnets = advertise_subnets.clone();
task_handles.spawn(run_session_proxy(subnets, send, recv));
}

// Reap completed session tasks.
Some(_) = task_handles.join_next() => {}
}
}
let (endpoint, connection) = tokio::time::timeout(timeout, connect_to_gateway(tunnel_conf))
.await
.context("tunnel connectivity probe timed out")??;

task_handles.shutdown().await;
// Flush the CONNECTION_CLOSE so the gateway unregisters this probe's connection promptly
// (keyed by agent_id) rather than after its idle timeout.
connection.close(0u32.into(), b"probe-complete");
let _ = tokio::time::timeout(Duration::from_secs(3), endpoint.wait_idle()).await;

Ok(ConnectionOutcome::Shutdown)
Ok(())
}

// ---------------------------------------------------------------------------
Expand Down Expand Up @@ -638,3 +665,70 @@ async fn run_session_proxy(advertise_subnets: Vec<Ipv4Network>, send: quinn::Sen
.await
.inspect_err(|e| error!(%e, "Session proxy failed"));
}

#[cfg(test)]
mod tests {
use camino::Utf8PathBuf;

use super::*;
use crate::config::TunnelConf;

/// Baseline [`TunnelConf`] shared by the tests in this module.
///
/// The tunnel is enabled, the gateway endpoint and all three certificate/key
/// paths are empty, nothing is advertised, and `server_spki_sha256` is unset.
/// Tests override only the fields they care about.
fn tunnel_conf_template() -> TunnelConf {
    TunnelConf {
        // Connection target and credentials: left empty here.
        gateway_endpoint: String::new(),
        client_cert_path: Utf8PathBuf::new(),
        client_key_path: Utf8PathBuf::new(),
        gateway_ca_cert_path: Utf8PathBuf::new(),
        server_spki_sha256: None,

        // Route advertisement: nothing advertised, no domain auto-detection.
        advertise_subnets: Vec::new(),
        advertise_domains: Vec::new(),
        auto_detect_domain: false,
        route_advertise_interval_secs: 60,

        // Heartbeat cadence.
        heartbeat_interval_secs: 15,

        enabled: true,
    }
}

#[tokio::test]
async fn probe_fails_fast_when_tunnel_disabled() {
let mut conf = tunnel_conf_template();
conf.enabled = false;

let error = probe_connectivity(&conf, Duration::from_secs(5))

@CBenoit Benoît Cortier (CBenoit) Jun 26, 2026

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here: a unit test should not run for more than 1 second.

.await
.expect_err("probe must fail when the tunnel is disabled");

assert!(
format!("{error:#}").contains("not enabled"),
"unexpected error: {error:#}"
);
}

#[tokio::test]
async fn probe_times_out_when_gateway_unreachable() {
// Throwaway PEMs so the pre-connect file reads succeed; nothing listens on the target
// port, so the probe fails quickly — via its own timeout or an immediate connect error.
let cert_key =
rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).expect("generate self-signed cert");

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue: Use a pre-generated certificate/key pair instead of generating a new one on each run; generating one per run slows down the test suite.

let dir = tempfile::tempdir().expect("temp dir");
let cert_path = dir.path().join("client.crt");
let key_path = dir.path().join("client.key");
let ca_path = dir.path().join("ca.crt");
std::fs::write(&cert_path, cert_key.cert.pem()).expect("write client cert");
std::fs::write(&key_path, cert_key.key_pair.serialize_pem()).expect("write client key");
std::fs::write(&ca_path, cert_key.cert.pem()).expect("write ca cert");

let mut conf = tunnel_conf_template();
// 127.0.0.1:1 is reserved and unbound; the QUIC handshake cannot complete.
conf.gateway_endpoint = "127.0.0.1:1".to_owned();
conf.client_cert_path = Utf8PathBuf::from_path_buf(cert_path).expect("utf8 cert path");
conf.client_key_path = Utf8PathBuf::from_path_buf(key_path).expect("utf8 key path");
conf.gateway_ca_cert_path = Utf8PathBuf::from_path_buf(ca_path).expect("utf8 ca path");

let started = std::time::Instant::now();
let result = probe_connectivity(&conf, Duration::from_secs(2)).await;

assert!(result.is_err(), "probe must fail when the gateway is unreachable");
assert!(started.elapsed() < Duration::from_secs(15), "probe must fail fast");

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

15 seconds is overly generous. A unit test should take 1 second at most.

}
}
Loading
Loading