Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion devolutions-agent/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ mod service;
use std::env;
use std::io::{self, BufRead};
use std::sync::mpsc;
use std::time::Duration;

use anyhow::{Context as _, Result, bail};
use ceviche::Service;
Expand Down Expand Up @@ -277,7 +278,12 @@ fn main() {
&command.enrollment_token,
command.advertise_subnets,
)
.await
.await?;

// Enrollment only proves HTTPS/TCP; fail the install now if the QUIC/UDP tunnel
// path is blocked, while the operator is still here to fix the firewall.
let conf = ConfHandle::init().context("load agent configuration for connectivity probe")?;
devolutions_agent::tunnel::probe_connectivity(&conf.get_conf().tunnel, Duration::from_secs(15)).await
});

if let Err(error) = result {
Expand Down
147 changes: 123 additions & 24 deletions devolutions-agent/src/tunnel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ impl Task for TunnelTask {
return Ok(());
}
Ok(ConnectionOutcome::CertRenewed) => {
// Renewal is a successful "completion", not a failure — skip
// the backoff and reconnect immediately with the new cert.
// Renewal is a completion, not a failure.
info!("Certificate renewed; reconnecting with new cert immediately");
backoff.reset();
continue;
Expand Down Expand Up @@ -190,25 +189,10 @@ enum ConnectionOutcome {
CertRenewed,
}

/// Run a single QUIC tunnel connection lifetime: config → connect → event loop.
///
/// - `Ok(Shutdown)`: graceful shutdown, exit the task.
/// - `Ok(CertRenewed)`: certificate renewed; caller should reconnect immediately.
/// - `Err(...)`: connection lost or handshake failed — caller should retry with backoff.
async fn run_single_connection(
conf_handle: &ConfHandle,
shutdown_signal: &mut ShutdownSignal,
) -> anyhow::Result<ConnectionOutcome> {
// Ensure rustls crypto provider is installed (ring).
let _ = rustls::crypto::ring::default_provider().install_default();

let agent_conf = conf_handle.get_conf();
let tunnel_conf = &agent_conf.tunnel;

let cert_path = &tunnel_conf.client_cert_path;
let key_path = &tunnel_conf.client_key_path;
let ca_path = &tunnel_conf.gateway_ca_cert_path;

/// Build the route advertisement payload from the current tunnel configuration.
fn route_advertisements(
tunnel_conf: &crate::config::TunnelConf,
) -> anyhow::Result<(Vec<Ipv4Network>, Vec<agent_tunnel_proto::DomainAdvertisement>)> {
let advertise_subnets: Vec<Ipv4Network> = tunnel_conf
.advertise_subnets
.iter()
Expand Down Expand Up @@ -260,23 +244,33 @@ async fn run_single_connection(
"Advertising subnets and domains"
);

Ok((advertise_subnets, advertise_domains))
}

/// Build the mTLS client config, resolve the gateway endpoint, and perform the
/// QUIC handshake, returning the live endpoint and connection.
async fn connect_to_gateway(
tunnel_conf: &crate::config::TunnelConf,
) -> anyhow::Result<(quinn::Endpoint, quinn::Connection)> {
// Ensure rustls crypto provider is installed (ring).
let _ = rustls::crypto::ring::default_provider().install_default();
// -- Build rustls ClientConfig --

let certs: Vec<rustls_pki_types::CertificateDer<'static>> = rustls_pemfile::certs(&mut std::io::BufReader::new(
std::fs::File::open(cert_path.as_str()).context("open client cert file")?,
std::fs::File::open(tunnel_conf.client_cert_path.as_str()).context("open client cert file")?,
))
.collect::<Result<Vec<_>, _>>()
.context("parse client certificates")?;

let key = rustls_pemfile::private_key(&mut std::io::BufReader::new(
std::fs::File::open(key_path.as_str()).context("open client key file")?,
std::fs::File::open(tunnel_conf.client_key_path.as_str()).context("open client key file")?,
))
.context("parse private key file")?
.context("no private key found in file")?;

let mut roots = rustls::RootCertStore::empty();
let ca_certs: Vec<rustls_pki_types::CertificateDer<'static>> = rustls_pemfile::certs(&mut std::io::BufReader::new(
std::fs::File::open(ca_path.as_str()).context("open CA cert file")?,
std::fs::File::open(tunnel_conf.gateway_ca_cert_path.as_str()).context("open CA cert file")?,
))
.collect::<Result<Vec<_>, _>>()
.context("parse CA certificates")?;
Expand Down Expand Up @@ -363,6 +357,47 @@ async fn run_single_connection(

info!("QUIC connection established");

Ok((endpoint, connection))
}

/// Confirm the QUIC/UDP path to the gateway is open by completing one mTLS+QUIC handshake, then
/// draining the connection, bounded by `timeout`.
pub async fn probe_connectivity(tunnel_conf: &crate::config::TunnelConf, timeout: Duration) -> anyhow::Result<()> {
if !tunnel_conf.enabled {
bail!("agent tunnel is not enabled");
}

let (endpoint, connection) = tokio::time::timeout(timeout, connect_to_gateway(tunnel_conf))
.await
.context("tunnel connectivity probe timed out")??;

// Flush the CONNECTION_CLOSE so the gateway unregisters this probe's connection promptly
// (keyed by agent_id) rather than after its idle timeout.
connection.close(0u32.into(), b"probe-complete");
let _ = tokio::time::timeout(Duration::from_secs(3), endpoint.wait_idle()).await;

Ok(())
}

/// Run a single QUIC tunnel connection lifetime: config → connect → event loop.
///
/// - `Ok(Shutdown)`: graceful shutdown, exit the task.
/// - `Ok(CertRenewed)`: certificate renewed; caller should reconnect immediately.
/// - `Err(_)`: connection lost or handshake failed — caller should retry with backoff.
async fn run_single_connection(
conf_handle: &ConfHandle,
shutdown_signal: &mut ShutdownSignal,
) -> anyhow::Result<ConnectionOutcome> {
let agent_conf = conf_handle.get_conf();
let tunnel_conf = &agent_conf.tunnel;

let cert_path = &tunnel_conf.client_cert_path;
let key_path = &tunnel_conf.client_key_path;
let ca_path = &tunnel_conf.gateway_ca_cert_path;

let (advertise_subnets, advertise_domains) = route_advertisements(tunnel_conf)?;
let (_endpoint, connection) = connect_to_gateway(tunnel_conf).await?;

// -- Open control stream --

let mut ctrl: ControlStream<_, _> = connection.open_bi().await.context("open control stream")?.into();
Expand Down Expand Up @@ -638,3 +673,67 @@ async fn run_session_proxy(advertise_subnets: Vec<Ipv4Network>, send: quinn::Sen
.await
.inspect_err(|e| error!(%e, "Session proxy failed"));
}

#[cfg(test)]
mod tests {
use camino::Utf8PathBuf;

use super::*;
use crate::config::TunnelConf;

fn tunnel_conf_template() -> TunnelConf {
TunnelConf {
enabled: true,
gateway_endpoint: String::new(),
client_cert_path: Utf8PathBuf::new(),
client_key_path: Utf8PathBuf::new(),
gateway_ca_cert_path: Utf8PathBuf::new(),
advertise_subnets: Vec::new(),
advertise_domains: Vec::new(),
auto_detect_domain: false,
heartbeat_interval_secs: 15,
route_advertise_interval_secs: 60,
server_spki_sha256: None,
}
}

#[tokio::test]
async fn probe_fails_fast_when_tunnel_disabled() {
let mut conf = tunnel_conf_template();
conf.enabled = false;

let error = probe_connectivity(&conf, Duration::from_secs(5))
.await
.expect_err("probe must fail when the tunnel is disabled");

assert!(format!("{error:#}").contains("not enabled"), "unexpected error: {error:#}");
}

#[tokio::test]
async fn probe_times_out_when_gateway_unreachable() {
// Throwaway PEMs so the pre-connect file reads succeed; nothing listens on the target
// port, so the handshake never completes and the probe must hit its own timeout.
let cert_key =
rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).expect("generate self-signed cert");
let dir = tempfile::tempdir().expect("temp dir");
let cert_path = dir.path().join("client.crt");
let key_path = dir.path().join("client.key");
let ca_path = dir.path().join("ca.crt");
std::fs::write(&cert_path, cert_key.cert.pem()).expect("write client cert");
std::fs::write(&key_path, cert_key.key_pair.serialize_pem()).expect("write client key");
std::fs::write(&ca_path, cert_key.cert.pem()).expect("write ca cert");

let mut conf = tunnel_conf_template();
// 127.0.0.1:1 is reserved and unbound; the QUIC handshake cannot complete.
conf.gateway_endpoint = "127.0.0.1:1".to_owned();
conf.client_cert_path = Utf8PathBuf::from_path_buf(cert_path).expect("utf8 cert path");
conf.client_key_path = Utf8PathBuf::from_path_buf(key_path).expect("utf8 key path");
conf.gateway_ca_cert_path = Utf8PathBuf::from_path_buf(ca_path).expect("utf8 ca path");

let started = std::time::Instant::now();
let result = probe_connectivity(&conf, Duration::from_secs(2)).await;

assert!(result.is_err(), "probe must fail when the gateway is unreachable");
assert!(started.elapsed() < Duration::from_secs(15), "probe must fail fast");
}
}
46 changes: 35 additions & 11 deletions package/AgentWindowsManaged/Actions/CustomActions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -498,17 +498,10 @@ ActionResult Fail(string msg)
// Observed.
}

// A hard Kill() bypasses BOTH recovery layers: the agent's transactional
// rollback never runs (we killed it, it didn't gracefully error), and no marker
// has been written yet (the marker write happens after the exit-code-0 check
// below). So if `up` wrote agent.json + cert files but then hung, those would be
// orphaned. Mirror the marker-failure path: best-effort read whatever cert paths
// landed in agent.json and clean them up + restore the pre-snapshot state. The
// snapshot locals (originalTunnel/originalGatewayCaB64/originalStateCaptured) were
// captured before `up` started, so they're valid here. ReadTunnelCertPaths can't
// throw, so this can't escape the timeout path.
List<string> timeoutCertPaths = ReadTunnelCertPaths(agentJsonPath);
CleanUpEnrollmentArtifacts(session, timeoutCertPaths, originalTunnel, originalGatewayCaB64, originalStateCaptured);
// A hard Kill() bypasses the agent's own rollback and no marker exists yet, so a hang
// after `up` persisted its enrollment would orphan it; undo it (guarded so an early
// hang that wrote nothing can't delete the prior install's certs).
RollBackFailedEnrollment(session, agentJsonPath, originalTunnel, originalGatewayCaB64, originalStateCaptured);

return Fail("Agent tunnel enrollment timed out. Verify your Devolutions Gateway is reachable from this machine.");
}
Expand All @@ -532,6 +525,11 @@ ActionResult Fail(string msg)
if (process.ExitCode != 0)
{
string detail = !string.IsNullOrWhiteSpace(stderr) ? Redact(stderr).Trim() : $"exit code {process.ExitCode}";

// `up` enrolls then probes, so a non-zero exit can leave a freshly-persisted
// enrollment on disk with no marker yet; undo it (guarded against early failures).
RollBackFailedEnrollment(session, agentJsonPath, originalTunnel, originalGatewayCaB64, originalStateCaptured);

return Fail($"Agent tunnel enrollment failed: {detail}");
}

Expand Down Expand Up @@ -914,6 +912,32 @@ public static ActionResult RollbackConfig(Session session)
/// cleanup when it cannot record the rollback marker. Best-effort: logs and continues past
/// individual failures so it never aborts a rollback.
/// </summary>
// Only undo when `up` actually persisted a NEW enrollment (client cert path changed from the
// pre-`up` snapshot); else an early failure would delete the prior install's still-referenced certs.
private static void RollBackFailedEnrollment(Session session, string agentJsonPath, JToken originalTunnel, string originalGatewayCaB64, bool originalStateCaptured)
{
if (!originalStateCaptured)
{
// Snapshot failed, so we can't tell new artifacts from the prior install's — skip
// rather than risk deleting cert/key we never observed (a harmless orphan beats deletion).
session.Log("skipping enrollment cleanup: pre-enrollment state was not captured");
return;
}

List<string> certPaths = ReadTunnelCertPaths(agentJsonPath);
string originalClientCert = originalTunnel?["ClientCertPath"]?.Value<string>();
string currentClientCert = certPaths.FirstOrDefault(p => p.EndsWith("-cert.pem", StringComparison.OrdinalIgnoreCase));

if (currentClientCert != null && !string.Equals(currentClientCert, originalClientCert, StringComparison.OrdinalIgnoreCase))
{
CleanUpEnrollmentArtifacts(session, certPaths, originalTunnel, originalGatewayCaB64, originalStateCaptured);
}
else
{
session.Log("skipping enrollment cleanup: `up` did not persist a new enrollment (client cert unchanged)");
}
}

private static void CleanUpEnrollmentArtifacts(Session session, List<string> newCertPaths, JToken originalTunnel, string originalGatewayCaB64, bool originalStateCaptured)
{
// The client cert/key are uniquely named per enrollment, so they're always deleted —
Expand Down
Loading