Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion devolutions-agent/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ mod service;
use std::env;
use std::io::{self, BufRead};
use std::sync::mpsc;
use std::time::Duration;

use anyhow::{Context as _, Result, bail};
use ceviche::Service;
Expand Down Expand Up @@ -277,7 +278,13 @@ fn main() {
&command.enrollment_token,
command.advertise_subnets,
)
.await
.await?;

// Enrollment only proves HTTPS/TCP; fail the install now if the QUIC/UDP tunnel
// path is blocked, while the operator is still here to fix the firewall.
let conf = ConfHandle::init().context("load agent configuration for connectivity probe")?;
devolutions_agent::tunnel::probe_connectivity(&conf.get_conf().tunnel, Duration::from_secs(15))
.await
});

if let Err(error) = result {
Expand Down
258 changes: 176 additions & 82 deletions devolutions-agent/src/tunnel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ impl Task for TunnelTask {
return Ok(());
}
Ok(ConnectionOutcome::CertRenewed) => {
// Renewal is a successful "completion", not a failure — skip
// the backoff and reconnect immediately with the new cert.
// Renewal is a completion, not a failure.
info!("Certificate renewed; reconnecting with new cert immediately");
backoff.reset();
continue;
Expand Down Expand Up @@ -194,14 +193,11 @@ enum ConnectionOutcome {
///
/// - `Ok(Shutdown)`: graceful shutdown, exit the task.
/// - `Ok(CertRenewed)`: certificate renewed; caller should reconnect immediately.
/// - `Err(...)`: connection lost or handshake failed — caller should retry with backoff.
/// - `Err(_)`: connection lost or handshake failed — caller should retry with backoff.
async fn run_single_connection(
conf_handle: &ConfHandle,
shutdown_signal: &mut ShutdownSignal,
) -> anyhow::Result<ConnectionOutcome> {
// Ensure rustls crypto provider is installed (ring).
let _ = rustls::crypto::ring::default_provider().install_default();

let agent_conf = conf_handle.get_conf();
let tunnel_conf = &agent_conf.tunnel;

Expand Down Expand Up @@ -260,23 +256,112 @@ async fn run_single_connection(
"Advertising subnets and domains"
);

let (_endpoint, connection) = connect_to_gateway(tunnel_conf).await?;

// -- Open control stream --

let mut ctrl: ControlStream<_, _> = connection.open_bi().await.context("open control stream")?.into();

// Send initial RouteAdvertise.
let epoch = 1u64;
let msg = ControlMessage::route_advertise(epoch, advertise_subnets.clone(), advertise_domains.clone());

ctrl.send(&msg).await.context("send initial RouteAdvertise")?;

info!(epoch, "Sent initial RouteAdvertise");

// -- Certificate renewal (post-connect, pre-traffic) --
//
// Run once per reconnect rather than on a periodic timer: the QUIC session
// has a 120s idle timeout and 15s keep-alive, so any blip / VPN reconnect
// / host sleep / gateway restart drops the connection within minutes and
// sends us back through this path. With a 1-year cert and a 15-day
// threshold, the renewal window will be hit on the first reconnect after
// T-15d, which is more than often enough in any real deployment.
if let Some(outcome) = try_renew_certificate(&mut ctrl, &connection, cert_path, key_path, ca_path).await? {
return Ok(outcome);
}

// Split: recv half goes to a reader task, send half stays for periodic messages.
let (mut ctrl_send, ctrl_recv) = ctrl.into_split();
let mut task_handles = tokio::task::JoinSet::new();
task_handles.spawn(run_control_reader(ctrl_recv));

// -- Main loop: accept incoming session streams + periodic tasks --

let route_interval = tunnel_conf.route_advertise_interval_secs;
let heartbeat_interval_secs = tunnel_conf.heartbeat_interval_secs;
let mut route_tick = tokio::time::interval(Duration::from_secs(route_interval));
let mut heartbeat_tick = tokio::time::interval(Duration::from_secs(heartbeat_interval_secs));
// Skip the first immediate tick (we already sent the initial RouteAdvertise).
route_tick.tick().await;
heartbeat_tick.tick().await;

loop {
tokio::select! {
biased;

_ = shutdown_signal.wait() => {
info!("Tunnel task shutting down");
connection.close(0u32.into(), b"shutting down");
break;
}

_ = route_tick.tick() => {
let msg = ControlMessage::route_advertise(epoch, advertise_subnets.clone(), advertise_domains.clone());
let _ = ctrl_send.send(&msg).await
.inspect(|_| trace!(epoch, "Sent RouteAdvertise (refresh)"))
.inspect_err(|e| error!(%e, "Failed to send RouteAdvertise"));
}

_ = heartbeat_tick.tick() => {
// TODO: track actual active_stream_count instead of hardcoded 0.
let msg = ControlMessage::heartbeat(current_time_millis(), 0);
let _ = ctrl_send.send(&msg).await
.inspect(|_| trace!("Sent Heartbeat"))
.inspect_err(|e| error!(%e, "Failed to send Heartbeat"));
}

result = connection.accept_bi() => {
let (send, recv) = result.context("accept incoming bidi stream")?;
let subnets = advertise_subnets.clone();
task_handles.spawn(run_session_proxy(subnets, send, recv));
}

// Reap completed session tasks.
Some(_) = task_handles.join_next() => {}
}
}

task_handles.shutdown().await;

Ok(ConnectionOutcome::Shutdown)
}

/// Build the mTLS client config, resolve the gateway endpoint, and perform the
/// QUIC handshake, returning the live endpoint and connection.
async fn connect_to_gateway(
tunnel_conf: &crate::config::TunnelConf,
) -> anyhow::Result<(quinn::Endpoint, quinn::Connection)> {
// Ensure rustls crypto provider is installed (ring).
let _ = rustls::crypto::ring::default_provider().install_default();
// -- Build rustls ClientConfig --

let certs: Vec<rustls_pki_types::CertificateDer<'static>> = rustls_pemfile::certs(&mut std::io::BufReader::new(
std::fs::File::open(cert_path.as_str()).context("open client cert file")?,
std::fs::File::open(tunnel_conf.client_cert_path.as_str()).context("open client cert file")?,
))
.collect::<Result<Vec<_>, _>>()
.context("parse client certificates")?;

let key = rustls_pemfile::private_key(&mut std::io::BufReader::new(
std::fs::File::open(key_path.as_str()).context("open client key file")?,
std::fs::File::open(tunnel_conf.client_key_path.as_str()).context("open client key file")?,
))
.context("parse private key file")?
.context("no private key found in file")?;

let mut roots = rustls::RootCertStore::empty();
let ca_certs: Vec<rustls_pki_types::CertificateDer<'static>> = rustls_pemfile::certs(&mut std::io::BufReader::new(
std::fs::File::open(ca_path.as_str()).context("open CA cert file")?,
std::fs::File::open(tunnel_conf.gateway_ca_cert_path.as_str()).context("open CA cert file")?,
))
.collect::<Result<Vec<_>, _>>()
.context("parse CA certificates")?;
Expand Down Expand Up @@ -363,84 +448,26 @@ async fn run_single_connection(

info!("QUIC connection established");

// -- Open control stream --

let mut ctrl: ControlStream<_, _> = connection.open_bi().await.context("open control stream")?.into();

// Send initial RouteAdvertise.
let epoch = 1u64;
let msg = ControlMessage::route_advertise(epoch, advertise_subnets.clone(), advertise_domains.clone());

ctrl.send(&msg).await.context("send initial RouteAdvertise")?;

info!(epoch, "Sent initial RouteAdvertise");
Ok((endpoint, connection))
}

// -- Certificate renewal (post-connect, pre-traffic) --
//
// Run once per reconnect rather than on a periodic timer: the QUIC session
// has a 120s idle timeout and 15s keep-alive, so any blip / VPN reconnect
// / host sleep / gateway restart drops the connection within minutes and
// sends us back through this path. With a 1-year cert and a 15-day
// threshold, the renewal window will be hit on the first reconnect after
// T-15d, which is more than often enough in any real deployment.
if let Some(outcome) = try_renew_certificate(&mut ctrl, &connection, cert_path, key_path, ca_path).await? {
return Ok(outcome);
/// Confirm the QUIC/UDP path to the gateway is open by completing one mTLS+QUIC handshake, then
/// draining the connection, bounded by `timeout`.
pub async fn probe_connectivity(tunnel_conf: &crate::config::TunnelConf, timeout: Duration) -> anyhow::Result<()> {
if !tunnel_conf.enabled {
bail!("agent tunnel is not enabled");
}

// Split: recv half goes to a reader task, send half stays for periodic messages.
let (mut ctrl_send, ctrl_recv) = ctrl.into_split();
let mut task_handles = tokio::task::JoinSet::new();
task_handles.spawn(run_control_reader(ctrl_recv));

// -- Main loop: accept incoming session streams + periodic tasks --

let route_interval = tunnel_conf.route_advertise_interval_secs;
let heartbeat_interval_secs = tunnel_conf.heartbeat_interval_secs;
let mut route_tick = tokio::time::interval(Duration::from_secs(route_interval));
let mut heartbeat_tick = tokio::time::interval(Duration::from_secs(heartbeat_interval_secs));
// Skip the first immediate tick (we already sent the initial RouteAdvertise).
route_tick.tick().await;
heartbeat_tick.tick().await;

loop {
tokio::select! {
biased;

_ = shutdown_signal.wait() => {
info!("Tunnel task shutting down");
connection.close(0u32.into(), b"shutting down");
break;
}

_ = route_tick.tick() => {
let msg = ControlMessage::route_advertise(epoch, advertise_subnets.clone(), advertise_domains.clone());
let _ = ctrl_send.send(&msg).await
.inspect(|_| trace!(epoch, "Sent RouteAdvertise (refresh)"))
.inspect_err(|e| error!(%e, "Failed to send RouteAdvertise"));
}

_ = heartbeat_tick.tick() => {
// TODO: track actual active_stream_count instead of hardcoded 0.
let msg = ControlMessage::heartbeat(current_time_millis(), 0);
let _ = ctrl_send.send(&msg).await
.inspect(|_| trace!("Sent Heartbeat"))
.inspect_err(|e| error!(%e, "Failed to send Heartbeat"));
}

result = connection.accept_bi() => {
let (send, recv) = result.context("accept incoming bidi stream")?;
let subnets = advertise_subnets.clone();
task_handles.spawn(run_session_proxy(subnets, send, recv));
}

// Reap completed session tasks.
Some(_) = task_handles.join_next() => {}
}
}
let (endpoint, connection) = tokio::time::timeout(timeout, connect_to_gateway(tunnel_conf))
.await
.context("tunnel connectivity probe timed out")??;

task_handles.shutdown().await;
// Flush the CONNECTION_CLOSE so the gateway unregisters this probe's connection promptly
// (keyed by agent_id) rather than after its idle timeout.
connection.close(0u32.into(), b"probe-complete");
let _ = tokio::time::timeout(Duration::from_secs(3), endpoint.wait_idle()).await;

Ok(ConnectionOutcome::Shutdown)
Ok(())
}

// ---------------------------------------------------------------------------
Expand Down Expand Up @@ -638,3 +665,70 @@ async fn run_session_proxy(advertise_subnets: Vec<Ipv4Network>, send: quinn::Sen
.await
.inspect_err(|e| error!(%e, "Session proxy failed"));
}

#[cfg(test)]
mod tests {
use camino::Utf8PathBuf;

use super::*;
use crate::config::TunnelConf;

/// Builds the baseline `TunnelConf` shared by the tests in this module.
///
/// The template has the tunnel enabled, an empty gateway endpoint, empty
/// certificate/key/CA paths, nothing advertised, and no `server_spki_sha256`
/// value. Each test overrides only the fields it exercises.
fn tunnel_conf_template() -> TunnelConf {
TunnelConf {
enabled: true,
// Endpoint and PEM paths start empty; tests that need them fill them in.
gateway_endpoint: String::new(),
client_cert_path: Utf8PathBuf::new(),
client_key_path: Utf8PathBuf::new(),
gateway_ca_cert_path: Utf8PathBuf::new(),
// No subnets or domains are advertised, and domain auto-detection is off.
advertise_subnets: Vec::new(),
advertise_domains: Vec::new(),
auto_detect_domain: false,
// Periodic message intervals, in seconds.
heartbeat_interval_secs: 15,
route_advertise_interval_secs: 60,
server_spki_sha256: None,
}
}

#[tokio::test]
async fn probe_fails_fast_when_tunnel_disabled() {
let mut conf = tunnel_conf_template();
conf.enabled = false;

let error = probe_connectivity(&conf, Duration::from_secs(5))

@CBenoit Benoît Cortier (CBenoit) Jun 26, 2026

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here: a unit test should not run for more than 1 second.

.await
.expect_err("probe must fail when the tunnel is disabled");

assert!(
format!("{error:#}").contains("not enabled"),
"unexpected error: {error:#}"
);
}

#[tokio::test]
async fn probe_times_out_when_gateway_unreachable() {
// Throwaway PEMs so the pre-connect file reads succeed; nothing listens on the target
// port, so the handshake never completes and the probe must hit its own timeout.
let cert_key =
rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).expect("generate self-signed cert");

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue: Use a pre-generated certificate/key pair instead of generating a new one on each run. Generating it every time slows down the test suite.

let dir = tempfile::tempdir().expect("temp dir");
let cert_path = dir.path().join("client.crt");
let key_path = dir.path().join("client.key");
let ca_path = dir.path().join("ca.crt");
std::fs::write(&cert_path, cert_key.cert.pem()).expect("write client cert");
std::fs::write(&key_path, cert_key.key_pair.serialize_pem()).expect("write client key");
std::fs::write(&ca_path, cert_key.cert.pem()).expect("write ca cert");

let mut conf = tunnel_conf_template();
// 127.0.0.1:1 is reserved and unbound; the QUIC handshake cannot complete.
conf.gateway_endpoint = "127.0.0.1:1".to_owned();
conf.client_cert_path = Utf8PathBuf::from_path_buf(cert_path).expect("utf8 cert path");
conf.client_key_path = Utf8PathBuf::from_path_buf(key_path).expect("utf8 key path");
conf.gateway_ca_cert_path = Utf8PathBuf::from_path_buf(ca_path).expect("utf8 ca path");

let started = std::time::Instant::now();
let result = probe_connectivity(&conf, Duration::from_secs(2)).await;

assert!(result.is_err(), "probe must fail when the gateway is unreachable");
assert!(started.elapsed() < Duration::from_secs(15), "probe must fail fast");

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

15 seconds is overly generous. A unit test should not take more than 1 second.

}
}
Loading
Loading