diff --git a/devolutions-agent/src/main.rs b/devolutions-agent/src/main.rs index 485c6db2e..de106a3a1 100644 --- a/devolutions-agent/src/main.rs +++ b/devolutions-agent/src/main.rs @@ -38,6 +38,7 @@ mod service; use std::env; use std::io::{self, BufRead}; use std::sync::mpsc; +use std::time::Duration; use anyhow::{Context as _, Result, bail}; use ceviche::Service; @@ -277,7 +278,13 @@ fn main() { &command.enrollment_token, command.advertise_subnets, ) - .await + .await?; + + // Enrollment only proves HTTPS/TCP; fail the install now if the QUIC/UDP tunnel + // path is blocked, while the operator is still here to fix the firewall. + let conf = ConfHandle::init().context("load agent configuration for connectivity probe")?; + devolutions_agent::tunnel::probe_connectivity(&conf.get_conf().tunnel, Duration::from_secs(15)) + .await }); if let Err(error) = result { diff --git a/devolutions-agent/src/tunnel.rs b/devolutions-agent/src/tunnel.rs index c48d865b5..d43c6a697 100644 --- a/devolutions-agent/src/tunnel.rs +++ b/devolutions-agent/src/tunnel.rs @@ -138,8 +138,7 @@ impl Task for TunnelTask { return Ok(()); } Ok(ConnectionOutcome::CertRenewed) => { - // Renewal is a successful "completion", not a failure — skip - // the backoff and reconnect immediately with the new cert. + // Renewal is a completion, not a failure. info!("Certificate renewed; reconnecting with new cert immediately"); backoff.reset(); continue; @@ -194,14 +193,11 @@ enum ConnectionOutcome { /// /// - `Ok(Shutdown)`: graceful shutdown, exit the task. /// - `Ok(CertRenewed)`: certificate renewed; caller should reconnect immediately. -/// - `Err(...)`: connection lost or handshake failed — caller should retry with backoff. +/// - `Err(_)`: connection lost or handshake failed — caller should retry with backoff. async fn run_single_connection( conf_handle: &ConfHandle, shutdown_signal: &mut ShutdownSignal, ) -> anyhow::Result { - // Ensure rustls crypto provider is installed (ring). - let _ = rustls::crypto::ring::default_provider().install_default(); - let agent_conf = conf_handle.get_conf(); let tunnel_conf = &agent_conf.tunnel; @@ -260,23 +256,112 @@ async fn run_single_connection( "Advertising subnets and domains" ); + let (_endpoint, connection) = connect_to_gateway(tunnel_conf).await?; + + // -- Open control stream -- + + let mut ctrl: ControlStream<_, _> = connection.open_bi().await.context("open control stream")?.into(); + + // Send initial RouteAdvertise. + let epoch = 1u64; + let msg = ControlMessage::route_advertise(epoch, advertise_subnets.clone(), advertise_domains.clone()); + + ctrl.send(&msg).await.context("send initial RouteAdvertise")?; + + info!(epoch, "Sent initial RouteAdvertise"); + + // -- Certificate renewal (post-connect, pre-traffic) -- + // + // Run once per reconnect rather than on a periodic timer: the QUIC session + // has a 120s idle timeout and 15s keep-alive, so any blip / VPN reconnect + // / host sleep / gateway restart drops the connection within minutes and + // sends us back through this path. With a 1-year cert and a 15-day + // threshold, the renewal window will be hit on the first reconnect after + // T-15d, which is more than often enough in any real deployment. + if let Some(outcome) = try_renew_certificate(&mut ctrl, &connection, cert_path, key_path, ca_path).await? { + return Ok(outcome); + } + + // Split: recv half goes to a reader task, send half stays for periodic messages. + let (mut ctrl_send, ctrl_recv) = ctrl.into_split(); + let mut task_handles = tokio::task::JoinSet::new(); + task_handles.spawn(run_control_reader(ctrl_recv)); + + // -- Main loop: accept incoming session streams + periodic tasks -- + + let route_interval = tunnel_conf.route_advertise_interval_secs; + let heartbeat_interval_secs = tunnel_conf.heartbeat_interval_secs; + let mut route_tick = tokio::time::interval(Duration::from_secs(route_interval)); + let mut heartbeat_tick = tokio::time::interval(Duration::from_secs(heartbeat_interval_secs)); + // Skip the first immediate tick (we already sent the initial RouteAdvertise). + route_tick.tick().await; + heartbeat_tick.tick().await; + + loop { + tokio::select! { + biased; + + _ = shutdown_signal.wait() => { + info!("Tunnel task shutting down"); + connection.close(0u32.into(), b"shutting down"); + break; + } + + _ = route_tick.tick() => { + let msg = ControlMessage::route_advertise(epoch, advertise_subnets.clone(), advertise_domains.clone()); + let _ = ctrl_send.send(&msg).await + .inspect(|_| trace!(epoch, "Sent RouteAdvertise (refresh)")) + .inspect_err(|e| error!(%e, "Failed to send RouteAdvertise")); + } + + _ = heartbeat_tick.tick() => { + // TODO: track actual active_stream_count instead of hardcoded 0. + let msg = ControlMessage::heartbeat(current_time_millis(), 0); + let _ = ctrl_send.send(&msg).await + .inspect(|_| trace!("Sent Heartbeat")) + .inspect_err(|e| error!(%e, "Failed to send Heartbeat")); + } + + result = connection.accept_bi() => { + let (send, recv) = result.context("accept incoming bidi stream")?; + let subnets = advertise_subnets.clone(); + task_handles.spawn(run_session_proxy(subnets, send, recv)); + } + + // Reap completed session tasks. + Some(_) = task_handles.join_next() => {} + } + } + + task_handles.shutdown().await; + + Ok(ConnectionOutcome::Shutdown) +} + +/// Build the mTLS client config, resolve the gateway endpoint, and perform the +/// QUIC handshake, returning the live endpoint and connection. +async fn connect_to_gateway( + tunnel_conf: &crate::config::TunnelConf, +) -> anyhow::Result<(quinn::Endpoint, quinn::Connection)> { + // Ensure rustls crypto provider is installed (ring). + let _ = rustls::crypto::ring::default_provider().install_default(); // -- Build rustls ClientConfig -- let certs: Vec> = rustls_pemfile::certs(&mut std::io::BufReader::new( - std::fs::File::open(cert_path.as_str()).context("open client cert file")?, + std::fs::File::open(tunnel_conf.client_cert_path.as_str()).context("open client cert file")?, )) .collect::, _>>() .context("parse client certificates")?; let key = rustls_pemfile::private_key(&mut std::io::BufReader::new( - std::fs::File::open(key_path.as_str()).context("open client key file")?, + std::fs::File::open(tunnel_conf.client_key_path.as_str()).context("open client key file")?, )) .context("parse private key file")? .context("no private key found in file")?; let mut roots = rustls::RootCertStore::empty(); let ca_certs: Vec> = rustls_pemfile::certs(&mut std::io::BufReader::new( - std::fs::File::open(ca_path.as_str()).context("open CA cert file")?, + std::fs::File::open(tunnel_conf.gateway_ca_cert_path.as_str()).context("open CA cert file")?, )) .collect::, _>>() .context("parse CA certificates")?; @@ -363,84 +448,26 @@ async fn run_single_connection( info!("QUIC connection established"); - // -- Open control stream -- - - let mut ctrl: ControlStream<_, _> = connection.open_bi().await.context("open control stream")?.into(); - - // Send initial RouteAdvertise. - let epoch = 1u64; - let msg = ControlMessage::route_advertise(epoch, advertise_subnets.clone(), advertise_domains.clone()); - - ctrl.send(&msg).await.context("send initial RouteAdvertise")?; - - info!(epoch, "Sent initial RouteAdvertise"); + Ok((endpoint, connection)) +} - // -- Certificate renewal (post-connect, pre-traffic) -- - // - // Run once per reconnect rather than on a periodic timer: the QUIC session - // has a 120s idle timeout and 15s keep-alive, so any blip / VPN reconnect - // / host sleep / gateway restart drops the connection within minutes and - // sends us back through this path. With a 1-year cert and a 15-day - // threshold, the renewal window will be hit on the first reconnect after - // T-15d, which is more than often enough in any real deployment. - if let Some(outcome) = try_renew_certificate(&mut ctrl, &connection, cert_path, key_path, ca_path).await? { - return Ok(outcome); +/// Confirm the QUIC/UDP path to the gateway is open by completing one mTLS+QUIC handshake within +/// `timeout`, then draining the connection (a best-effort teardown that adds up to ~3s). +pub async fn probe_connectivity(tunnel_conf: &crate::config::TunnelConf, timeout: Duration) -> anyhow::Result<()> { + if !tunnel_conf.enabled { + bail!("agent tunnel is not enabled"); } - // Split: recv half goes to a reader task, send half stays for periodic messages. - let (mut ctrl_send, ctrl_recv) = ctrl.into_split(); - let mut task_handles = tokio::task::JoinSet::new(); - task_handles.spawn(run_control_reader(ctrl_recv)); - - // -- Main loop: accept incoming session streams + periodic tasks -- - - let route_interval = tunnel_conf.route_advertise_interval_secs; - let heartbeat_interval_secs = tunnel_conf.heartbeat_interval_secs; - let mut route_tick = tokio::time::interval(Duration::from_secs(route_interval)); - let mut heartbeat_tick = tokio::time::interval(Duration::from_secs(heartbeat_interval_secs)); - // Skip the first immediate tick (we already sent the initial RouteAdvertise). - route_tick.tick().await; - heartbeat_tick.tick().await; - - loop { - tokio::select! { - biased; - - _ = shutdown_signal.wait() => { - info!("Tunnel task shutting down"); - connection.close(0u32.into(), b"shutting down"); - break; - } - - _ = route_tick.tick() => { - let msg = ControlMessage::route_advertise(epoch, advertise_subnets.clone(), advertise_domains.clone()); - let _ = ctrl_send.send(&msg).await - .inspect(|_| trace!(epoch, "Sent RouteAdvertise (refresh)")) - .inspect_err(|e| error!(%e, "Failed to send RouteAdvertise")); - } - - _ = heartbeat_tick.tick() => { - // TODO: track actual active_stream_count instead of hardcoded 0. - let msg = ControlMessage::heartbeat(current_time_millis(), 0); - let _ = ctrl_send.send(&msg).await - .inspect(|_| trace!("Sent Heartbeat")) - .inspect_err(|e| error!(%e, "Failed to send Heartbeat")); - } - - result = connection.accept_bi() => { - let (send, recv) = result.context("accept incoming bidi stream")?; - let subnets = advertise_subnets.clone(); - task_handles.spawn(run_session_proxy(subnets, send, recv)); - } - - // Reap completed session tasks. - Some(_) = task_handles.join_next() => {} - } - } + let (endpoint, connection) = tokio::time::timeout(timeout, connect_to_gateway(tunnel_conf)) + .await + .context("tunnel connectivity probe timed out")??; - task_handles.shutdown().await; + // Flush the CONNECTION_CLOSE so the gateway unregisters this probe's connection promptly + // (keyed by agent_id) rather than after its idle timeout. + connection.close(0u32.into(), b"probe-complete"); + let _ = tokio::time::timeout(Duration::from_secs(3), endpoint.wait_idle()).await; - Ok(ConnectionOutcome::Shutdown) + Ok(()) } // --------------------------------------------------------------------------- @@ -638,3 +665,70 @@ async fn run_session_proxy(advertise_subnets: Vec, send: quinn::Sen .await .inspect_err(|e| error!(%e, "Session proxy failed")); } + +#[cfg(test)] +mod tests { + use camino::Utf8PathBuf; + + use super::*; + use crate::config::TunnelConf; + + fn tunnel_conf_template() -> TunnelConf { + TunnelConf { + enabled: true, + gateway_endpoint: String::new(), + client_cert_path: Utf8PathBuf::new(), + client_key_path: Utf8PathBuf::new(), + gateway_ca_cert_path: Utf8PathBuf::new(), + advertise_subnets: Vec::new(), + advertise_domains: Vec::new(), + auto_detect_domain: false, + heartbeat_interval_secs: 15, + route_advertise_interval_secs: 60, + server_spki_sha256: None, + } + } + + #[tokio::test] + async fn probe_fails_fast_when_tunnel_disabled() { + let mut conf = tunnel_conf_template(); + conf.enabled = false; + + let error = probe_connectivity(&conf, Duration::from_secs(5)) + .await + .expect_err("probe must fail when the tunnel is disabled"); + + assert!( + format!("{error:#}").contains("not enabled"), + "unexpected error: {error:#}" + ); + } + + #[tokio::test] + async fn probe_times_out_when_gateway_unreachable() { + // Throwaway PEMs so the pre-connect file reads succeed; nothing listens on the target + // port, so the probe fails quickly — via its own timeout or an immediate connect error. + let cert_key = + rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).expect("generate self-signed cert"); + let dir = tempfile::tempdir().expect("temp dir"); + let cert_path = dir.path().join("client.crt"); + let key_path = dir.path().join("client.key"); + let ca_path = dir.path().join("ca.crt"); + std::fs::write(&cert_path, cert_key.cert.pem()).expect("write client cert"); + std::fs::write(&key_path, cert_key.key_pair.serialize_pem()).expect("write client key"); + std::fs::write(&ca_path, cert_key.cert.pem()).expect("write ca cert"); + + let mut conf = tunnel_conf_template(); + // 127.0.0.1:1 is reserved and unbound; the QUIC handshake cannot complete. + conf.gateway_endpoint = "127.0.0.1:1".to_owned(); + conf.client_cert_path = Utf8PathBuf::from_path_buf(cert_path).expect("utf8 cert path"); + conf.client_key_path = Utf8PathBuf::from_path_buf(key_path).expect("utf8 key path"); + conf.gateway_ca_cert_path = Utf8PathBuf::from_path_buf(ca_path).expect("utf8 ca path"); + + let started = std::time::Instant::now(); + let result = probe_connectivity(&conf, Duration::from_secs(2)).await; + + assert!(result.is_err(), "probe must fail when the gateway is unreachable"); + assert!(started.elapsed() < Duration::from_secs(15), "probe must fail fast"); + } +} diff --git a/package/AgentWindowsManaged/Actions/CustomActions.cs b/package/AgentWindowsManaged/Actions/CustomActions.cs index 6ebe4b33f..3fc034fdd 100644 --- a/package/AgentWindowsManaged/Actions/CustomActions.cs +++ b/package/AgentWindowsManaged/Actions/CustomActions.cs @@ -498,17 +498,10 @@ ActionResult Fail(string msg) // Observed. } - // A hard Kill() bypasses BOTH recovery layers: the agent's transactional - // rollback never runs (we killed it, it didn't gracefully error), and no marker - // has been written yet (the marker write happens after the exit-code-0 check - // below). So if `up` wrote agent.json + cert files but then hung, those would be - // orphaned. Mirror the marker-failure path: best-effort read whatever cert paths - // landed in agent.json and clean them up + restore the pre-snapshot state. The - // snapshot locals (originalTunnel/originalGatewayCaB64/originalStateCaptured) were - // captured before `up` started, so they're valid here. ReadTunnelCertPaths can't - // throw, so this can't escape the timeout path. - List timeoutCertPaths = ReadTunnelCertPaths(agentJsonPath); - CleanUpEnrollmentArtifacts(session, timeoutCertPaths, originalTunnel, originalGatewayCaB64, originalStateCaptured); + // A hard Kill() bypasses the agent's own rollback and no marker exists yet, so a hang + // after `up` persisted its enrollment would orphan it; undo it (guarded so an early + // hang that wrote nothing can't delete the prior install's certs). + RollbackFailedEnrollment(session, agentJsonPath, originalTunnel, originalGatewayCaB64, originalStateCaptured); return Fail("Agent tunnel enrollment timed out. Verify your Devolutions Gateway is reachable from this machine."); } @@ -532,6 +525,11 @@ ActionResult Fail(string msg) if (process.ExitCode != 0) { string detail = !string.IsNullOrWhiteSpace(stderr) ? Redact(stderr).Trim() : $"exit code {process.ExitCode}"; + + // `up` enrolls then probes, so a non-zero exit can leave a freshly-persisted + // enrollment on disk with no marker yet; undo it (guarded against early failures). + RollbackFailedEnrollment(session, agentJsonPath, originalTunnel, originalGatewayCaB64, originalStateCaptured); + return Fail($"Agent tunnel enrollment failed: {detail}"); } @@ -914,6 +912,32 @@ public static ActionResult RollbackConfig(Session session) /// cleanup when it cannot record the rollback marker. Best-effort: logs and continues past /// individual failures so it never aborts a rollback. /// + // Only undo when `up` actually persisted a NEW enrollment (client cert path changed from the + // pre-`up` snapshot); else an early failure would delete the prior install's still-referenced certs. + private static void RollbackFailedEnrollment(Session session, string agentJsonPath, JToken originalTunnel, string originalGatewayCaB64, bool originalStateCaptured) + { + if (!originalStateCaptured) + { + // Snapshot failed, so we can't tell new artifacts from the prior install's — skip + // rather than risk deleting cert/key we never observed (a harmless orphan beats deletion). + session.Log("skipping enrollment cleanup: pre-enrollment state was not captured"); + return; + } + + List certPaths = ReadTunnelCertPaths(agentJsonPath); + string originalClientCert = originalTunnel?["ClientCertPath"]?.Value(); + string currentClientCert = certPaths.FirstOrDefault(p => p.EndsWith("-cert.pem", StringComparison.OrdinalIgnoreCase)); + + if (currentClientCert != null && !string.Equals(currentClientCert, originalClientCert, StringComparison.OrdinalIgnoreCase)) + { + CleanUpEnrollmentArtifacts(session, certPaths, originalTunnel, originalGatewayCaB64, originalStateCaptured); + } + else + { + session.Log("skipping enrollment cleanup: `up` did not persist a new enrollment (client cert unchanged)"); + } + } + private static void CleanUpEnrollmentArtifacts(Session session, List newCertPaths, JToken originalTunnel, string originalGatewayCaB64, bool originalStateCaptured) { // The client cert/key are uniquely named per enrollment, so they're always deleted —