From d499f08f78028fd02553aacb6d14e416da7a3973 Mon Sep 17 00:00:00 2001 From: Srikrishna Veturi Date: Mon, 13 Apr 2026 13:52:06 -0500 Subject: [PATCH 1/8] Report eBPF service statuses instead of checking installation (#334) * Report eBPF service statuses instead of checking installation --------- Co-authored-by: Srikrishna Veturi --- proxy_agent_extension/src/service_main.rs | 282 ++++++++++++++---- proxy_agent_shared/src/service.rs | 82 +++++ .../src/service/windows_service.rs | 103 ++++++- 3 files changed, 405 insertions(+), 62 deletions(-) diff --git a/proxy_agent_extension/src/service_main.rs b/proxy_agent_extension/src/service_main.rs index b0140d7b..a40ba79c 100644 --- a/proxy_agent_extension/src/service_main.rs +++ b/proxy_agent_extension/src/service_main.rs @@ -341,70 +341,85 @@ fn write_state_event( } #[cfg(windows)] -fn report_ebpf_status(status_obj: &mut StatusObj) { - match service::check_service_installed(constants::EBPF_CORE) { - (true, message) => { - logger::write(message.to_string()); - match service::check_service_installed(constants::EBPF_EXT) { - (true, message) => { - logger::write(message.to_string()); - status_obj.substatus = { - let mut substatus = status_obj.substatus.clone(); - substatus.push(SubStatus { - name: constants::EBPF_SUBSTATUS_NAME.to_string(), - status: constants::SUCCESS_STATUS.to_string(), - code: constants::STATUS_CODE_OK, - formattedMessage: FormattedMessage { - lang: constants::LANG_EN_US.to_string(), - message: "Ebpf Drivers successfully queried.".to_string(), - }, - }); - substatus - }; - } - (false, message) => { - logger::write(message.to_string()); - status_obj.substatus = { - let mut substatus = status_obj.substatus.clone(); - substatus.push(SubStatus { - name: constants::EBPF_SUBSTATUS_NAME.to_string(), - status: constants::ERROR_STATUS.to_string(), - code: constants::STATUS_CODE_NOT_OK, - formattedMessage: FormattedMessage { - lang: constants::LANG_EN_US.to_string(), - message: format!( - "Ebpf Driver: {} unsuccessfully queried.", - constants::EBPF_EXT - ), - }, - }); - substatus - }; - } +fn build_ebpf_substatus( + core: &proxy_agent_shared::service::ServiceStatusInfo, + ext: &proxy_agent_shared::service::ServiceStatusInfo, +) -> SubStatus { + use proxy_agent_shared::service::ServiceState; + + let (status, code, message) = match (&core.state, &ext.state) { + (Some(core_state), Some(ext_state)) => { + let both_running = + *core_state == ServiceState::Running && *ext_state == ServiceState::Running; + if both_running { + ( + constants::SUCCESS_STATUS.to_string(), + constants::STATUS_CODE_OK, + format!( + "EbpfCore: {}, NetEbpfExt: {}", + core.summary(), + ext.summary() + ), + ) + } else { + ( + constants::ERROR_STATUS.to_string(), + constants::STATUS_CODE_NOT_OK, + format!( + "EbpfCore: {}, NetEbpfExt: {}", + core.summary(), + ext.summary() + ), + ) } } - (false, message) => { - logger::write(message.to_string()); - status_obj.substatus = { - let mut substatus = status_obj.substatus.clone(); - substatus.push(SubStatus { - name: constants::EBPF_SUBSTATUS_NAME.to_string(), - status: constants::ERROR_STATUS.to_string(), - code: constants::STATUS_CODE_NOT_OK, - formattedMessage: FormattedMessage { - lang: constants::LANG_EN_US.to_string(), - message: format!( - "Ebpf Driver: {} unsuccessfully queried.", - constants::EBPF_CORE - ), - }, - }); - substatus - }; - } + (None, None) => ( + constants::ERROR_STATUS.to_string(), + constants::STATUS_CODE_NOT_OK, + "EbpfCore: unsuccessfully queried, NetEbpfExt: unsuccessfully queried.".to_string(), + ), + (None, _) => ( + constants::ERROR_STATUS.to_string(), + constants::STATUS_CODE_NOT_OK, + format!( + "EbpfCore: unsuccessfully queried, NetEbpfExt: {}", + ext.summary() + ), + ), + (_, None) => ( + constants::ERROR_STATUS.to_string(), + constants::STATUS_CODE_NOT_OK, + format!( + "EbpfCore: {}, NetEbpfExt: unsuccessfully queried.", + core.summary() + ), + ), + }; + + SubStatus { + name: constants::EBPF_SUBSTATUS_NAME.to_string(), + status, + code, + formattedMessage: FormattedMessage { + lang: constants::LANG_EN_US.to_string(), + message, + }, } } +#[cfg(windows)] +fn report_ebpf_status(status_obj: &mut StatusObj) { + let core_status = service::check_service_status(constants::EBPF_CORE); + logger::write(format!("check_service_status: {}", core_status.message())); + + let ext_status = service::check_service_status(constants::EBPF_EXT); + logger::write(format!("check_service_status: {}", ext_status.message())); + + let mut substatus = status_obj.substatus.clone(); + substatus.push(build_ebpf_substatus(&core_status, &ext_status)); + status_obj.substatus = substatus; +} + fn backup_proxy_agent(setup_tool: &String) { match Command::new(setup_tool).arg("backup").output() { Ok(output) => { @@ -1184,6 +1199,155 @@ mod tests { status.substatus[3].name, constants::EBPF_SUBSTATUS_NAME.to_string() ); + + // Verify the eBPF substatus message includes service status info + let ebpf_substatus = &status.substatus[3]; + let ebpf_message = &ebpf_substatus.formattedMessage.message; + if ebpf_message.contains("unsuccessfully queried") { + // At least one service not installed — status should be Error + assert_eq!( + ebpf_substatus.status, + constants::ERROR_STATUS, + "Expected Error status when a service is not installed" + ); + } else { + // Both services found — message should contain status details for each driver + assert!( + ebpf_message.contains("EbpfCore:"), + "Expected message to contain 'EbpfCore:', got: {ebpf_message}" + ); + assert!( + ebpf_message.contains("NetEbpfExt:"), + "Expected message to contain 'NetEbpfExt:', got: {ebpf_message}" + ); + // Status depends on whether both services are running + if ebpf_message.contains("Running") && !ebpf_message.contains("Stopped") { + assert_eq!( + ebpf_substatus.status, + constants::SUCCESS_STATUS, + "Expected Success when both services are running" + ); + assert_eq!(ebpf_substatus.code, constants::STATUS_CODE_OK); + } else { + assert_eq!( + ebpf_substatus.status, + constants::ERROR_STATUS, + "Expected Error when at least one service is not running" + ); + assert_eq!(ebpf_substatus.code, constants::STATUS_CODE_NOT_OK); + } + } + } + + #[test] + #[cfg(windows)] + fn test_build_ebpf_substatus() { + use proxy_agent_shared::service::{ServiceState, ServiceStatusInfo}; + + fn make_info(name: &str, state: Option) -> ServiceStatusInfo { + let start_type = if state.is_some() { + "AutoStart".to_string() + } else { + "NotInstalled".to_string() + }; + ServiceStatusInfo { + service_name: name.to_string(), + state, + start_type, + } + } + + // 1. Both not installed + let sub = super::build_ebpf_substatus( + &make_info(constants::EBPF_CORE, None), + &make_info(constants::EBPF_EXT, None), + ); + assert_eq!(sub.status, constants::ERROR_STATUS, "Both not installed"); + assert_eq!(sub.code, constants::STATUS_CODE_NOT_OK); + let msg = &sub.formattedMessage.message; + assert!( + msg.contains(constants::EBPF_CORE) && msg.contains(constants::EBPF_EXT), + "Expected both driver names in message, got: {msg}" + ); + + // 2. Core not installed, Ext running + let sub = super::build_ebpf_substatus( + &make_info(constants::EBPF_CORE, None), + &make_info(constants::EBPF_EXT, Some(ServiceState::Running)), + ); + assert_eq!(sub.status, constants::ERROR_STATUS, "Core not installed"); + assert_eq!(sub.code, constants::STATUS_CODE_NOT_OK); + let msg = &sub.formattedMessage.message; + assert!( + msg.contains(constants::EBPF_CORE), + "Expected EbpfCore in message, got: {msg}" + ); + assert!( + msg.contains("Running"), + "Expected Ext summary (Running) in message, got: {msg}" + ); + + // 3. Core running, Ext not installed + let sub = super::build_ebpf_substatus( + &make_info(constants::EBPF_CORE, Some(ServiceState::Running)), + &make_info(constants::EBPF_EXT, None), + ); + assert_eq!(sub.status, constants::ERROR_STATUS, "Ext not installed"); + assert_eq!(sub.code, constants::STATUS_CODE_NOT_OK); + let msg = &sub.formattedMessage.message; + assert!( + msg.contains("Running"), + "Expected Core summary (Running) in message, got: {msg}" + ); + assert!( + msg.contains(constants::EBPF_EXT), + "Expected NetEbpfExt in message, got: {msg}" + ); + + // 4. Both running → success + let sub = super::build_ebpf_substatus( + &make_info(constants::EBPF_CORE, Some(ServiceState::Running)), + &make_info(constants::EBPF_EXT, Some(ServiceState::Running)), + ); + assert_eq!(sub.status, constants::SUCCESS_STATUS, "Both running"); + assert_eq!(sub.code, constants::STATUS_CODE_OK); + let msg = &sub.formattedMessage.message; + assert!( + msg.contains("EbpfCore:") && msg.contains("NetEbpfExt:"), + "Expected both driver labels in message, got: {msg}" + ); + + // 5. Core stopped, Ext running → error + let sub = super::build_ebpf_substatus( + &make_info(constants::EBPF_CORE, Some(ServiceState::Stopped)), + &make_info(constants::EBPF_EXT, Some(ServiceState::Running)), + ); + assert_eq!( + sub.status, + constants::ERROR_STATUS, + "Core stopped, Ext running" + ); + assert_eq!(sub.code, constants::STATUS_CODE_NOT_OK); + + // 6. Core running, Ext stopped → error + let sub = super::build_ebpf_substatus( + &make_info(constants::EBPF_CORE, Some(ServiceState::Running)), + &make_info(constants::EBPF_EXT, Some(ServiceState::Stopped)), + ); + assert_eq!( + sub.status, + constants::ERROR_STATUS, + "Core running, Ext stopped" + ); + assert_eq!(sub.code, constants::STATUS_CODE_NOT_OK); + + // 7. Both stopped → error + let sub = super::build_ebpf_substatus( + &make_info(constants::EBPF_CORE, Some(ServiceState::Stopped)), + &make_info(constants::EBPF_EXT, Some(ServiceState::Stopped)), + ); + assert_eq!(sub.status, constants::ERROR_STATUS, "Both stopped"); + assert_eq!(sub.code, constants::STATUS_CODE_NOT_OK); } #[tokio::test] diff --git a/proxy_agent_shared/src/service.rs b/proxy_agent_shared/src/service.rs index 53c2cd6f..2bda4385 100644 --- a/proxy_agent_shared/src/service.rs +++ b/proxy_agent_shared/src/service.rs @@ -145,8 +145,49 @@ pub fn check_service_installed(service_name: &str) -> (bool, String) { } } +/// Checks whether a Windows service is installed and queries its runtime state and start type. +/// Returns a `ServiceStatusInfo` whose `state` is `Some(ServiceState)` when the service exists, +/// or `None` when the service is not installed. +#[cfg(windows)] +pub fn check_service_status(service_name: &str) -> windows_service::ServiceStatusInfo { + let (state, start_type) = match windows_service::query_service_status(service_name) { + Ok(status) => { + let start_type = match windows_service::query_service_config(service_name) { + Ok(config) => format!("{:?}", config.start_type), + Err(e) => { + log::warn!( + "Failed to query config for service '{}': {}", + service_name, + e + ); + "Unknown".to_string() + } + }; + (Some(status.current_state), start_type) + } + Err(e) => { + log::debug!( + "Failed to query status for service '{}': {}. Treating as not installed.", + service_name, + e + ); + (None, "NotInstalled".to_string()) + } + }; + + windows_service::ServiceStatusInfo { + service_name: service_name.to_string(), + state, + start_type, + } +} + #[cfg(windows)] pub use windows_service::set_default_failure_actions; +#[cfg(windows)] +pub use windows_service::ServiceState; +#[cfg(windows)] +pub use windows_service::ServiceStatusInfo; #[cfg(test)] mod tests { @@ -191,4 +232,45 @@ mod tests { _ = super::stop_and_delete_service(service_name).await.unwrap(); } } + + #[tokio::test] + async fn test_check_service_status() { + #[cfg(windows)] + { + let service_name = "test_check_service_status"; + // try delete the service if it exists + _ = super::stop_and_delete_service(service_name).await; + + // Verify non-existent service returns not installed + let status = super::check_service_status(service_name); + assert_eq!(status.state, None, "Expected None for non-existent service"); + assert!(status.message().contains("query failed")); + assert_eq!(status.summary(), "NotInstalled"); + + // Install a test service and verify status is reported + let exe_path = std::env::current_exe().unwrap(); + let result = super::install_service(service_name, service_name, vec![], exe_path); + assert!(result.is_ok()); + + let status = super::check_service_status(service_name); + assert!(status.state.is_some(), "Expected service to be installed"); + assert!(status.message().contains("status:")); + // Service should be stopped (test exe can't actually run as a service) + assert_eq!( + status.state, + Some(super::ServiceState::Stopped), + "Expected Some(ServiceState::Stopped), got: {:?}", + status.state + ); + // Summary should also contain start type info + let summary = status.summary(); + assert!( + summary.contains("AutoStart"), + "Expected summary to contain 'AutoStart', got: {summary}" + ); + + // clean up + _ = super::stop_and_delete_service(service_name).await.unwrap(); + } + } } diff --git a/proxy_agent_shared/src/service/windows_service.rs b/proxy_agent_shared/src/service/windows_service.rs index c2a86de2..b8f552c5 100644 --- a/proxy_agent_shared/src/service/windows_service.rs +++ b/proxy_agent_shared/src/service/windows_service.rs @@ -7,14 +7,43 @@ use std::ffi::OsString; use std::path::PathBuf; use std::str; use std::time::Duration; +pub use windows_service::service::ServiceState; use windows_service::service::{ ServiceAccess, ServiceAction, ServiceActionType, ServiceConfig, ServiceErrorControl, - ServiceFailureResetPeriod, ServiceInfo, ServiceStartType, ServiceState, ServiceStatus, - ServiceType, + ServiceFailureResetPeriod, ServiceInfo, ServiceStartType, ServiceStatus, ServiceType, }; use windows_service::service::{ServiceDependency, ServiceFailureActions}; use windows_service::service_manager::{ServiceManager, ServiceManagerAccess}; +/// Holds the runtime status of a Windows service. +#[derive(Debug)] +pub struct ServiceStatusInfo { + pub service_name: String, + pub state: Option, + pub start_type: String, +} + +impl ServiceStatusInfo { + /// Human-readable summary, e.g. "Running, AutoStart" or "NotInstalled". + pub fn summary(&self) -> String { + match self.state { + Some(ref state) => format!("{:?}, {}", state, self.start_type), + None => "NotInstalled".to_string(), + } + } + + /// Log-friendly message including the service name and summary. + pub fn message(&self) -> String { + match self.state { + Some(_) => format!("service: {} status: {}", self.service_name, self.summary()), + None => format!( + "service: {} status query failed, service may not be installed", + self.service_name + ), + } + } +} + pub async fn start_service_with_retry( service_name: &str, retry_count: u32, @@ -167,7 +196,7 @@ pub fn install_or_update_service( } } -fn query_service_status(service_name: &str) -> Result { +pub fn query_service_status(service_name: &str) -> Result { let service_manager = ServiceManager::local_computer(None::<&str>, ServiceManagerAccess::CONNECT) .map_err(|e| Error::WindowsService(e, std::io::Error::last_os_error()))?; @@ -305,6 +334,7 @@ pub fn set_default_failure_actions(service_name: &str) -> Result<()> { #[cfg(test)] mod tests { use std::{path::PathBuf, process::Command}; + use windows_service::service::ServiceState; #[tokio::test] async fn test_install_service() { @@ -413,4 +443,71 @@ mod tests { //Clean up - delete service super::stop_and_delete_service(service_name).await.unwrap(); } + + #[test] + fn test_service_status_info_summary() { + // Not installed + let info = super::ServiceStatusInfo { + service_name: "TestSvc".to_string(), + state: None, + start_type: "NotInstalled".to_string(), + }; + assert_eq!(info.summary(), "NotInstalled"); + + // Running, AutoStart + let info = super::ServiceStatusInfo { + service_name: "TestSvc".to_string(), + state: Some(ServiceState::Running), + start_type: "AutoStart".to_string(), + }; + assert_eq!(info.summary(), "Running, AutoStart"); + + // Stopped, Disabled + let info = super::ServiceStatusInfo { + service_name: "TestSvc".to_string(), + state: Some(ServiceState::Stopped), + start_type: "Disabled".to_string(), + }; + assert_eq!(info.summary(), "Stopped, Disabled"); + } + + #[test] + fn test_service_status_info_message() { + // Not installed — message must mention the service name and "query failed" + let info = super::ServiceStatusInfo { + service_name: "TestSvc".to_string(), + state: None, + start_type: "NotInstalled".to_string(), + }; + let msg = info.message(); + assert!( + msg.contains("TestSvc"), + "Expected service name in message, got: {msg}" + ); + assert!( + msg.contains("query failed"), + "Expected 'query failed' in message, got: {msg}" + ); + + // Installed and running — message must contain "status:" and the summary + let info = super::ServiceStatusInfo { + service_name: "TestSvc".to_string(), + state: Some(ServiceState::Running), + start_type: "AutoStart".to_string(), + }; + let msg = info.message(); + assert!( + msg.contains("TestSvc"), + "Expected service name in message, got: {msg}" + ); + assert!( + msg.contains("status:"), + "Expected 'status:' in message, got: {msg}" + ); + assert!( + msg.contains(&info.summary()), + "Expected summary '{}' in message, got: {msg}", + info.summary() + ); + } } From d81bb0ad66e9e7973d834c664783da29e5482146 Mon Sep 17 00:00:00 2001 From: Zhidong Peng Date: Tue, 21 Apr 2026 11:20:49 -0700 Subject: [PATCH 2/8] Fix clippy::unnecessary_sort_by (#336) --- proxy_agent_extension/src/service_main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/proxy_agent_extension/src/service_main.rs b/proxy_agent_extension/src/service_main.rs index a40ba79c..2ca9fc99 100644 --- a/proxy_agent_extension/src/service_main.rs +++ b/proxy_agent_extension/src/service_main.rs @@ -814,7 +814,7 @@ fn get_top_proxy_connection_summary( mut summary: Vec, max_count: usize, ) -> Vec { - summary.sort_by(|a, b| a.count.cmp(&b.count)); + summary.sort_by_key(|a| a.count); let len = summary.len(); if len > max_count { summary = summary.split_off(len - max_count); From 329c3b05f321eee1a9d498f4b06e320ea6ea03b1 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 23 Apr 2026 12:50:28 -0700 Subject: [PATCH 3/8] Bump rand from 0.8.5 to 0.8.6 (#339) Bumps [rand](https://github.com/rust-random/rand) from 0.8.5 to 0.8.6. - [Release notes](https://github.com/rust-random/rand/releases) - [Changelog](https://github.com/rust-random/rand/blob/0.8.6/CHANGELOG.md) - [Commits](https://github.com/rust-random/rand/compare/0.8.5...0.8.6) --- updated-dependencies: - dependency-name: rand dependency-version: 0.8.6 dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1df43fca..cec20c17 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -980,9 +980,9 @@ dependencies = [ [[package]] name = "rand" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +checksum = "5ca0ecfa931c29007047d1bc58e623ab12e5590e8c7cc53200d5202b69266d8a" dependencies = [ "libc", "rand_chacha", From 53a7855f6cb043be6c571c86c28f06b88eba2aab Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 23 Apr 2026 20:06:24 +0000 Subject: [PATCH 4/8] Bump openssl from 0.10.73 to 0.10.78 (#338) Bumps [openssl](https://github.com/rust-openssl/rust-openssl) from 0.10.73 to 0.10.78. - [Release notes](https://github.com/rust-openssl/rust-openssl/releases) - [Commits](https://github.com/rust-openssl/rust-openssl/compare/openssl-v0.10.73...openssl-v0.10.78) --- updated-dependencies: - dependency-name: openssl dependency-version: 0.10.78 dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Zhidong Peng --- Cargo.lock | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cec20c17..ce02cf76 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -824,9 +824,9 @@ checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" [[package]] name = "openssl" -version = "0.10.73" +version = "0.10.78" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8505734d46c8ab1e19a1dce3aef597ad87dcb4c37e7188231769bd6bd51cebf8" +checksum = "f38c4372413cdaaf3cc79dd92d29d7d9f5ab09b51b10dded508fb90bb70b9222" dependencies = [ "bitflags", "cfg-if", @@ -859,9 +859,9 @@ dependencies = [ [[package]] name = "openssl-sys" -version = "0.9.109" +version = "0.9.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90096e2e47630d78b7d1c20952dc621f957103f8bc2c8359ec81290d75238571" +checksum = "13ce1245cd07fcc4cfdb438f7507b0c7e4f3849a69fd84d52374c66d83741bb6" dependencies = [ "cc", "libc", From 2c82ee15e3d498b36ea789c369684c890ac44ce1 Mon Sep 17 00:00:00 2001 From: Zhidong Peng Date: Fri, 24 Apr 2026 12:55:30 -0700 Subject: [PATCH 5/8] GPA service to use host-date-time for signed http requests (#335) * GPA service to use host-date-time for signed http requests * add logging * fix typo * Bump rand from 0.8.5 to 0.8.6 (#339) Bumps [rand](https://github.com/rust-random/rand) from 0.8.5 to 0.8.6. - [Release notes](https://github.com/rust-random/rand/releases) - [Changelog](https://github.com/rust-random/rand/blob/0.8.6/CHANGELOG.md) - [Commits](https://github.com/rust-random/rand/compare/0.8.5...0.8.6) --- updated-dependencies: - dependency-name: rand dependency-version: 0.8.6 dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Bump openssl from 0.10.73 to 0.10.78 (#338) Bumps [openssl](https://github.com/rust-openssl/rust-openssl) from 0.10.73 to 0.10.78. - [Release notes](https://github.com/rust-openssl/rust-openssl/releases) - [Commits](https://github.com/rust-openssl/rust-openssl/compare/openssl-v0.10.73...openssl-v0.10.78) --- updated-dependencies: - dependency-name: openssl dependency-version: 0.10.78 dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Zhidong Peng * resolve comments Co-authored-by: Copilot * fix spelling --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Zhidong Peng Co-authored-by: Copilot --- .github/actions/spelling/expect.txt | 2 + proxy_agent/src/key_keeper/key.rs | 34 +++++++- proxy_agent_shared/src/hyper_client.rs | 108 ++++++++++++++++++++++++- proxy_agent_shared/src/misc_helpers.rs | 101 ++++++++++++++++++++++- 4 files changed, 240 insertions(+), 5 deletions(-) diff --git a/.github/actions/spelling/expect.txt b/.github/actions/spelling/expect.txt index 8a5c2865..6e010cc7 100644 --- a/.github/actions/spelling/expect.txt +++ b/.github/actions/spelling/expect.txt @@ -206,6 +206,7 @@ openprocess oneshot opencode opensource +parseable PERCPU pgpkey pgrep @@ -301,6 +302,7 @@ testurl tgid THH thiserror +Thu timedout timeup tlsv diff --git a/proxy_agent/src/key_keeper/key.rs b/proxy_agent/src/key_keeper/key.rs index e8c2b61d..cc850e55 100644 --- a/proxy_agent/src/key_keeper/key.rs +++ b/proxy_agent/src/key_keeper/key.rs @@ -36,9 +36,9 @@ use hyper::Uri; use proxy_agent_shared::hyper_client; use proxy_agent_shared::logger::LoggerLevel; use serde_derive::{Deserialize, Serialize}; -use std::ffi::OsString; use std::fmt::{Display, Formatter}; use std::{collections::HashMap, path::PathBuf}; +use std::{ffi::OsString, time::Duration}; const AUDIT_MODE: &str = "audit"; const ENFORCE_MODE: &str = "enforce"; @@ -729,7 +729,11 @@ impl Display for KeyAction { const STATUS_URL: &str = "/secure-channel/status"; const KEY_URL: &str = "/secure-channel/key"; +const HOST_DATE_TIME_DRIFT_MAX_AGE: Duration = Duration::from_secs(60 * 15); +/// Get the current status of the key from the secure channel. +/// This function will perform a single GET request to the secure channel status endpoint. +/// If the host time sync is stale, it will sync the host time from the response headers. pub async fn get_status(host: &str, port: u16) -> Result { let endpoint = hyper_client::HostEndpoint::new(host, port, STATUS_URL); let mut headers = HashMap::new(); @@ -737,8 +741,32 @@ pub async fn get_status(host: &str, port: u16) -> Result { hyper_client::METADATA_HEADER.to_string(), "True ".to_string(), ); - let status: KeyStatus = - hyper_client::get(&endpoint, &headers, None, None, logger::write_warning).await?; + + let request = hyper_client::build_request(Method::GET, &endpoint, &headers, None, None, None)?; + + let response = hyper_client::send_request( + &endpoint.host, + endpoint.port, + request, + logger::write_warning, + ) + .await?; + + if !response.status().is_success() { + return Err(proxy_agent_shared::error::Error::Hyper( + proxy_agent_shared::error::HyperErrorType::ServerError( + endpoint.to_string(), + response.status(), + ), + ))?; + } + + if proxy_agent_shared::misc_helpers::host_time_sync_is_stale(HOST_DATE_TIME_DRIFT_MAX_AGE) { + let host_time_synced = hyper_client::sync_host_time_from_headers(response.headers()); + logger::write(format!("Host time synced: {host_time_synced}")); + } + + let status: KeyStatus = hyper_client::read_response_body(response).await?; status.validate()?; Ok(status) diff --git a/proxy_agent_shared/src/hyper_client.rs b/proxy_agent_shared/src/hyper_client.rs index 8fe933a6..71175448 100644 --- a/proxy_agent_shared/src/hyper_client.rs +++ b/proxy_agent_shared/src/hyper_client.rs @@ -142,6 +142,30 @@ where read_response_body(response).await } +/// Try to sync host time from a response header map. +/// The function first looks for a custom date header (x-ms-azure-host-date) and +/// if not found, falls back to the standard HTTP date header. +/// if custom date header is present but invalid, it will not fall back to standard date header +/// Returns true when sync is updated successfully. +pub fn sync_host_time_from_headers(headers: &hyper::HeaderMap) -> bool { + // first try custom date header + if let Some(host_date) = headers.get(DATE_HEADER) { + if let Ok(host_date_rfc1123) = host_date.to_str() { + return misc_helpers::sync_host_utc_time_from_rfc1123_string(host_date_rfc1123); + } + } + + // fallback to standard HTTP date header + if let Some(host_date) = headers.get(hyper::header::DATE) { + if let Ok(host_date_rfc1123) = host_date.to_str() { + return misc_helpers::sync_host_utc_time_from_rfc1123_string(host_date_rfc1123); + } + } + + // return false if no valid date header found + false +} + pub async fn read_response_body( mut response: hyper::Response, ) -> Result @@ -539,7 +563,7 @@ mod tests { use crate::{ host_clients::{imds_client::ImdsClient, wire_server_client::WireServerClient}, logger::logger_manager, - server_mock, + misc_helpers, server_mock, }; use tokio_util::sync::CancellationToken; @@ -608,4 +632,86 @@ mod tests { cancellation_token.cancel(); } + + #[test] + fn sync_host_time_from_headers_tests() { + // should return false when no date headers are present + let headers = hyper::HeaderMap::new(); + assert!( + !super::sync_host_time_from_headers(&headers), + "should return false when no date headers are present" + ); + + // should return true with valid custom date header + let mut headers = hyper::HeaderMap::new(); + headers.insert( + super::DATE_HEADER, + "Wed, 23 Apr 2025 12:00:00 GMT".parse().unwrap(), + ); + assert!( + super::sync_host_time_from_headers(&headers), + "should return true with valid custom date header" + ); + + // should return true with valid standard Date header + let mut headers = hyper::HeaderMap::new(); + headers.insert( + hyper::header::DATE, + "Wed, 23 Apr 2025 12:00:00 GMT".parse().unwrap(), + ); + assert!( + super::sync_host_time_from_headers(&headers), + "should return true with valid standard Date header" + ); + + // when both headers are present but invalid, should return false without panic (to_str() succeeds but RFC1123 parse fails) + let mut headers = hyper::HeaderMap::new(); + headers.insert(super::DATE_HEADER, "not-a-valid-date".parse().unwrap()); + headers.insert(hyper::header::DATE, "also-not-valid".parse().unwrap()); + assert!( + !super::sync_host_time_from_headers(&headers), + "should return false when both headers have invalid dates" + ); + + // When the custom header is present but has an invalid date string, + // the function returns false immediately (does not fall back to standard Date header) + // because to_str() succeeds but the RFC1123 parse fails. + let mut headers = hyper::HeaderMap::new(); + headers.insert(super::DATE_HEADER, "not-a-valid-date".parse().unwrap()); + headers.insert( + hyper::header::DATE, + "Wed, 23 Apr 2025 12:00:00 GMT".parse().unwrap(), + ); + assert!( + !super::sync_host_time_from_headers(&headers), + "should return false without falling back when custom header has invalid date" + ); + + // when both headers are present and valid, should return true (sync from custom header takes precedence) + // and not fallback to standard Date header + let mut headers = hyper::HeaderMap::new(); + let custom_host_time = "Wed, 23 Apr 2025 12:00:00 GMT"; + headers.insert(super::DATE_HEADER, custom_host_time.parse().unwrap()); + headers.insert( + hyper::header::DATE, + "Thu, 24 Apr 2025 12:00:00 GMT".parse().unwrap(), + ); + // Both are valid; the function should return true + assert!( + super::sync_host_time_from_headers(&headers), + "should return true when both headers are present" + ); + // verify the sync time is from the custom header, not the standard Date header (which is 1 day later) + let sync_time = misc_helpers::parse_rfc1123_to_offset_datetime( + &misc_helpers::get_date_time_rfc1123_string(), + ) + .unwrap(); + let expected_sync_time = + misc_helpers::parse_rfc1123_to_offset_datetime(custom_host_time).unwrap(); + let diff = sync_time - expected_sync_time; + assert!( + diff < time::Duration::seconds(5), + "sync time should be close to the custom header time" + ); + } } diff --git a/proxy_agent_shared/src/misc_helpers.rs b/proxy_agent_shared/src/misc_helpers.rs index 9378c1b0..9d121c0d 100644 --- a/proxy_agent_shared/src/misc_helpers.rs +++ b/proxy_agent_shared/src/misc_helpers.rs @@ -11,6 +11,8 @@ use std::{ fs::{self, File}, path::{Path, PathBuf}, process::Command, + sync::RwLock, + time::Instant, }; use thread_id; use time::{format_description, OffsetDateTime, PrimitiveDateTime}; @@ -48,6 +50,14 @@ static RFC1123_FORMAT: std::sync::LazyLock>> = + std::sync::LazyLock::new(|| RwLock::new(None)); + pub fn get_date_time_string_with_milliseconds() -> String { let time_str = OffsetDateTime::now_utc() .format(&*ISO8601_MILLIS_FORMAT) @@ -63,11 +73,66 @@ pub fn get_date_time_string() -> String { } pub fn get_date_time_rfc1123_string() -> String { - OffsetDateTime::now_utc() + get_current_utc_time_synced() .format(&*RFC1123_FORMAT) .expect("Failed to format RFC1123 date") } +/// Update host-time sync state from a RFC1123 datetime string. +/// Returns true when sync state is updated successfully, false otherwise. +pub fn sync_host_utc_time_from_rfc1123_string(host_utc_rfc1123: &str) -> bool { + let Ok(parsed_host_utc) = parse_rfc1123_to_offset_datetime(host_utc_rfc1123) else { + return false; + }; + + let Ok(mut state) = HOST_TIME_SYNC_STATE.write() else { + return false; + }; + + *state = Some(HostTimeSyncState { + synced_host_utc: parsed_host_utc, + synced_instant: Instant::now(), + }); + true +} + +pub fn parse_rfc1123_to_offset_datetime(rfc1123_str: &str) -> Result { + PrimitiveDateTime::parse(rfc1123_str, &*RFC1123_FORMAT) + .map(|dt| dt.assume_utc()) + .map_err(|e| { + Error::ParseDateTimeStringError(format!( + "Failed to parse RFC1123 datetime string '{rfc1123_str}': {e}" + )) + }) +} + +/// Returns true when current host-time sync state is older than `max_age`. +/// If there is no host-time sync state yet, this returns true. +pub fn host_time_sync_is_stale(max_age: std::time::Duration) -> bool { + let Ok(state) = HOST_TIME_SYNC_STATE.read() else { + return true; + }; + state + .as_ref() + .is_none_or(|synced| synced.synced_instant.elapsed() > max_age) +} + +fn get_current_utc_time_synced() -> OffsetDateTime { + let Ok(state) = HOST_TIME_SYNC_STATE.read() else { + return OffsetDateTime::now_utc(); + }; + + let Some(synced) = state.as_ref() else { + return OffsetDateTime::now_utc(); + }; + + let elapsed = synced.synced_instant.elapsed(); + match time::Duration::try_from(elapsed) { + Ok(elapsed_time) => synced.synced_host_utc + elapsed_time, + Err(_) => OffsetDateTime::now_utc(), + } +} + pub fn get_date_time_unix_nano() -> i128 { OffsetDateTime::now_utc().unix_timestamp_nanos() } @@ -471,6 +536,7 @@ mod tests { use std::env; use std::fs; use std::path::PathBuf; + use std::time::Duration; #[derive(Serialize, Deserialize)] struct TestStruct { @@ -823,4 +889,37 @@ mod tests { "Should fail to parse invalid datetime string" ); } + + #[test] + fn sync_host_utc_time_from_rfc1123_string_test() { + let host_time = "Mon, 01 Jan 2024 00:00:00 GMT"; + assert!( + super::sync_host_utc_time_from_rfc1123_string(host_time), + "Expected valid host RFC1123 time to update sync state" + ); + + assert!( + !super::host_time_sync_is_stale(Duration::from_secs(3600)), + "Sync state should not be stale right after update" + ); + + std::thread::sleep(Duration::from_millis(1)); + assert!( + super::host_time_sync_is_stale(Duration::from_millis(0)), + "Sync state should be stale when max_age is zero" + ); + + // reset sync state to None for other tests + let _ = super::HOST_TIME_SYNC_STATE + .write() + .map(|mut state| *state = None); + } + + #[test] + fn sync_host_utc_time_from_rfc1123_string_invalid_input_test() { + assert!( + !super::sync_host_utc_time_from_rfc1123_string("invalid-rfc1123"), + "Expected invalid host RFC1123 time to fail" + ); + } } From c259110c3efb2d8bf4816499f72547ed8fba6ae6 Mon Sep 17 00:00:00 2001 From: Zhidong Peng Date: Mon, 27 Apr 2026 12:58:10 -0700 Subject: [PATCH 6/8] Add local file-based access-control rule support. (#329) * Add local file-based access-control rule support. * formatting * resolve comments and validate the parsed local rules. * fix formatting. * fix case-insensitive match * prefix_local_rule_names Co-authored-by: Copilot * Display useLocalFileRules. * update log level at attemptting Co-authored-by: Copilot * fix formatting --------- Co-authored-by: Zhidong Peng Co-authored-by: Copilot --- Cargo.lock | 7 + proxy_agent/Cargo.toml | 1 + proxy_agent/src/key_keeper.rs | 188 ++- proxy_agent/src/key_keeper/local_rules.rs | 1328 +++++++++++++++++++++ 4 files changed, 1463 insertions(+), 61 deletions(-) create mode 100644 proxy_agent/src/key_keeper/local_rules.rs diff --git a/Cargo.lock b/Cargo.lock index ce02cf76..109d3775 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -175,6 +175,7 @@ name = "azure-proxy-agent" version = "9.9.9" dependencies = [ "aya", + "base64", "bitflags", "clap", "http", @@ -223,6 +224,12 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + [[package]] name = "bitflags" version = "2.6.0" diff --git a/proxy_agent/Cargo.toml b/proxy_agent/Cargo.toml index 6b0987d6..9bc6b227 100644 --- a/proxy_agent/Cargo.toml +++ b/proxy_agent/Cargo.toml @@ -29,6 +29,7 @@ clap = { version = "4.5.17", features =["derive"] } # Command Line Argument Pars thiserror = "1.0.64" libc = "0.2.147" socket2 = "0.5" # Set socket options without tokio/std conversion +base64 = "0.22" [dependencies.uuid] version = "1.3.0" diff --git a/proxy_agent/src/key_keeper.rs b/proxy_agent/src/key_keeper.rs index 3f289bde..119f2902 100644 --- a/proxy_agent/src/key_keeper.rs +++ b/proxy_agent/src/key_keeper.rs @@ -25,12 +25,16 @@ //! ``` pub mod key; +pub mod local_rules; use self::key::Key; +use self::local_rules::{ + get_rules_dir_from_key_dir, resolve_effective_rules, LocalRuleStateTracker, LocalRuleTarget, +}; use crate::common::error::{Error, KeyErrorType}; use crate::common::result::Result; use crate::common::{constants, helpers, logger}; -use crate::key_keeper::key::KeyStatus; +use crate::key_keeper::key::{AuthorizationRules, KeyStatus}; use crate::provision; use crate::proxy::authorization_rules::{AuthorizationRulesForLogging, ComputedAuthorizationRules}; use crate::shared_state::access_control_wrapper::AccessControlSharedState; @@ -70,6 +74,8 @@ pub struct KeyKeeper { port: u16, /// key_dir: the folder to save the key details key_dir: PathBuf, + /// rules_dir: the folder to save customer-managed local access control rules + rules_dir: PathBuf, /// status_dir: the folder to log the access control rule details status_dir: PathBuf, /// interval: the interval to poll the secure channel status @@ -110,10 +116,12 @@ impl KeyKeeper { interval: Duration, shared_state: &SharedState, ) -> Self { + let rules_dir = get_rules_dir_from_key_dir(&key_dir); KeyKeeper { host, port, key_dir, + rules_dir, status_dir, interval, cancellation_token: shared_state.get_cancellation_token(), @@ -132,34 +140,8 @@ impl KeyKeeper { self.update_status_message("poll secure channel status task started.".to_string(), true) .await; - if let Err(e) = misc_helpers::try_create_folder(&self.key_dir) { - logger::write_warning(format!( - "key folder {} created failed with error {}.", - misc_helpers::path_to_string(&self.key_dir), - e - )); - } else { - logger::write(format!( - "key folder {} created if not exists before.", - misc_helpers::path_to_string(&self.key_dir) - )); - } - - match acl::acl_directory(self.key_dir.clone()) { - Ok(()) => { - logger::write(format!( - "Folder {} ACLed if has not before.", - misc_helpers::path_to_string(&self.key_dir) - )); - } - Err(e) => { - logger::write_warning(format!( - "Folder {} ACLed failed with error {}.", - misc_helpers::path_to_string(&self.key_dir), - e - )); - } - } + self.ensure_secure_directory(&self.key_dir, "key"); + self.ensure_secure_directory(&self.rules_dir, "rules"); // acl current executable dir #[cfg(windows)] @@ -194,6 +176,7 @@ impl KeyKeeper { let mut first_iteration: bool = true; let mut started_event_threads: bool = false; let mut provision_timeout: bool = false; + let mut local_rule_state_tracker = LocalRuleStateTracker::default(); let notify = match self.key_keeper_shared_state.get_notify().await { Ok(notify) => notify, Err(e) => { @@ -280,7 +263,8 @@ impl KeyKeeper { previous_key_status_message = Some(key_status_message); } - self.update_access_control_rules(&status).await; + self.update_access_control_rules(&status, &mut local_rule_state_tracker) + .await; let state = status.get_secure_channel_state(); let secure_channel_state_updated = self @@ -483,11 +467,18 @@ impl KeyKeeper { /// Update access control rules from the key status /// Returns true if any rules changed - async fn update_access_control_rules(&self, status: &KeyStatus) -> bool { + async fn update_access_control_rules( + &self, + status: &KeyStatus, + local_rule_state_tracker: &mut LocalRuleStateTracker, + ) -> bool { let mut access_control_rules_changed = false; let wireserver_rule_id = status.get_wireserver_rule_id(); let imds_rule_id = status.get_imds_rule_id(); let hostga_rule_id = status.get_hostga_rule_id(); + let mut wireserver_rule_id_changed = false; + let mut imds_rule_id_changed = false; + let mut hostga_rule_id_changed = false; // Update wireserver rules match self @@ -496,18 +487,11 @@ impl KeyKeeper { .await { Ok((updated, old_wire_server_rule_id)) => { + wireserver_rule_id_changed = updated; if updated { logger::write_warning(format!( "Wireserver rule id changed from '{old_wire_server_rule_id}' to '{wireserver_rule_id}'." )); - if let Err(e) = self - .access_control_shared_state - .set_wireserver_rules(status.get_wireserver_rules()) - .await - { - logger::write_error(format!("Failed to set wireserver rules: {e}")); - } - access_control_rules_changed = true; } } Err(e) => { @@ -522,18 +506,11 @@ impl KeyKeeper { .await { Ok((updated, old_imds_rule_id)) => { + imds_rule_id_changed = updated; if updated { logger::write_warning(format!( "IMDS rule id changed from '{old_imds_rule_id}' to '{imds_rule_id}'." )); - if let Err(e) = self - .access_control_shared_state - .set_imds_rules(status.get_imds_rules()) - .await - { - logger::write_error(format!("Failed to set imds rules: {e}")); - } - access_control_rules_changed = true; } } Err(e) => { @@ -548,18 +525,11 @@ impl KeyKeeper { .await { Ok((updated, old_hostga_rule_id)) => { + hostga_rule_id_changed = updated; if updated { logger::write_warning(format!( "HostGA rule id changed from '{old_hostga_rule_id}' to '{hostga_rule_id}'." )); - if let Err(e) = self - .access_control_shared_state - .set_hostga_rules(status.get_hostga_rules()) - .await - { - logger::write_error(format!("Failed to set HostGA rules: {e}")); - } - access_control_rules_changed = true; } } Err(e) => { @@ -567,9 +537,70 @@ impl KeyKeeper { } } + let (wireserver_rules, wireserver_local_state_changed) = resolve_effective_rules( + &self.rules_dir, + status.get_wireserver_rules(), + LocalRuleTarget::WireServer, + &mut local_rule_state_tracker.wireserver, + wireserver_rule_id_changed, + ) + .await; + let (imds_rules, imds_local_state_changed) = resolve_effective_rules( + &self.rules_dir, + status.get_imds_rules(), + LocalRuleTarget::Imds, + &mut local_rule_state_tracker.imds, + imds_rule_id_changed, + ) + .await; + + if wireserver_rule_id_changed || wireserver_local_state_changed { + if let Err(e) = self + .access_control_shared_state + .set_wireserver_rules(wireserver_rules.clone()) + .await + { + logger::write_error(format!("Failed to set wireserver rules: {e}")); + } + access_control_rules_changed = true; + } + + if imds_rule_id_changed || imds_local_state_changed { + if let Err(e) = self + .access_control_shared_state + .set_imds_rules(imds_rules.clone()) + .await + { + logger::write_error(format!("Failed to set imds rules: {e}")); + } + access_control_rules_changed = true; + } + + // HostGA rules only come from server and do not have local rules, so only update when rule id changed + let hostga_rules = status.get_hostga_rules(); + if hostga_rule_id_changed { + if let Err(e) = self + .access_control_shared_state + .set_hostga_rules(hostga_rules.clone()) + .await + { + logger::write_error(format!("Failed to set HostGA rules: {e}")); + } + access_control_rules_changed = true; + } + // Write authorization rules to file if changed if access_control_rules_changed { - if let (Ok(wireserver_rules), Ok(imds_rules), Ok(hostga_rules)) = ( + let effective_rules = Some(AuthorizationRules { + wireserver: wireserver_rules.clone(), + imds: imds_rules.clone(), + hostga: hostga_rules.clone(), + }); + if let ( + Ok(computed_wireserver_rules), + Ok(computed_imds_rules), + Ok(computed_hostga_rules), + ) = ( self.access_control_shared_state .get_wireserver_rules() .await, @@ -577,11 +608,11 @@ impl KeyKeeper { self.access_control_shared_state.get_hostga_rules().await, ) { let rules = AuthorizationRulesForLogging::new( - status.authorizationRules.clone(), + effective_rules, ComputedAuthorizationRules { - wireserver: wireserver_rules, - imds: imds_rules, - hostga: hostga_rules, + wireserver: computed_wireserver_rules, + imds: computed_imds_rules, + hostga: computed_hostga_rules, }, ); rules.write_all(&self.status_dir, constants::MAX_LOG_FILE_COUNT); @@ -591,6 +622,39 @@ impl KeyKeeper { access_control_rules_changed } + /// Ensure the directory exists and has secure ACLs. + /// If the directory does not exist, it will be created. + fn ensure_secure_directory(&self, dir: &Path, dir_kind: &str) { + if let Err(e) = misc_helpers::try_create_folder(dir) { + logger::write_warning(format!( + "{dir_kind} folder {} created failed with error {}.", + misc_helpers::path_to_string(dir), + e + )); + } else { + logger::write(format!( + "{dir_kind} folder {} created if not exists before.", + misc_helpers::path_to_string(dir) + )); + } + + match acl::acl_directory(dir.to_path_buf()) { + Ok(()) => { + logger::write(format!( + "Folder {} ACLed if has not before.", + misc_helpers::path_to_string(dir) + )); + } + Err(e) => { + logger::write_warning(format!( + "Folder {} ACLed failed with error {}.", + misc_helpers::path_to_string(dir), + e + )); + } + } + } + /// Handle key acquisition from local or server /// Returns true if successful, false if should continue to next iteration async fn handle_key_acquisition(&self, status: &KeyStatus, state: &str) -> bool { @@ -1001,8 +1065,9 @@ impl KeyKeeper { #[cfg(test)] mod tests { use super::key::Key; - use super::KeyKeeper; + use super::local_rules; use crate::key_keeper; + use crate::key_keeper::KeyKeeper; use proxy_agent_shared::misc_helpers; use proxy_agent_shared::server_mock; use std::env; @@ -1084,6 +1149,7 @@ mod tests { host: ip.to_string(), port, key_dir: cloned_keys_dir.clone(), + rules_dir: local_rules::get_rules_dir_from_key_dir(&cloned_keys_dir), status_dir: cloned_keys_dir.clone(), interval: Duration::from_millis(10), cancellation_token: cancellation_token.clone(), diff --git a/proxy_agent/src/key_keeper/local_rules.rs b/proxy_agent/src/key_keeper/local_rules.rs new file mode 100644 index 00000000..40451d94 --- /dev/null +++ b/proxy_agent/src/key_keeper/local_rules.rs @@ -0,0 +1,1328 @@ +// Copyright (c) Microsoft Corporation +// SPDX-License-Identifier: MIT + +use crate::common::error::Error; +use crate::common::logger; +use crate::common::result::Result; +use crate::key_keeper::key::{ + AccessControlRules, AuthorizationItem, Identity, Privilege, Role, RoleAssignment, +}; +use base64::{engine::general_purpose, Engine as _}; +use proxy_agent_shared::logger::LoggerLevel; +use proxy_agent_shared::misc_helpers; +use proxy_agent_shared::telemetry::event_logger; +use serde_derive::Deserialize; +use std::collections::HashSet; +use std::fs; +use std::path::{Path, PathBuf}; +use std::time::{Duration, SystemTime}; + +const LOCAL_RULE_FILE_PARSE_RETRY_COUNT: usize = 3; +const LOCAL_RULE_FILE_PARSE_RETRY_DELAY: Duration = Duration::from_millis(50); + +#[derive(Clone, Copy)] +pub(crate) enum LocalRuleTarget { + WireServer, + Imds, +} + +impl LocalRuleTarget { + pub(crate) fn display_name(self) -> &'static str { + match self { + LocalRuleTarget::WireServer => "WireServer", + LocalRuleTarget::Imds => "IMDS", + } + } + + pub(crate) fn file_name(self) -> &'static str { + match self { + LocalRuleTarget::WireServer => "WireServer_Rules.json", + LocalRuleTarget::Imds => "IMDS_Rules.json", + } + } +} + +#[derive(Clone, Debug, Default, PartialEq)] +pub(crate) enum LocalRuleFileState { + #[default] + Unknown, + Error(String), + Missing, + Present(SystemTime), +} + +#[derive(Clone, Default)] +pub(crate) struct LocalRuleMonitorState { + pub(crate) use_local_file_rules: bool, + pub(crate) file_state: LocalRuleFileState, + pub(crate) parse_failed: bool, + pub(crate) effective_rules: Option, +} + +#[derive(Default)] +pub(crate) struct LocalRuleStateTracker { + pub(crate) wireserver: LocalRuleMonitorState, + pub(crate) imds: LocalRuleMonitorState, +} + +#[derive(Default)] +pub(crate) struct RuleIdDescriptor { + pub(crate) logical_id: String, + pub(crate) use_local_file_rules: bool, +} + +impl RuleIdDescriptor { + pub(crate) fn is_empty(&self) -> bool { + self.logical_id.is_empty() + } + + pub(crate) fn display_id(&self) -> String { + if self.logical_id.is_empty() { + "unknown".to_string() + } else { + format!( + "{}-useLocalFileRules-{}", + self.logical_id, self.use_local_file_rules + ) + } + } +} + +#[derive(Deserialize)] +#[allow(non_snake_case)] +struct EncodedRuleId { + #[serde(default)] + id: String, + #[serde(default)] + useLocalFileRules: bool, +} + +#[derive(Deserialize)] +#[allow(non_snake_case)] +pub(crate) struct LocalAuthorizationRulesFile { + #[serde(default)] + pub(crate) defaultAccess: Option, + #[serde(default)] + pub(crate) id: Option, + #[serde(default)] + pub(crate) rules: Option, +} + +/// As the rules folder is a sibling folder of the key folder, +/// get the rules folder path based on the key folder path. +pub(crate) fn get_rules_dir_from_key_dir(key_dir: &Path) -> PathBuf { + let folder_name = if cfg!(windows) { "Rules" } else { "rules" }; + match key_dir.parent() { + Some(parent) => parent.join(folder_name), + None => key_dir.join(folder_name), + } +} + +/// Get the state of the local rule file - whether it is present or missing, and if present, its last modified time. +fn get_local_rule_file_state(local_rules_file_path: &Path) -> LocalRuleFileState { + match fs::metadata(local_rules_file_path) { + Ok(metadata) => match metadata.modified() { + Ok(modified) => LocalRuleFileState::Present(modified), + Err(_) => LocalRuleFileState::Present(SystemTime::UNIX_EPOCH), // TODO: if we cannot get the modified time, we can treat it as present but with an unknown modified time. Using UNIX_EPOCH as a placeholder here, but it might be better to have a separate variant for this case. + }, + Err(e) if e.kind() == std::io::ErrorKind::NotFound => LocalRuleFileState::Missing, + Err(e) => LocalRuleFileState::Error(format!( + "Unexpected error reading local rules file metadata: {e}" + )), + } +} + +/// Parse the rule ID descriptor from the raw rule ID string. +/// The raw rule ID can be either a plain logical ID or +/// a base64-encoded JSON string containing the logical ID and whether to use local file rules. +fn parse_rule_id_descriptor(raw_rule_id: Option<&str>) -> RuleIdDescriptor { + let raw_rule_id = raw_rule_id.unwrap_or_default().trim(); + if raw_rule_id.is_empty() { + return RuleIdDescriptor::default(); + } + + if let Ok(decoded) = general_purpose::STANDARD.decode(raw_rule_id) { + if let Ok(contract) = serde_json::from_slice::(&decoded) { + return RuleIdDescriptor { + logical_id: contract.id, + use_local_file_rules: contract.useLocalFileRules, + }; + } + } + + // If parsing fails, treat the raw rule ID as the logical ID and do not use local file rules. + RuleIdDescriptor { + logical_id: raw_rule_id.to_string(), + use_local_file_rules: false, + } +} + +fn normalize_authorization_item( + authorization_item: Option, + rule_id_descriptor: &RuleIdDescriptor, +) -> Option { + authorization_item.map(|mut item| { + if !rule_id_descriptor.is_empty() { + item.id = rule_id_descriptor.display_id(); + } + item + }) +} + +fn merge_authorization_item( + remote_rules: Option, + local_rules: LocalAuthorizationRulesFile, + rule_id_descriptor: &RuleIdDescriptor, +) -> Option { + let mut merged_item = remote_rules.unwrap_or(AuthorizationItem { + defaultAccess: "deny".to_string(), + mode: "disabled".to_string(), + id: rule_id_descriptor.display_id(), + rules: None, + }); + + if !rule_id_descriptor.is_empty() { + merged_item.id = rule_id_descriptor.display_id(); + } + + // merge local rule id with remote rule id by appending local id to remote id with an underscore, + // so that we can have both ids in the merged result for better traceability. + // For example, if remote id is "decoded-id" and local id is "local-id", the merged id will be "decoded-id_local-id". + // This also ensures that if local rules are applied, the merged rules will have a different id than the remote rules, + // which can help to avoid confusion and make it clear that the rules have been modified by local rules. + if let Some(local_id) = local_rules.id { + merged_item.id = merged_item.id + "_" + &local_id; + } + + // for defaultAccess, we will let local rules override remote rules if local rules have it defined, + // as defaultAccess is a high-level setting that can significantly change the access control behavior. + if let Some(default_access) = local_rules.defaultAccess { + merged_item.defaultAccess = default_access; + } + + merged_item.rules = merge_access_control_rules(merged_item.rules, local_rules.rules); + Some(merged_item) +} + +const LOCAL_RULE_NAME_PREFIX: &str = "LocalFileRules_"; + +/// Add prefix 'LocalFileRules_' to all name and reference parts in local rules +/// to avoid conflicts with remote rules that may have the same names but different definitions. +fn prefix_local_rule_names(mut rules: AccessControlRules) -> AccessControlRules { + // Prefix privilege names + if let Some(privileges) = &mut rules.privileges { + for privilege in privileges.iter_mut() { + privilege.name = format!("{LOCAL_RULE_NAME_PREFIX}{}", privilege.name); + } + } + + // Prefix role names and their privilege references + if let Some(roles) = &mut rules.roles { + for role in roles.iter_mut() { + role.name = format!("{LOCAL_RULE_NAME_PREFIX}{}", role.name); + for privilege_ref in role.privileges.iter_mut() { + *privilege_ref = format!("{LOCAL_RULE_NAME_PREFIX}{privilege_ref}"); + } + } + } + + // Prefix identity names + if let Some(identities) = &mut rules.identities { + for identity in identities.iter_mut() { + identity.name = format!("{LOCAL_RULE_NAME_PREFIX}{}", identity.name); + } + } + + // Prefix role assignment role and identity references + if let Some(role_assignments) = &mut rules.roleAssignments { + for ra in role_assignments.iter_mut() { + ra.role = format!("{LOCAL_RULE_NAME_PREFIX}{}", ra.role); + for identity_ref in ra.identities.iter_mut() { + *identity_ref = format!("{LOCAL_RULE_NAME_PREFIX}{identity_ref}"); + } + } + } + + rules +} + +fn merge_access_control_rules( + remote_rules: Option, + local_rules: Option, +) -> Option { + match (remote_rules, local_rules) { + (None, None) => None, + (Some(rules), None) | (None, Some(rules)) => Some(rules), + (Some(remote), Some(local)) => { + let prefixed_local = prefix_local_rule_names(local); + Some(AccessControlRules { + privileges: merge_rule_vectors(remote.privileges, prefixed_local.privileges), + roles: merge_rule_vectors(remote.roles, prefixed_local.roles), + identities: merge_rule_vectors(remote.identities, prefixed_local.identities), + roleAssignments: merge_rule_vectors( + remote.roleAssignments, + prefixed_local.roleAssignments, + ), + }) + } + } +} + +fn merge_rule_vectors(remote: Option>, local: Option>) -> Option> { + match (remote, local) { + (None, None) => None, + (Some(values), None) | (None, Some(values)) => Some(values), + (Some(mut remote_values), Some(mut local_values)) => { + remote_values.append(&mut local_values); + Some(remote_values) + } + } +} + +/// SPEC: Failed to parse file after retries - block all the proxied requests to corresponding host service, +/// In case of local rules file parse failure, build fail-closed rules which deny all access. +/// This ensures that if there is an issue with the local rules file, +/// we do not accidentally allow access due to fallback to remote rules which might be more permissive. +/// The fail-closed rules will keep the same logical ID as the remote rules (if any) for better traceability, +/// but will set defaultAccess to "deny" and remove all specific rules. +pub(crate) fn build_fail_closed_rules( + remote_rules: Option, + rule_id_descriptor: &RuleIdDescriptor, +) -> Option { + let mut rules = remote_rules.unwrap_or(AuthorizationItem { + defaultAccess: "deny".to_string(), + mode: "enforce".to_string(), + id: rule_id_descriptor.display_id(), + rules: None, + }); + + if !rule_id_descriptor.is_empty() { + rules.id = rule_id_descriptor.display_id(); + } + + // block all the requests by setting defaultAccess to deny and removing all specific rules, + // regardless of what the remote rules are. + rules.defaultAccess = "deny".to_string(); + rules.rules = None; + + Some(rules) +} + +fn validate_local_rules_file(local_rules: &LocalAuthorizationRulesFile) -> Result<()> { + if let Some(default_access) = &local_rules.defaultAccess { + if !matches!( + default_access.trim().to_lowercase().as_str(), // case-insensitive match + "allow" | "deny" + ) { + return Err(Error::Invalid(format!( + "local rules defaultAccess must be 'allow' or 'deny', got '{}'", + default_access + ))); + } + } + + if let Some(id) = &local_rules.id { + if id.trim().is_empty() { + return Err(Error::Invalid( + "local rules id cannot be empty when provided".to_string(), + )); + } + } + + if let Some(rules) = &local_rules.rules { + validate_access_control_rules(rules)?; + } + + Ok(()) +} + +fn validate_access_control_rules(rules: &AccessControlRules) -> Result<()> { + let privilege_names = validate_privileges(rules.privileges.as_ref())?; + let role_names = validate_roles(rules.roles.as_ref(), &privilege_names)?; + let identity_names = validate_identities(rules.identities.as_ref())?; + validate_role_assignments(rules.roleAssignments.as_ref(), &role_names, &identity_names)?; + Ok(()) +} + +/// Validate privileges and return the set of privilege names. +/// Privilege names must be unique and non-empty. +/// Each privilege must have a non-empty path, and if queryParameters are defined, they cannot contain empty keys or values. +fn validate_privileges(privileges: Option<&Vec>) -> Result> { + let mut names = HashSet::new(); + if let Some(privileges) = privileges { + for privilege in privileges { + if privilege.name.trim().is_empty() { + return Err(Error::Invalid("privilege name cannot be empty".to_string())); + } + if privilege.path.trim().is_empty() { + return Err(Error::Invalid(format!( + "privilege '{}' path cannot be empty", + privilege.name + ))); + } + if !names.insert(privilege.name.clone()) { + return Err(Error::Invalid(format!( + "duplicate privilege name '{}'", + privilege.name + ))); + } + if let Some(query_parameters) = &privilege.queryParameters { + for (key, value) in query_parameters { + if key.trim().is_empty() || value.trim().is_empty() { + return Err(Error::Invalid(format!( + "privilege '{}' queryParameters cannot contain empty keys or values", + privilege.name + ))); + } + } + } + } + } + Ok(names) +} + +/// Validate roles and return the set of role names. +/// Role names must be unique and non-empty. +/// Each role, it must reference at least one privilege, and the referenced privileges cannot be duplicated within the role. +fn validate_roles( + roles: Option<&Vec>, + privilege_names: &HashSet, +) -> Result> { + let mut names = HashSet::new(); + if let Some(roles) = roles { + for role in roles { + if role.name.trim().is_empty() { + return Err(Error::Invalid("role name cannot be empty".to_string())); + } + if role.privileges.is_empty() { + return Err(Error::Invalid(format!( + "role '{}' must reference at least one privilege", + role.name + ))); + } + if !names.insert(role.name.clone()) { + return Err(Error::Invalid(format!( + "duplicate role name '{}'", + role.name + ))); + } + + let mut referenced_privileges = HashSet::new(); + for privilege_name in &role.privileges { + if privilege_name.trim().is_empty() { + return Err(Error::Invalid(format!( + "role '{}' contains an empty privilege reference", + role.name + ))); + } + if !referenced_privileges.insert(privilege_name.clone()) { + return Err(Error::Invalid(format!( + "role '{}' contains duplicate privilege reference '{}'", + role.name, privilege_name + ))); + } + if !privilege_names.contains(privilege_name) { + return Err(Error::Invalid(format!( + "role '{}' references unknown privilege '{}'", + role.name, privilege_name + ))); + } + } + } + } + Ok(names) +} + +/// Validate identities and return the set of identity names. +/// Identity names must be unique and non-empty. +/// Each identity must have at least one selector defined (userName, groupName, exePath or processName). +fn validate_identities(identities: Option<&Vec>) -> Result> { + let mut names = HashSet::new(); + if let Some(identities) = identities { + for identity in identities { + if identity.name.trim().is_empty() { + return Err(Error::Invalid("identity name cannot be empty".to_string())); + } + if !names.insert(identity.name.clone()) { + return Err(Error::Invalid(format!( + "duplicate identity name '{}'", + identity.name + ))); + } + + let selectors = [ + identity.userName.as_deref(), + identity.groupName.as_deref(), + identity.exePath.as_deref(), + identity.processName.as_deref(), + ]; + let has_selector = selectors + .into_iter() + .flatten() + .any(|value| !value.trim().is_empty()); + if !has_selector { + return Err(Error::Invalid(format!( + "identity '{}' must specify at least one selector", + identity.name + ))); + } + } + } + Ok(names) +} + +fn validate_role_assignments( + role_assignments: Option<&Vec>, + role_names: &HashSet, + identity_names: &HashSet, +) -> Result<()> { + if let Some(role_assignments) = role_assignments { + for role_assignment in role_assignments { + if role_assignment.role.trim().is_empty() { + return Err(Error::Invalid( + "roleAssignment role cannot be empty".to_string(), + )); + } + if !role_names.contains(&role_assignment.role) { + return Err(Error::Invalid(format!( + "roleAssignment references unknown role '{}'", + role_assignment.role + ))); + } + if role_assignment.identities.is_empty() { + return Err(Error::Invalid(format!( + "roleAssignment for role '{}' must reference at least one identity", + role_assignment.role + ))); + } + + let mut referenced_identities = HashSet::new(); + for identity_name in &role_assignment.identities { + if identity_name.trim().is_empty() { + return Err(Error::Invalid(format!( + "roleAssignment for role '{}' contains an empty identity reference", + role_assignment.role + ))); + } + if !referenced_identities.insert(identity_name.clone()) { + return Err(Error::Invalid(format!( + "roleAssignment for role '{}' contains duplicate identity reference '{}'", + role_assignment.role, identity_name + ))); + } + if !identity_names.contains(identity_name) { + return Err(Error::Invalid(format!( + "roleAssignment for role '{}' references unknown identity '{}'", + role_assignment.role, identity_name + ))); + } + } + } + } + Ok(()) +} + +pub(crate) async fn read_local_rules_file( + file_path: &Path, + target: LocalRuleTarget, +) -> Result { + let mut last_error = String::new(); + for attempt in 1..=LOCAL_RULE_FILE_PARSE_RETRY_COUNT { + match misc_helpers::json_read_from_file::(file_path) { + Ok(local_rules) => { + validate_local_rules_file(&local_rules)?; + + let message = format!( + "Successfully parsed {} local rules file {} with id '{}' on attempt {}.", + target.display_name(), + file_path.display(), + local_rules.id.as_deref().unwrap_or("unknown"), + attempt + ); + write_local_rules_event( + LoggerLevel::Info, + target, + "LocalRulesFileParseSuccess", + message, + ); + + return Ok(local_rules); + } + Err(e) => { + last_error = e.to_string(); + // write trace level log for each parse failure attempt, + // and return error with last_error if it is the final attempt, + // it is to avoid flooding logs with transient parse failures but still have visibility when it finally fails after retries. + logger::write(format!( + "Failed to parse {} local rules file {} on attempt {}: {}", + target.display_name(), + file_path.display(), + attempt, + last_error + )); + if attempt < LOCAL_RULE_FILE_PARSE_RETRY_COUNT { + tokio::time::sleep(LOCAL_RULE_FILE_PARSE_RETRY_DELAY).await; + } + } + } + } + + Err(Error::Invalid(format!( + "Failed to parse local rules file '{}' after {} attempts: {}", + file_path.display(), + LOCAL_RULE_FILE_PARSE_RETRY_COUNT, + last_error + ))) +} + +/// Resolve the effective authorization rules by considering both remote rules and local file rules based on the descriptor and current state. +/// Return the effective rules and whether the rules have changed compared to the previous effective rules in the tracker. +pub(crate) async fn resolve_effective_rules( + rules_dir: &Path, + remote_rules: Option, + target: LocalRuleTarget, + tracker: &mut LocalRuleMonitorState, + remote_rule_changed: bool, +) -> (Option, bool) { + let descriptor = parse_rule_id_descriptor(remote_rules.as_ref().map(|item| item.id.as_str())); + let normalized_remote_rules = normalize_authorization_item(remote_rules, &descriptor); + let use_local_file_rules_changed = + tracker.use_local_file_rules != descriptor.use_local_file_rules; + let previous_parse_failed = tracker.parse_failed; + + if use_local_file_rules_changed { + let action = if descriptor.use_local_file_rules { + "enabled" + } else { + "disabled" + }; + write_local_rules_event( + LoggerLevel::Warn, + target, + "LocalFileRulesStateChanged", + format!("{} local file rules {action}.", target.display_name()), + ); + } + + tracker.use_local_file_rules = descriptor.use_local_file_rules; + if !descriptor.use_local_file_rules { + // not using local file rules, return normalized remote rules directly. + // also reset the tracker state as we are not monitoring local file changes in this case. + tracker.file_state = LocalRuleFileState::Unknown; + tracker.parse_failed = false; + tracker.effective_rules = normalized_remote_rules.clone(); + return ( + normalized_remote_rules, + use_local_file_rules_changed || previous_parse_failed, + ); + } + + let local_rules_file = rules_dir.join(target.file_name()); + let current_file_state = get_local_rule_file_state(&local_rules_file); + let file_state_changed = tracker.file_state != current_file_state; + + if file_state_changed { + match (&tracker.file_state, ¤t_file_state) { + (_, LocalRuleFileState::Present(_)) + if matches!( + tracker.file_state, + LocalRuleFileState::Unknown | LocalRuleFileState::Missing + ) => + { + write_local_rules_event( + LoggerLevel::Warn, + target, + "LocalFileRulesStateChanged", + format!( + "{} local rules file found at {}.", + target.display_name(), + local_rules_file.display() + ), + ); + } + (LocalRuleFileState::Present(_), LocalRuleFileState::Present(_)) => { + write_local_rules_event( + LoggerLevel::Warn, + target, + "LocalFileRulesStateChanged", + format!( + "{} local rules file changed at {}.", + target.display_name(), + local_rules_file.display() + ), + ); + } + (_, LocalRuleFileState::Missing) => { + write_local_rules_event( + LoggerLevel::Warn, + target, + "LocalFileRulesStateChanged", + format!( + "{} local rules file deleted or not found at {}.", + target.display_name(), + local_rules_file.display() + ), + ); + } + (_, LocalRuleFileState::Error(error)) => { + write_local_rules_event( + LoggerLevel::Error, + target, + "LocalFileRulesStateChanged", + format!( + "{} local rules file metadata read failed at {}: {}", + target.display_name(), + local_rules_file.display(), + error + ), + ); + } + _ => {} + } + } + + tracker.file_state = current_file_state.clone(); + let needs_refresh = remote_rule_changed + || use_local_file_rules_changed + || file_state_changed + || previous_parse_failed; + + if !needs_refresh { + return (tracker.effective_rules.clone(), false); + } + + // SPEC: No corresponding local rules file found - treat it as no local advanced configuration set yet, which means root-only to WS and all to IMDS. + // This is to support new VM just provisioned and the customer payload has not downloaded & applied within the VM yet. + if matches!(current_file_state, LocalRuleFileState::Missing) { + tracker.parse_failed = false; + tracker.effective_rules = normalized_remote_rules.clone(); + return ( + normalized_remote_rules, + use_local_file_rules_changed || file_state_changed || previous_parse_failed, + ); + } + + if let LocalRuleFileState::Error(error) = ¤t_file_state { + let message = format!( + "Failed to read {} local rules file metadata {}: {}. Treat it as parse failure and apply fail-closed rules.", + target.display_name(), + local_rules_file.display(), + error + ); + write_local_rules_event( + LoggerLevel::Error, + target, + "LocalRulesFileMetadataReadFailed", + message, + ); + + let fail_closed_rules = build_fail_closed_rules(normalized_remote_rules, &descriptor); + tracker.parse_failed = true; + tracker.effective_rules = fail_closed_rules.clone(); + return ( + fail_closed_rules, + !previous_parse_failed || file_state_changed, + ); + } + + match read_local_rules_file(&local_rules_file, target).await { + Ok(local_rules) => { + let effective_rules = + merge_authorization_item(normalized_remote_rules, local_rules, &descriptor); + tracker.parse_failed = false; + tracker.effective_rules = effective_rules.clone(); + ( + effective_rules, + use_local_file_rules_changed || file_state_changed || previous_parse_failed, + ) + } + Err(e) => { + let message = format!( + "Failed to parse {} local rules file {}: {}. Apply fail-closed rules.", + target.display_name(), + local_rules_file.display(), + e + ); + write_local_rules_event( + LoggerLevel::Error, + target, + "LocalRulesFileParseFailed", + message, + ); + + let fail_closed_rules = build_fail_closed_rules(normalized_remote_rules, &descriptor); + tracker.parse_failed = true; + tracker.effective_rules = fail_closed_rules.clone(); + ( + fail_closed_rules, + !previous_parse_failed || file_state_changed, + ) + } + } +} + +pub(crate) fn write_local_rules_event( + level: LoggerLevel, + target: LocalRuleTarget, + method_name: &str, + message: String, +) { + event_logger::write_event( + level, + message, + method_name, + target.display_name(), + logger::AGENT_LOGGER_KEY, + ); +} + +#[cfg(test)] +mod tests { + use super::{ + get_rules_dir_from_key_dir, merge_authorization_item, parse_rule_id_descriptor, + prefix_local_rule_names, read_local_rules_file, resolve_effective_rules, + validate_access_control_rules, validate_identities, validate_privileges, + validate_role_assignments, validate_roles, LocalAuthorizationRulesFile, + LocalRuleMonitorState, LocalRuleTarget, RuleIdDescriptor, LOCAL_RULE_NAME_PREFIX, + }; + use crate::key_keeper::key::{ + AccessControlRules, AuthorizationItem, Identity, Privilege, Role, RoleAssignment, + }; + use base64::{engine::general_purpose, Engine as _}; + use proxy_agent_shared::misc_helpers; + use std::collections::HashSet; + use std::env; + use std::fs; + use std::path::{Path, PathBuf}; + + fn create_temp_rules_dir(test_name: &str) -> PathBuf { + let mut dir = env::temp_dir(); + dir.push(format!("local_rules_{test_name}")); + _ = fs::remove_dir_all(&dir); + fs::create_dir_all(&dir).unwrap(); + dir + } + + fn write_file(path: &Path, content: &str) { + fs::write(path, content).unwrap(); + } + + fn encoded_rule_id(logical_id: &str) -> String { + general_purpose::STANDARD.encode(format!( + r#"{{"id":"{logical_id}","useLocalFileRules":true}}"# + )) + } + + fn sample_privilege(name: &str) -> Privilege { + Privilege { + name: name.to_string(), + path: format!("/{name}"), + queryParameters: None, + } + } + + fn sample_role(name: &str, privileges: Vec<&str>) -> Role { + Role { + name: name.to_string(), + privileges: privileges.into_iter().map(str::to_string).collect(), + } + } + + fn sample_identity(name: &str) -> Identity { + Identity { + name: name.to_string(), + userName: Some(name.to_string()), + groupName: None, + exePath: None, + processName: None, + } + } + + fn sample_role_assignment(role: &str, identities: Vec<&str>) -> RoleAssignment { + RoleAssignment { + role: role.to_string(), + identities: identities.into_iter().map(str::to_string).collect(), + } + } + + fn sample_access_control_rules() -> AccessControlRules { + AccessControlRules { + privileges: Some(vec![sample_privilege("p1")]), + roles: Some(vec![sample_role("r1", vec!["p1"])]), + identities: Some(vec![sample_identity("i1")]), + roleAssignments: Some(vec![sample_role_assignment("r1", vec!["i1"])]), + } + } + + fn string_set(values: &[&str]) -> HashSet { + values.iter().map(|value| (*value).to_string()).collect() + } + + fn assert_validation_ok(result: crate::common::result::Result) -> T { + result.unwrap() + } + + fn assert_validation_err(result: crate::common::result::Result) { + assert!(result.is_err()); + } + + fn assert_names_match(actual: HashSet, expected: &[&str]) { + assert_eq!(actual, string_set(expected)); + } + + fn write_wireserver_rules_file(rules_dir: &Path, content: &str) -> PathBuf { + let rules_file = rules_dir.join(LocalRuleTarget::WireServer.file_name()); + write_file(&rules_file, content); + rules_file + } + + async fn run_read_wireserver_rules_file_case( + test_name: &str, + content: &str, + ) -> crate::common::result::Result { + ensure_test_config_in_exe_dir(); + let rules_dir = create_temp_rules_dir(test_name); + let rules_file = write_wireserver_rules_file(&rules_dir, content); + let result = read_local_rules_file(&rules_file, LocalRuleTarget::WireServer).await; + _ = fs::remove_dir_all(&rules_dir); + result + } + + async fn run_resolve_wireserver_case( + test_name: &str, + local_file_content: Option<&str>, + remote_default_access: &str, + remote_rules: Option, + remote_rule_changed: bool, + ) -> (Option, bool, LocalRuleMonitorState) { + ensure_test_config_in_exe_dir(); + let rules_dir = create_temp_rules_dir(test_name); + if let Some(content) = local_file_content { + _ = write_wireserver_rules_file(&rules_dir, content); + } + + let remote_rules = Some(AuthorizationItem { + defaultAccess: remote_default_access.to_string(), + mode: "enforce".to_string(), + id: encoded_rule_id("decoded-id"), + rules: remote_rules, + }); + let mut tracker = LocalRuleMonitorState::default(); + + let result = resolve_effective_rules( + &rules_dir, + remote_rules, + LocalRuleTarget::WireServer, + &mut tracker, + remote_rule_changed, + ) + .await; + + _ = fs::remove_dir_all(&rules_dir); + (result.0, result.1, tracker) + } + + fn ensure_test_config_in_exe_dir() { + let mut config_target = misc_helpers::get_current_exe_dir(); + #[cfg(windows)] + config_target.push("GuestProxyAgent.json"); + #[cfg(not(windows))] + config_target.push("proxy-agent.json"); + + if config_target.exists() { + return; + } + + let mut config_source = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + config_source.push("config"); + #[cfg(windows)] + config_source.push("GuestProxyAgent.windows.json"); + #[cfg(not(windows))] + config_source.push("GuestProxyAgent.linux.json"); + + let config_content = fs::read_to_string(config_source).unwrap(); + fs::write(config_target, config_content).unwrap(); + } + + #[test] + fn parse_rule_id_descriptor_test() { + let legacy = parse_rule_id_descriptor(Some("legacy-id")); + assert_eq!(legacy.logical_id, "legacy-id"); + assert!(!legacy.use_local_file_rules); + + let encoded = general_purpose::STANDARD + .encode(r#"{"id":"sig-resource-id","useLocalFileRules":true}"#); + let descriptor = parse_rule_id_descriptor(Some(&encoded)); + assert_eq!(descriptor.logical_id, "sig-resource-id"); + assert!(descriptor.use_local_file_rules); + } + + #[test] + fn get_rules_dir_from_key_dir_test() { + #[cfg(windows)] + assert_eq!( + get_rules_dir_from_key_dir(Path::new("C:\\WindowsAzure\\ProxyAgent\\Keys")), + PathBuf::from("C:\\WindowsAzure\\ProxyAgent\\Rules") + ); + + #[cfg(not(windows))] + assert_eq!( + get_rules_dir_from_key_dir(Path::new("/var/lib/azure-proxy-agent/keys")), + PathBuf::from("/var/lib/azure-proxy-agent/rules") + ); + } + + #[test] + fn merge_authorization_item_test() { + let remote_rules = AuthorizationItem { + defaultAccess: "deny".to_string(), + mode: "enforce".to_string(), + id: "remote-id".to_string(), + rules: Some(AccessControlRules { + privileges: Some(vec![Privilege { + name: "remote-privilege".to_string(), + path: "/remote".to_string(), + queryParameters: None, + }]), + roles: Some(vec![Role { + name: "remote-role".to_string(), + privileges: vec!["remote-privilege".to_string()], + }]), + identities: Some(vec![Identity { + name: "remote-identity".to_string(), + userName: Some("root".to_string()), + groupName: None, + exePath: None, + processName: None, + }]), + roleAssignments: Some(vec![RoleAssignment { + role: "remote-role".to_string(), + identities: vec!["remote-identity".to_string()], + }]), + }), + }; + let descriptor = RuleIdDescriptor { + logical_id: "decoded-id".to_string(), + use_local_file_rules: true, + }; + let local_rules = LocalAuthorizationRulesFile { + id: Some("local-id".to_string()), + defaultAccess: Some("allow".to_string()), + rules: Some(AccessControlRules { + privileges: Some(vec![Privilege { + name: "local-privilege".to_string(), + path: "/local".to_string(), + queryParameters: None, + }]), + roles: Some(vec![Role { + name: "local-role".to_string(), + privileges: vec!["local-privilege".to_string()], + }]), + identities: Some(vec![Identity { + name: "local-identity".to_string(), + userName: Some("agent".to_string()), + groupName: None, + exePath: None, + processName: None, + }]), + roleAssignments: Some(vec![RoleAssignment { + role: "local-role".to_string(), + identities: vec!["local-identity".to_string()], + }]), + }), + }; + + let merged = + merge_authorization_item(Some(remote_rules), local_rules, &descriptor).unwrap(); + assert_eq!(merged.id, "decoded-id-useLocalFileRules-true_local-id"); + assert_eq!(merged.defaultAccess, "allow"); + let merged_rules = merged.rules.unwrap(); + + let privileges = merged_rules.privileges.unwrap(); + assert_eq!(privileges.len(), 2); + assert_eq!(privileges[0].name, "remote-privilege"); + assert_eq!(privileges[1].name, "LocalFileRules_local-privilege"); + + let roles = merged_rules.roles.unwrap(); + assert_eq!(roles.len(), 2); + assert_eq!(roles[0].name, "remote-role"); + assert_eq!(roles[1].name, "LocalFileRules_local-role"); + assert_eq!(roles[1].privileges[0], "LocalFileRules_local-privilege"); + + let identities = merged_rules.identities.unwrap(); + assert_eq!(identities.len(), 2); + assert_eq!(identities[0].name, "remote-identity"); + assert_eq!(identities[1].name, "LocalFileRules_local-identity"); + + let role_assignments = merged_rules.roleAssignments.unwrap(); + assert_eq!(role_assignments.len(), 2); + assert_eq!(role_assignments[0].role, "remote-role"); + assert_eq!(role_assignments[1].role, "LocalFileRules_local-role"); + assert_eq!( + role_assignments[1].identities[0], + "LocalFileRules_local-identity" + ); + } + + #[test] + fn prefix_local_rule_names_test() { + let rules = sample_access_control_rules(); + let prefixed = prefix_local_rule_names(rules); + + let privileges = prefixed.privileges.unwrap(); + assert_eq!(privileges.len(), 1); + assert_eq!(privileges[0].name, format!("{LOCAL_RULE_NAME_PREFIX}p1")); + assert_eq!(privileges[0].path, "/p1"); // path should not be prefixed + + let roles = prefixed.roles.unwrap(); + assert_eq!(roles[0].name, format!("{LOCAL_RULE_NAME_PREFIX}r1")); + assert_eq!( + roles[0].privileges[0], + format!("{LOCAL_RULE_NAME_PREFIX}p1") + ); + + let identities = prefixed.identities.unwrap(); + assert_eq!(identities[0].name, format!("{LOCAL_RULE_NAME_PREFIX}i1")); + assert_eq!(identities[0].userName.as_deref(), Some("i1")); // selector values should not be prefixed + + let role_assignments = prefixed.roleAssignments.unwrap(); + assert_eq!( + role_assignments[0].role, + format!("{LOCAL_RULE_NAME_PREFIX}r1") + ); + assert_eq!( + role_assignments[0].identities[0], + format!("{LOCAL_RULE_NAME_PREFIX}i1") + ); + } + + #[test] + fn validate_access_control_rules_success_test() { + let rules = sample_access_control_rules(); + assert_validation_ok(validate_access_control_rules(&rules)); + } + + #[test] + fn validate_access_control_rules_invalid_role_assignment_test() { + let mut rules = sample_access_control_rules(); + rules.roleAssignments = Some(vec![sample_role_assignment("missing-role", vec!["i1"])]); + + assert_validation_err(validate_access_control_rules(&rules)); + } + + #[test] + fn validate_privileges_success_test() { + let privileges = vec![sample_privilege("p1"), sample_privilege("p2")]; + let privilege_names = assert_validation_ok(validate_privileges(Some(&privileges))); + assert_names_match(privilege_names, &["p1", "p2"]); + } + + #[test] + fn validate_privileges_duplicate_name_test() { + let privileges = vec![sample_privilege("p1"), sample_privilege("p1")]; + assert_validation_err(validate_privileges(Some(&privileges))); + } + + #[test] + fn validate_roles_success_test() { + let privilege_names = string_set(&["p1", "p2"]); + let roles = vec![sample_role("r1", vec!["p1", "p2"])]; + let role_names = assert_validation_ok(validate_roles(Some(&roles), &privilege_names)); + assert_names_match(role_names, &["r1"]); + } + + #[test] + fn validate_roles_unknown_privilege_test() { + let privilege_names = string_set(&["p1"]); + let roles = vec![sample_role("r1", vec!["missing-privilege"])]; + assert_validation_err(validate_roles(Some(&roles), &privilege_names)); + } + + #[test] + fn validate_identities_success_test() { + let identities = vec![sample_identity("i1"), sample_identity("i2")]; + let identity_names = assert_validation_ok(validate_identities(Some(&identities))); + assert_names_match(identity_names, &["i1", "i2"]); + } + + #[test] + fn validate_identities_missing_selector_test() { + let identities = vec![Identity { + name: "i1".to_string(), + userName: None, + groupName: None, + exePath: None, + processName: None, + }]; + assert_validation_err(validate_identities(Some(&identities))); + } + + #[test] + fn validate_role_assignments_success_test() { + let role_names = string_set(&["r1"]); + let identity_names = string_set(&["i1", "i2"]); + let role_assignments = vec![sample_role_assignment("r1", vec!["i1", "i2"])]; + + assert_validation_ok(validate_role_assignments( + Some(&role_assignments), + &role_names, + &identity_names, + )); + } + + #[test] + fn validate_role_assignments_unknown_identity_test() { + let role_names = string_set(&["r1"]); + let identity_names = string_set(&["i1"]); + let role_assignments = vec![sample_role_assignment("r1", vec!["missing-identity"])]; + + assert_validation_err(validate_role_assignments( + Some(&role_assignments), + &role_names, + &identity_names, + )); + } + + #[tokio::test] + async fn read_local_rules_file_success_test() { + let parsed = run_read_wireserver_rules_file_case( + "read_local_rules_file_success_test", + r#"{ + "id": "local-read-id", + "defaultAccess": "allow", + "rules": { + "privileges": [ + { "name": "p1", "path": "/a" } + ] + } + }"#, + ) + .await + .unwrap(); + assert_eq!(parsed.id.as_deref(), Some("local-read-id")); + assert_eq!(parsed.defaultAccess.as_deref(), Some("allow")); + assert_eq!(parsed.rules.unwrap().privileges.unwrap().len(), 1); + } + + #[tokio::test] + async fn read_local_rules_file_invalid_json_test() { + let result = run_read_wireserver_rules_file_case( + "read_local_rules_file_invalid_json_test", + "{ invalid json ", + ) + .await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn read_local_rules_file_invalid_default_access_test() { + let result = run_read_wireserver_rules_file_case( + "read_local_rules_file_invalid_default_access_test", + r#"{ + "defaultAccess": "maybe", + "rules": { + "privileges": [ + { "name": "p1", "path": "/a" } + ] + } + }"#, + ) + .await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn read_local_rules_file_invalid_role_reference_test() { + let result = run_read_wireserver_rules_file_case( + "read_local_rules_file_invalid_role_reference_test", + r#"{ + "rules": { + "privileges": [ + { "name": "p1", "path": "/a" } + ], + "roles": [ + { "name": "r1", "privileges": ["missing-privilege"] } + ] + } + }"#, + ) + .await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn resolve_effective_rules_local_file_present_merges_test() { + let (effective_rules, changed, tracker) = run_resolve_wireserver_case( + "resolve_effective_rules_local_file_present_merges_test", + Some( + r#"{ + "id": "local-effective-id", + "defaultAccess": "allow", + "rules": { + "privileges": [ + { "name": "local-priv", "path": "/local" } + ] + } + }"#, + ), + "deny", + None, + true, + ) + .await; + + assert!(changed); + let effective_rules = effective_rules.unwrap(); + assert_eq!( + effective_rules.id, + "decoded-id-useLocalFileRules-true_local-effective-id" + ); + assert_eq!(effective_rules.defaultAccess, "allow"); + assert!(effective_rules.rules.is_some()); + assert!(!tracker.parse_failed); + } + + #[tokio::test] + async fn resolve_effective_rules_local_file_missing_returns_none_test() { + let (effective_rules, changed, tracker) = run_resolve_wireserver_case( + "resolve_effective_rules_local_file_missing_returns_none_test", + None, + "deny", + None, + false, + ) + .await; + + assert!(changed); + assert!(effective_rules.is_some()); + assert!(!tracker.parse_failed); + } + + #[tokio::test] + async fn resolve_effective_rules_invalid_local_file_fail_closed_test() { + let (effective_rules, changed, tracker) = run_resolve_wireserver_case( + "resolve_effective_rules_invalid_local_file_fail_closed_test", + Some("{ invalid json "), + "allow", + Some(AccessControlRules { + privileges: Some(vec![Privilege { + name: "remote-privilege".to_string(), + path: "/remote".to_string(), + queryParameters: None, + }]), + roles: None, + identities: None, + roleAssignments: None, + }), + false, + ) + .await; + + assert!(changed); + let effective_rules = effective_rules.unwrap(); + assert_eq!(effective_rules.id, "decoded-id-useLocalFileRules-true"); + assert_eq!(effective_rules.defaultAccess, "deny"); + assert!(effective_rules.rules.is_none()); + assert!(tracker.parse_failed); + } +} From 3d4c0e0150a8bec87c2833a9036bbeed7c10148f Mon Sep 17 00:00:00 2001 From: Zhidong Peng Date: Mon, 27 Apr 2026 14:05:24 -0700 Subject: [PATCH 7/8] cmdline to take the first 4 arguments (#340) * cmdline to take the first 4 arguments * fix in common code path --- .github/actions/spelling/expect.txt | 1 + proxy_agent/src/proxy.rs | 5 +- proxy_agent/src/proxy/windows.rs | 72 ++++++++++++++++++++++++++++- 3 files changed, 76 insertions(+), 2 deletions(-) diff --git a/.github/actions/spelling/expect.txt b/.github/actions/spelling/expect.txt index 6e010cc7..b82055f4 100644 --- a/.github/actions/spelling/expect.txt +++ b/.github/actions/spelling/expect.txt @@ -207,6 +207,7 @@ oneshot opencode opensource parseable +peekable PERCPU pgpkey pgrep diff --git a/proxy_agent/src/proxy.rs b/proxy_agent/src/proxy.rs index 9278e634..b0be815c 100644 --- a/proxy_agent/src/proxy.rs +++ b/proxy_agent/src/proxy.rs @@ -77,6 +77,7 @@ pub struct User { const UNDEFINED: &str = "undefined"; const EMPTY: &str = "empty"; +const MAX_CMD_ARGS: usize = 4; async fn get_user( logon_id: u64, @@ -112,10 +113,12 @@ fn get_process_info(process_id: u32) -> (PathBuf, String) { let cmdline_path = format!("/proc/{}/cmdline", process_id); let process_cmd_line = match std::fs::read(&cmdline_path) { Ok(bytes) => { - // cmdline is null-separated, convert to space-separated string + // cmdline is null-separated; take only the first few arguments + // to avoid capturing credentials in later args bytes .split(|&b| b == 0) .filter(|s| !s.is_empty()) + .take(MAX_CMD_ARGS) .map(|s| String::from_utf8_lossy(s).into_owned()) .collect::>() .join(" ") diff --git a/proxy_agent/src/proxy/windows.rs b/proxy_agent/src/proxy/windows.rs index acc770eb..3667c635 100644 --- a/proxy_agent/src/proxy/windows.rs +++ b/proxy_agent/src/proxy/windows.rs @@ -292,7 +292,46 @@ pub fn get_process_cmd(handler: isize) -> Result { std::slice::from_raw_parts(cmd_buffer.Buffer, (cmd_buffer.Length / 2) as usize) }); - Ok(cmd) + // Only keep the first few arguments to avoid capturing credentials + // that may appear in later command-line arguments + Ok(truncate_cmd_args(&cmd, super::MAX_CMD_ARGS)) +} + +/// Truncate a command line string to at most `max_args` arguments, +/// respecting double-quoted strings that may contain whitespace. +fn truncate_cmd_args(cmd: &str, max_args: usize) -> String { + let mut args = Vec::new(); + let mut chars = cmd.chars().peekable(); + + while args.len() < max_args { + // Skip whitespace between arguments + while chars.peek().is_some_and(|c| c.is_whitespace()) { + chars.next(); + } + if chars.peek().is_none() { + break; + } + + let mut arg = String::new(); + if chars.peek() == Some(&'"') { + // Quoted argument — consume until closing quote + arg.push(chars.next().unwrap()); // opening " + for c in chars.by_ref() { + arg.push(c); + if c == '"' { + break; + } + } + } else { + // Unquoted argument — consume until whitespace + while chars.peek().is_some_and(|c| !c.is_whitespace()) { + arg.push(chars.next().unwrap()); + } + } + args.push(arg); + } + + args.join(" ") } #[allow(dead_code)] @@ -379,4 +418,35 @@ mod tests { ); assert!(!cmd.is_empty(), "process cmd should not be empty"); } + + #[test] + fn truncate_cmd_args_tests() { + // no arguments + let result = super::truncate_cmd_args("", 4); + assert_eq!(result, "", "empty input should return empty string"); + + // fewer than max args + let result = super::truncate_cmd_args("app.exe arg1", 4); + assert_eq!( + result, "app.exe arg1", + "should return all args when fewer than max" + ); + + // exactly max args + let result = super::truncate_cmd_args("app.exe arg1 arg2 arg3 --secret=password", 4); + assert_eq!( + result, "app.exe arg1 arg2 arg3", + "should truncate to first 4 args" + ); + + // more than max args and quoted arg with spaces + let result = super::truncate_cmd_args( + r#""C:\Program Files\app.exe" arg1 "arg with spaces" arg3 --secret=password"#, + 4, + ); + assert_eq!( + result, r#""C:\Program Files\app.exe" arg1 "arg with spaces" arg3"#, + "should treat quoted strings with whitespace as single args" + ); + } } From 579e9e276fa9bc268c7127e72d5ed7074a798793 Mon Sep 17 00:00:00 2001 From: "Zhidong Peng (HE/HIM)" Date: Mon, 27 Apr 2026 14:17:06 -0700 Subject: [PATCH 8/8] Update version to 1.0.43 --- Cargo.lock | 8 ++++---- proxy_agent/Cargo.toml | 2 +- proxy_agent_extension/Cargo.toml | 2 +- proxy_agent_setup/Cargo.toml | 2 +- proxy_agent_shared/Cargo.toml | 2 +- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 03c0dc09..525ec413 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,7 +4,7 @@ version = 4 [[package]] name = "ProxyAgentExt" -version = "1.0.42" +version = "1.0.43" dependencies = [ "clap", "ctor", @@ -172,7 +172,7 @@ dependencies = [ [[package]] name = "azure-proxy-agent" -version = "1.0.42" +version = "1.0.43" dependencies = [ "aya", "base64", @@ -932,7 +932,7 @@ dependencies = [ [[package]] name = "proxy_agent_setup" -version = "1.0.42" +version = "1.0.43" dependencies = [ "clap", "proxy_agent_shared", @@ -944,7 +944,7 @@ dependencies = [ [[package]] name = "proxy_agent_shared" -version = "1.0.42" +version = "1.0.43" dependencies = [ "chrono", "concurrent-queue", diff --git a/proxy_agent/Cargo.toml b/proxy_agent/Cargo.toml index 90bf53a3..c8e9dc47 100644 --- a/proxy_agent/Cargo.toml +++ b/proxy_agent/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "azure-proxy-agent" -version = "1.0.42" # always 3-number version +version = "1.0.43" # always 3-number version edition = "2021" build = "build.rs" readme = "README.md" diff --git a/proxy_agent_extension/Cargo.toml b/proxy_agent_extension/Cargo.toml index 3bccea47..c9194ac1 100644 --- a/proxy_agent_extension/Cargo.toml +++ b/proxy_agent_extension/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ProxyAgentExt" -version = "1.0.42" # always 3-number version +version = "1.0.43" # always 3-number version edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/proxy_agent_setup/Cargo.toml b/proxy_agent_setup/Cargo.toml index 4b81903a..acc434b0 100644 --- a/proxy_agent_setup/Cargo.toml +++ b/proxy_agent_setup/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "proxy_agent_setup" -version = "1.0.42" +version = "1.0.43" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/proxy_agent_shared/Cargo.toml b/proxy_agent_shared/Cargo.toml index 1b185ffb..28106c9d 100644 --- a/proxy_agent_shared/Cargo.toml +++ b/proxy_agent_shared/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "proxy_agent_shared" -version = "1.0.42" +version = "1.0.43" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html