diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 7dce6f8f..22dfbf71 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -75,7 +75,6 @@ jobs: build-essential \ linux-tools-$(uname -r) \ linux-tools-common \ - linux-tools-generic \ rpm \ musl-tools \ diff --git a/.github/workflows/reusable-build.yml b/.github/workflows/reusable-build.yml index 750b320d..f5fe69b1 100644 --- a/.github/workflows/reusable-build.yml +++ b/.github/workflows/reusable-build.yml @@ -76,7 +76,7 @@ jobs: - name: Run Code Coverage for proxy_agent_shared run: | - cargo llvm-cov --target x86_64-pc-windows-msvc --manifest-path ./proxy_agent_shared/Cargo.toml --output-path ./out/proxy_agent_shared_codeCov.txt --release + cargo llvm-cov --target x86_64-pc-windows-msvc --manifest-path ./proxy_agent_shared/Cargo.toml --output-path ./out/proxy_agent_shared_codeCov.txt --release -- --test-threads=1 type ./out/proxy_agent_shared_codeCov.txt - name: Parse Code Coverage for proxy_agent_shared @@ -292,7 +292,6 @@ jobs: build-essential \ linux-tools-$(uname -r) \ linux-tools-common \ - linux-tools-generic \ rpm \ musl-tools \ libssl-dev \ @@ -325,7 +324,7 @@ jobs: - name: Run Code Coverage for proxy_agent_shared run: | - cargo llvm-cov --target x86_64-unknown-linux-musl --manifest-path ./proxy_agent_shared/Cargo.toml --output-path ./out/proxy_agent_shared_codeCov.txt --release + cargo llvm-cov --target x86_64-unknown-linux-musl --manifest-path ./proxy_agent_shared/Cargo.toml --output-path ./out/proxy_agent_shared_codeCov.txt --release -- --test-threads=1 cat ./out/proxy_agent_shared_codeCov.txt - name: Parse Code Coverage for proxy_agent_shared @@ -464,7 +463,6 @@ jobs: build-essential \ linux-tools-$(uname -r) \ linux-tools-common \ - linux-tools-generic \ rpm \ musl-tools \ gcc-aarch64-linux-gnu \ diff --git a/.vscode/settings.json b/.vscode/settings.json index 9bf0838b..3e71cf20 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -32,5 +32,8 @@ ], "vscode-nmake-tools.workspaceBuildDirectories": [ "." - ] + ], + "chat.tools.terminal.autoApprove": { + "./build-linux.sh": true + } } \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 1b252ea4..39438e11 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,7 +4,7 @@ version = 4 [[package]] name = "ProxyAgentExt" -version = "1.0.39" +version = "1.0.40" dependencies = [ "clap", "ctor", @@ -172,12 +172,11 @@ dependencies = [ [[package]] name = "azure-proxy-agent" -version = "1.0.39" +version = "1.0.40" dependencies = [ "aya", "bitflags", "clap", - "ctor", "http", "http-body-util", "hyper", @@ -192,6 +191,7 @@ dependencies = [ "serde-xml-rs", "serde_derive", "serde_json", + "socket2", "static_vcruntime", "sysinfo", "thiserror", @@ -243,9 +243,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.7.1" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" [[package]] name = "cc" @@ -403,9 +403,9 @@ checksum = "c426d2ba3e525b39c1f0a9ba41b9fe61878dee11fa4e4a76b6ab440f46c5db5d" [[package]] name = "deranged" -version = "0.3.11" +version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +checksum = "ececcb659e7ba858fb4f10388c250a7252eb0a27373f1a72b8748afdd248e587" dependencies = [ "powerfmt", ] @@ -791,9 +791,9 @@ dependencies = [ [[package]] name = "num-conv" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050" [[package]] name = "num-traits" @@ -925,7 +925,7 @@ dependencies = [ [[package]] name = "proxy_agent_setup" -version = "1.0.39" +version = "1.0.40" dependencies = [ "clap", "proxy_agent_shared", @@ -937,17 +937,17 @@ dependencies = [ [[package]] name = "proxy_agent_shared" -version = "1.0.39" +version = "1.0.40" dependencies = [ "chrono", "concurrent-queue", - "ctor", "hex", "http", "http-body-util", "hyper", "hyper-util", "itertools", + "libc", "log", "once_cell", "openssl", @@ -1092,10 +1092,11 @@ checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" [[package]] name = "serde" -version = "1.0.210" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" dependencies = [ + "serde_core", "serde_derive", ] @@ -1111,11 +1112,20 @@ dependencies = [ "xml-rs", ] +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + [[package]] name = "serde_derive" -version = "1.0.210" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", @@ -1241,30 +1251,30 @@ dependencies = [ [[package]] name = "time" -version = "0.3.36" +version = "0.3.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" +checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" dependencies = [ "deranged", "itoa", "num-conv", "powerfmt", - "serde", + "serde_core", "time-core", "time-macros", ] [[package]] name = "time-core" -version = "0.1.2" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" +checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" [[package]] name = "time-macros" -version = "0.2.18" +version = "0.2.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf" +checksum = "2e70e4c5a0e0a8a4823ad65dfe1a6930e4f4d756dcd9dd7939022b5e8c501215" dependencies = [ "num-conv", "time-core", diff --git a/build.cmd b/build.cmd index 57e2a790..72d0640b 100644 --- a/build.cmd +++ b/build.cmd @@ -134,12 +134,10 @@ if "%Target%"=="arm64" ( xcopy /Y /S /C /Q %out_dir% %root_path%proxy_agent_shared\target\%Configuration%\ echo ======= run rust proxy_agent_shared tests - echo call cargo test --all-features %release_flag% --manifest-path %cargo_toml% --target-dir %out_path% --target %build_target% -- --test-threads=1 --nocapture - call cargo test --all-features %release_flag% --manifest-path %cargo_toml% --target-dir %out_path% --target %build_target% -- --test-threads=1 --nocapture - if %ERRORLEVEL% NEQ 0 ( - echo call cargo test proxy_agent_shared with exit-code: %errorlevel% - exit /b %errorlevel% - ) + REM %ERRORLEVEL% inside a (...) block is expanded when the entire block is parsed, not when each line run + REM use exit /b 1 to propagate error codes inside a (...) block + echo call cargo test --all-features %release_flag% --manifest-path %cargo_toml% --target-dir %out_path% --target %build_target% -- --test-threads=1 --nocapture ^|^| exit /b 1 + call cargo test --all-features %release_flag% --manifest-path %cargo_toml% --target-dir %out_path% --target %build_target% -- --test-threads=1 --nocapture || exit /b 1 ) echo ======= copy config file for windows platform @@ -167,12 +165,9 @@ if "%Target%"=="arm64" ( xcopy /Y /S /C /Q %out_dir% %root_path%proxy_agent\target\%Configuration%\ echo ======= run rust proxy_agent tests - echo call cargo test --all-features %release_flag% --manifest-path %cargo_toml% --target-dir %out_path% --target %build_target% -- --test-threads=1 --nocapture - call cargo test --all-features %release_flag% --manifest-path %cargo_toml% --target-dir %out_path% --target %build_target% -- --test-threads=1 --nocapture - if %ERRORLEVEL% NEQ 0 ( - echo call cargo test proxy_agent with exit-code: %errorlevel% - exit /b %errorlevel% - ) + REM use exit /b 1 to propagate error codes inside a (...) block + echo call cargo test --all-features %release_flag% --manifest-path %cargo_toml% --target-dir %out_path% --target %build_target% -- --test-threads=1 --nocapture ^|^| exit /b 1 + call cargo test --all-features %release_flag% --manifest-path %cargo_toml% --target-dir %out_path% --target %build_target% -- --test-threads=1 --nocapture || exit /b 1 ) echo ======= build proxy_agent_extension @@ -196,12 +191,9 @@ if "%Target%"=="arm64" ( xcopy /Y /S /C /Q %out_dir% %root_path%proxy_agent_extension\target\%Configuration%\ echo ======= run rust proxy_agent_extension tests - echo call cargo test --all-features %release_flag% --manifest-path %cargo_toml% --target-dir %out_path% --target %build_target% -- --test-threads=1 --nocapture - call cargo test --all-features %release_flag% --manifest-path %cargo_toml% --target-dir %out_path% --target %build_target% -- --test-threads=1 --nocapture - if %ERRORLEVEL% NEQ 0 ( - echo call cargo test proxy_agent_extension with exit-code: %errorlevel% - exit /b %errorlevel% - ) + REM use exit /b 1 to propagate error codes inside a (...) block + echo call cargo test --all-features %release_flag% --manifest-path %cargo_toml% --target-dir %out_path% --target %build_target% -- --test-threads=1 --nocapture ^|^| exit /b 1 + call cargo test --all-features %release_flag% --manifest-path %cargo_toml% --target-dir %out_path% --target %build_target% -- --test-threads=1 --nocapture || exit /b 1 ) echo ======= build proxy_agent_setup diff --git a/proxy_agent/Cargo.toml b/proxy_agent/Cargo.toml index 8f279dda..8da0e07b 100644 --- a/proxy_agent/Cargo.toml +++ b/proxy_agent/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "azure-proxy-agent" -version = "1.0.39" # always 3-number version +version = "1.0.40" # always 3-number version edition = "2021" build = "build.rs" readme = "README.md" @@ -27,7 +27,8 @@ tower = { version = "0.5.2", features = ["full"] } tower-http = { version = "0.6.2", features = ["limit"] } clap = { version = "4.5.17", features =["derive"] } # Command Line Argument Parser thiserror = "1.0.64" -ctor = "0.3.6" # used for test setup and clean up +libc = "0.2.147" +socket2 = "0.5" # Set socket options without tokio/std conversion [dependencies.uuid] version = "1.3.0" @@ -40,7 +41,6 @@ features = [ sysinfo = "0.30.13" # read process information for Linux aya = "0.13.1" # linux ebpf program loader uzers = "0.12.1" # get user name -libc = "0.2.147" # linux call [target.'cfg(not(windows))'.dependencies.nix] version = "0.29.0" diff --git a/proxy_agent/src/common/config.rs b/proxy_agent/src/common/config.rs index a9dae135..021ef065 100644 --- a/proxy_agent/src/common/config.rs +++ b/proxy_agent/src/common/config.rs @@ -122,39 +122,44 @@ impl Default for Config { config_file_full_path = misc_helpers::get_current_exe_dir(); config_file_full_path.push(CONFIG_FILE_NAME); } + Config::from_json_file(config_file_full_path) } } impl Config { + /// Load config from a specific JSON file path pub fn from_json_file(file_path: PathBuf) -> Self { - misc_helpers::json_read_from_file::(&file_path).unwrap_or_else(|_| { - panic!( - "Error in reading Config from Json file: {}", - misc_helpers::path_to_string(&file_path) - ) - }) + let mut config = + misc_helpers::json_read_from_file::(&file_path).unwrap_or_else(|_| { + panic!( + "Error in reading Config from Json file: {}", + misc_helpers::path_to_string(&file_path) + ) + }); + config.resolve_env_variables(); + config + } + + /// Resolve environment variables in path fields once during construction + /// This allows us to keep the rest of the code simple without worrying about env vars, + /// and also ensures that we only pay the cost of resolving env vars once at startup rather than on every access. + fn resolve_env_variables(&mut self) { + self.logFolder = misc_helpers::resolve_env_variables(&self.logFolder); + self.eventFolder = misc_helpers::resolve_env_variables(&self.eventFolder); + self.latchKeyFolder = misc_helpers::resolve_env_variables(&self.latchKeyFolder); } pub fn get_log_folder(&self) -> String { - match misc_helpers::resolve_env_variables(&self.logFolder) { - Ok(val) => val, - Err(_) => self.logFolder.clone(), - } + self.logFolder.clone() } pub fn get_event_folder(&self) -> String { - match misc_helpers::resolve_env_variables(&self.eventFolder) { - Ok(val) => val, - Err(_) => self.eventFolder.clone(), - } + self.eventFolder.clone() } pub fn get_latch_key_folder(&self) -> String { - match misc_helpers::resolve_env_variables(&self.latchKeyFolder) { - Ok(val) => val, - Err(_) => self.latchKeyFolder.clone(), - } + self.latchKeyFolder.clone() } pub fn get_monitor_interval(&self) -> u64 { @@ -231,19 +236,19 @@ mod tests { let config = create_config_file(config_file_path); assert_eq!( - r#"C:\logFolderName"#.to_string(), + r#"C:\logFolderName"#, config.get_log_folder(), "Log Folder mismatch" ); assert_eq!( - r#"C:\eventFolderName"#.to_string(), + r#"C:\eventFolderName"#, config.get_event_folder(), "Event Folder mismatch" ); assert_eq!( - r#"C:\latchKeyFolderName"#.to_string(), + r#"C:\latchKeyFolderName"#, config.get_latch_key_folder(), "Latch Key Folder mismatch" ); @@ -267,7 +272,7 @@ mod tests { ); assert_eq!( - "ebpfProgramName".to_string(), + "ebpfProgramName", config.get_ebpf_program_name(), "get_ebpf_program_name mismatch" ); diff --git a/proxy_agent/src/common/logger.rs b/proxy_agent/src/common/logger.rs index 437a3f69..a83c390b 100644 --- a/proxy_agent/src/common/logger.rs +++ b/proxy_agent/src/common/logger.rs @@ -43,14 +43,14 @@ fn log(log_level: LoggerLevel, message: String) { #[cfg(not(windows))] pub fn write_serial_console_log(message: String) { - use proxy_agent_shared::misc_helpers; + use proxy_agent_shared::{current_info, misc_helpers}; use std::io::Write; let message = format!( "{} {}_{}({}) - {}\n", misc_helpers::get_date_time_string_with_milliseconds(), env!("CARGO_PKG_NAME"), - misc_helpers::get_current_version(), + current_info::get_current_exe_version(), std::process::id(), message ); diff --git a/proxy_agent/src/key_keeper.rs b/proxy_agent/src/key_keeper.rs index fb527e9b..1c090d1d 100644 --- a/proxy_agent/src/key_keeper.rs +++ b/proxy_agent/src/key_keeper.rs @@ -15,11 +15,12 @@ //! use std::time::Duration; //! //! let shared_state = SharedState::start_all(); -//! let base_url = "http://127:0.0.1:8081/"; +//! let host = "127.0.0.1".to_string(); +//! let port = 8081u16; //! let key_dir = PathBuf::from("path"); //! let interval = Duration::from_secs(10); //! let config_start_redirector = false; -//! let key_keeper = key_keeper::KeyKeeper::new(base_url.parse().unwrap(), key_dir, interval, config_start_redirector, &shared_state); +//! let key_keeper = key_keeper::KeyKeeper::new(host, port, key_dir, interval, config_start_redirector, &shared_state); //! tokio::spawn(key_keeper.poll_secure_channel_status()); //! ``` @@ -40,7 +41,6 @@ use crate::shared_state::provision_wrapper::ProvisionSharedState; use crate::shared_state::redirector_wrapper::RedirectorSharedState; use crate::shared_state::{EventThreadsSharedState, SharedState}; use crate::{acl, redirector}; -use hyper::Uri; use proxy_agent_shared::common_state::CommonState; use proxy_agent_shared::logger::LoggerLevel; use proxy_agent_shared::misc_helpers; @@ -64,8 +64,10 @@ const DELAY_START_EVENT_THREADS_IN_MILLISECONDS: u128 = 60000; // 1 minute #[derive(Clone)] pub struct KeyKeeper { - /// base_url: the WireServer endpoint to poll the secure channel status - base_url: Uri, + /// host: the WireServer host to poll the secure channel status + host: String, + /// port: the WireServer port to poll the secure channel status + port: u16, /// key_dir: the folder to save the key details key_dir: PathBuf, /// status_dir: the folder to log the access control rule details @@ -101,14 +103,16 @@ enum WakeReason { impl KeyKeeper { pub fn new( - base_url: Uri, + host: String, + port: u16, key_dir: PathBuf, status_dir: PathBuf, interval: Duration, shared_state: &SharedState, ) -> Self { KeyKeeper { - base_url, + host, + port, key_dir, status_dir, interval, @@ -247,7 +251,7 @@ impl KeyKeeper { .await; started_event_threads = self.handle_event_threads_start(started_event_threads).await; - let status = match key::get_status(&self.base_url).await { + let status = match key::get_status(&self.host, self.port).await { Ok(s) => s, Err(e) => { self.update_status_message(format!("Failed to get key status - {e}"), true) @@ -629,7 +633,7 @@ impl KeyKeeper { /// Acquire key from server, persist it, and attest it /// Returns true if successful, false if should continue to next iteration async fn acquire_key_from_server(&self) -> bool { - let key = match key::acquire_key(&self.base_url).await { + let key = match key::acquire_key(&self.host, self.port).await { Ok(k) => k, Err(e) => { self.update_status_message(format!("Failed to acquire key details: {e:?}"), true) @@ -660,7 +664,7 @@ impl KeyKeeper { } // attest the key - match key::attest_key(&self.base_url, &key).await { + match key::attest_key(&self.host, self.port, &key).await { Ok(()) => { // update in memory if let Err(e) = self.update_key_to_shared_state(key.clone()).await { @@ -1032,7 +1036,7 @@ mod tests { match fs::remove_dir_all(&temp_test_path) { Ok(_) => {} Err(e) => { - print!("Failed to remove_dir_all with error {}.", e); + eprintln!("Failed to remove_dir_all with error {}.", e); } } @@ -1040,11 +1044,9 @@ mod tests { // start wire_server listener let ip = "127.0.0.1"; let port = 8081u16; - tokio::spawn(server_mock::start( - ip.to_string(), - port, - cancellation_token.clone(), - )); + let port = server_mock::start(ip.to_string(), port, cancellation_token.clone()) + .await + .unwrap(); tokio::time::sleep(Duration::from_millis(100)).await; // start with disabled secure channel state @@ -1053,13 +1055,14 @@ mod tests { // start poll_secure_channel_status let cloned_keys_dir = keys_dir.to_path_buf(); let key_keeper = KeyKeeper { - base_url: (format!("http://{}:{}/", ip, port)).parse().unwrap(), + host: ip.to_string(), + port, key_dir: cloned_keys_dir.clone(), status_dir: cloned_keys_dir.clone(), interval: Duration::from_millis(10), cancellation_token: cancellation_token.clone(), key_keeper_shared_state: key_keeper::KeyKeeperSharedState::start_new(), - common_state: key_keeper::CommonState::start_new(), + common_state: key_keeper::CommonState::start_new(cancellation_token.clone()), redirector_shared_state: key_keeper::RedirectorSharedState::start_new(), provision_shared_state: key_keeper::ProvisionSharedState::start_new(), agent_status_shared_state: key_keeper::AgentStatusSharedState::start_new(), diff --git a/proxy_agent/src/key_keeper/key.rs b/proxy_agent/src/key_keeper/key.rs index 667b0932..e8c2b61d 100644 --- a/proxy_agent/src/key_keeper/key.rs +++ b/proxy_agent/src/key_keeper/key.rs @@ -213,12 +213,20 @@ impl Clone for Privilege { } impl Privilege { - pub fn is_match(&self, logger: &mut ConnectionLogger, request_url: &Uri) -> bool { + /// Note: `self.path` and `self.queryParameters` keys/values are expected to be + /// pre-lowercased (done in `ComputedAuthorizationItem::from_authorization_item`). + /// `lowered_request_path` should be `request_url.path().to_lowercase()`, hoisted by the caller. + pub fn is_match( + &self, + logger: &mut ConnectionLogger, + request_url: &Uri, + lowered_request_path: &str, + ) -> bool { logger.write( LoggerLevel::Trace, format!("Start to match privilege '{}'", self.name), ); - if request_url.path().to_lowercase().starts_with(&self.path) { + if lowered_request_path.starts_with(&self.path) { logger.write( LoggerLevel::Trace, format!("Matched privilege path '{}'", self.path), @@ -234,12 +242,14 @@ impl Privilege { ); for (key, value) in query_parameters { + // We may need to optimize this like `lowered_request_path` if there are too many query parameters in the future, + // but currently we expect only a few query parameters at most, so the performance impact should be minimal. match hyper_client::query_pairs(request_url) .into_iter() - .find(|(k, _)| k.to_lowercase() == key.to_lowercase()) + .find(|(k, _)| k.to_lowercase() == *key) { Some((_, v)) => { - if v.to_lowercase() == value.to_lowercase() { + if v.to_lowercase() == *value { logger.write( LoggerLevel::Trace, format!( @@ -720,40 +730,22 @@ impl Display for KeyAction { const STATUS_URL: &str = "/secure-channel/status"; const KEY_URL: &str = "/secure-channel/key"; -pub async fn get_status(base_url: &Uri) -> Result { - let (host, port) = hyper_client::host_port_from_uri(base_url)?; - let url = format!("http://{host}:{port}{STATUS_URL}"); - let url: Uri = url.parse().map_err(|e| { - Error::Key(KeyErrorType::ParseKeyUrl( - base_url.to_string(), - STATUS_URL.to_string(), - e, - )) - })?; +pub async fn get_status(host: &str, port: u16) -> Result { + let endpoint = hyper_client::HostEndpoint::new(host, port, STATUS_URL); let mut headers = HashMap::new(); headers.insert( hyper_client::METADATA_HEADER.to_string(), "True ".to_string(), ); let status: KeyStatus = - hyper_client::get(&url, &headers, None, None, logger::write_warning).await?; + hyper_client::get(&endpoint, &headers, None, None, logger::write_warning).await?; status.validate()?; Ok(status) } -pub async fn acquire_key(base_url: &Uri) -> Result { - let (host, port) = hyper_client::host_port_from_uri(base_url)?; - let url = format!("http://{host}:{port}{KEY_URL}"); - let url: Uri = url.parse().map_err(|e| { - Error::Key(KeyErrorType::ParseKeyUrl( - base_url.to_string(), - KEY_URL.to_string(), - e, - )) - })?; - - let (host, port) = hyper_client::host_port_from_uri(&url)?; +pub async fn acquire_key(host: &str, port: u16) -> Result { + let endpoint = hyper_client::HostEndpoint::new(host, port, KEY_URL); let mut headers = HashMap::new(); headers.insert( hyper_client::METADATA_HEADER.to_string(), @@ -763,21 +755,26 @@ pub async fn acquire_key(base_url: &Uri) -> Result { let body = r#"{"authorizationScheme": "Azure-HMAC-SHA256"}"#.to_string(); let request = hyper_client::build_request( hyper::Method::POST, - &url, + &endpoint, &headers, Some(body.as_bytes()), None, None, )?; - let response = hyper_client::send_request(&host, port, request, logger::write_warning) - .await - .map_err(|e| { - Error::Key(KeyErrorType::SendKeyRequest( - format!("{}", KeyAction::Acquire), - e.to_string(), - )) - })?; + let response = hyper_client::send_request( + &endpoint.host, + endpoint.port, + request, + logger::write_warning, + ) + .await + .map_err(|e| { + Error::Key(KeyErrorType::SendKeyRequest( + format!("{}", KeyAction::Acquire), + e.to_string(), + )) + })?; if response.status() != StatusCode::OK { return Err(Error::Key(KeyErrorType::KeyResponse( @@ -790,17 +787,10 @@ pub async fn acquire_key(base_url: &Uri) -> Result { .map_err(Error::ProxyAgentSharedError) } -pub async fn attest_key(base_url: &Uri, key: &Key) -> Result<()> { +pub async fn attest_key(host: &str, port: u16, key: &Key) -> Result<()> { // secure-channel/key/{key_guid}/key-attestation - let (host, port) = hyper_client::host_port_from_uri(base_url)?; - let url = format!( - "http://{}:{}{}/{}/key-attestation", - host, port, KEY_URL, key.guid - ); - let url: Uri = url - .parse() - .map_err(|e| Error::Key(KeyErrorType::ParseKeyUrl(base_url.to_string(), url, e)))?; - + let path = format!("{}/{}/key-attestation", KEY_URL, key.guid); + let endpoint = hyper_client::HostEndpoint::new(host, port, &path); let mut headers = HashMap::new(); headers.insert( hyper_client::METADATA_HEADER.to_string(), @@ -808,21 +798,26 @@ pub async fn attest_key(base_url: &Uri, key: &Key) -> Result<()> { ); let request = hyper_client::build_request( Method::POST, - &url, + &endpoint, &headers, None, Some(key.guid.to_string()), Some(key.key.to_string()), )?; - let response = hyper_client::send_request(&host, port, request, logger::write_warning) - .await - .map_err(|e| { - Error::Key(KeyErrorType::SendKeyRequest( - format!("{}", KeyAction::Attest), - e.to_string(), - )) - })?; + let response = hyper_client::send_request( + &endpoint.host, + endpoint.port, + request, + logger::write_warning, + ) + .await + .map_err(|e| { + Error::Key(KeyErrorType::SendKeyRequest( + format!("{}", KeyAction::Attest), + e.to_string(), + )) + })?; if response.status() != StatusCode::OK { return Err(Error::Key(KeyErrorType::KeyResponse( @@ -1415,7 +1410,7 @@ mod tests { .parse() .unwrap(); assert!( - privilege.is_match(&mut logger, &url), + privilege.is_match(&mut logger, &url, &url.path().to_lowercase()), "privilege should be matched" ); @@ -1423,13 +1418,13 @@ mod tests { .parse() .unwrap(); assert!( - !privilege.is_match(&mut logger, &url), + !privilege.is_match(&mut logger, &url, &url.path().to_lowercase()), "privilege should not be matched" ); let url = "http://localhost/test?key1=value1".parse().unwrap(); assert!( - !privilege.is_match(&mut logger, &url), + !privilege.is_match(&mut logger, &url, &url.path().to_lowercase()), "privilege should not be matched" ); @@ -1442,7 +1437,7 @@ mod tests { .parse() .unwrap(); assert!( - privilege1.is_match(&mut logger, &url), + privilege1.is_match(&mut logger, &url, &url.path().to_lowercase()), "privilege should be matched" ); @@ -1459,7 +1454,7 @@ mod tests { .parse() .unwrap(); assert!( - !privilege2.is_match(&mut logger, &url), + !privilege2.is_match(&mut logger, &url, &url.path().to_lowercase()), "privilege should not be matched" ); } diff --git a/proxy_agent/src/main.rs b/proxy_agent/src/main.rs index 7a73720c..5f090fd4 100644 --- a/proxy_agent/src/main.rs +++ b/proxy_agent/src/main.rs @@ -15,7 +15,7 @@ use common::cli::{Commands, CLI}; use common::constants; use common::helpers; use provision::provision_query::ProvisionQuery; -use proxy_agent_shared::misc_helpers; +use proxy_agent_shared::current_info; use shared_state::SharedState; use std::{process, time::Duration}; @@ -48,7 +48,7 @@ async fn main() { let _time = helpers::get_elapsed_time_in_millisec(); if CLI.version { - println!("{}", misc_helpers::get_current_version()); + println!("{}", current_info::get_current_exe_version()); return; } diff --git a/proxy_agent/src/provision.rs b/proxy_agent/src/provision.rs index fd11fbea..82f5772c 100644 --- a/proxy_agent/src/provision.rs +++ b/proxy_agent/src/provision.rs @@ -19,6 +19,7 @@ use crate::{proxy_agent_status, redirector}; use proxy_agent_shared::logger::LoggerLevel; use proxy_agent_shared::telemetry::event_logger; use proxy_agent_shared::telemetry::event_reader::EventReader; +use proxy_agent_shared::telemetry::event_sender::EventSender; use proxy_agent_shared::{misc_helpers, proxy_agent_aggregate_status}; use std::path::PathBuf; use std::time::Duration; @@ -370,18 +371,36 @@ pub async fn start_event_threads(event_threads_shared_state: EventThreadsSharedS tokio::spawn({ let event_reader = EventReader::new( config::get_events_dir(), - true, - event_threads_shared_state.cancellation_token.clone(), event_threads_shared_state.common_state.clone(), "ProxyAgent".to_string(), "MicrosoftAzureGuestProxyAgent".to_string(), ); async move { event_reader - .start(Some(Duration::from_secs(300)), None, None) + .start(true, Some(Duration::from_secs(300))) .await; } }); + tokio::spawn({ + let event_reader = EventReader::new( + config::get_events_dir(), + event_threads_shared_state.common_state.clone(), + "ProxyAgent".to_string(), + "MicrosoftAzureGuestProxyAgent".to_string(), + ); + async move { + event_reader + .start_extension_status_event_processor(true, Some(Duration::from_secs(60))) + .await; + } + }); + tokio::spawn({ + let event_sender = EventSender::new(event_threads_shared_state.common_state.clone()); + async move { + event_sender.start(None, None).await; + } + }); + if let Err(e) = event_threads_shared_state .provision_shared_state .set_event_log_threads_initialized() @@ -642,17 +661,12 @@ pub mod provision_query { // bool - true provision finished; false provision not finished // String - provision error message, empty means provision success or provision failed. async fn get_current_provision_status(&self, notify: bool) -> Result { - let provision_url: String = format!( - "http://{}:{}{}", - Ipv4Addr::LOCALHOST, + let endpoint = hyper_client::HostEndpoint::new( + Ipv4Addr::LOCALHOST.to_string(), self.port, - PROVISION_URL_PATH + PROVISION_URL_PATH, ); - let provision_url: hyper::Uri = provision_url - .parse::() - .map_err(|e| Error::ParseUrl(provision_url, e.to_string()))?; - let mut headers = HashMap::new(); headers.insert( hyper_client::METADATA_HEADER.to_string(), @@ -665,7 +679,7 @@ pub mod provision_query { if notify { headers.insert(constants::NOTIFY_HEADER.to_string(), "true".to_string()); } - hyper_client::get(&provision_url, &headers, None, None, logger::write_warning) + hyper_client::get(&endpoint, &headers, None, None, logger::write_warning) .await .map_err(Error::ProxyAgentSharedError) } diff --git a/proxy_agent/src/proxy.rs b/proxy_agent/src/proxy.rs index dcc50b1a..9278e634 100644 --- a/proxy_agent/src/proxy.rs +++ b/proxy_agent/src/proxy.rs @@ -46,9 +46,6 @@ use crate::shared_state::proxy_server_wrapper::ProxyServerSharedState; use serde_derive::{Deserialize, Serialize}; use std::{ffi::OsString, net::IpAddr, path::PathBuf}; -#[cfg(not(windows))] -use sysinfo::{Pid, ProcessRefreshKind, RefreshKind, System, UpdateKind}; - #[derive(Serialize, Deserialize, Clone)] #[allow(non_snake_case)] pub struct Claims { @@ -90,33 +87,41 @@ async fn get_user( Ok(user) } else { let user = User::from_logon_id(logon_id)?; - if let Err(e) = proxy_server_shared_state.add_user(user.clone()).await { - println!("Failed to add user: {e} to cache"); + if let Err(_e) = proxy_server_shared_state.add_user(user.clone()).await { + #[cfg(test)] + eprintln!("Failed to add user: {_e} to cache"); } Ok(user) } } +/// Get process information (executable path and command line) for the given process ID. +/// Reads directly from the /proc filesystem on Linux for better performance. +/// Returns (executable path, command line). If the process information cannot be retrieved, returns (empty path, "undefined" command line). +/// Remarks: both /proc/{pid}/exe and /proc/{pid}/cmdline are universally supported across all Linux distributions since kernel 1.0 +/// Remarks: Do not use sysinfo::System::refresh_process_specifics(pid, refresh_kind) to get this information, +/// as it reads all files from /proc/{pid}/* and Create Process struct with all fields, +/// which is very inefficient when we only need the executable path and command line. #[cfg(not(windows))] fn get_process_info(process_id: u32) -> (PathBuf, String) { - let mut process_name = PathBuf::default(); - let mut process_cmd_line = UNDEFINED.to_string(); + // Get executable path from /proc/{pid}/exe symlink + let exe_path = format!("/proc/{}/exe", process_id); + let process_name = std::fs::read_link(&exe_path).unwrap_or_default(); - let pid = Pid::from_u32(process_id); - let sys = System::new_with_specifics( - RefreshKind::new().with_processes( - ProcessRefreshKind::new() - .with_cmd(UpdateKind::Always) - .with_exe(UpdateKind::Always), - ), - ); - if let Some(p) = sys.process(pid) { - process_name = match p.exe() { - Some(path) => path.to_path_buf(), - None => PathBuf::default(), - }; - process_cmd_line = p.cmd().join(" "); - } + // Get command line from /proc/{pid}/cmdline (null-separated arguments) + let cmdline_path = format!("/proc/{}/cmdline", process_id); + let process_cmd_line = match std::fs::read(&cmdline_path) { + Ok(bytes) => { + // cmdline is null-separated, convert to space-separated string + bytes + .split(|&b| b == 0) + .filter(|s| !s.is_empty()) + .map(|s| String::from_utf8_lossy(s).into_owned()) + .collect::>() + .join(" ") + } + Err(_) => UNDEFINED.to_string(), + }; (process_name, process_cmd_line) } @@ -147,12 +152,12 @@ impl Claims { let u = get_user(entry.logon_id, proxy_server_shared_state).await?; Ok(Claims { userId: entry.logon_id, - userName: u.user_name.to_string(), + userName: u.user_name.clone(), userGroups: u.user_groups.clone(), processId: p.pid, processName: p.name, processFullPath: p.exe_full_name, - processCmdLine: p.command_line.to_string(), + processCmdLine: p.command_line.clone(), runAsElevated: entry.is_admin == 1, clientIp: client_ip.to_string(), clientPort: client_port, @@ -170,26 +175,26 @@ impl Process { }; let options = PROCESS_QUERY_INFORMATION | PROCESS_VM_READ; - let handler = proxy_agent_shared::windows::get_process_handler(pid, options) - .unwrap_or_else(|e| { - println!("Failed to get process handler: {e}"); - 0 - }); - let base_info = windows::query_basic_process_info(handler); - match base_info { - Ok(_) => { - process_full_path = windows::get_process_full_name(handler).unwrap_or_default(); - cmd = windows::get_process_cmd(handler).unwrap_or(UNDEFINED.to_string()); - } - Err(e) => { - process_full_path = PathBuf::default(); - cmd = UNDEFINED.to_string(); - println!("Failed to query basic process info: {e}"); + let handler = + proxy_agent_shared::windows::get_process_handler(pid, options).unwrap_or(0); + if handler != 0 { + // Get process info directly - if either fails, the process may have exited + process_full_path = windows::get_process_full_name(handler).unwrap_or_default(); + cmd = windows::get_process_cmd(handler).unwrap_or(UNDEFINED.to_string()); + + // close the handle + if let Err(_e) = proxy_agent_shared::windows::close_handler(handler) { + #[cfg(test)] + println!("Failed to close process handler: {_e}"); } - } - // close the handle - if let Err(e) = proxy_agent_shared::windows::close_handler(handler) { - println!("Failed to close process handler: {e}"); + } else { + process_full_path = PathBuf::default(); + cmd = UNDEFINED.to_string(); + #[cfg(test)] + eprintln!( + "Failed to get_process_handler: {}", + std::io::Error::last_os_error() + ); } } #[cfg(not(windows))] @@ -200,7 +205,7 @@ impl Process { } // redact the secrets in the command line - let cmd = proxy_agent_shared::secrets_redactor::redact_secrets(cmd); + let cmd = proxy_agent_shared::secrets_redactor::redact_secrets_string(cmd); let process_name = process_full_path .file_name() @@ -249,8 +254,8 @@ impl User { Ok(User { logon_id, - user_name: user_name.to_string(), - user_groups: user_groups.clone(), + user_name, + user_groups, }) } } diff --git a/proxy_agent/src/proxy/authorization_rules.rs b/proxy_agent/src/proxy/authorization_rules.rs index 27ef84fa..3c706083 100644 --- a/proxy_agent/src/proxy/authorization_rules.rs +++ b/proxy_agent/src/proxy/authorization_rules.rs @@ -111,7 +111,21 @@ impl ComputedAuthorizationItem { .collect::>(); privilege_dict = privileges .into_iter() - .map(|privilege| (privilege.name.clone(), privilege)) + .map(|privilege| { + // case insensitive for path and query parameters key/values, + // to make it easier for users to write the rules without worrying about the case sensitivity. + // The name of the privilege is case sensitive, as it is used as the key in the privilege_dict and privilege_assignments. + let normalized = Privilege { + name: privilege.name, + path: privilege.path.to_lowercase(), + queryParameters: privilege.queryParameters.map(|qp| { + qp.into_iter() + .map(|(k, v)| (k.to_lowercase(), v.to_lowercase())) + .collect() + }), + }; + (normalized.name.clone(), normalized) + }) .collect::>(); for role_assignment in role_assignments { @@ -177,10 +191,11 @@ impl ComputedAuthorizationItem { return true; } + let lowered_request_path = request_url.path().to_lowercase(); let mut any_privilege_matched = false; for privilege in self.privileges.values() { let privilege_name = &privilege.name; - if privilege.is_match(logger, &request_url) { + if privilege.is_match(logger, &request_url, &lowered_request_path) { any_privilege_matched = true; logger.write( LoggerLevel::Trace, @@ -263,6 +278,10 @@ pub struct AuthorizationRulesForLogging { pub computedRules: ComputedAuthorizationRules, } +/// Remark: Regex::new is performance-sensitive, so we use LazyLock to compile it only once and reuse it for subsequent calls +static AUTHORIZATION_RULES_FILE_SEARCH_REGEX: std::sync::LazyLock = + std::sync::LazyLock::new(|| regex::Regex::new(r"^AuthorizationRules_.*\.json$").unwrap()); + impl AuthorizationRulesForLogging { pub fn new( input_rules: Option, @@ -280,7 +299,10 @@ impl AuthorizationRulesForLogging { /// The file is written to the path_dir specified by the input parameter pub fn write_all(&self, path_dir: &Path, max_file_count: usize) { // remove the old files - let files = match misc_helpers::search_files(path_dir, r"^AuthorizationRules_.*\.json$") { + let files = match misc_helpers::search_files( + path_dir, + &AUTHORIZATION_RULES_FILE_SEARCH_REGEX, + ) { Ok(files) => files, Err(e) => { // This should not happen, log the error and skip write the file @@ -343,7 +365,9 @@ mod tests { AccessControlRules, AuthorizationItem, AuthorizationRules, Identity, Privilege, Role, RoleAssignment, }; - use crate::proxy::authorization_rules::{AuthorizationMode, ComputedAuthorizationItem}; + use crate::proxy::authorization_rules::{ + AuthorizationMode, ComputedAuthorizationItem, AUTHORIZATION_RULES_FILE_SEARCH_REGEX, + }; use crate::proxy::{proxy_connection::ConnectionLogger, Claims}; use proxy_agent_shared::misc_helpers; use std::ffi::OsString; @@ -365,7 +389,7 @@ mod tests { }]), privileges: Some(vec![Privilege { name: "test".to_string(), - path: "/test".to_string(), + path: "/TEST".to_string(), // test the case insensitivity of the path queryParameters: None, }]), identities: Some(vec![Identity { @@ -407,8 +431,12 @@ mod tests { runAsElevated: true, }; // assert the claim is allowed given the rules above - let url = hyper::Uri::from_str("http://localhost/test/test").unwrap(); + + // test the case insensitivity of the path + let url = hyper::Uri::from_str("http://localhost/tESt/test").unwrap(); assert!(rules.is_allowed(&mut test_logger, url, claims.clone())); + + // test the case insensitivity of the path and the relative url let relative_url = hyper::Uri::from_str("/test/test").unwrap(); assert!(rules.is_allowed(&mut test_logger, relative_url.clone(), claims.clone())); claims.userName = "test1".to_string(); @@ -544,7 +572,7 @@ mod tests { match std::fs::remove_dir_all(&temp_test_path) { Ok(_) => {} Err(e) => { - print!("Failed to remove_dir_all with error {}.", e); + eprintln!("Failed to remove_dir_all with error {}.", e); } } misc_helpers::try_create_folder(&temp_test_path).unwrap(); @@ -600,7 +628,8 @@ mod tests { } let files = - misc_helpers::search_files(&temp_test_path, r"^AuthorizationRules_.*\.json$").unwrap(); + misc_helpers::search_files(&temp_test_path, &AUTHORIZATION_RULES_FILE_SEARCH_REGEX) + .unwrap(); assert_eq!(files.len(), max_file_count); // clean up and ignore the clean up errors diff --git a/proxy_agent/src/proxy/proxy_connection.rs b/proxy_agent/src/proxy/proxy_connection.rs index e20956a2..a810c77a 100644 --- a/proxy_agent/src/proxy/proxy_connection.rs +++ b/proxy_agent/src/proxy/proxy_connection.rs @@ -339,14 +339,14 @@ impl ConnectionLogger { // write to system log for connection logger explicitly, // as the connection logger only writes to file when the connection is dropped and, // connection logger file log does not write to system log implicitly. - logger_manager::write_system_log(logger_level, message.to_string()); + logger_manager::write_system_log(logger_level, message.clone()); if let Some(log_for_event) = crate::common::config::get_file_log_level_for_events() { if log_for_event >= logger_level { // write to event proxy_agent_shared::telemetry::event_logger::write_event_only( logger_level, - message.to_string(), + message.clone(), "ConnectionLogger", "ProxyAgent", ); @@ -360,11 +360,9 @@ impl ConnectionLogger { return; } - self.queue.push(format!( - "{}{}", - logger::get_log_header(logger_level), - message - )); + let mut msg = logger::get_log_header(logger_level); + msg.push_str(&message); + self.queue.push(msg); } } diff --git a/proxy_agent/src/proxy/proxy_server.rs b/proxy_agent/src/proxy/proxy_server.rs index cdea7b37..fb2a0abd 100644 --- a/proxy_agent/src/proxy/proxy_server.rs +++ b/proxy_agent/src/proxy/proxy_server.rs @@ -246,31 +246,27 @@ impl ProxyServer { }; let mut tcp_connection_logger = ConnectionLogger::new(tcp_connection_id, 0); tcp_connection_logger.write( - LoggerLevel::Info, + LoggerLevel::Trace, format!("Accepted new tcp connection [{tcp_connection_id}]."), ); tokio::spawn({ let cloned_proxy_server = self.clone(); async move { - let (stream, _cloned_std_stream) = - match Self::set_stream_read_time_out(stream, &mut tcp_connection_logger) { - Ok((stream, cloned_std_stream)) => (stream, cloned_std_stream), - Err(e) => { - tcp_connection_logger.write( - LoggerLevel::Error, - format!("Failed to set stream read timeout: {e}"), - ); - return; - } - }; + // Get raw socket ID before any conversion (Windows only) + #[cfg(windows)] + let raw_socket_id = Self::get_stream_raw_socket_id(&stream); + + // Set read timeout directly on the socket without conversion + Self::set_stream_read_time_out(&stream, &mut tcp_connection_logger); + let tcp_connection_context = TcpConnectionContext::new( tcp_connection_id, client_addr, cloned_proxy_server.redirector_shared_state.clone(), cloned_proxy_server.proxy_server_shared_state.clone(), #[cfg(windows)] - ProxyServer::get_stream_rocket_id(&_cloned_std_stream), + raw_socket_id, ) .await; @@ -324,46 +320,24 @@ impl ProxyServer { } #[cfg(windows)] - fn get_stream_rocket_id(stream: &std::net::TcpStream) -> usize { + fn get_stream_raw_socket_id(stream: &TcpStream) -> usize { use std::os::windows::io::AsRawSocket; stream.as_raw_socket() as usize } // Set the read timeout for the stream - fn set_stream_read_time_out( - stream: TcpStream, - connection_logger: &mut ConnectionLogger, - ) -> Result<(TcpStream, std::net::TcpStream)> { - // Convert the stream to a std stream - let std_stream = stream.into_std().map_err(|e| { - Error::Io( - "Failed to convert Tokio stream into std equivalent".to_string(), - e, - ) - })?; + // Uses socket2::SockRef to set socket options directly on the tokio stream + // socket2 crate already used by tokio internally, so it won't cause extra dependency + fn set_stream_read_time_out(stream: &TcpStream, connection_logger: &mut ConnectionLogger) { + use socket2::SockRef; - // Set the read timeout - if let Err(e) = std_stream.set_read_timeout(Some(std::time::Duration::from_secs(10))) { + let sock_ref = SockRef::from(stream); + if let Err(e) = sock_ref.set_read_timeout(Some(std::time::Duration::from_secs(10))) { connection_logger.write( LoggerLevel::Warn, format!("Failed to set read timeout: {e}"), ); } - - // Clone the stream for the service_fn - let cloned_std_stream = std_stream - .try_clone() - .map_err(|e| Error::Io("Failed to clone TCP stream".to_string(), e))?; - - // Convert the std stream back - let tokio_tcp_stream = TcpStream::from_std(std_stream).map_err(|e| { - Error::Io( - "Failed to convert std stream into Tokio equivalent".to_string(), - e, - ) - })?; - - Ok((tokio_tcp_stream, cloned_std_stream)) } async fn handle_new_http_request( @@ -393,7 +367,7 @@ impl ProxyServer { tcp_connection_context.clone(), ); http_connection_context.log( - LoggerLevel::Info, + LoggerLevel::Trace, format!( "Got request from {} for {} {}", tcp_connection_context.client_addr, @@ -471,7 +445,7 @@ impl ProxyServer { return Ok(Self::closed_response(StatusCode::MISDIRECTED_REQUEST)); } }; - http_connection_context.log(LoggerLevel::Info, claim_details.to_string()); + http_connection_context.log(LoggerLevel::Trace, claim_details.to_string()); // authenticate the connection let access_control_rules = match proxy_authorizer::get_access_control_rules( @@ -564,7 +538,7 @@ impl ProxyServer { if http_connection_context.should_skip_sig() { http_connection_context.log( - LoggerLevel::Info, + LoggerLevel::Trace, format!( "Skip compute signature for the request for {} {}", http_connection_context.method, http_connection_context.url @@ -798,12 +772,12 @@ impl ProxyServer { let summary = ProxySummary { id: http_connection_context.id, userId: claims.userId, - userName: claims.userName.to_string(), + userName: claims.userName.clone(), userGroups: claims.userGroups.clone(), - clientIp: claims.clientIp.to_string(), + clientIp: claims.clientIp.clone(), clientPort: claims.clientPort, processFullPath: claims.processFullPath, - processCmdLine: claims.processCmdLine.to_string(), + processCmdLine: claims.processCmdLine.clone(), runAsElevated: claims.runAsElevated, method: http_connection_context.method.to_string(), url: http_connection_context.url.to_string(), @@ -817,41 +791,67 @@ impl ProxyServer { elapsedTime: elapsed_time.as_millis(), errorDetails: error_details, }; - if let Ok(json) = serde_json::to_string(&summary) { - event_logger::write_event( - LoggerLevel::Info, - json, - "log_connection_summary", - "proxy_server", - ConnectionLogger::CONNECTION_LOGGER_KEY, - ); - }; http_connection_context.log( LoggerLevel::Trace, "Starting add connection summary for status reporting.".to_string(), ); if log_authorize_failed { - if let Err(e) = self + match self .connection_summary_shared_state - .add_one_failed_connection_summary(summary) + .add_one_failed_connection_summary(summary.clone()) .await { - http_connection_context.log( - LoggerLevel::Warn, - format!("Failed to add failed connection summary: {e}"), - ); + Ok(is_new_bucket) => { + if is_new_bucket { + // if it's a new bucket, we don't need to add to failed connection summary again + if let Ok(json) = serde_json::to_string(&summary) { + event_logger::write_event( + LoggerLevel::Info, + json, + "log_connection_summary", + "proxy_server", + ConnectionLogger::CONNECTION_LOGGER_KEY, + ); + }; + } + } + Err(e) => { + http_connection_context.log( + LoggerLevel::Warn, + format!("Failed to add failed connection summary: {e}"), + ); + } + } + } else { + match self + .connection_summary_shared_state + .add_one_connection_summary(summary.clone()) + .await + { + Ok(is_new_bucket) => { + if is_new_bucket { + // if it's a new bucket, we log it to event logger + if let Ok(json) = serde_json::to_string(&summary) { + event_logger::write_event( + LoggerLevel::Info, + json, + "log_connection_summary", + "proxy_server", + ConnectionLogger::CONNECTION_LOGGER_KEY, + ); + }; + } + } + Err(e) => { + http_connection_context.log( + LoggerLevel::Warn, + format!("Failed to add connection summary: {e}"), + ); + } } - } else if let Err(e) = self - .connection_summary_shared_state - .add_one_connection_summary(summary) - .await - { - http_connection_context.log( - LoggerLevel::Warn, - format!("Failed to add connection summary: {e}"), - ); } + http_connection_context.log( LoggerLevel::Trace, "Finished log_connection_summary.".to_string(), @@ -1045,10 +1045,10 @@ mod tests { let sleep_duration = Duration::from_millis(100); tokio::time::sleep(sleep_duration).await; - let url: hyper::Uri = format!("http://{}:{}/", host, port).parse().unwrap(); + let endpoint = hyper_client::HostEndpoint::new(host, port, "/"); let request = hyper_client::build_request( Method::GET, - &url, + &endpoint, &HashMap::new(), None, key_keeper_shared_state @@ -1085,12 +1085,10 @@ mod tests { ); // test with traversal characters - let url: hyper::Uri = format!("http://{}:{}/test/../", host, port) - .parse() - .unwrap(); + let endpoint = hyper_client::HostEndpoint::new(host, port, "/test/../"); let request = hyper_client::build_request( Method::GET, - &url, + &endpoint, &HashMap::new(), None, key_keeper_shared_state @@ -1116,7 +1114,7 @@ mod tests { let body = vec![88u8; super::REQUEST_BODY_LOW_LIMIT_SIZE + 1]; let request = hyper_client::build_request( Method::POST, - &url, + &endpoint, &HashMap::new(), Some(body.as_slice()), key_keeper_shared_state diff --git a/proxy_agent/src/proxy/proxy_summary.rs b/proxy_agent/src/proxy/proxy_summary.rs index 4de51bf3..ff823d7d 100644 --- a/proxy_agent/src/proxy/proxy_summary.rs +++ b/proxy_agent/src/proxy/proxy_summary.rs @@ -48,13 +48,13 @@ impl ProxySummary { impl From for ProxyConnectionSummary { fn from(proxy_summary: ProxySummary) -> ProxyConnectionSummary { ProxyConnectionSummary { - userName: proxy_summary.userName.to_string(), - userGroups: Some(proxy_summary.userGroups.clone()), - ip: proxy_summary.ip.to_string(), + userName: proxy_summary.userName, + userGroups: Some(proxy_summary.userGroups), + ip: proxy_summary.ip, port: proxy_summary.port, processFullPath: Some(proxy_summary.processFullPath.to_string_lossy().to_string()), - processCmdLine: proxy_summary.processCmdLine.to_string(), - responseStatus: proxy_summary.responseStatus.to_string(), + processCmdLine: proxy_summary.processCmdLine, + responseStatus: proxy_summary.responseStatus, count: 1, } } diff --git a/proxy_agent/src/proxy/windows.rs b/proxy_agent/src/proxy/windows.rs index 9b4ac2f7..acc770eb 100644 --- a/proxy_agent/src/proxy/windows.rs +++ b/proxy_agent/src/proxy/windows.rs @@ -24,7 +24,6 @@ use windows_sys::Win32::System::ProcessStatus::{ K32GetModuleBaseNameW, // kernel32.dll K32GetModuleFileNameExW, // kernel32.dll }; -use windows_sys::Win32::System::Threading::PROCESS_BASIC_INFORMATION; const LG_INCLUDE_INDIRECT: u32 = 1u32; const MAX_PREFERRED_LENGTH: u32 = 4294967295u32; @@ -239,80 +238,61 @@ const MAX_PATH: usize = 260; const STATUS_BUFFER_OVERFLOW: NTSTATUS = -2147483643; const STATUS_BUFFER_TOO_SMALL: NTSTATUS = -1073741789; const STATUS_INFO_LENGTH_MISMATCH: NTSTATUS = -1073741820; -const PROCESS_BASIC_INFORMATION_CLASS: PROCESSINFOCLASS = 0; const PROCESS_COMMAND_LINE_INFORMATION_CLASS: PROCESSINFOCLASS = 60; -pub fn query_basic_process_info(handler: isize) -> Result { - unsafe { - let mut process_basic_information = std::mem::zeroed::(); - let mut return_length = 0; - let status: NTSTATUS = NtQueryInformationProcess( - handler, - PROCESS_BASIC_INFORMATION_CLASS, - &mut process_basic_information as *mut _ as *mut _, - std::mem::size_of::() as u32, - &mut return_length, - ); - - if status != 0 { - return Err(Error::WindowsApi(WindowsApiErrorType::WindowsOsError( - std::io::Error::from_raw_os_error(status), - ))); - } - Ok(process_basic_information) - } -} - pub fn get_process_cmd(handler: isize) -> Result { - unsafe { - let mut return_length = 0; - let status: NTSTATUS = NtQueryInformationProcess( + let mut return_length = 0; + let status: NTSTATUS = unsafe { + NtQueryInformationProcess( handler, PROCESS_COMMAND_LINE_INFORMATION_CLASS, null_mut(), 0, &mut return_length as *mut _, - ); + ) + }; - if status != STATUS_BUFFER_OVERFLOW - && status != STATUS_BUFFER_TOO_SMALL - && status != STATUS_INFO_LENGTH_MISMATCH - { - return Err(Error::WindowsApi(WindowsApiErrorType::WindowsOsError( - std::io::Error::from_raw_os_error(status), - ))); - } - println!("return_length: {return_length}"); + if status != STATUS_BUFFER_OVERFLOW + && status != STATUS_BUFFER_TOO_SMALL + && status != STATUS_INFO_LENGTH_MISMATCH + { + return Err(Error::WindowsApi(WindowsApiErrorType::WindowsOsError( + std::io::Error::from_raw_os_error(status), + ))); + } + #[cfg(test)] + println!("return_length: {return_length}"); - let buf_len = (return_length as usize) / 2; - let mut buffer: Vec = vec![0; buf_len + 1]; - buffer.resize(buf_len + 1, 0); // set everything to 0 + let buf_len = (return_length as usize) / 2; + let mut buffer: Vec = vec![0; buf_len + 1]; + buffer.resize(buf_len + 1, 0); // set everything to 0 - let status: NTSTATUS = NtQueryInformationProcess( + let status: NTSTATUS = unsafe { + NtQueryInformationProcess( handler, PROCESS_COMMAND_LINE_INFORMATION_CLASS, buffer.as_mut_ptr() as *mut _, return_length, &mut return_length as *mut _, - ); - if status < 0 { - eprintln!("NtQueryInformationProcess failed with status: {status}"); - return Err(Error::WindowsApi(WindowsApiErrorType::WindowsOsError( - std::io::Error::from_raw_os_error(status), - ))); - } - buffer.set_len(buf_len); - buffer.push(0); + ) + }; + if status < 0 { + #[cfg(test)] + eprintln!("NtQueryInformationProcess failed with status: {status}"); + return Err(Error::WindowsApi(WindowsApiErrorType::WindowsOsError( + std::io::Error::from_raw_os_error(status), + ))); + } + unsafe { buffer.set_len(buf_len) }; + buffer.push(0); - let cmd_buffer = *(buffer.as_ptr() as *const UNICODE_STRING); + let cmd_buffer = unsafe { *(buffer.as_ptr() as *const UNICODE_STRING) }; - let cmd = String::from_utf16_lossy(std::slice::from_raw_parts( - cmd_buffer.Buffer, - (cmd_buffer.Length / 2) as usize, - )); + let cmd = String::from_utf16_lossy(unsafe { + std::slice::from_raw_parts(cmd_buffer.Buffer, (cmd_buffer.Length / 2) as usize) + }); - Ok(cmd) - } + Ok(cmd) } #[allow(dead_code)] @@ -389,9 +369,6 @@ mod tests { let full_name = super::get_process_full_name(handler).unwrap(); let cmd = super::get_process_cmd(handler).unwrap(); - let base_info = super::query_basic_process_info(handler); - assert!(base_info.is_ok(), "base_info must be ok"); - assert!( !name.as_os_str().is_empty(), "process name should not be empty" diff --git a/proxy_agent/src/proxy_agent_status.rs b/proxy_agent/src/proxy_agent_status.rs index 6a248474..aec85804 100644 --- a/proxy_agent/src/proxy_agent_status.rs +++ b/proxy_agent/src/proxy_agent_status.rs @@ -28,18 +28,17 @@ //! tokio::spawn(proxy_agent_status_task.start()); //! ``` -use crate::common::logger; +use crate::common::{constants, logger}; use crate::key_keeper::UNKNOWN_STATE; use crate::shared_state::agent_status_wrapper::{AgentStatusModule, AgentStatusSharedState}; use crate::shared_state::connection_summary_wrapper::ConnectionSummarySharedState; use crate::shared_state::key_keeper_wrapper::KeyKeeperSharedState; -use proxy_agent_shared::logger::LoggerLevel; -use proxy_agent_shared::misc_helpers; use proxy_agent_shared::proxy_agent_aggregate_status::{ GuestProxyAgentAggregateStatus, ModuleState, OverallState, ProxyAgentDetailStatus, ProxyAgentStatus, }; -use proxy_agent_shared::telemetry::event_logger; +use proxy_agent_shared::telemetry::{event_logger, Extension, OperationStatus}; +use proxy_agent_shared::{current_info, misc_helpers}; use std::collections::HashMap; use std::path::PathBuf; use std::time::{Duration, Instant}; @@ -116,8 +115,6 @@ impl ProxyAgentStatusTask { } async fn loop_status(&self) { - let map_clear_duration = Duration::from_secs(60 * 60 * 24); - let mut start_time = Instant::now(); let status_report_duration = Duration::from_secs(60 * 15); let mut status_report_time = Instant::now(); @@ -128,7 +125,12 @@ impl ProxyAgentStatusTask { e )); } - + let agent_status = Extension { + name: constants::PROXY_AGENT_SERVICE_NAME.to_string(), + version: current_info::get_current_exe_version(), + is_internal: true, + extension_type: "Monitoring".to_string(), + }; loop { #[cfg(not(windows))] { @@ -142,34 +144,23 @@ impl ProxyAgentStatusTask { Ok(status) => status, Err(e) => format!("Error serializing proxy agent status: {e}"), }; - event_logger::write_event( - LoggerLevel::Info, - status, - "loop_status", - "proxy_agent_status", - logger::AGENT_LOGGER_KEY, - ); + event_logger::report_extension_status_event( + agent_status.clone(), + OperationStatus { + operation_success: aggregate_status.proxyAgentStatus.status + == OverallState::SUCCESS, + task_name: "loop_status".to_string(), + operation: "report_proxy_agent_status".to_string(), + message: status, + duration: status_report_time.elapsed().as_millis() as i64, + }, + ) + .await; status_report_time = Instant::now(); } // write the aggregate status to status.json file self.write_aggregate_status_to_file(aggregate_status).await; - //Clear the connection map and reset start_time after 24 hours - if start_time.elapsed() >= map_clear_duration { - logger::write_information( - "Clearing the connection summary map and failed authenticate summary map." - .to_string(), - ); - if let Err(e) = self - .connection_summary_shared_state - .clear_all_summary() - .await - { - logger::write_error(format!("Error clearing the connection summary map and failed authenticate summary map: {e}")); - } - start_time = Instant::now(); - } - tokio::time::sleep(self.interval).await; } } @@ -243,7 +234,7 @@ impl ProxyAgentStatusTask { }; ProxyAgentStatus { - version: misc_helpers::get_current_version(), + version: current_info::get_current_exe_version(), status, // monitorStatus is proxy_agent_status itself status monitorStatus: ProxyAgentDetailStatus { @@ -300,7 +291,7 @@ impl ProxyAgentStatusTask { async fn write_aggregate_status_to_file(&self, status: GuestProxyAgentAggregateStatus) { let full_file_path = self.status_dir.join("status.json"); - if let Err(e) = misc_helpers::json_write_to_file(&status, &full_file_path) { + if let Err(e) = misc_helpers::json_write_to_file_async(&status, &full_file_path).await { self.update_agent_status_message(format!( "Error writing aggregate status to status file: {e}" )) diff --git a/proxy_agent/src/service.rs b/proxy_agent/src/service.rs index 23e8011e..c6d74b77 100644 --- a/proxy_agent/src/service.rs +++ b/proxy_agent/src/service.rs @@ -10,6 +10,7 @@ use crate::proxy::proxy_server::ProxyServer; use crate::redirector::{self, Redirector}; use crate::shared_state::SharedState; use proxy_agent_shared::current_info; +use proxy_agent_shared::hyper_client::HostEndpoint; use proxy_agent_shared::logger::rolling_logger::RollingLogger; use proxy_agent_shared::logger::{logger_manager, LoggerLevel}; use proxy_agent_shared::proxy_agent_aggregate_status; @@ -46,7 +47,7 @@ pub async fn start_service(shared_state: SharedState) { let start_message = format!( "============== GuestProxyAgent ({}) is starting on {}({}), elapsed: {}", - proxy_agent_shared::misc_helpers::get_current_version(), + current_info::get_current_exe_version(), current_info::get_long_os_version(), current_info::get_cpu_arch(), helpers::get_elapsed_time_in_millisec() @@ -57,9 +58,8 @@ pub async fn start_service(shared_state: SharedState) { tokio::spawn({ let key_keeper = KeyKeeper::new( - (format!("http://{}/", constants::WIRE_SERVER_IP)) - .parse() - .unwrap(), + constants::WIRE_SERVER_IP.to_string(), + HostEndpoint::DEFAULT_HTTP_PORT, config::get_keys_dir(), proxy_agent_aggregate_status::get_proxy_agent_aggregate_status_folder(), config::get_poll_key_status_duration(), @@ -158,31 +158,3 @@ pub fn stop_service(shared_state: SharedState) { event_logger::stop(); } - -#[cfg(test)] -mod tests { - use ctor::{ctor, dtor}; - use proxy_agent_shared::logger::LoggerLevel; - use std::env; - use std::fs; - - const TEST_LOGGER_KEY: &str = "proxy_agent_test"; - - fn get_temp_test_dir() -> std::path::PathBuf { - let mut temp_test_path = env::temp_dir(); - temp_test_path.push(TEST_LOGGER_KEY); - temp_test_path - } - - #[ctor] - fn setup() { - // Setup logger_manager for unit tests - super::setup_loggers(get_temp_test_dir(), LoggerLevel::Trace); - } - - #[dtor] - fn cleanup() { - // clean up and ignore the clean up errors - _ = fs::remove_dir_all(&get_temp_test_dir()); - } -} diff --git a/proxy_agent/src/shared_state.rs b/proxy_agent/src/shared_state.rs index 58553849..47d8c0e2 100644 --- a/proxy_agent/src/shared_state.rs +++ b/proxy_agent/src/shared_state.rs @@ -56,10 +56,12 @@ pub struct SharedState { impl SharedState { pub fn start_all() -> Self { + let cancellation_token = CancellationToken::new(); + SharedState { - cancellation_token: CancellationToken::new(), + cancellation_token: cancellation_token.clone(), key_keeper_shared_state: key_keeper_wrapper::KeyKeeperSharedState::start_new(), - common_state: CommonState::start_new(), + common_state: CommonState::start_new(cancellation_token.clone()), provision_shared_state: provision_wrapper::ProvisionSharedState::start_new(), agent_status_shared_state: agent_status_wrapper::AgentStatusSharedState::start_new(), redirector_shared_state: redirector_wrapper::RedirectorSharedState::start_new(), diff --git a/proxy_agent/src/shared_state/connection_summary_wrapper.rs b/proxy_agent/src/shared_state/connection_summary_wrapper.rs index dd472c78..4d94a57a 100644 --- a/proxy_agent/src/shared_state/connection_summary_wrapper.rs +++ b/proxy_agent/src/shared_state/connection_summary_wrapper.rs @@ -9,17 +9,21 @@ use crate::common::logger; use crate::common::result::Result; use crate::{common::error::Error, proxy::proxy_summary::ProxySummary}; use proxy_agent_shared::proxy_agent_aggregate_status::ProxyConnectionSummary; +use proxy_agent_shared::time_buckets::TimeBucketedItem; use std::collections::{hash_map, HashMap}; use tokio::sync::{mpsc, oneshot}; +const BUCKET_DURATION_SECS: u64 = 900; // 15-minute buckets +const MAX_AGE_SECS: u64 = 4 * 3600; // 4 hours + enum ConnectionSummaryAction { AddOneConnection { summary: ProxySummary, - response: oneshot::Sender<()>, + response: oneshot::Sender, }, AddOneFailedConnection { summary: ProxySummary, - response: oneshot::Sender<()>, + response: oneshot::Sender, }, GetAllConnection { response: oneshot::Sender>, @@ -39,47 +43,60 @@ impl ConnectionSummarySharedState { pub fn start_new() -> Self { let (tx, mut rx) = mpsc::channel(100); tokio::spawn(async move { - // The proxy connection summary from the proxy - let mut proxy_summary: HashMap = HashMap::new(); - // The failed authenticate summary from the proxy - let mut failed_authenticate_summary: HashMap = + // The proxy connection summary from the proxy (using time-bucketed items) + let mut proxy_summary: HashMap> = HashMap::new(); + // The failed authenticate summary from the proxy (using time-bucketed items) + let mut failed_authenticate_summary: HashMap< + String, + TimeBucketedItem, + > = HashMap::new(); + let max_age_duration = std::time::Duration::from_secs(MAX_AGE_SECS); + let bucket_duration = std::time::Duration::from_secs(BUCKET_DURATION_SECS); while let Some(action) = rx.recv().await { match action { ConnectionSummaryAction::AddOneConnection { summary, response } => { + let mut is_new_bucket = true; let key = summary.to_key_string(); if let hash_map::Entry::Vacant(e) = proxy_summary.entry(key.clone()) { - e.insert(summary.into()); + e.insert(TimeBucketedItem::new( + summary.into(), + bucket_duration, + max_age_duration, + )); } else if let Some(connection_summary) = proxy_summary.get_mut(&key) { - //increase_count(connection_summary); - connection_summary.count += 1; + is_new_bucket = connection_summary.add_one(); } - if response.send(()).is_err() { + if response.send(is_new_bucket).is_err() { logger::write_warning("Failed to send response to ConnectionSummaryAction::AddOneConnection".to_string()); } } ConnectionSummaryAction::AddOneFailedConnection { summary, response } => { + let mut is_new_bucket = true; let key = summary.to_key_string(); if let hash_map::Entry::Vacant(e) = failed_authenticate_summary.entry(key.clone()) { - e.insert(summary.into()); + e.insert(TimeBucketedItem::new( + summary.into(), + bucket_duration, + max_age_duration, + )); } else if let Some(connection_summary) = failed_authenticate_summary.get_mut(&key) { - //increase_count(connection_summary); - connection_summary.count += 1; + is_new_bucket = connection_summary.add_one(); } - if response.send(()).is_err() { + if response.send(is_new_bucket).is_err() { logger::write_warning("Failed to send response to ConnectionSummaryAction::AddOneFailedConnection".to_string()); } } ConnectionSummaryAction::GetAllConnection { response } => { - let mut copy_summary: Vec = Vec::new(); - for (_, connection_summary) in proxy_summary.iter() { - copy_summary.push(connection_summary.clone()); - } + // Remove entries with no recent connections and collect summaries + proxy_summary.retain(|_, v| !v.is_empty()); + let copy_summary: Vec = + proxy_summary.values_mut().map(|v| v.to_item()).collect(); if let Err(summary) = response.send(copy_summary) { logger::write_warning(format!( "Failed to send response to ConnectionSummaryAction::GetAllConnection with summary count '{:?}'", @@ -88,10 +105,12 @@ impl ConnectionSummarySharedState { } } ConnectionSummaryAction::GetAllFailedConnection { response } => { - let mut copy_summary: Vec = Vec::new(); - for (_, connection_summary) in failed_authenticate_summary.iter() { - copy_summary.push(connection_summary.clone()); - } + // Remove entries with no recent failed connections and collect summaries + failed_authenticate_summary.retain(|_, v| !v.is_empty()); + let copy_summary: Vec = failed_authenticate_summary + .values_mut() + .map(|v| v.to_item()) + .collect(); if let Err(summary) = response.send(copy_summary) { logger::write_warning(format!( "Failed to send response to ConnectionSummaryAction::GetAllFailedConnection with summary count '{:?}'", @@ -100,6 +119,7 @@ impl ConnectionSummarySharedState { } } ConnectionSummaryAction::ClearAll { response } => { + // force clear all summaries proxy_summary.clear(); failed_authenticate_summary.clear(); if response.send(()).is_err() { @@ -116,7 +136,10 @@ impl ConnectionSummarySharedState { ConnectionSummarySharedState(tx) } - pub async fn add_one_connection_summary(&self, summary: ProxySummary) -> Result<()> { + /// Add one connection summary + /// Returns true if a new time-bucketed item was created. + /// It does implicitly removes expired time-bucketed items + pub async fn add_one_connection_summary(&self, summary: ProxySummary) -> Result { let (response_tx, response_rx) = oneshot::channel(); self.0 .send(ConnectionSummaryAction::AddOneConnection { @@ -135,7 +158,10 @@ impl ConnectionSummarySharedState { }) } - pub async fn add_one_failed_connection_summary(&self, summary: ProxySummary) -> Result<()> { + /// Add one failed connection summary + /// Returns true if a new time bucket is created for this summary, false otherwise + /// It does implicitly removes expired time-bucketed items + pub async fn add_one_failed_connection_summary(&self, summary: ProxySummary) -> Result { let (response_tx, response_rx) = oneshot::channel(); self.0 .send(ConnectionSummaryAction::AddOneFailedConnection { @@ -157,6 +183,7 @@ impl ConnectionSummarySharedState { }) } + /// Clear both connection summaries explicitly pub async fn clear_all_summary(&self) -> Result<()> { let (response_tx, response_rx) = oneshot::channel(); self.0 @@ -176,6 +203,8 @@ impl ConnectionSummarySharedState { Ok(()) } + /// Get success connection summaries + /// Returns a vector of ProxyConnectionSummary, implicitly removed expired time-bucketed items pub async fn get_all_connection_summary(&self) -> Result> { let (response_tx, response_rx) = oneshot::channel(); self.0 @@ -194,6 +223,8 @@ impl ConnectionSummarySharedState { }) } + /// Get failed connection summaries + /// Returns a vector of ProxyConnectionSummary, implicitly removed expired time-bucketed items pub async fn get_all_failed_connection_summary(&self) -> Result> { let (response_tx, response_rx) = oneshot::channel(); self.0 diff --git a/proxy_agent_extension/Cargo.toml b/proxy_agent_extension/Cargo.toml index 10e66563..c36ec78d 100644 --- a/proxy_agent_extension/Cargo.toml +++ b/proxy_agent_extension/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ProxyAgentExt" -version = "1.0.39" # always 3-number version +version = "1.0.40" # always 3-number version edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/proxy_agent_extension/src/handler_main.rs b/proxy_agent_extension/src/handler_main.rs index 41920f4d..94220412 100644 --- a/proxy_agent_extension/src/handler_main.rs +++ b/proxy_agent_extension/src/handler_main.rs @@ -6,6 +6,7 @@ use crate::logger; use crate::structs; use crate::ExtensionCommand; use once_cell::sync::Lazy; +use proxy_agent_shared::current_info; use proxy_agent_shared::misc_helpers; use proxy_agent_shared::version::Version; use std::fs::{self}; @@ -41,7 +42,7 @@ pub async fn program_start(command: ExtensionCommand, config_seq_no: String) { logger::write(format!( "GuestProxyAgentExtension Version: {}, OS Arch: {}, OS Version: {}", - misc_helpers::get_current_version(), + current_info::get_current_exe_version(), misc_helpers::get_processor_arch(), misc_helpers::get_long_os_version() )); diff --git a/proxy_agent_extension/src/service_main.rs b/proxy_agent_extension/src/service_main.rs index aeeac69a..3739a9ba 100644 --- a/proxy_agent_extension/src/service_main.rs +++ b/proxy_agent_extension/src/service_main.rs @@ -4,6 +4,7 @@ use crate::common; use crate::constants; use crate::logger; use crate::structs::*; +use proxy_agent_shared::current_info; use proxy_agent_shared::logger::LoggerLevel; use proxy_agent_shared::proxy_agent_aggregate_status::{ self, GuestProxyAgentAggregateStatus, ProxyConnectionSummary, @@ -28,7 +29,7 @@ const MAX_STATE_COUNT: u32 = 120; pub fn run() { let message = format!( "============== GuestProxyAgentExtension Enabling Agent, Version: {}, OS Arch: {}, OS Version: {}", - misc_helpers::get_current_version(), + current_info::get_current_exe_version(), misc_helpers::get_processor_arch(), misc_helpers::get_long_os_version() ); diff --git a/proxy_agent_setup/Cargo.toml b/proxy_agent_setup/Cargo.toml index ccc017b0..7a1e3a33 100644 --- a/proxy_agent_setup/Cargo.toml +++ b/proxy_agent_setup/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "proxy_agent_setup" -version = "1.0.39" +version = "1.0.40" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/proxy_agent_setup/src/main.rs b/proxy_agent_setup/src/main.rs index 1e14ba0b..4a90862e 100644 --- a/proxy_agent_setup/src/main.rs +++ b/proxy_agent_setup/src/main.rs @@ -13,6 +13,7 @@ pub mod setup; mod linux; use clap::Parser; +use proxy_agent_shared::current_info; use proxy_agent_shared::misc_helpers; use proxy_agent_shared::service; use std::process; @@ -32,7 +33,7 @@ async fn main() { let cli = args::Cli::parse(); logger::write(format!( "\r\n\r\n============== ProxyAgent Setup Tool ({}) is starting with args: {} ==============", - misc_helpers::get_current_version(), + current_info::get_current_exe_version(), cli )); diff --git a/proxy_agent_setup/src/running.rs b/proxy_agent_setup/src/running.rs index 93905cb8..f6048e88 100644 --- a/proxy_agent_setup/src/running.rs +++ b/proxy_agent_setup/src/running.rs @@ -26,8 +26,7 @@ pub fn proxy_agent_running_folder(_service_name: &str) -> PathBuf { pub fn proxy_agent_parent_folder() -> PathBuf { #[cfg(windows)] { - let path = misc_helpers::resolve_env_variables("%SYSTEMDRIVE%\\WindowsAzure\\ProxyAgent") - .unwrap_or("C:\\WindowsAzure\\ProxyAgent".to_string()); + let path = misc_helpers::resolve_env_variables("%SYSTEMDRIVE%\\WindowsAzure\\ProxyAgent"); PathBuf::from(path) } #[cfg(not(windows))] diff --git a/proxy_agent_shared/Cargo.toml b/proxy_agent_shared/Cargo.toml index d8137078..d105b15e 100644 --- a/proxy_agent_shared/Cargo.toml +++ b/proxy_agent_shared/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "proxy_agent_shared" -version = "1.0.39" +version = "1.0.40" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -8,7 +8,7 @@ edition = "2021" [dependencies] concurrent-queue = "2.1.0" # for event queue once_cell = "1.17.0" # use Lazy -time = { version = "0.3.30", features = ["formatting", "parsing"] } +time = { version = "0.3.47", features = ["formatting", "parsing"] } thread-id = "4.0.0" serde = "1.0.152" serde_derive = "1.0.152" @@ -16,10 +16,10 @@ serde_json = "1.0.91" # json Deserializer serde-xml-rs = "0.8.1" # xml Deserializer with xml attribute regex = "1.11" # match file name thiserror = "1.0.64" -tokio = { version = "1", features = ["rt", "macros", "sync", "time"] } +tokio = { version = "1", features = ["fs", "rt", "macros", "net", "sync", "time"] } tokio-util = "0.7.11" +libc = "0.2.147" log = { version = "0.4.26", features = ["std"] } -ctor = "0.3.6" # used for test setup and clean up hex = "0.4.3" # hex encode itertools = "0.10.5" # use to sort iterator elements into a new iterator in ascending order http = "1.1.0" diff --git a/proxy_agent_shared/src/common_state.rs b/proxy_agent_shared/src/common_state.rs index e5ab29a5..4a4ce0c9 100644 --- a/proxy_agent_shared/src/common_state.rs +++ b/proxy_agent_shared/src/common_state.rs @@ -4,8 +4,10 @@ //! This module contains the logic to get and update common states. use crate::result::Result; -use crate::{error::Error, logger::logger_manager, telemetry::event_reader::VmMetaData}; -use tokio::sync::{mpsc, oneshot}; +use crate::{error::Error, logger::logger_manager, telemetry::telemetry_event::VmMetaData}; +use std::sync::Arc; +use tokio::sync::{mpsc, oneshot, Notify}; +use tokio_util::sync::CancellationToken; pub const SECURE_KEY_GUID: &str = "key_guid"; pub const SECURE_KEY_VALUE: &str = "key_value"; @@ -27,18 +29,27 @@ enum CommonStateAction { key: String, response: oneshot::Sender>, }, + GetTelemetryEventNotify { + response: oneshot::Sender>, + }, } #[derive(Clone, Debug)] -pub struct CommonState(mpsc::Sender); +pub struct CommonState { + /// The cancellation token is used to cancel the agent when the agent is stopped + cancellation_token: CancellationToken, + sender: mpsc::Sender, +} impl CommonState { - pub fn start_new() -> Self { + pub fn start_new(cancellation_token: CancellationToken) -> Self { let (sender, mut receiver) = mpsc::channel(100); tokio::spawn(async move { let mut vm_meta_data: Option = None; let mut states: std::collections::HashMap = std::collections::HashMap::new(); + let telemetry_event_notify = Arc::new(Notify::new()); + loop { match receiver.recv().await { Some(CommonStateAction::SetVmMetaData { @@ -79,6 +90,13 @@ impl CommonState { )); } } + Some(CommonStateAction::GetTelemetryEventNotify { response }) => { + if let Err(notify) = response.send(telemetry_event_notify.clone()) { + logger_manager::write_warn(format!( + "Failed to send response to CommonStateAction::GetTelemetryEventNotify '{notify:?}'" + )); + } + } None => { break; } @@ -86,12 +104,15 @@ impl CommonState { } }); - Self(sender) + Self { + cancellation_token, + sender, + } } pub async fn set_vm_meta_data(&self, vm_meta_data: Option) -> Result<()> { let (response, receiver) = oneshot::channel(); - self.0 + self.sender .send(CommonStateAction::SetVmMetaData { vm_meta_data, response, @@ -110,7 +131,7 @@ impl CommonState { pub async fn get_vm_meta_data(&self) -> Result> { let (response, receiver) = oneshot::channel(); - self.0 + self.sender .send(CommonStateAction::GetVmMetaData { response }) .await .map_err(|e| { @@ -126,7 +147,7 @@ impl CommonState { pub async fn set_state(&self, key: String, value: String) -> Result<()> { let (response, receiver) = oneshot::channel(); - self.0 + self.sender .send(CommonStateAction::SetState { key, value, @@ -143,7 +164,7 @@ impl CommonState { pub async fn get_state(&self, key: String) -> Result> { let (response, receiver) = oneshot::channel(); - self.0 + self.sender .send(CommonStateAction::GetState { key, response }) .await .map_err(|e| { @@ -153,4 +174,121 @@ impl CommonState { .await .map_err(|e| Error::RecvError("CommonStateAction::GetState".to_string(), e)) } + + pub async fn get_telemetry_event_notify(&self) -> Result> { + let (response, receiver) = oneshot::channel(); + self.sender + .send(CommonStateAction::GetTelemetryEventNotify { response }) + .await + .map_err(|e| { + Error::SendError( + "CommonStateAction::GetTelemetryEventNotify".to_string(), + e.to_string(), + ) + })?; + receiver.await.map_err(|e| { + Error::RecvError("CommonStateAction::GetTelemetryEventNotify".to_string(), e) + }) + } + + pub async fn notify_telemetry_event(&self) -> Result<()> { + let notify = self.get_telemetry_event_notify().await?; + notify.notify_one(); + Ok(()) + } + + pub fn get_cancellation_token(&self) -> CancellationToken { + self.cancellation_token.clone() + } + + pub fn cancel_cancellation_token(&self) { + self.cancellation_token.cancel(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_common_state_key_value_operations() { + let cancellation_token = CancellationToken::new(); + let common_state = CommonState::start_new(cancellation_token); + + // Get non-existent key should return None + let value = common_state + .get_state("non_existent_key".to_string()) + .await + .unwrap(); + assert!(value.is_none(), "Non-existent key should return None"); + + // Set and get a key-value pair + common_state + .set_state(SECURE_KEY_GUID.to_string(), "test-guid-value".to_string()) + .await + .unwrap(); + let value = common_state + .get_state(SECURE_KEY_GUID.to_string()) + .await + .unwrap(); + assert_eq!(value, Some("test-guid-value".to_string())); + + // Set and get another key-value pair + common_state + .set_state(SECURE_KEY_VALUE.to_string(), "test-key-value".to_string()) + .await + .unwrap(); + let value = common_state + .get_state(SECURE_KEY_VALUE.to_string()) + .await + .unwrap(); + assert_eq!(value, Some("test-key-value".to_string())); + + // Update existing key + common_state + .set_state( + SECURE_KEY_GUID.to_string(), + "updated-guid-value".to_string(), + ) + .await + .unwrap(); + let value = common_state + .get_state(SECURE_KEY_GUID.to_string()) + .await + .unwrap(); + assert_eq!(value, Some("updated-guid-value".to_string())); + + // First key should still have its value + let value = common_state + .get_state(SECURE_KEY_VALUE.to_string()) + .await + .unwrap(); + assert_eq!(value, Some("test-key-value".to_string())); + } + + #[tokio::test] + async fn test_common_state_multiple_operations() { + let cancellation_token = CancellationToken::new(); + let common_state = CommonState::start_new(cancellation_token); + + // Perform multiple operations in sequence + for i in 0..10 { + let key = format!("key_{}", i); + let value = format!("value_{}", i); + common_state + .set_state(key.clone(), value.clone()) + .await + .unwrap(); + let retrieved = common_state.get_state(key).await.unwrap(); + assert_eq!(retrieved, Some(value)); + } + + // Verify all values are still accessible + for i in 0..10 { + let key = format!("key_{}", i); + let expected_value = format!("value_{}", i); + let retrieved = common_state.get_state(key).await.unwrap(); + assert_eq!(retrieved, Some(expected_value)); + } + } } diff --git a/proxy_agent_shared/src/current_info.rs b/proxy_agent_shared/src/current_info.rs index e9e66544..c2070e8c 100644 --- a/proxy_agent_shared/src/current_info.rs +++ b/proxy_agent_shared/src/current_info.rs @@ -45,6 +45,8 @@ static CURRENT_OS_INFO: Lazy<(String, String)> = Lazy::new(|| { (arch, os) }); +static CURRENT_EXE_VERSION: Lazy = Lazy::new(misc_helpers::get_current_exe_version); + pub fn get_ram_in_mb() -> u64 { CURRENT_SYS_INFO.0 } @@ -61,6 +63,10 @@ pub fn get_long_os_version() -> String { CURRENT_OS_INFO.1.to_string() } +pub fn get_current_exe_version() -> String { + CURRENT_EXE_VERSION.clone() +} + #[cfg(test)] mod tests { #[test] diff --git a/proxy_agent_shared/src/etw/etw_reader.rs b/proxy_agent_shared/src/etw/etw_reader.rs index 5c06065b..bbc6ec44 100644 --- a/proxy_agent_shared/src/etw/etw_reader.rs +++ b/proxy_agent_shared/src/etw/etw_reader.rs @@ -32,21 +32,25 @@ pub struct System { pub provider: Provider, #[serde(rename = "EventID")] pub event_id: u32, - #[serde(rename = "Version")] + /// Version is only present in ETW events, not classic event log entries + #[serde(rename = "Version", default)] pub version: u8, #[serde(rename = "Level")] pub level: u8, - #[serde(rename = "Task")] + /// Task is only present in ETW events, not classic event log entries + #[serde(rename = "Task", default)] pub task: u8, - #[serde(rename = "Opcode")] + /// Opcode is only present in ETW events, not classic event log entries + #[serde(rename = "Opcode", default)] pub opcode: u8, - #[serde(rename = "Keywords")] + #[serde(rename = "Keywords", default)] pub keywords: String, #[serde(rename = "TimeCreated")] pub time_created: TimeCreated, #[serde(rename = "EventRecordID")] pub event_record_id: u64, - #[serde(rename = "Execution")] + /// Execution may not be present in some classic event log entries + #[serde(rename = "Execution", default)] pub execution: Execution, #[serde(rename = "Channel")] pub channel: String, @@ -68,11 +72,11 @@ pub struct TimeCreated { pub system_time: Option, } -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Default, Deserialize, Serialize)] pub struct Execution { - #[serde(rename = "@ProcessID")] + #[serde(rename = "@ProcessID", default)] pub process_id: u32, - #[serde(rename = "@ThreadID")] + #[serde(rename = "@ThreadID", default)] pub thread_id: u32, } diff --git a/proxy_agent_shared/src/etw/etw_writter.rs b/proxy_agent_shared/src/etw/etw_writter.rs index bc71d821..c2ff3798 100644 --- a/proxy_agent_shared/src/etw/etw_writter.rs +++ b/proxy_agent_shared/src/etw/etw_writter.rs @@ -45,7 +45,7 @@ impl WindowsEventWritter { ); let value = crate::misc_helpers::resolve_env_variables( r"%SystemRoot%\Microsoft.NET\Framework64\v4.0.30319\EventLogMessages.dll", - )?; + ); crate::windows::set_reg_string(&key_name, "EventMessageFile", value)?; let source_name_wide = super::to_wide(source_name); diff --git a/proxy_agent_shared/src/host_clients/imds_client.rs b/proxy_agent_shared/src/host_clients/imds_client.rs index 265e3794..cc067545 100644 --- a/proxy_agent_shared/src/host_clients/imds_client.rs +++ b/proxy_agent_shared/src/host_clients/imds_client.rs @@ -6,10 +6,9 @@ //! The GPA service uses the IMDS service to get the instance information of the VM. use super::instance_info::InstanceInfo; -use crate::hyper_client; +use crate::hyper_client::{self, HostEndpoint}; use crate::logger::logger_manager; -use crate::{error::Error, result::Result}; -use hyper::Uri; +use crate::result::Result; use std::collections::HashMap; pub struct ImdsClient { @@ -17,7 +16,7 @@ pub struct ImdsClient { port: u16, } -const IMDS_URI: &str = "metadata/instance?api-version=2018-02-01"; +const IMDS_URI: &str = "/metadata/instance?api-version=2018-02-01"; impl ImdsClient { pub fn new(ip: &str, port: u16) -> Self { @@ -27,19 +26,26 @@ impl ImdsClient { } } + fn endpoint(&self, path: &str) -> HostEndpoint { + HostEndpoint::new(&self.ip, self.port, path) + } + pub async fn get_imds_instance_info( &self, key_guid: Option, key: Option, ) -> Result { - let url: String = format!("http://{}:{}/{}", self.ip, self.port, IMDS_URI); - - let url: Uri = url - .parse::() - .map_err(|e| Error::ParseUrl(url, e.to_string()))?; + let endpoint = self.endpoint(IMDS_URI); let mut headers = HashMap::new(); headers.insert("Metadata".to_string(), "true".to_string()); - hyper_client::get(&url, &headers, key_guid, key, logger_manager::write_warn).await + hyper_client::get( + &endpoint, + &headers, + key_guid, + key, + logger_manager::write_warn, + ) + .await } } diff --git a/proxy_agent_shared/src/host_clients/wire_server_client.rs b/proxy_agent_shared/src/host_clients/wire_server_client.rs index 5d7ada35..e36eece3 100644 --- a/proxy_agent_shared/src/host_clients/wire_server_client.rs +++ b/proxy_agent_shared/src/host_clients/wire_server_client.rs @@ -4,14 +4,13 @@ //! This module contains the logic to interact with the wire server for sending telemetry data and getting goal state. use crate::host_clients::goal_state::{GoalState, SharedConfig}; -use crate::hyper_client; +use crate::hyper_client::{self, HostEndpoint}; use crate::{ error::{Error, WireServerErrorType}, logger::logger_manager, result::Result, }; use http::Method; -use hyper::Uri; use std::collections::HashMap; pub struct WireServerClient { @@ -19,8 +18,8 @@ pub struct WireServerClient { port: u16, } -const TELEMETRY_DATA_URI: &str = "machine/?comp=telemetrydata"; -const GOALSTATE_URI: &str = "machine?comp=goalstate"; +const TELEMETRY_DATA_URI: &str = "/machine/?comp=telemetrydata"; +const GOALSTATE_URI: &str = "/machine?comp=goalstate"; impl WireServerClient { pub fn new(ip: &str, port: u16) -> Self { @@ -30,15 +29,16 @@ impl WireServerClient { } } + fn endpoint(&self, path: &str) -> HostEndpoint { + HostEndpoint::new(&self.ip, self.port, path) + } + pub async fn send_telemetry_data(&self, xml_data: String) -> Result<()> { if xml_data.is_empty() { return Ok(()); } - let url = format!("http://{}:{}/{}", self.ip, self.port, TELEMETRY_DATA_URI); - let url: Uri = url - .parse::() - .map_err(|e| Error::ParseUrl(url, e.to_string()))?; + let endpoint = self.endpoint(TELEMETRY_DATA_URI); let mut headers = HashMap::new(); headers.insert("x-ms-version".to_string(), "2012-11-30".to_string()); headers.insert( @@ -48,15 +48,15 @@ impl WireServerClient { let request = hyper_client::build_request( Method::POST, - &url, + &endpoint, &headers, Some(xml_data.as_bytes()), None, // post telemetry data does not require signing None, )?; let response = match hyper_client::send_request( - &self.ip, - self.port, + &endpoint.host, + endpoint.port, request, logger_manager::write_warn, ) @@ -75,7 +75,7 @@ impl WireServerClient { if !status.is_success() { return Err(Error::WireServer( WireServerErrorType::Telemetry, - format!("Failed to get response from {url}, status code: {status}"), + format!("Failed to get response from {endpoint}, status code: {status}"), )); } @@ -87,16 +87,19 @@ impl WireServerClient { key_guid: Option, key: Option, ) -> Result { - let url = format!("http://{}:{}/{}", self.ip, self.port, GOALSTATE_URI); - let url = url - .parse::() - .map_err(|e| Error::ParseUrl(url, e.to_string()))?; + let endpoint = self.endpoint(GOALSTATE_URI); let mut headers = HashMap::new(); headers.insert("x-ms-version".to_string(), "2012-11-30".to_string()); - hyper_client::get(&url, &headers, key_guid, key, logger_manager::write_warn) - .await - .map_err(|e| Error::WireServer(WireServerErrorType::GoalState, e.to_string())) + hyper_client::get( + &endpoint, + &headers, + key_guid, + key, + logger_manager::write_warn, + ) + .await + .map_err(|e| Error::WireServer(WireServerErrorType::GoalState, e.to_string())) } pub async fn get_shared_config( @@ -106,13 +109,20 @@ impl WireServerClient { key: Option, ) -> Result { let mut headers = HashMap::new(); - let url = url - .parse::() - .map_err(|e| Error::ParseUrl(url, e.to_string()))?; headers.insert("x-ms-version".to_string(), "2012-11-30".to_string()); - hyper_client::get(&url, &headers, key_guid, key, logger_manager::write_warn) - .await - .map_err(|e| Error::WireServer(WireServerErrorType::SharedConfig, e.to_string())) + let uri = url + .parse::() + .map_err(|e| Error::ParseUrl(url.clone(), e.to_string()))?; + let endpoint = HostEndpoint::from_full_uri(uri)?; + hyper_client::get( + &endpoint, + &headers, + key_guid, + key, + logger_manager::write_warn, + ) + .await + .map_err(|e| Error::WireServer(WireServerErrorType::SharedConfig, e.to_string())) } } diff --git a/proxy_agent_shared/src/hyper_client.rs b/proxy_agent_shared/src/hyper_client.rs index d4e0d89e..8fe933a6 100644 --- a/proxy_agent_shared/src/hyper_client.rs +++ b/proxy_agent_shared/src/hyper_client.rs @@ -31,8 +31,94 @@ pub const CLAIMS_IS_ROOT: &str = "isRoot"; const LF: &str = "\n"; +/// Pre-parsed HTTP endpoint containing host, port, and path/query. +/// Use this to avoid re-parsing URIs multiple times which is performance-sensitive. +#[derive(Debug, Clone)] +pub struct HostEndpoint { + pub host: String, + pub port: u16, + /// The path and query portion of the URI (e.g., "/api/status?version=1") + pub path_and_query: String, +} + +impl HostEndpoint { + pub const DEFAULT_HTTP_PORT: u16 = 80; + pub const DEFAULT_HTTPS_PORT: u16 = 443; + + /// Create a new HostEndpoint with explicit components + pub fn new(host: impl Into, port: u16, path_and_query: impl Into) -> Self { + Self { + host: host.into(), + port, + path_and_query: path_and_query.into(), + } + } + + /// Create a HostEndpoint from a full URI string (e.g., "http://host:port/path?query") + /// This will parse the URI and extract the host, port, and path/query components. + /// Remark: Do not use this function in performance-sensitive code paths, as URI parsing can be relatively expensive. + /// Instead, use the `new` constructor with pre-parsed components when possible. + /// Remark: This function assumes the URI is well-formed and contains a host. It will return an error if the URI is invalid or missing required components. + pub fn from_full_uri(uri: Uri) -> Result { + let host = match uri.host() { + Some(h) => h.to_string(), + None => { + return Err(Error::Hyper(HyperErrorType::RequestBuilder( + "URI must have a host".to_string(), + ))); + } + }; + let default_port = if uri.scheme_str() == Some("https") { + Self::DEFAULT_HTTPS_PORT + } else { + Self::DEFAULT_HTTP_PORT + }; + let port = uri.port_u16().unwrap_or(default_port); + let path_and_query = match uri.path_and_query() { + Some(pq) => pq.as_str().to_string(), + None => "/".to_string(), // default to root path + }; + + Ok(Self { + host, + port, + path_and_query, + }) + } + + /// Create a HostEndpoint from a URI string (e.g., "http://host:port/path?query") + /// This will parse the URI and extract the host, port, and path/query components. + /// Remark: Do not use this function in performance-sensitive code paths, as URI parsing can be relatively expensive. + /// Instead, use the `new` constructor with pre-parsed components when possible. + pub fn from_uri_str(uri_str: &str) -> Result { + let uri = uri_str.parse::().map_err(|e| { + Error::Hyper(HyperErrorType::RequestBuilder(format!( + "Failed to parse URI string: {uri_str} with error: {e}" + ))) + })?; + Self::from_full_uri(uri) + } + + /// Get the address string for TCP connection (host:port) + #[inline] + pub fn addr(&self) -> String { + format!("{}:{}", self.host, self.port) + } +} + +impl std::fmt::Display for HostEndpoint { + /// Format as full URI string (e.g., "http://host:port/path?query") + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "http://{}:{}{}", + self.host, self.port, self.path_and_query + ) + } +} + pub async fn get( - full_url: &Uri, + endpoint: &HostEndpoint, headers: &HashMap, key_guid: Option, key: Option, @@ -42,14 +128,13 @@ where T: DeserializeOwned, F: Fn(String) + Send + 'static, { - let request = build_request(Method::GET, full_url, headers, None, key_guid, key)?; + let request = build_request(Method::GET, endpoint, headers, None, key_guid, key)?; - let (host, port) = host_port_from_uri(full_url)?; - let response = send_request(&host, port, request, log_fun).await?; + let response = send_request(&endpoint.host, endpoint.port, request, log_fun).await?; let status = response.status(); if !status.is_success() { return Err(Error::Hyper(HyperErrorType::ServerError( - full_url.to_string(), + endpoint.to_string(), status, ))); } @@ -162,22 +247,20 @@ where pub fn build_request( method: http::Method, - full_url: &Uri, + endpoint: &HostEndpoint, headers: &HashMap, body: Option<&[u8]>, key_guid: Option, key: Option, ) -> Result>> { - let (host, _) = host_port_from_uri(full_url)?; - let mut request_builder = Request::builder() .method(method) - .uri(match full_url.path_and_query() { - Some(pq) => pq.as_str(), - None => full_url.path(), - }) + .uri(&endpoint.path_and_query) .header(DATE_HEADER, misc_helpers::get_date_time_rfc1123_string()) - .header(hyper::header::HOST, host) + // The header() method accepts types that implement Into, and &str implements this trait. + // The HeaderValue will internally copy the bytes (which is unavoidable since it needs to own the data), + // So you're not creating any intermediate String allocations. + .header(hyper::header::HOST, &endpoint.host) .header( CLAIMS_HEADER, format!("{{ \"{}\": \"{}\"}}", CLAIMS_IS_ROOT, true,), @@ -278,21 +361,6 @@ where Ok(sender) } -pub fn host_port_from_uri(full_url: &Uri) -> Result<(String, u16)> { - let host = match full_url.host() { - Some(h) => h.to_string(), - None => { - return Err(Error::ParseUrl( - full_url.to_string(), - "Failed to get host from uri".to_string(), - )) - } - }; - let port = full_url.port_u16().unwrap_or(80); - - Ok((host, port)) -} - /* StringToSign = Method + "\n" + HexEncoded(Body) + "\n" + @@ -471,6 +539,7 @@ mod tests { use crate::{ host_clients::{imds_client::ImdsClient, wire_server_client::WireServerClient}, logger::logger_manager, + server_mock, }; use tokio_util::sync::CancellationToken; @@ -510,13 +579,11 @@ mod tests { async fn http_request_tests() { // start mock server let ip = "127.0.0.1"; - let port = 7072u16; + let port = 9072u16; let cancellation_token = CancellationToken::new(); - tokio::spawn(crate::server_mock::start( - ip.to_string(), - port, - cancellation_token.clone(), - )); + let port = server_mock::start(ip.to_string(), port, cancellation_token.clone()) + .await + .unwrap(); tokio::time::sleep(std::time::Duration::from_millis(100)).await; logger_manager::write_info("server_mock started.".to_string()); diff --git a/proxy_agent_shared/src/lib.rs b/proxy_agent_shared/src/lib.rs index 3961bb6a..6549eb75 100644 --- a/proxy_agent_shared/src/lib.rs +++ b/proxy_agent_shared/src/lib.rs @@ -15,6 +15,7 @@ pub mod result; pub mod secrets_redactor; pub mod service; pub mod telemetry; +pub mod time_buckets; pub mod version; #[cfg(windows)] diff --git a/proxy_agent_shared/src/logger/logger_manager.rs b/proxy_agent_shared/src/logger/logger_manager.rs index 4ab046b8..4d3bf188 100644 --- a/proxy_agent_shared/src/logger/logger_manager.rs +++ b/proxy_agent_shared/src/logger/logger_manager.rs @@ -195,11 +195,12 @@ fn get_max_system_logger_level() -> LoggerLevel { mod tests { use crate::logger::LoggerLevel; use crate::misc_helpers; - use ctor::{ctor, dtor}; use std::env; use std::fs; + use std::sync::Once; const TEST_LOGGER_KEY: &str = "logger_manager_test"; + static TEST_INIT: Once = Once::new(); fn get_temp_test_dir() -> std::path::PathBuf { let mut temp_test_path = env::temp_dir(); @@ -207,32 +208,38 @@ mod tests { temp_test_path } - #[ctor] fn setup() { - // Setup logger_manager for unit tests - let logger = crate::logger::rolling_logger::RollingLogger::create_new( - get_temp_test_dir(), - "test.log".to_string(), - 200, - 6, - ); - let mut loggers = std::collections::HashMap::new(); - loggers.insert(TEST_LOGGER_KEY.to_string(), logger); - crate::logger::logger_manager::set_loggers( - loggers, - TEST_LOGGER_KEY.to_string(), - LoggerLevel::Trace, - ); + TEST_INIT.call_once(|| { + // Setup logger_manager for unit tests + let logger = crate::logger::rolling_logger::RollingLogger::create_new( + get_temp_test_dir(), + "test.log".to_string(), + 200, + 6, + ); + let mut loggers = std::collections::HashMap::new(); + loggers.insert(TEST_LOGGER_KEY.to_string(), logger); + crate::logger::logger_manager::set_loggers( + loggers, + TEST_LOGGER_KEY.to_string(), + LoggerLevel::Trace, + ); + + unsafe { + libc::atexit(cleanup); + } + }); } - #[dtor] - fn cleanup() { + extern "C" fn cleanup() { // clean up and ignore the clean up errors _ = fs::remove_dir_all(&get_temp_test_dir()); } #[test] fn logger_manager_test() { + setup(); + for _ in [0; 20] { super::write_log( LoggerLevel::Trace, diff --git a/proxy_agent_shared/src/misc_helpers.rs b/proxy_agent_shared/src/misc_helpers.rs index b6ee12b2..9378c1b0 100644 --- a/proxy_agent_shared/src/misc_helpers.rs +++ b/proxy_agent_shared/src/misc_helpers.rs @@ -25,32 +25,47 @@ pub fn get_thread_identity() -> String { format!("{:0>8}", thread_id::get()) } -pub fn get_date_time_string_with_milliseconds() -> String { - let date_format = - format_description::parse("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond]") - .unwrap(); +// Static format descriptors parsed once and reused for all calls +static ISO8601_MILLIS_FORMAT: std::sync::LazyLock< + Vec>, +> = std::sync::LazyLock::new(|| { + format_description::parse("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond]") + .expect("Invalid ISO8601 millis date format") +}); + +static ISO8601_FORMAT: std::sync::LazyLock>> = + std::sync::LazyLock::new(|| { + format_description::parse("[year]-[month]-[day]T[hour]:[minute]:[second]Z") + .expect("Invalid ISO8601 date format") + }); + +// This format is also the preferred HTTP date format. https://httpwg.org/specs/rfc9110.html#http.date +static RFC1123_FORMAT: std::sync::LazyLock>> = + std::sync::LazyLock::new(|| { + format_description::parse( + "[weekday repr:short], [day] [month repr:short] [year] [hour]:[minute]:[second] GMT", + ) + .expect("Invalid RFC1123 date format") + }); - let time_str = OffsetDateTime::now_utc().format(&date_format).unwrap(); +pub fn get_date_time_string_with_milliseconds() -> String { + let time_str = OffsetDateTime::now_utc() + .format(&*ISO8601_MILLIS_FORMAT) + .expect("Failed to format ISO8601 millis date"); + // Truncate to 23 chars: "YYYY-MM-DDTHH:MM:SS.mmm" time_str.chars().take(23).collect() } pub fn get_date_time_string() -> String { - let date_format = - format_description::parse("[year]-[month]-[day]T[hour]:[minute]:[second]Z").unwrap(); - - let time_str = OffsetDateTime::now_utc().format(&date_format).unwrap(); - time_str.chars().collect() + OffsetDateTime::now_utc() + .format(&*ISO8601_FORMAT) + .expect("Failed to format ISO8601 date") } -// This format is also the preferred HTTP date format. https://httpwg.org/specs/rfc9110.html#http.date pub fn get_date_time_rfc1123_string() -> String { - let date_format = format_description::parse( - "[weekday repr:short], [day] [month repr:short] [year] [hour]:[minute]:[second] GMT", - ) - .unwrap(); - - let time_str = OffsetDateTime::now_utc().format(&date_format).unwrap(); - time_str.chars().collect() + OffsetDateTime::now_utc() + .format(&*RFC1123_FORMAT) + .expect("Failed to format RFC1123 date") } pub fn get_date_time_unix_nano() -> i128 { @@ -126,19 +141,45 @@ pub fn try_create_folder(dir: &Path) -> Result<()> { Ok(()) } +/// Writes a serializable object to a file in JSON format. +/// It first writes to a temporary file and then renames it to the target file to avoid leaving a corrupted file if the write operation fails. +/// Remark: it uses BufWriter to reduce system calls and improve performance. +/// Remark: Called from sync code, infrequent writes and small objects pub fn json_write_to_file(obj: &T, file_path: &Path) -> Result<()> where T: ?Sized + Serialize, { + use std::io::BufWriter; + // write to a temp file and rename to avoid corrupted file let temp_file_path = file_path.with_extension("tmp"); let file = File::create(&temp_file_path)?; - serde_json::to_writer_pretty(file, obj)?; + let writer = BufWriter::new(file); // Reduces system calls + serde_json::to_writer_pretty(writer, obj)?; std::fs::rename(temp_file_path, file_path)?; Ok(()) } +/// Async version of json_write_to_file using tokio::fs +/// Serializes to memory first (CPU work), then writes asynchronously (IO work) +/// This avoids blocking the async runtime during serialization +/// Remark: Called from async context, writing while handing concurrent requests, and potentially larger objects +pub async fn json_write_to_file_async(obj: &T, file_path: &Path) -> Result<()> +where + T: ?Sized + Serialize, +{ + // Serialize to memory first (CPU work - fast) + let json_bytes = serde_json::to_vec_pretty(obj)?; + + // Write asynchronously (IO work) + let temp_file_path = file_path.with_extension("tmp"); + tokio::fs::write(&temp_file_path, json_bytes).await?; + tokio::fs::rename(&temp_file_path, file_path).await?; + + Ok(()) +} + pub fn json_read_from_file(file_path: &Path) -> Result where T: DeserializeOwned, @@ -214,7 +255,7 @@ pub fn get_current_version() -> String { /// otherwise fallback to Cargo.toml version. /// # Returns /// A string representing the current executable version -pub fn get_current_exe_version() -> String { +pub(crate) fn get_current_exe_version() -> String { #[cfg(windows)] { match try_get_current_exe_version() { @@ -261,28 +302,18 @@ pub fn empty_path() -> PathBuf { PathBuf::new() } -/// Search files in a directory with a regex pattern +/// Search files in a directory with a regex /// # Arguments /// * `dir` - The directory to search -/// * `search_regex_pattern` - The regex pattern to search +/// * `search_regex` - The regex to search /// # Returns /// A vector of PathBufs that match the search pattern in ascending order /// # Errors /// Returns an error if the regex pattern is invalid or if there is an IO error -/// # Example -/// ```rust -/// use std::path::PathBuf; -/// use proxy_agent_shared::misc_helpers; -/// let dir = PathBuf::from("."); -/// let search_regex_pattern = r"^(.*\.log)$"; // search for files with .log extension -/// let files = misc_helpers::search_files(&dir, search_regex_pattern).unwrap(); -/// -/// let search_regex_pattern = r"^MyFile.*\.json$"; // Regex pattern to match "MyFile*.json" -/// let files = misc_helpers::search_files(&dir, search_regex_pattern).unwrap(); -/// ``` -pub fn search_files(dir: &Path, search_regex_pattern: &str) -> Result> { +/// Remarks: The Regex::new is expensive, so the caller should cache the regex if it is used frequently, +/// for example, by using once_cell::sync::Lazy or std::sync::LazyLock to create a static regex instance. +pub fn search_files(dir: &Path, search_regex: &Regex) -> Result> { let mut files = Vec::new(); - let regex = Regex::new(search_regex_pattern)?; for entry in fs::read_dir(dir)? { let entry = entry?; @@ -292,7 +323,7 @@ pub fn search_files(dir: &Path, search_regex_pattern: &str) -> Result Result { } } +/// Static regex for matching environment variables like %VAR% +/// expect() only panics on failure to compile the regex, which should not happen since the pattern is constant and valid +/// This is the idiomatic Rust pattern for creating a static regex that is compiled once and reused, ensuring thread safety and performance +/// Remark: Regex::new is performance-sensitive, so we use LazyLock to compile it only once and reuse it for subsequent calls to resolve_env_variables, which can be called frequently. +/// This avoids the overhead of compiling the regex on every call, improving performance while ensuring thread safety. +static ENV_VAR_REGEX: std::sync::LazyLock = + std::sync::LazyLock::new(|| Regex::new(r"%(\w+)%").expect("Invalid env var regex pattern")); + /// This function replaces all occurrences of %VAR% in the input string with the value of the environment variable VAR /// If the environment variable is not set, it returns the original string with VAR unchanged. /// # Arguments /// * `input` - The input string to resolve environment variables in /// # Returns /// A Result containing the resolved string or an error if the regex pattern is invalid -pub fn resolve_env_variables(input: &str) -> Result { - let re = Regex::new(r"%(\w+)%")?; - let ret = re +/// The resolved string with environment variables expanded +pub fn resolve_env_variables(input: &str) -> String { + if input.is_empty() || !input.contains('%') { + // If the input string is empty or does not contain '%', return the original string + return input.to_string(); + } + + ENV_VAR_REGEX .replace_all(input, |caps: ®ex::Captures| { std::env::var(&caps[1]).unwrap_or_else(|_| caps[1].to_string()) }) - .to_string(); - - Ok(ret) + .to_string() } /// Compute HMAC-SHA256 signature for the input using the provided hex-encoded key @@ -424,6 +466,7 @@ pub fn xml_escape(s: String) -> String { #[cfg(test)] mod tests { + use regex::Regex; use serde_derive::{Deserialize, Serialize}; use std::env; use std::fs; @@ -578,14 +621,16 @@ mod tests { let json_file = json_file.join("test_1.json"); super::json_write_to_file(&test, &json_file).unwrap(); - let files = super::search_files(&temp_test_path, "test.json").unwrap(); + let regex = Regex::new(r"test.json").unwrap(); + let files = super::search_files(&temp_test_path, ®ex).unwrap(); assert_eq!( 1, files.len(), "file count mismatch with 'test.json' search" ); - let files = super::search_files(&temp_test_path, r"^test.*\.json$").unwrap(); + let regex = Regex::new(r"^test.*\.json$").unwrap(); + let files = super::search_files(&temp_test_path, ®ex).unwrap(); assert_eq!( 2, files.len(), @@ -647,12 +692,12 @@ mod tests { "{}\\WindowsAzure\\ProxyAgent\\Package_1.0.0", env::var("SYSTEMDRIVE").unwrap_or("SYSTEMDRIVE".to_string()) ); - let resolved = super::resolve_env_variables(input).unwrap(); + let resolved = super::resolve_env_variables(input); assert_eq!(expected, resolved, "resolved string mismatch"); let input = "/var/log/azure-proxy-agent/"; let expected = "/var/log/azure-proxy-agent/".to_string(); - let resolved = super::resolve_env_variables(input).unwrap(); + let resolved = super::resolve_env_variables(input); assert_eq!(expected, resolved, "resolved string mismatch"); } diff --git a/proxy_agent_shared/src/proxy_agent_aggregate_status.rs b/proxy_agent_shared/src/proxy_agent_aggregate_status.rs index 6a6a2971..9d862db7 100644 --- a/proxy_agent_shared/src/proxy_agent_aggregate_status.rs +++ b/proxy_agent_shared/src/proxy_agent_aggregate_status.rs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft Corporation // SPDX-License-Identifier: MIT -use crate::misc_helpers; +use crate::{misc_helpers, time_buckets::Countable}; use serde_derive::{Deserialize, Serialize}; use std::{collections::HashMap, path::PathBuf}; use time::OffsetDateTime; @@ -12,8 +12,7 @@ const PROXY_AGENT_AGGREGATE_STATUS_FOLDER: &str = "/var/log/azure-proxy-agent/"; pub const PROXY_AGENT_AGGREGATE_STATUS_FILE_NAME: &str = "status.json"; pub fn get_proxy_agent_aggregate_status_folder() -> std::path::PathBuf { - let path = misc_helpers::resolve_env_variables(PROXY_AGENT_AGGREGATE_STATUS_FOLDER) - .unwrap_or(PROXY_AGENT_AGGREGATE_STATUS_FOLDER.to_string()); + let path = misc_helpers::resolve_env_variables(PROXY_AGENT_AGGREGATE_STATUS_FOLDER); PathBuf::from(path) } @@ -81,6 +80,12 @@ impl Clone for ProxyConnectionSummary { } } +impl Countable for ProxyConnectionSummary { + fn set_count(&mut self, count: u64) { + self.count = count; + } +} + #[derive(Serialize, Deserialize)] #[allow(non_snake_case)] pub struct GuestProxyAgentAggregateStatus { diff --git a/proxy_agent_shared/src/secrets_redactor.rs b/proxy_agent_shared/src/secrets_redactor.rs index bf983664..7cb070f6 100644 --- a/proxy_agent_shared/src/secrets_redactor.rs +++ b/proxy_agent_shared/src/secrets_redactor.rs @@ -1,7 +1,31 @@ // Copyright (c) Microsoft Corporation // SPDX-License-Identifier: MIT +use std::borrow::Cow; + const REDACTED_TEXT: &str = "[REDACTED]"; +/// Common substrings that indicate a secret might be present - for quick pre-filtering +/// These are not regex patterns, just simple substrings to check for before running the more expensive regexes. +const SECRET_INDICATORS: [&str; 15] = [ + "pwd=", + "password=", + "AccountKey=", + "PrimaryKey=", + "SecondaryKey=", + "sig=", + "AzCa", + "PRIVATE KEY", + "token", + "ado", + "vsts", + "key", + "secret", + "authorization", + "eyJ", +]; +/// Regular expression patterns to identify secrets. These are more expensive to run, so we first check for indicators. +/// Remarks: when add more patterns, please also add corresponding indicators in SECRET_INDICATORS for better performance. +/// And try to make the pattern as specific as possible to avoid false positives and unnecessary redaction. const CRED_PATTERNS: [&str; 17] = [ // SQL Connection String Password "pwd=[^;]*", @@ -45,20 +69,46 @@ fn init_regex_patterns() -> Vec { patterns } -pub fn redact_secrets(text: String) -> String { - if text.is_empty() { - return text; +/// Quick check if text might contain secrets (case-insensitive for most indicators) +#[inline] +fn might_contain_secrets(text: &str) -> bool { + let lower = text.to_ascii_lowercase(); + SECRET_INDICATORS.iter().any(|indicator| { + if *indicator == "AzCa" || *indicator == "PRIVATE KEY" || *indicator == "eyJ" { + // Case-sensitive check for these + text.contains(indicator) + } else { + lower.contains(&indicator.to_ascii_lowercase()) + } + }) +} + +/// Redacts secrets from text. Returns the original text unchanged if no secrets found. +/// Takes `&str` to avoid unnecessary ownership transfer. +fn redact_secrets(text: &str) -> Cow<'_, str> { + if text.is_empty() || !might_contain_secrets(text) { + return Cow::Borrowed(text); } - let mut redacted_text = text.clone(); + let mut redacted_text = Cow::Borrowed(text); for pattern in REGEX_PATTERNS.iter() { - redacted_text = pattern - .replace_all(&redacted_text, REDACTED_TEXT) - .to_string(); + if let Cow::Owned(s) = pattern.replace_all(&redacted_text, REDACTED_TEXT) { + redacted_text = Cow::Owned(s); + } } redacted_text } +/// Convenience function that takes ownership and returns String +/// Use this when you already have a String and need a String back +#[inline] +pub fn redact_secrets_string(text: String) -> String { + match redact_secrets(&text) { + Cow::Borrowed(_) => text, // No changes, return original + Cow::Owned(s) => s, // Changed, return new string + } +} + #[cfg(test)] mod tests { use super::*; @@ -121,7 +171,23 @@ authorization: aws4-hmac-sha256"#, ), ]; for (input, expected) in test_strings { - assert_eq!(redact_secrets(input.to_string()), expected.to_string()); + assert_eq!(redact_secrets(input), expected); } } + + #[test] + fn test_no_secrets_no_allocation() { + let text = "This is a normal log message without any secrets"; + let result = redact_secrets(text); + // Should return Borrowed (no allocation) when no secrets found + assert!(matches!(result, std::borrow::Cow::Borrowed(_))); + assert_eq!(result, text); + } + + #[test] + fn test_redact_secrets_string() { + let text = "pwd=secret123;".to_string(); + let result = redact_secrets_string(text); + assert_eq!(result, "[REDACTED];"); + } } diff --git a/proxy_agent_shared/src/server_mock.rs b/proxy_agent_shared/src/server_mock.rs index 757c19a1..0472cf93 100644 --- a/proxy_agent_shared/src/server_mock.rs +++ b/proxy_agent_shared/src/server_mock.rs @@ -21,38 +21,61 @@ static EMPTY_GUID: Lazy = Lazy::new(|| "00000000-0000-0000-0000-00000000 static GUID: Lazy = Lazy::new(|| Uuid::new_v4().to_string()); static mut CURRENT_STATE: Lazy = Lazy::new(|| String::from("wireserver")); -pub async fn start(ip: String, port: u16, cancellation_token: CancellationToken) { +/// A mock server to simulate the behavior of the Azure WireServer for testing purposes. +/// It listens for incoming HTTP requests and responds with predefined responses based on the request path and method. +/// The server can be started on a specified IP and port, and it can be gracefully shut down using a cancellation token. +/// If the port is already in use, it will automatically bind to an ephemeral port and log a warning message. +/// Returns the port number that the server is listening on, or an error if the server fails to start. +pub async fn start(ip: String, port: u16, cancellation_token: CancellationToken) -> Result { logger_manager::write_info("Mock Server starting...".to_string()); let addr = format!("{ip}:{port}"); - let listener = TcpListener::bind(&addr).await.unwrap(); - println!("Listening on http://{addr}"); - - loop { - tokio::select! { - _ = cancellation_token.cancelled() => { - logger_manager::write_warn("cancellation token signal received, stop the listener.".to_string()); - return; + let listener = match TcpListener::bind(&addr).await { + Ok(l) => l, + Err(e) => { + // if the specified port is already in use, bind to an ephemeral port instead + if e.kind() == std::io::ErrorKind::AddrInUse { + logger_manager::write_warn(format!( + "Port {port} is already in use, trying to bind to an ephemeral port." + )); + TcpListener::bind(format!("{ip}:0")).await? + } else { + return Err(e.into()); } - result = listener.accept() => { - match result { - Ok((stream, _)) =>{ - let ip = ip.to_string(); - tokio::spawn(async move { - let io = TokioIo::new(stream); + } + }; + let local_addr = listener.local_addr().unwrap(); + // Update the port if we had to bind to an ephemeral port + let port = local_addr.port(); + println!("Mock Server Listening on http://{local_addr}"); + tokio::spawn(async move { + loop { + tokio::select! { + _ = cancellation_token.cancelled() => { + logger_manager::write_warn("cancellation token signal received, stop the listener.".to_string()); + return; + } + result = listener.accept() => { + match result { + Ok((stream, _)) =>{ let ip = ip.to_string(); - let service = service_fn(move |req| handle_request(ip.to_string(), port, req)); - if let Err(err) = http1::Builder::new().serve_connection(io, service).await { - println!("Error serving connection: {err:?}"); - } - }); - }, - Err(e) => { - logger_manager::write_err(format!("Failed to accept connection: {e}")); + tokio::spawn(async move { + let io = TokioIo::new(stream); + let ip = ip.to_string(); + let service = service_fn(move |req| handle_request(ip.to_string(), port, req)); + if let Err(err) = http1::Builder::new().serve_connection(io, service).await { + println!("Error serving connection: {err:?}"); + } + }); + }, + Err(e) => { + logger_manager::write_err(format!("Failed to accept connection: {e}")); + } } } } } - } + }); + Ok(port) } async fn handle_request( diff --git a/proxy_agent_shared/src/telemetry.rs b/proxy_agent_shared/src/telemetry.rs index 77be8ee0..c831acc4 100644 --- a/proxy_agent_shared/src/telemetry.rs +++ b/proxy_agent_shared/src/telemetry.rs @@ -2,12 +2,28 @@ // SPDX-License-Identifier: MIT pub mod event_logger; pub mod event_reader; +pub mod event_sender; pub mod span; pub mod telemetry_event; -use crate::misc_helpers; +use crate::{current_info, misc_helpers}; use serde_derive::{Deserialize, Serialize}; +pub const GENERIC_EVENT_FILE_SEARCH_PATTERN: &str = r"^[0-9]+\.json$"; +pub static GENERIC_EVENT_FILE_SEARCH_REGEX: std::sync::LazyLock = + std::sync::LazyLock::new(|| regex::Regex::new(GENERIC_EVENT_FILE_SEARCH_PATTERN).unwrap()); +pub fn new_generic_event_file_name() -> String { + format!("{}.json", misc_helpers::get_date_time_unix_nano()) +} +pub const EXTENSION_EVENT_FILE_SEARCH_PATTERN: &str = r"^extension_[0-9]+\.json$"; +pub static EXTENSION_EVENT_FILE_SEARCH_REGEX: std::sync::LazyLock = + std::sync::LazyLock::new(|| regex::Regex::new(EXTENSION_EVENT_FILE_SEARCH_PATTERN).unwrap()); + +pub fn new_extension_event_file_name() -> String { + format!("extension_{}.json", misc_helpers::get_date_time_unix_nano()) +} + +/// Represents a telemetry event for TelemetryGenericLogsEvent #[derive(Serialize, Deserialize)] #[allow(non_snake_case)] pub struct Event { @@ -26,7 +42,7 @@ impl Event { Event { EventLevel: level, Message: message, - Version: misc_helpers::get_current_exe_version(), + Version: current_info::get_current_exe_version(), TaskName: task_name, EventPid: std::process::id().to_string(), EventTid: misc_helpers::get_thread_identity(), @@ -36,6 +52,53 @@ impl Event { } } +#[derive(Serialize, Deserialize, Clone)] +pub struct Extension { + pub name: String, + pub version: String, + pub is_internal: bool, + pub extension_type: String, +} + +#[derive(Serialize, Deserialize, Clone)] +pub struct OperationStatus { + pub operation_success: bool, + pub operation: String, + pub task_name: String, + pub message: String, + pub duration: i64, +} + +#[derive(Serialize, Deserialize, Clone)] +pub struct ExtensionStatusEvent { + pub extension: Extension, + pub operation_status: OperationStatus, + + pub event_pid: String, + pub event_tid: String, + pub time_stamp: String, +} + +impl ExtensionStatusEvent { + /// Create a new ExtensionStatusEvent + /// Rust does not recommend using too many arguments in a function, + /// so we use structs to group related arguments together. + /// # Arguments + /// * `extension` - The extension information + /// * `operation_status` - The operation status information + /// # Returns + /// A new instance of `ExtensionStatusEvent` + pub fn new(extension: Extension, operation_status: OperationStatus) -> Self { + ExtensionStatusEvent { + extension, + operation_status, + event_pid: std::process::id().to_string(), + event_tid: misc_helpers::get_thread_identity(), + time_stamp: misc_helpers::get_date_time_string_with_milliseconds(), + } + } +} + #[cfg(test)] mod tests { #[test] @@ -51,4 +114,33 @@ mod tests { assert_eq!(event.TaskName, "test task name".to_string()); assert_eq!(event.OperationId, "test operation id".to_string()); } + + #[test] + fn test_extension_status_event_new() { + let extension = super::Extension { + name: "test extension".to_string(), + version: "1.0.0".to_string(), + is_internal: true, + extension_type: "test type".to_string(), + }; + let operation_status = super::OperationStatus { + operation_success: true, + task_name: "test task".to_string(), + operation: "test operation".to_string(), + message: "test message".to_string(), + duration: 100, + }; + let event = super::ExtensionStatusEvent::new(extension.clone(), operation_status.clone()); + assert_eq!(event.extension.name, extension.name); + assert_eq!(event.extension.version, extension.version); + assert_eq!(event.extension.is_internal, extension.is_internal); + assert_eq!(event.extension.extension_type, extension.extension_type); + assert_eq!( + event.operation_status.operation_success, + operation_status.operation_success + ); + assert_eq!(event.operation_status.operation, operation_status.operation); + assert_eq!(event.operation_status.message, operation_status.message); + assert_eq!(event.operation_status.duration, operation_status.duration); + } } diff --git a/proxy_agent_shared/src/telemetry/event_logger.rs b/proxy_agent_shared/src/telemetry/event_logger.rs index 6d7ea3a9..e7b343df 100644 --- a/proxy_agent_shared/src/telemetry/event_logger.rs +++ b/proxy_agent_shared/src/telemetry/event_logger.rs @@ -12,11 +12,13 @@ use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::time::Duration; -pub const MAX_MESSAGE_LENGTH: usize = 1024 * 4; // 4KB - +const MAX_MESSAGE_LENGTH: usize = 1024 * 4; // 4KB static EVENT_QUEUE: Lazy> = Lazy::new(|| ConcurrentQueue::::bounded(1000)); static SHUT_DOWN: Lazy> = Lazy::new(|| Arc::new(AtomicBool::new(false))); +/// Store the event directory path, so that other modules can access it if needed. +static EVENTS_DIR: tokio::sync::OnceCell = tokio::sync::OnceCell::const_new(); +const MAX_EXTENSION_EVENT_FILE_COUNT: usize = 1000; pub async fn start( event_dir: PathBuf, @@ -37,6 +39,12 @@ pub async fn start( set_status_fn(message.to_string()); } + if EVENTS_DIR.set(event_dir.clone()).is_err() { + let message = "Event directory is already set, cannot set it again."; + set_status_fn(message.to_string()); + logger_manager::write_log(Level::Warn, message.to_string()); + } + let shutdown = SHUT_DOWN.clone(); if interval == Duration::default() { interval = Duration::from_secs(60); @@ -72,7 +80,10 @@ pub async fn start( // Check the event file counts, // if it exceeds the max file number, drop the new events - match misc_helpers::get_files(&event_dir) { + match misc_helpers::search_files( + &event_dir, + &crate::telemetry::GENERIC_EVENT_FILE_SEARCH_REGEX, + ) { Ok(files) => { if files.len() >= max_event_file_count { logger_manager::write_log(Level::Warn, format!( @@ -90,9 +101,8 @@ pub async fn start( } let mut file_path = event_dir.to_path_buf(); - - file_path.push(format!("{}.json", misc_helpers::get_date_time_unix_nano())); - match misc_helpers::json_write_to_file(&events, &file_path) { + file_path.push(crate::telemetry::new_generic_event_file_name()); + match misc_helpers::json_write_to_file_async(&events, &file_path).await { Ok(()) => { logger_manager::write_log( Level::Trace, @@ -120,6 +130,8 @@ pub fn stop() { SHUT_DOWN.store(true, Ordering::Relaxed); } +/// Write event and log to file +/// This event will send out as `TelemetryGenericLogsEvent` pub fn write_event( level: Level, message: String, @@ -133,6 +145,8 @@ pub fn write_event( logger_manager::log(logger_key.to_string(), level, message); } +/// Write event only without logging to file +/// This event will send out as `TelemetryGenericLogsEvent` pub fn write_event_only(level: Level, message: String, method_name: &str, module_name: &str) { let event_message = if message.len() > MAX_MESSAGE_LENGTH { message[..MAX_MESSAGE_LENGTH].to_string() @@ -155,6 +169,59 @@ pub fn write_event_only(level: Level, message: String, method_name: &str, module }; } +pub async fn report_extension_status_event( + extension: crate::telemetry::Extension, + operation_status: crate::telemetry::OperationStatus, +) { + let event_dir = match EVENTS_DIR.get() { + Some(dir) => dir.clone(), + None => { + logger_manager::write_log( + Level::Warn, + "Event directory is not set, cannot report extension status event.".to_string(), + ); + return; + } + }; + + // Check the event file counts, + // if it exceeds the max file number, drop the new events + match misc_helpers::search_files( + &event_dir, + &crate::telemetry::EXTENSION_EVENT_FILE_SEARCH_REGEX, + ) { + Ok(files) => { + if files.len() >= MAX_EXTENSION_EVENT_FILE_COUNT { + logger_manager::write_log(Level::Warn, format!( + "Event files exceed the max file count {}, drop and skip the write to disk.", + MAX_EXTENSION_EVENT_FILE_COUNT + )); + return; + } + } + Err(e) => { + logger_manager::write_log( + Level::Warn, + format!("Failed to get event files with error: {e}"), + ); + } + } + + let event = crate::telemetry::ExtensionStatusEvent::new(extension, operation_status); + let mut file_path = event_dir.to_path_buf(); + file_path.push(crate::telemetry::new_extension_event_file_name()); + if let Err(e) = misc_helpers::json_write_to_file_async(&event, &file_path).await { + logger_manager::write_log( + Level::Warn, + format!( + "Failed to write extension status event to the file {} with error: {}", + file_path.display(), + e + ), + ); + } +} + #[cfg(test)] mod tests { use crate::misc_helpers; @@ -162,16 +229,41 @@ mod tests { use std::fs; use std::time::Duration; + const TEST_EVENTS_DIR: &str = "test_events_dir"; + const TEST_LOGGER_KEY: &str = "test_logger_key"; + #[tokio::test] async fn event_logger_test() { let mut temp_test_path = env::temp_dir(); - let logger_key = "event_logger_test"; - temp_test_path.push(logger_key); + temp_test_path.push(TEST_EVENTS_DIR); // clean up and ignore the clean up errors _ = fs::remove_dir_all(&temp_test_path); let mut events_dir: std::path::PathBuf = temp_test_path.to_path_buf(); events_dir.push("Events"); + // When EVENTS_DIR is not set, report_extension_status_event should return early + // This test verifies the function handles the case gracefully + // Note: Since EVENTS_DIR is a static OnceCell, if other tests set it first, + // this test will still pass but will write to that directory instead + + let extension = crate::telemetry::Extension { + name: "test_extension".to_string(), + version: "1.0.0".to_string(), + is_internal: false, + extension_type: "test_type".to_string(), + }; + let operation_status = crate::telemetry::OperationStatus { + operation_success: false, + operation: "test_operation".to_string(), + task_name: "test_task".to_string(), + message: "error message".to_string(), + duration: 50, + }; + + // This should not panic even if EVENTS_DIR is not set + super::report_extension_status_event(extension, operation_status).await; + + // Start the event logger loop and set the EVENTS_DIR let cloned_events_dir = events_dir.to_path_buf(); tokio::spawn(async { super::start(cloned_events_dir, Duration::from_millis(100), 3, |_| { @@ -183,7 +275,7 @@ mod tests { }); // write some events to the queue and flush to disk - write_events(logger_key).await; + write_events(TEST_LOGGER_KEY).await; let files = misc_helpers::get_files(&events_dir).unwrap(); let file_count = files.len(); @@ -194,7 +286,7 @@ mod tests { // write some events to the queue and flush to disk 3 times for _ in [0; 3] { - write_events(logger_key).await; + write_events(TEST_LOGGER_KEY).await; } let files = misc_helpers::get_files(&events_dir).unwrap(); @@ -209,7 +301,7 @@ mod tests { // wait for stop signal responded tokio::time::sleep(Duration::from_millis(500)).await; - write_events(logger_key).await; + write_events(TEST_LOGGER_KEY).await; let files = misc_helpers::get_files(&events_dir).unwrap(); assert_eq!( @@ -218,6 +310,54 @@ mod tests { "No more files could write to event folder after stop()" ); + // Create test extension and operation status + let extension = crate::telemetry::Extension { + name: "test_extension".to_string(), + version: "1.0.0".to_string(), + is_internal: true, + extension_type: "test_type".to_string(), + }; + let operation_status = crate::telemetry::OperationStatus { + operation_success: true, + operation: "test_operation".to_string(), + task_name: "test_task".to_string(), + message: "test_message".to_string(), + duration: 100, + }; + + // Call report_extension_status_event + super::report_extension_status_event(extension.clone(), operation_status.clone()).await; + + // Wait for the file to be written + tokio::time::sleep(Duration::from_millis(100)).await; + + // Verify extension event file was created + let files = misc_helpers::search_files( + &events_dir, + &crate::telemetry::EXTENSION_EVENT_FILE_SEARCH_REGEX, + ) + .unwrap(); + assert!( + !files.is_empty(), + "Extension status event file should be created" + ); + + // Read and verify the event content + let event: crate::telemetry::ExtensionStatusEvent = + misc_helpers::json_read_from_file(&files[0]).unwrap(); + assert_eq!(event.extension.name, extension.name); + assert_eq!(event.extension.version, extension.version); + assert_eq!(event.extension.is_internal, extension.is_internal); + assert_eq!(event.extension.extension_type, extension.extension_type); + assert_eq!( + event.operation_status.operation_success, + operation_status.operation_success + ); + assert_eq!(event.operation_status.operation, operation_status.operation); + assert_eq!(event.operation_status.task_name, operation_status.task_name); + assert_eq!(event.operation_status.message, operation_status.message); + assert_eq!(event.operation_status.duration, operation_status.duration); + _ = fs::remove_dir_all(&temp_test_path); } diff --git a/proxy_agent_shared/src/telemetry/event_reader.rs b/proxy_agent_shared/src/telemetry/event_reader.rs index ac3d526f..ef0c7d79 100644 --- a/proxy_agent_shared/src/telemetry/event_reader.rs +++ b/proxy_agent_shared/src/telemetry/event_reader.rs @@ -1,173 +1,133 @@ // Copyright (c) Microsoft Corporation // SPDX-License-Identifier: MIT -//! This module contains the logic to read the telemetry event files and send them to the wire server. -//! The telemetry event files are written by the event_logger module. +//! This module contains the logic to read the telemetry event files. +//! //! The telemetry event files are written by the event_logger module. -use crate::common_state; use crate::common_state::CommonState; -use crate::host_clients::imds_client::ImdsClient; -use crate::host_clients::wire_server_client::WireServerClient; +use crate::current_info; use crate::logger::logger_manager; use crate::misc_helpers; -use crate::result::Result; -use crate::telemetry::telemetry_event::TelemetryData; +use crate::telemetry::event_sender; use crate::telemetry::telemetry_event::TelemetryEvent; +use crate::telemetry::telemetry_event::TelemetryExtensionEventsEvent; +use crate::telemetry::telemetry_event::TelemetryGenericLogsEvent; use crate::telemetry::Event; use std::fs::remove_file; use std::path::PathBuf; use std::time::Duration; -use tokio_util::sync::CancellationToken; -#[cfg(test)] -const EMPTY_GUID: &str = "00000000-0000-0000-0000-000000000000"; - -const WIRE_SERVER_IP: &str = "168.63.129.16"; -const WIRE_SERVER_PORT: u16 = 80u16; -const IMDS_IP: &str = "169.254.169.254"; -const IMDS_PORT: u16 = 80u16; - -/// VmMetaData contains the metadata of the VM. -/// The metadata is used to identify the VM and the image origin. -/// It will be part of the telemetry data send to the wire server. -/// The metadata is updated by the wire server and the IMDS client. -#[derive(Clone, Debug)] -pub struct VmMetaData { - pub container_id: String, - pub tenant_name: String, - pub role_name: String, - pub role_instance_name: String, - pub subscription_id: String, - pub resource_group_name: String, - pub vm_id: String, - pub image_origin: u64, +/// Configuration for limiting EventReader behavior +#[derive(Default, Clone)] +pub struct EventReaderLimits { + pub max_events_per_round: Option, + pub max_event_file_size_bytes: Option, + pub version: Option, } -impl VmMetaData { - #[cfg(test)] - pub fn empty() -> Self { - VmMetaData { - container_id: EMPTY_GUID.to_string(), - tenant_name: EMPTY_GUID.to_string(), - role_name: EMPTY_GUID.to_string(), - role_instance_name: EMPTY_GUID.to_string(), - subscription_id: EMPTY_GUID.to_string(), - resource_group_name: EMPTY_GUID.to_string(), - vm_id: EMPTY_GUID.to_string(), - image_origin: 3, // unknown - } +impl EventReaderLimits { + pub fn new() -> Self { + EventReaderLimits::default() + } + + pub fn with_max_events_per_round(mut self, max: usize) -> Self { + self.max_events_per_round = Some(max); + self + } + + pub fn with_max_event_file_size_bytes(mut self, max: u64) -> Self { + self.max_event_file_size_bytes = Some(max); + self + } + + pub fn with_version(mut self, version: String) -> Self { + self.version = Some(version); + self } } pub struct EventReader { dir_path: PathBuf, - delay_start: bool, - cancellation_token: CancellationToken, common_state: CommonState, execution_mode: String, event_name: String, + limits: EventReaderLimits, } impl EventReader { + /// Create a new EventReader without limits on event file size and max events per round. + /// The event reader will read the event files from the specified directory. + /// If delay_start is true, the event reader will delay start for 60 seconds. + /// The common_state is used to store the vm metadata. + /// The execution_mode is used to indicate the mode of the agent. + /// The event_name is used to indicate the name of the event reader. pub fn new( dir_path: PathBuf, - delay_start: bool, - cancellation_token: CancellationToken, common_state: CommonState, execution_mode: String, event_name: String, ) -> EventReader { EventReader { dir_path, - delay_start, - cancellation_token, common_state, execution_mode, event_name, + limits: EventReaderLimits::default(), } } - pub async fn start( - &self, - interval: Option, - server_ip: Option<&str>, - server_port: Option, - ) { - logger_manager::write_info("telemetry event reader task started.".to_string()); + /// Create a new EventReader with limits configuration. + pub fn new_with_limits( + dir_path: PathBuf, + common_state: CommonState, + execution_mode: String, + event_name: String, + limits: EventReaderLimits, + ) -> EventReader { + EventReader { + dir_path, + common_state, + execution_mode, + event_name, + limits, + } + } - let wire_server_client = WireServerClient::new( - server_ip.unwrap_or(WIRE_SERVER_IP), - server_port.unwrap_or(WIRE_SERVER_PORT), - ); - let imds_client = ImdsClient::new( - server_ip.unwrap_or(IMDS_IP), - server_port.unwrap_or(IMDS_PORT), - ); + pub async fn start(&self, delay_start: bool, interval: Option) { + if delay_start { + // delay start the event_reader task to give additional CPU cycles to more important threads + tokio::time::sleep(Duration::from_secs(60)).await; + } + logger_manager::write_info("telemetry event reader task started.".to_string()); let interval = interval.unwrap_or(Duration::from_secs(300)); + let cancellation_token = self.common_state.get_cancellation_token(); tokio::select! { - _ = self.loop_reader(interval, wire_server_client, imds_client ) => {} - _ = self.cancellation_token.cancelled() => { + _ = self.loop_reader(interval) => {} + _ = cancellation_token.cancelled() => { logger_manager::write_warn("cancellation token signal received, stop the telemetry event reader task.".to_string()); } } } - async fn loop_reader( - &self, - interval: Duration, - wire_server_client: WireServerClient, - imds_client: ImdsClient, - ) { - let mut first = true; - + async fn loop_reader(&self, interval: Duration) { loop { - if first { - if self.delay_start { - // delay start the event_reader task to give additional CPU cycles to more important threads - tokio::time::sleep(Duration::from_secs(60)).await; - } - first = false; - } - - // refresh vm metadata - match self - .update_vm_meta_data(&wire_server_client, &imds_client) - .await - { - Ok(()) => { - logger_manager::write_info("success updated the vm metadata.".to_string()); - } - Err(e) => { - logger_manager::write_warn(format!( - "Failed to read vm metadata with error {e}." - )); - } - } - - if let Ok(Some(vm_meta_data)) = self.common_state.get_vm_meta_data().await { - let _processed = self - .process_events(&wire_server_client, &vm_meta_data) - .await; - } - + self.process_once().await; tokio::time::sleep(interval).await; } } - async fn process_events( - &self, - wire_server_client: &WireServerClient, - vm_meta_data: &VmMetaData, - ) -> usize { + /// Process the event files from the directory once. + pub async fn process_once(&self) -> usize { let event_count: usize; - // get all .json event files in the directory - match misc_helpers::search_files(&self.dir_path, r"^(.*\.json)$") { + // get all [0-9]+.json event filenames with numbers in the directory + match misc_helpers::search_files( + &self.dir_path, + &crate::telemetry::GENERIC_EVENT_FILE_SEARCH_REGEX, + ) { Ok(files) => { let file_count = files.len(); - event_count = self - .process_events_and_clean(files, wire_server_client, vm_meta_data) - .await; + event_count = self.process_events_and_clean(files).await; let message = format!( "Telemetry event reader sent {event_count} events from {file_count} files" ); @@ -185,159 +145,94 @@ impl EventReader { event_count } - async fn update_vm_meta_data( - &self, - wire_server_client: &WireServerClient, - imds_client: &ImdsClient, - ) -> Result<()> { - let guid = self - .common_state - .get_state(common_state::SECURE_KEY_GUID.to_string()) - .await - .unwrap_or(None); - let key = self - .common_state - .get_state(common_state::SECURE_KEY_VALUE.to_string()) - .await - .unwrap_or(None); - let goal_state = wire_server_client - .get_goalstate(guid.clone(), key.clone()) - .await?; - let shared_config = wire_server_client - .get_shared_config( - goal_state.get_shared_config_uri(), - guid.clone(), - key.clone(), - ) - .await?; - - let instance_info = imds_client - .get_imds_instance_info(guid.clone(), key.clone()) - .await?; - let vm_meta_data = VmMetaData { - container_id: goal_state.get_container_id(), - role_name: shared_config.get_role_name(), - role_instance_name: shared_config.get_role_instance_name(), - tenant_name: shared_config.get_deployment_name(), - subscription_id: instance_info.get_subscription_id(), - resource_group_name: instance_info.get_resource_group_name(), - vm_id: instance_info.get_vm_id(), - image_origin: instance_info.get_image_origin(), - }; - - self.common_state - .set_vm_meta_data(Some(vm_meta_data)) - .await?; - - Ok(()) - } - - async fn process_events_and_clean( - &self, - files: Vec, - wire_server_client: &WireServerClient, - vm_meta_data: &VmMetaData, - ) -> usize { + async fn process_events_and_clean(&self, files: Vec) -> usize { let mut num_events_logged = 0; for file in files { + if let Some(max_events) = self.limits.max_events_per_round { + if num_events_logged >= max_events { + logger_manager::write_warn(format!( + "EventReader:: Reached the max number of events to be read per round: {}. Stop processing file {} this round.", + max_events, + file.display() + )); + // do not delete this event json file, will try process it at next round + break; + } + } + + match file.metadata() { + Err(e) => { + logger_manager::write_warn(format!( + "EventReader:: Failed to get metadata for file {}: {}", + file.display(), + e + )); + continue; + } + Ok(metadata) => { + if let Some(max_size) = self.limits.max_event_file_size_bytes { + if metadata.len() > max_size { + logger_manager::write_warn(format!( + "EventReader:: File {} exceeds the size limit of {} bytes, skip it.", + file.display(), + max_size + )); + // clean up the file to avoid blocking further processing + Self::clean_file(file); + continue; + } + } + } + } match misc_helpers::json_read_from_file::>(&file) { Ok(events) => { num_events_logged += events.len(); - self.send_events(events, wire_server_client, vm_meta_data) - .await; + self.handle_events(events).await; } Err(e) => { logger_manager::write_warn(format!( - "Failed to read events from file {}: {}", + "EventReader:: Failed to read events from file {}: {}", file.display(), e )); } } - Self::clean_files(file); + Self::clean_file(file); } num_events_logged } - const MAX_MESSAGE_SIZE: usize = 1024 * 64; - async fn send_events( - &self, - mut events: Vec, - wire_server_client: &WireServerClient, - vm_meta_data: &VmMetaData, - ) { + async fn handle_events(&self, mut events: Vec) { + let mut queued_event = false; while !events.is_empty() { - let mut telemetry_data = TelemetryData::new(); - let mut add_more_events = true; - while !events.is_empty() && add_more_events { - match events.pop() { - Some(event) => { - telemetry_data.add_event(TelemetryEvent::from_event_log( - &event, - vm_meta_data.clone(), - self.execution_mode.clone(), - self.event_name.clone(), - )); - - if telemetry_data.get_size() >= Self::MAX_MESSAGE_SIZE { - telemetry_data.remove_last_event(); - if telemetry_data.event_count() == 0 { - match serde_json::to_string(&event) { - Ok(json) => { - logger_manager::write_warn(format!( - "Event data too large. Not sending to wire-server. Event: {json}.", - )); - } - Err(_) => { - logger_manager::write_warn( - "Event data too large. Not sending to wire-server. Event cannot be displayed.".to_string() - ); - } - } - } else { - events.push(event); - } - add_more_events = false; - } - } - None => { - break; - } + match events.pop() { + Some(event) => { + let telemetry_event = TelemetryGenericLogsEvent::from_event_log( + &event, + self.execution_mode.clone(), + self.event_name.clone(), + self.limits.version.clone(), + ); + let telemetry_event = TelemetryEvent::GenericLogsEvent(telemetry_event); + event_sender::enqueue_event(telemetry_event); + queued_event = true; + } + None => { + break; } } - - Self::send_data_to_wire_server(telemetry_data, wire_server_client).await; - } - } - - async fn send_data_to_wire_server( - telemetry_data: TelemetryData, - wire_server_client: &WireServerClient, - ) { - if telemetry_data.event_count() == 0 { - return; } - for _ in [0; 5] { - match wire_server_client - .send_telemetry_data(telemetry_data.to_xml()) - .await - { - Ok(()) => { - break; - } - Err(e) => { - logger_manager::write_warn(format!( - "Failed to send telemetry data to host with error: {e}" - )); - // wait 15 seconds and retry - tokio::time::sleep(Duration::from_secs(15)).await; - } + if queued_event { + if let Err(e) = self.common_state.notify_telemetry_event().await { + logger_manager::write_warn(format!( + "Failed to notify telemetry event with error: {e}" + )); } } } - fn clean_files(file: PathBuf) { + fn clean_file(file: PathBuf) { match remove_file(&file) { Ok(_) => { logger_manager::write_info(format!("Removed File: {}", file.display())); @@ -352,13 +247,93 @@ impl EventReader { } } - #[cfg(test)] - async fn get_vm_meta_data(&self) -> VmMetaData { - if let Ok(Some(vm_meta_data)) = self.common_state.get_vm_meta_data().await { - vm_meta_data - } else { - VmMetaData::empty() + pub async fn start_extension_status_event_processor( + &self, + delay_start: bool, + interval: Option, + ) { + if delay_start { + // delay start the event_reader task to give additional CPU cycles to more important threads + tokio::time::sleep(Duration::from_secs(60)).await; + } + + logger_manager::write_info( + "telemetry extension status event reader task started.".to_string(), + ); + let interval = interval.unwrap_or(Duration::from_secs(60)); + let cancellation_token = self.common_state.get_cancellation_token(); + tokio::select! { + _ = self.loop_extension_status_event_processor(interval ) => {} + _ = cancellation_token.cancelled() => { + logger_manager::write_warn("cancellation token signal received, stop the telemetry extension status event reader task.".to_string()); + } + } + } + + async fn loop_extension_status_event_processor(&self, interval: Duration) { + loop { + self.process_extension_status_events().await; + tokio::time::sleep(interval).await; + } + } + + async fn process_extension_status_events(&self) -> usize { + let mut event_count: usize = 0; + // get all extension status event filenames in the directory + match misc_helpers::search_files( + &self.dir_path, + &crate::telemetry::EXTENSION_EVENT_FILE_SEARCH_REGEX, + ) { + Ok(files) => { + let file_count = files.len(); + for file in files { + event_count += self.process_one_extension_status_event_file(file).await; + } + logger_manager::write_info( format!( + "Telemetry event reader sent {event_count} extension status events from {file_count} files" + )); + } + Err(e) => { + logger_manager::write_warn(format!( + "Extension Status Event Files not found in directory {}: {}", + self.dir_path.display(), + e + )); + } + } + event_count + } + + async fn process_one_extension_status_event_file(&self, file: PathBuf) -> usize { + let mut num_events_logged = 0; + + match misc_helpers::json_read_from_file::(&file) { + Ok(event) => { + num_events_logged += 1; + let telemetry_event = TelemetryExtensionEventsEvent::from_extension_status_event( + &event, + self.execution_mode.clone(), + current_info::get_current_exe_version(), + ); + let telemetry_event = TelemetryEvent::ExtensionEvent(telemetry_event); + event_sender::enqueue_event(telemetry_event); + if let Err(e) = self.common_state.notify_telemetry_event().await { + logger_manager::write_warn(format!( + "Failed to notify telemetry event with error: {e}" + )); + } + } + Err(e) => { + logger_manager::write_warn(format!( + "EventReader:: Failed to read extension status event from file {}: {}", + file.display(), + e + )); + } } + + Self::clean_file(file); + num_events_logged } } @@ -366,8 +341,8 @@ impl EventReader { mod tests { use super::*; use crate::misc_helpers; - use crate::server_mock; use std::{env, fs}; + use tokio_util::sync::CancellationToken; #[tokio::test] async fn test_event_reader_thread() { @@ -378,42 +353,15 @@ mod tests { let mut events_dir = temp_dir.to_path_buf(); events_dir.push("Events"); - // start wire_server listener - let ip = "127.0.0.1"; - let port = 7071u16; - let cancellation_token = CancellationToken::new(); - let common_state = CommonState::start_new(); - let event_reader = EventReader { - dir_path: events_dir.clone(), - delay_start: false, - cancellation_token: cancellation_token.clone(), - common_state: common_state.clone(), - execution_mode: "Test".to_string(), - event_name: "test_event_reader_thread".to_string(), - }; - let wire_server_client = WireServerClient::new(ip, port); - let imds_client = ImdsClient::new(ip, port); - tokio::spawn(server_mock::start( - ip.to_string(), - port, - cancellation_token.clone(), - )); - tokio::time::sleep(Duration::from_millis(100)).await; - logger_manager::write_info("server_mock started.".to_string()); - - match event_reader - .update_vm_meta_data(&wire_server_client, &imds_client) - .await - { - Ok(()) => { - logger_manager::write_info("success updated the vm metadata.".to_string()); - } - Err(e) => { - logger_manager::write_warn(format!("Failed to read vm metadata with error {}.", e)); - } - } + let common_state = CommonState::start_new(CancellationToken::new()); + let event_reader = EventReader::new( + events_dir.clone(), + common_state.clone(), + "Test".to_string(), + "test_event_reader_thread".to_string(), + ); - // Write 10 events to events dir + // Write events to events dir let message = r#"{\"method\":\"GET\",\"url\":\"/machine/37569ad2-69a3-44fd-b653-813e62a177cf/68938c06%2D5233%2D4ff9%2Da173%2D0ac0a2754f8a.%5FWS2022?comp=config&type=hostingEnvironmentConfig&incarnation=2\",\"ip\":\"168.63.129.16\",\"port\":80,\"userId\":999,\"userName\":\"WS2022$\",\"processName\":\"C:\\\\WindowsAzure\\\\GuestAgent_2.7.41491.1071_2023-03-02_185502\\\\WindowsAzureGuestAgent.exe\",\"runAsElevated\":true,\"responseStatus\":\"200 OK\",\"elapsedTime\":8}"#; let mut events: Vec = Vec::new(); for _ in [0; 10] { @@ -428,29 +376,81 @@ mod tests { misc_helpers::try_create_folder(&events_dir).unwrap(); let mut file_path = events_dir.to_path_buf(); file_path.push(format!("{}.json", misc_helpers::get_date_time_unix_nano())); - misc_helpers::json_write_to_file(&events, &file_path).unwrap(); - + misc_helpers::json_write_to_file_async(&events, &file_path) + .await + .unwrap(); + tokio::time::sleep(Duration::from_millis(1)).await; + let mut file_path = events_dir.to_path_buf(); + file_path.push(format!("{}.json", misc_helpers::get_date_time_unix_nano())); + misc_helpers::json_write_to_file_async(&events, &file_path) + .await + .unwrap(); + // test EventReader with limits + let event_reader_limits = EventReaderLimits::new() + .with_max_event_file_size_bytes(1024 * 10) + .with_max_events_per_round(10) + .with_version("test_version".to_string()); + let event_reader_with_limits = EventReader::new_with_limits( + events_dir.clone(), + common_state.clone(), + "Test".to_string(), + "test_event_reader_thread".to_string(), + event_reader_limits.clone(), + ); // Check the events processed - let vm_meta_data = event_reader.get_vm_meta_data().await; - let events_processed = event_reader - .process_events(&wire_server_client, &vm_meta_data) - .await; + let events_processed = event_reader_with_limits.process_once().await; logger_manager::write_info(format!("Send {} events from event files", events_processed)); - //Should be 10 events written and read into events Vector + //Should be 10 events processed and read into events Vector assert_eq!(events_processed, 10, "Events processed should be 10"); let files = misc_helpers::get_files(&events_dir).unwrap(); - assert!(files.is_empty(), "Events files not cleaned up."); + assert_eq!(1, files.len(), "Must still have 1 event file."); + // test EventReader with limits - second round + let events_processed = event_reader_with_limits.process_once().await; + logger_manager::write_info(format!("Send {} events from event files", events_processed)); + //Should be 10 events processed and read into events Vector + assert_eq!(events_processed, 10, "Events processed should be 10"); + let files = misc_helpers::get_files(&events_dir).unwrap(); + assert!(files.is_empty(), "Must have no event files."); - // Test not processing the non-json files + // Write 2 event files again for next test + tokio::time::sleep(Duration::from_millis(1)).await; + let mut file_path = events_dir.to_path_buf(); + file_path.push(format!("{}.json", misc_helpers::get_date_time_unix_nano())); + misc_helpers::json_write_to_file_async(&events, &file_path) + .await + .unwrap(); + tokio::time::sleep(Duration::from_millis(1)).await; + let mut file_path = events_dir.to_path_buf(); + file_path.push(format!("{}.json", misc_helpers::get_date_time_unix_nano())); + misc_helpers::json_write_to_file_async(&events, &file_path) + .await + .unwrap(); + let files = misc_helpers::get_files(&events_dir).unwrap(); + assert_eq!(2, files.len(), "Must have 2 event files."); + + // test EventReader without limits + let events_processed = event_reader.process_once().await; + logger_manager::write_info(format!("Send {} events from event files", events_processed)); + //Should be 20 events processed and read into events Vector + assert_eq!(events_processed, 20, "Events processed should be 20"); + let files = misc_helpers::get_files(&events_dir).unwrap(); + assert!(files.is_empty(), "Must have no event files."); + + // Test not processing the non-json files, nor the file name containing non-numeric characters let mut file_path = events_dir.to_path_buf(); file_path.push(format!( "{}.notjson", misc_helpers::get_date_time_unix_nano() )); - misc_helpers::json_write_to_file(&events, &file_path).unwrap(); - let events_processed = event_reader - .process_events(&wire_server_client, &vm_meta_data) - .await; + misc_helpers::json_write_to_file_async(&events, &file_path) + .await + .unwrap(); + let mut file_path = events_dir.to_path_buf(); + file_path.push(format!("a{}.json", misc_helpers::get_date_time_unix_nano())); + misc_helpers::json_write_to_file_async(&events, &file_path) + .await + .unwrap(); + let events_processed = event_reader.process_once().await; assert_eq!(0, events_processed, "events_processed must be 0."); let files = misc_helpers::get_files(&events_dir).unwrap(); assert!( @@ -458,7 +458,342 @@ mod tests { ".notjson files should not been cleaned up." ); - cancellation_token.cancel(); + common_state.cancel_cancellation_token(); + _ = fs::remove_dir_all(&temp_dir); + } + + #[tokio::test] + async fn test_extension_status_event_processor() { + let mut temp_dir = env::temp_dir(); + temp_dir.push("test_extension_status_event_processor"); + + _ = fs::remove_dir_all(&temp_dir); + let mut events_dir = temp_dir.to_path_buf(); + events_dir.push("Events"); + misc_helpers::try_create_folder(&events_dir).unwrap(); + + let cancellation_token = CancellationToken::new(); + let common_state = CommonState::start_new(cancellation_token.clone()); + let event_reader = EventReader::new( + events_dir.clone(), + common_state.clone(), + "Test".to_string(), + "test_extension_status_event_processor".to_string(), + ); + + // Create test extension status event files + let extension = crate::telemetry::Extension { + name: "test_extension".to_string(), + version: "1.0.0".to_string(), + is_internal: true, + extension_type: "test_type".to_string(), + }; + let operation_status = crate::telemetry::OperationStatus { + operation_success: true, + operation: "test_operation".to_string(), + task_name: "test_task".to_string(), + message: "test_message".to_string(), + duration: 100, + }; + let event = crate::telemetry::ExtensionStatusEvent::new( + extension.clone(), + operation_status.clone(), + ); + + // Write extension event files with proper naming pattern + let mut file_path = events_dir.to_path_buf(); + file_path.push(crate::telemetry::new_extension_event_file_name()); + misc_helpers::json_write_to_file_async(&event, &file_path) + .await + .unwrap(); + + tokio::time::sleep(Duration::from_millis(1)).await; + + let mut file_path2 = events_dir.to_path_buf(); + file_path2.push(crate::telemetry::new_extension_event_file_name()); + misc_helpers::json_write_to_file_async(&event, &file_path2) + .await + .unwrap(); + + // Verify files were created + let files = misc_helpers::search_files( + &events_dir, + &crate::telemetry::EXTENSION_EVENT_FILE_SEARCH_REGEX, + ) + .unwrap(); + assert_eq!(2, files.len(), "Should have 2 extension event files"); + + // Process extension status events directly (without starting the loop) + let events_processed = event_reader.process_extension_status_events().await; + assert_eq!( + 2, events_processed, + "Should have processed 2 extension status events" + ); + + // Verify files were cleaned up after processing + let files = misc_helpers::search_files( + &events_dir, + &crate::telemetry::EXTENSION_EVENT_FILE_SEARCH_REGEX, + ) + .unwrap(); + assert!( + files.is_empty(), + "Extension event files should be cleaned up after processing" + ); + + // Test with non-matching file names (should not be processed) + let mut non_matching_file = events_dir.to_path_buf(); + non_matching_file.push("not_extension_event.json"); + misc_helpers::json_write_to_file_async(&event, &non_matching_file) + .await + .unwrap(); + + let events_processed = event_reader.process_extension_status_events().await; + assert_eq!( + 0, events_processed, + "Should not process files with non-matching names" + ); + + // Non-matching file should still exist + assert!( + non_matching_file.exists(), + "Non-matching file should not be cleaned up" + ); + + // Test start_extension_status_event_processor with cancellation + // Write another event file + let mut file_path3 = events_dir.to_path_buf(); + file_path3.push(crate::telemetry::new_extension_event_file_name()); + misc_helpers::json_write_to_file_async(&event, &file_path3) + .await + .unwrap(); + + // Start the processor in a separate task + let event_reader_for_task = EventReader::new( + events_dir.clone(), + common_state.clone(), + "Test".to_string(), + "test_extension_status_event_processor".to_string(), + ); + let handle = tokio::spawn(async move { + event_reader_for_task + .start_extension_status_event_processor(false, Some(Duration::from_millis(50))) + .await; + }); + + // Wait for processing + tokio::time::sleep(Duration::from_millis(100)).await; + + // Cancel the token to stop the processor + common_state.cancel_cancellation_token(); + + // Wait for the task to complete + let result = tokio::time::timeout(Duration::from_secs(2), handle).await; + assert!( + result.is_ok(), + "Extension status event processor should stop when cancelled" + ); + + // Verify the file was processed + let files = misc_helpers::search_files( + &events_dir, + &crate::telemetry::EXTENSION_EVENT_FILE_SEARCH_REGEX, + ) + .unwrap(); + assert!( + files.is_empty(), + "Extension event file should be processed before cancellation" + ); + + _ = fs::remove_dir_all(&temp_dir); + } + + #[tokio::test] + async fn test_mixed_event_files() { + let mut temp_dir = env::temp_dir(); + temp_dir.push("test_mixed_event_files"); + + _ = fs::remove_dir_all(&temp_dir); + let mut events_dir = temp_dir.to_path_buf(); + events_dir.push("Events"); + misc_helpers::try_create_folder(&events_dir).unwrap(); + + let cancellation_token = CancellationToken::new(); + let common_state = CommonState::start_new(cancellation_token.clone()); + let event_reader = EventReader::new( + events_dir.clone(), + common_state.clone(), + "Test".to_string(), + "test_mixed_event_files".to_string(), + ); + + // Create generic event files (numeric names like 1234567890.json) + let message = "Test message for mixed events"; + let mut generic_events: Vec = Vec::new(); + for _ in 0..5 { + generic_events.push(Event::new( + "Informational".to_string(), + message.to_string(), + "test_mixed_event_files".to_string(), + "test_mixed_event_files".to_string(), + )); + } + + // Write 2 generic event files + let mut generic_file1 = events_dir.to_path_buf(); + generic_file1.push(crate::telemetry::new_generic_event_file_name()); + misc_helpers::json_write_to_file_async(&generic_events, &generic_file1) + .await + .unwrap(); + + tokio::time::sleep(Duration::from_millis(1)).await; + + let mut generic_file2 = events_dir.to_path_buf(); + generic_file2.push(crate::telemetry::new_generic_event_file_name()); + misc_helpers::json_write_to_file_async(&generic_events, &generic_file2) + .await + .unwrap(); + + // Create extension status event files + let extension = crate::telemetry::Extension { + name: "test_extension".to_string(), + version: "1.0.0".to_string(), + is_internal: true, + extension_type: "test_type".to_string(), + }; + let operation_status = crate::telemetry::OperationStatus { + operation_success: true, + operation: "test_operation".to_string(), + task_name: "test_task".to_string(), + message: "test_message".to_string(), + duration: 100, + }; + let extension_event = crate::telemetry::ExtensionStatusEvent::new( + extension.clone(), + operation_status.clone(), + ); + + // Write 3 extension event files + for _ in 0..3 { + let mut ext_file = events_dir.to_path_buf(); + ext_file.push(crate::telemetry::new_extension_event_file_name()); + misc_helpers::json_write_to_file_async(&extension_event, &ext_file) + .await + .unwrap(); + tokio::time::sleep(Duration::from_millis(1)).await; + } + + // Verify all files were created + let generic_files = misc_helpers::search_files( + &events_dir, + &crate::telemetry::GENERIC_EVENT_FILE_SEARCH_REGEX, + ) + .unwrap(); + assert_eq!(2, generic_files.len(), "Should have 2 generic event files"); + + let extension_files = misc_helpers::search_files( + &events_dir, + &crate::telemetry::EXTENSION_EVENT_FILE_SEARCH_REGEX, + ) + .unwrap(); + assert_eq!( + 3, + extension_files.len(), + "Should have 3 extension event files" + ); + + // Process generic events using process_once + let generic_events_processed = event_reader.process_once().await; + assert_eq!( + 10, generic_events_processed, + "Should have processed 10 generic events (5 events x 2 files)" + ); + + // Verify only generic files were cleaned up + let generic_files = misc_helpers::search_files( + &events_dir, + &crate::telemetry::GENERIC_EVENT_FILE_SEARCH_REGEX, + ) + .unwrap(); + assert!( + generic_files.is_empty(), + "Generic event files should be cleaned up" + ); + + let extension_files = misc_helpers::search_files( + &events_dir, + &crate::telemetry::EXTENSION_EVENT_FILE_SEARCH_REGEX, + ) + .unwrap(); + assert_eq!( + 3, + extension_files.len(), + "Extension event files should still exist" + ); + + // Process extension events using process_extension_status_events + let extension_events_processed = event_reader.process_extension_status_events().await; + assert_eq!( + 3, extension_events_processed, + "Should have processed 3 extension status events" + ); + + // Verify extension files were cleaned up + let extension_files = misc_helpers::search_files( + &events_dir, + &crate::telemetry::EXTENSION_EVENT_FILE_SEARCH_REGEX, + ) + .unwrap(); + assert!( + extension_files.is_empty(), + "Extension event files should be cleaned up" + ); + + // Test that both processors ignore each other's files + // Write one of each type again + let mut generic_file = events_dir.to_path_buf(); + generic_file.push(crate::telemetry::new_generic_event_file_name()); + misc_helpers::json_write_to_file_async(&generic_events, &generic_file) + .await + .unwrap(); + + let mut ext_file = events_dir.to_path_buf(); + ext_file.push(crate::telemetry::new_extension_event_file_name()); + misc_helpers::json_write_to_file_async(&extension_event, &ext_file) + .await + .unwrap(); + + // Process extension events - should only process extension file + let extension_events_processed = event_reader.process_extension_status_events().await; + assert_eq!( + 1, extension_events_processed, + "Should only process extension event file" + ); + + // Generic file should still exist + let generic_files = misc_helpers::search_files( + &events_dir, + &crate::telemetry::GENERIC_EVENT_FILE_SEARCH_REGEX, + ) + .unwrap(); + assert_eq!( + 1, + generic_files.len(), + "Generic event file should still exist after extension processing" + ); + + // Process generic events - should only process generic file + let generic_events_processed = event_reader.process_once().await; + assert_eq!( + 5, generic_events_processed, + "Should only process generic event file" + ); + + // All files should be cleaned up now + let all_files = misc_helpers::get_files(&events_dir).unwrap(); + assert!(all_files.is_empty(), "All event files should be cleaned up"); + + common_state.cancel_cancellation_token(); _ = fs::remove_dir_all(&temp_dir); } } diff --git a/proxy_agent_shared/src/telemetry/event_sender.rs b/proxy_agent_shared/src/telemetry/event_sender.rs new file mode 100644 index 00000000..3acba8ff --- /dev/null +++ b/proxy_agent_shared/src/telemetry/event_sender.rs @@ -0,0 +1,668 @@ +// Copyright (c) Microsoft Corporation +// SPDX-License-Identifier: MIT + +//! This module contains the logic to send the telemetry event to the wire server. +use std::time::Duration; + +use crate::common_state::{self, CommonState}; +use crate::host_clients::imds_client::ImdsClient; +use crate::host_clients::wire_server_client::WireServerClient; +use crate::logger::{logger_manager, LoggerLevel}; +use crate::result::Result; +use crate::telemetry::telemetry_event::{ + TelemetryData, TelemetryEvent, TelemetryEventVMData, VmMetaData, +}; +use concurrent_queue::ConcurrentQueue; +use once_cell::sync::Lazy; + +static TELEMETRY_EVENT_QUEUE: Lazy> = + Lazy::new(|| ConcurrentQueue::::bounded(1000)); + +const MAX_MESSAGE_SIZE: usize = 1024 * 64; +const WIRE_SERVER_IP: &str = "168.63.129.16"; +const WIRE_SERVER_PORT: u16 = 80u16; +const IMDS_IP: &str = "169.254.169.254"; +const IMDS_PORT: u16 = 80u16; + +pub struct EventSender { + common_state: CommonState, +} + +impl EventSender { + pub fn new(common_state: CommonState) -> Self { + EventSender { common_state } + } + + pub async fn start(&self, server_ip: Option<&str>, server_port: Option) { + logger_manager::write_info("telemetry event sender task started.".to_string()); + let notify = match self.common_state.get_telemetry_event_notify().await { + Ok(notify) => notify, + Err(e) => { + logger_manager::write_err(format!("Failed to get notify: {e}")); + return; + } + }; + let cancellation_token = self.common_state.get_cancellation_token(); + + loop { + tokio::select! { + _ = cancellation_token.cancelled() => { + logger_manager::write_info("telemetry event sender task cancelled.".to_string()); + // Close the event queue to stop accepting new events + TELEMETRY_EVENT_QUEUE.close(); + break; + } + _ = notify.notified() => { + self.process_event_queue(server_ip, server_port).await; + } + } + } + } + + async fn process_event_queue(&self, server_ip: Option<&str>, server_port: Option) { + if TELEMETRY_EVENT_QUEUE.is_empty() { + return; + } + + let wire_server_client = WireServerClient::new( + server_ip.unwrap_or(WIRE_SERVER_IP), + server_port.unwrap_or(WIRE_SERVER_PORT), + ); + let imds_client = ImdsClient::new( + server_ip.unwrap_or(IMDS_IP), + server_port.unwrap_or(IMDS_PORT), + ); + // refresh vm metadata + match self + .update_vm_meta_data(&wire_server_client, &imds_client) + .await + { + Ok(()) => { + logger_manager::write_info("success updated the vm metadata.".to_string()); + } + Err(e) => { + logger_manager::write_warn(format!("Failed to update vm metadata with error {e}.")); + } + } + + if let Ok(Some(vm_meta_data)) = self.common_state.get_vm_meta_data().await { + let vm_data = TelemetryEventVMData::new_from_vm_meta_data(&vm_meta_data); + self.send_events(&wire_server_client, &vm_data).await + } else { + logger_manager::write_warn( + "VmMetaData is not available. Skipping sending telemetry events.".to_string(), + ); + } + } + + pub async fn update_vm_meta_data( + &self, + wire_server_client: &WireServerClient, + imds_client: &ImdsClient, + ) -> Result<()> { + let guid = self + .common_state + .get_state(common_state::SECURE_KEY_GUID.to_string()) + .await + .unwrap_or(None); + let key = self + .common_state + .get_state(common_state::SECURE_KEY_VALUE.to_string()) + .await + .unwrap_or(None); + let goal_state = wire_server_client + .get_goalstate(guid.clone(), key.clone()) + .await?; + let shared_config = wire_server_client + .get_shared_config( + goal_state.get_shared_config_uri(), + guid.clone(), + key.clone(), + ) + .await?; + + let instance_info = imds_client + .get_imds_instance_info(guid.clone(), key.clone()) + .await?; + let vm_meta_data = VmMetaData { + container_id: goal_state.get_container_id(), + role_name: shared_config.get_role_name(), + role_instance_name: shared_config.get_role_instance_name(), + tenant_name: shared_config.get_deployment_name(), + subscription_id: instance_info.get_subscription_id(), + resource_group_name: instance_info.get_resource_group_name(), + vm_id: instance_info.get_vm_id(), + image_origin: instance_info.get_image_origin(), + }; + + self.common_state + .set_vm_meta_data(Some(vm_meta_data)) + .await?; + + Ok(()) + } + + async fn send_events( + &self, + wire_server_client: &WireServerClient, + vm_data: &TelemetryEventVMData, + ) { + while !TELEMETRY_EVENT_QUEUE.is_closed() && !TELEMETRY_EVENT_QUEUE.is_empty() { + let mut telemetry_data = TelemetryData::new_with_vm_data(vm_data.clone()); + let mut add_more_events = true; + while !TELEMETRY_EVENT_QUEUE.is_empty() && add_more_events { + match TELEMETRY_EVENT_QUEUE.pop() { + Ok(event) => { + telemetry_data.add_event(event.clone()); + + if telemetry_data.get_size() >= MAX_MESSAGE_SIZE { + _ = telemetry_data.remove_last_event(event.clone()); + if telemetry_data.event_count() == 0 { + logger_manager::write_warn(format!( + "Event data too large. Not sending to wire-server. Event: {}.", + event.to_xml_event(vm_data), + )); + } else if let Err(e) = TELEMETRY_EVENT_QUEUE.push(event) { + logger_manager::write_warn(format!( + "Failed to re-enqueue telemetry event with error: {e}" + )); + } + add_more_events = false; + } + } + Err(err) => { + logger_manager::write_warn(format!( + "Failed to pop telemetry event from queue with error: {err}" + )); + break; + } + } + } + + Self::send_data_to_wire_server(telemetry_data, wire_server_client).await; + } + } + + async fn send_data_to_wire_server( + telemetry_data: TelemetryData, + wire_server_client: &WireServerClient, + ) { + if telemetry_data.event_count() == 0 { + return; + } + + let event_count = telemetry_data.event_count(); + for _ in [0; 5] { + match wire_server_client + .send_telemetry_data(telemetry_data.to_xml()) + .await + { + Ok(()) => { + logger_manager::write_log( + LoggerLevel::Trace, + format!("Successfully sent {event_count} telemetry events to wire server."), + ); + break; + } + Err(e) => { + logger_manager::write_warn(format!( + "Failed to send telemetry data to host with error: {e}" + )); + // wait 15 seconds and retry + tokio::time::sleep(Duration::from_secs(15)).await; + } + } + } + } +} + +pub(crate) fn enqueue_event(event: TelemetryEvent) { + if let Err(e) = TELEMETRY_EVENT_QUEUE.push(event) { + logger_manager::write_warn(format!("Failed to enqueue telemetry event with error: {e}")); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::host_clients::wire_server_client::WireServerClient; + use crate::server_mock; + use crate::telemetry::telemetry_event::{ + TelemetryExtensionEventsEvent, TelemetryGenericLogsEvent, VmMetaData, + }; + use crate::telemetry::{Event, ExtensionStatusEvent}; + use tokio_util::sync::CancellationToken; + + fn create_test_vm_meta_data() -> VmMetaData { + VmMetaData { + container_id: "test-container-id".to_string(), + tenant_name: "test-tenant".to_string(), + role_name: "test-role".to_string(), + role_instance_name: "test-role-instance".to_string(), + subscription_id: "test-subscription-id".to_string(), + resource_group_name: "test-resource-group".to_string(), + vm_id: "test-vm-id".to_string(), + image_origin: 1, + } + } + + fn create_test_event(message: &str) -> TelemetryEvent { + let event_log = Event::new( + "Informational".to_string(), + message.to_string(), + "test_task".to_string(), + "test_module".to_string(), + ); + TelemetryEvent::GenericLogsEvent(TelemetryGenericLogsEvent::from_event_log( + &event_log, + "test_execution_mode".to_string(), + "test_event_name".to_string(), + Some("1.0.0".to_string()), + )) + } + + fn create_test_extension_event() -> TelemetryEvent { + let extension = crate::telemetry::Extension { + name: "test_extension".to_string(), + version: "1.0.0".to_string(), + is_internal: true, + extension_type: "test_type".to_string(), + }; + let operation_status = crate::telemetry::OperationStatus { + operation_success: true, + operation: "install".to_string(), + task_name: "test_task".to_string(), + message: "Installation successful".to_string(), + duration: 500, + }; + let extension_status_event = ExtensionStatusEvent::new(extension, operation_status); + let telemetry_event = TelemetryExtensionEventsEvent::from_extension_status_event( + &extension_status_event, + "production".to_string(), + "1.0.0".to_string(), + ); + TelemetryEvent::ExtensionEvent(telemetry_event) + } + + #[tokio::test] + async fn test_event_sender_new() { + let cancellation_token = CancellationToken::new(); + let common_state = CommonState::start_new(cancellation_token); + let event_sender = EventSender::new(common_state); + + // Verify EventSender was created (common_state is private, so we just check it doesn't panic) + assert!(std::mem::size_of_val(&event_sender) > 0); + } + + #[tokio::test] + async fn test_common_state_vm_meta_data() { + let cancellation_token = CancellationToken::new(); + let common_state = CommonState::start_new(cancellation_token); + + // Initially should be None + let vm_meta_data = common_state.get_vm_meta_data().await.unwrap(); + assert!(vm_meta_data.is_none()); + + // Set vm_meta_data + let test_meta_data = create_test_vm_meta_data(); + common_state + .set_vm_meta_data(Some(test_meta_data)) + .await + .unwrap(); + + // Verify it was set and TelemetryEventVMData conversion works + let retrieved = common_state.get_vm_meta_data().await.unwrap().unwrap(); + let vm_data = TelemetryEventVMData::new_from_vm_meta_data(&retrieved); + + assert_eq!(vm_data.container_id, "test-container-id"); + assert_eq!(vm_data.tenant_name, "test-tenant"); + assert_eq!(vm_data.role_name, "test-role"); + assert_eq!(vm_data.role_instance_name, "test-role-instance"); + assert_eq!(vm_data.subscription_id, "test-subscription-id"); + assert_eq!(vm_data.resource_group_name, "test-resource-group"); + assert_eq!(vm_data.vm_id, "test-vm-id"); + assert_eq!(vm_data.image_origin, 1); + + // Test notify functionality + let notify_result = common_state.get_telemetry_event_notify().await; + assert!(notify_result.is_ok()); + assert!(common_state.notify_telemetry_event().await.is_ok()); + } + + #[test] + fn test_queue_bounded_capacity() { + // Create a local bounded queue for testing capacity behavior + let test_queue: ConcurrentQueue = ConcurrentQueue::bounded(10); + + // Fill the queue + for i in 0..10 { + let event = create_test_event(&format!("Test message {}", i)); + assert!( + test_queue.push(event).is_ok(), + "Should be able to push event {}", + i + ); + } + + // Queue should be full now + assert!(test_queue.is_full(), "Queue should be full after 10 pushes"); + + // Try to push one more - should fail + let extra_event = create_test_event("Extra event"); + assert!( + test_queue.push(extra_event).is_err(), + "Push should fail when queue is full" + ); + } + + #[test] + fn test_telemetry_event_xml_format() { + let vm_meta_data = create_test_vm_meta_data(); + let vm_data = TelemetryEventVMData::new_from_vm_meta_data(&vm_meta_data); + + // Test single event XML + let event = create_test_event("Test XML message"); + let event_xml = event.to_xml_event(&vm_data); + assert!(event_xml.contains("")); + assert!(event_xml.contains("")); + assert!(event_xml.contains("TenantName")); + assert!(event_xml.contains("test-tenant")); + + // Test provider ID + assert_eq!( + event.get_provider_id(), + "FFF0196F-EE4C-4EAF-9AA5-776F622DEB4F" + ); + + // Test full TelemetryData XML structure + let mut telemetry_data = TelemetryData::new_with_vm_data(vm_data); + telemetry_data.add_event(event); + let xml = telemetry_data.to_xml(); + + assert!(xml.starts_with("")); + assert!(xml.contains("")); + assert!(xml.contains("")); + assert!(xml.contains("")); + assert!(xml.contains("")); + } + + #[test] + fn test_extension_event_xml_format() { + let vm_meta_data = create_test_vm_meta_data(); + let vm_data = TelemetryEventVMData::new_from_vm_meta_data(&vm_meta_data); + + // Test extension event XML + let event = create_test_extension_event(); + let event_xml = event.to_xml_event(&vm_data); + assert!(event_xml.contains("")); + assert!(event_xml.contains("")); + assert!(event_xml.contains("ExtensionType")); + assert!(event_xml.contains("test_type")); + assert!(event_xml.contains("Name")); + assert!(event_xml.contains("test_extension")); + + // Test provider ID for extension events + assert_eq!( + event.get_provider_id(), + "69B669B9-4AF8-4C50-BDC4-6006FA76E975" + ); + + // Test TelemetryData with extension event + let mut telemetry_data = TelemetryData::new_with_vm_data(vm_data); + telemetry_data.add_event(event); + let xml = telemetry_data.to_xml(); + assert!(xml.contains("")); + } + + #[test] + fn test_mixed_events_xml_format() { + let vm_meta_data = create_test_vm_meta_data(); + let vm_data = TelemetryEventVMData::new_from_vm_meta_data(&vm_meta_data); + + let mut telemetry_data = TelemetryData::new_with_vm_data(vm_data); + + // Add generic logs event + let generic_event = create_test_event("Test generic message"); + telemetry_data.add_event(generic_event); + + // Add extension event + let extension_event = create_test_extension_event(); + telemetry_data.add_event(extension_event); + + assert_eq!(telemetry_data.event_count(), 2); + + let xml = telemetry_data.to_xml(); + + // Verify both providers are present + assert!(xml.contains("")); + assert!(xml.contains("")); + assert!(xml.contains("")); // Generic logs event + assert!(xml.contains("")); // Extension event + } + + #[test] + fn test_queue_with_extension_events() { + // Create a local bounded queue for testing + let test_queue: ConcurrentQueue = ConcurrentQueue::bounded(10); + + // Add generic and extension events + let generic_event = create_test_event("Generic message"); + let extension_event = create_test_extension_event(); + + assert!(test_queue.push(generic_event.clone()).is_ok()); + assert!(test_queue.push(extension_event.clone()).is_ok()); + + assert_eq!(test_queue.len(), 2); + + // Verify FIFO order and event types + let popped1 = test_queue.pop(); + assert!(popped1.is_ok()); + assert_eq!( + popped1.unwrap().get_provider_id(), + "FFF0196F-EE4C-4EAF-9AA5-776F622DEB4F" + ); + + let popped2 = test_queue.pop(); + assert!(popped2.is_ok()); + assert_eq!( + popped2.unwrap().get_provider_id(), + "69B669B9-4AF8-4C50-BDC4-6006FA76E975" + ); + + assert!(test_queue.is_empty()); + } + + #[tokio::test] + async fn test_update_vm_meta_data_with_mock_server() { + let ip = "127.0.0.1"; + let port = 9073u16; + + let cancellation_token = CancellationToken::new(); + let common_state = CommonState::start_new(cancellation_token.clone()); + let event_sender = EventSender::new(common_state.clone()); + + let port = server_mock::start(ip.to_string(), port, cancellation_token.clone()) + .await + .unwrap(); + tokio::time::sleep(Duration::from_millis(100)).await; + + let wire_server_client = WireServerClient::new(ip, port); + let imds_client = ImdsClient::new(ip, port); + + // Initially vm_meta_data should be None + let vm_meta_data = common_state.get_vm_meta_data().await.unwrap(); + assert!(vm_meta_data.is_none()); + + // Update vm_meta_data + let result = event_sender + .update_vm_meta_data(&wire_server_client, &imds_client) + .await; + assert!(result.is_ok(), "update_vm_meta_data should succeed"); + + // Verify vm_meta_data was set + let vm_meta_data = common_state.get_vm_meta_data().await.unwrap(); + assert!(vm_meta_data.is_some(), "vm_meta_data should be set"); + + let vm_data = vm_meta_data.unwrap(); + // Values come from mock server responses + assert!(!vm_data.container_id.is_empty()); + assert!(!vm_data.role_name.is_empty()); + + cancellation_token.cancel(); + } + + /// Consolidated test for all TELEMETRY_EVENT_QUEUE and wire server operations. + /// This test must run in a single test function because the global static queue + /// cannot be reopened once closed. The test covers: + /// 1. Enqueue events and verify FIFO order + /// 2. Process empty queue + /// 3. Send data to wire server (empty and with events) + /// 4. Enqueue and process events with mock server + /// 5. EventSender lifecycle (cancellation) - must be last as it closes the queue + #[tokio::test] + async fn test_telemetry_event_queue_operations() { + // ===== Part 1: Test enqueue and FIFO order ===== + // Clear the queue first + while TELEMETRY_EVENT_QUEUE.pop().is_ok() {} + + // Mock server details + let ip = "127.0.0.1"; + let port = 9071u16; + + // Create EventSender + let cancellation_token = CancellationToken::new(); + let process_common_state = CommonState::start_new(cancellation_token.clone()); + let event_sender = EventSender::new(process_common_state.clone()); + + // Enqueue events + let event1 = create_test_event("Test message 1"); + let event2 = create_test_event("Test message 2"); + let event3 = create_test_event("Test message 3"); + + enqueue_event(event1.clone()); + assert!( + !TELEMETRY_EVENT_QUEUE.is_empty(), + "Queue should not be empty after enqueue" + ); + + enqueue_event(event2.clone()); + enqueue_event(event3.clone()); + assert_eq!(TELEMETRY_EVENT_QUEUE.len(), 3, "Queue should have 3 events"); + + // Verify FIFO order + assert!(TELEMETRY_EVENT_QUEUE.pop().unwrap() == event1); + assert!(TELEMETRY_EVENT_QUEUE.pop().unwrap() == event2); + assert!(TELEMETRY_EVENT_QUEUE.pop().unwrap() == event3); + assert!(TELEMETRY_EVENT_QUEUE.is_empty()); + + // ===== Part 2: Test process empty queue - should return without error ===== + event_sender.process_event_queue(None, None).await; + assert!(TELEMETRY_EVENT_QUEUE.is_empty()); + + // ===== Part 3: Test enqueue mixed events (generic and extension) ===== + let generic_event = create_test_event("Generic event for queue"); + let extension_event = create_test_extension_event(); + + enqueue_event(generic_event); + enqueue_event(extension_event); + assert_eq!( + TELEMETRY_EVENT_QUEUE.len(), + 2, + "Queue should have 2 mixed events" + ); + + // Clear for next test + while TELEMETRY_EVENT_QUEUE.pop().is_ok() {} + + // ===== Part 4: Test enqueue and process with mock server ===== + // Start mock server FIRST to respond to goalstate and shared config requests + let port = server_mock::start(ip.to_string(), port, cancellation_token.clone()) + .await + .unwrap(); + tokio::time::sleep(Duration::from_millis(100)).await; + + // Enqueue events + let event_a = create_test_event("Test event A for processing"); + let event_b = create_test_event("Test event B for processing"); + let event_c = create_test_event("Test event C for processing"); + + enqueue_event(event_a); + enqueue_event(event_b); + enqueue_event(event_c); + assert_eq!( + TELEMETRY_EVENT_QUEUE.len(), + 3, + "Queue should have 3 events after enqueue" + ); + + // Start the event sender in a separate task + let handle = tokio::spawn(async move { + event_sender.start(Some(ip), Some(port)).await; + }); + + // Give it a moment to start event sender task + tokio::time::sleep(Duration::from_millis(50)).await; + + // Notify to process events now that VM data can be retrieved + process_common_state.notify_telemetry_event().await.unwrap(); + + // Give it a moment to process the events (needs enough time for multiple HTTP requests) + tokio::time::sleep(Duration::from_millis(500)).await; + + // Verify queue is empty after processing + assert_eq!( + TELEMETRY_EVENT_QUEUE.len(), + 0, + "Queue should be empty after processing" + ); + + // Verify queue is NOT closed after processing + assert!( + !TELEMETRY_EVENT_QUEUE.is_closed(), + "Queue should not be closed after processing" + ); + + // ===== Part 5: Test send_data_to_wire_server ===== + let wire_server_client = WireServerClient::new(ip, port); + let vm_meta_data = create_test_vm_meta_data(); + let vm_data = TelemetryEventVMData::new_from_vm_meta_data(&vm_meta_data); + + // Test sending empty data - should return early without error + let empty_data = TelemetryData::new_with_vm_data(vm_data.clone()); + assert_eq!(empty_data.event_count(), 0); + EventSender::send_data_to_wire_server(empty_data, &wire_server_client).await; + + // Test sending data with events + let mut telemetry_data = TelemetryData::new_with_vm_data(vm_data.clone()); + telemetry_data.add_event(create_test_event("Test event 1")); + telemetry_data.add_event(create_test_event("Test event 2")); + assert_eq!(telemetry_data.event_count(), 2); + println!("{}", telemetry_data.to_xml()); + EventSender::send_data_to_wire_server(telemetry_data, &wire_server_client).await; + + // Test sending data with mixed events + let mut mixed_data = TelemetryData::new_with_vm_data(vm_data); + mixed_data.add_event(create_test_event("Generic event")); + mixed_data.add_event(create_test_extension_event()); + assert_eq!(mixed_data.event_count(), 2); + EventSender::send_data_to_wire_server(mixed_data, &wire_server_client).await; + + // ===== Part 6: Test EventSender lifecycle (cancellation) ===== + // This MUST be last as it closes the queue permanently + + // Cancel the token - this will close the queue, stop the event sender task and stop mock server + process_common_state.cancel_cancellation_token(); + + // Wait for the task to complete + let result = tokio::time::timeout(Duration::from_secs(2), handle).await; + assert!(result.is_ok(), "Event sender should stop when cancelled"); + + // Verify queue is now closed + assert!( + TELEMETRY_EVENT_QUEUE.is_closed(), + "Queue should be closed after cancellation" + ); + } +} diff --git a/proxy_agent_shared/src/telemetry/telemetry_event.rs b/proxy_agent_shared/src/telemetry/telemetry_event.rs index 79149d70..0b25b484 100644 --- a/proxy_agent_shared/src/telemetry/telemetry_event.rs +++ b/proxy_agent_shared/src/telemetry/telemetry_event.rs @@ -3,26 +3,186 @@ //! This module contains the logic to generate the telemetry data to be send to wire server. -use super::event_reader::VmMetaData; -use crate::telemetry::Event; +use crate::telemetry::{Event, ExtensionStatusEvent}; use crate::{current_info, misc_helpers}; use once_cell::sync::Lazy; use serde_derive::{Deserialize, Serialize}; -/// TelemetryData struct to hold the telemetry events send to wire server. -pub struct TelemetryData { +const METRICS_PROVIDER_ID: &str = "FFF0196F-EE4C-4EAF-9AA5-776F622DEB4F"; +const STATUS_PROVIDER_ID: &str = "69B669B9-4AF8-4C50-BDC4-6006FA76E975"; + +/// VmMetaData contains the metadata of the VM. +/// The metadata is used to identify the VM and the image origin. +/// It will be part of the telemetry data send to the wire server. +/// The metadata is updated by the wire server and the IMDS client. +#[derive(Clone, Debug)] +pub struct VmMetaData { + pub container_id: String, + pub tenant_name: String, + pub role_name: String, + pub role_instance_name: String, + pub subscription_id: String, + pub resource_group_name: String, + pub vm_id: String, + pub image_origin: u64, +} + +/// Base struct containing common fields shared between telemetry event types. +#[derive(PartialEq, Eq, Hash, Clone)] +pub struct TelemetryEventVMData { + pub container_id: String, + pub keyword_name: String, + pub os_version: String, + pub ram: u64, + pub processors: u64, + pub tenant_name: String, + pub role_name: String, + pub role_instance_name: String, + pub subscription_id: String, + pub resource_group_name: String, + pub vm_id: String, + pub image_origin: u64, +} + +impl TelemetryEventVMData { + pub fn new_from_vm_meta_data(vm_meta_data: &VmMetaData) -> Self { + TelemetryEventVMData { + keyword_name: CURRENT_KEYWORD_NAME.to_string(), + os_version: current_info::get_long_os_version(), + ram: current_info::get_ram_in_mb(), + processors: current_info::get_cpu_count() as u64, + container_id: vm_meta_data.container_id.clone(), + tenant_name: vm_meta_data.tenant_name.clone(), + role_name: vm_meta_data.role_name.clone(), + role_instance_name: vm_meta_data.role_instance_name.clone(), + subscription_id: vm_meta_data.subscription_id.clone(), + resource_group_name: vm_meta_data.resource_group_name.clone(), + vm_id: vm_meta_data.vm_id.clone(), + image_origin: vm_meta_data.image_origin, + } + } + + /// Convert the base fields to XML format. + pub fn to_xml_params(&self) -> String { + let mut xml = String::new(); + xml.push_str(&format!( + "", + misc_helpers::xml_escape(self.keyword_name.clone()) + )); + xml.push_str(&format!( + "", + misc_helpers::xml_escape(self.tenant_name.clone()) + )); + xml.push_str(&format!( + "", + misc_helpers::xml_escape(self.role_name.clone()) + )); + xml.push_str(&format!( + "", + misc_helpers::xml_escape(self.role_instance_name.clone()) + )); + xml.push_str(&format!( + "", + misc_helpers::xml_escape(self.container_id.clone()) + )); + xml.push_str(&format!( + "", + misc_helpers::xml_escape(self.resource_group_name.clone()) + )); + xml.push_str(&format!( + "", + misc_helpers::xml_escape(self.subscription_id.clone()) + )); + xml.push_str(&format!( + "", + misc_helpers::xml_escape(self.vm_id.clone()) + )); + xml.push_str(&format!( + "", + self.image_origin + )); + xml.push_str(&format!( + "", + misc_helpers::xml_escape(self.os_version.clone()) + )); + xml.push_str(&format!( + "", + self.ram + )); + xml.push_str(&format!( + "", + self.processors + )); + xml + } +} + +/// TelemetryProvider struct to hold the telemetry events for a specific provider. +pub struct TelemetryProvider { + pub id: String, events: Vec, } -impl Default for TelemetryData { - fn default() -> Self { - Self::new() +impl TelemetryProvider { + pub fn new(id: String) -> Self { + TelemetryProvider { + id, + events: Vec::new(), + } + } + + pub fn add_event(&mut self, event: TelemetryEvent) { + self.events.push(event); + } + + pub fn event_count(&self) -> usize { + self.events.len() + } + + pub fn remove_event(&mut self, event: TelemetryEvent) -> Option { + if let Some(pos) = self.events.iter().position(|x| *x == event) { + Some(self.events.remove(pos)) + } else { + None + } + } + + pub fn to_xml(&self, vm_data: &TelemetryEventVMData) -> String { + let mut xml: String = String::new(); + xml.push_str(&format!( + "", + misc_helpers::xml_escape(self.id.clone()) + )); + + for e in &self.events { + match e { + TelemetryEvent::GenericLogsEvent(event) => { + xml.push_str(&event.to_xml_event(vm_data)); + } + TelemetryEvent::ExtensionEvent(event) => { + xml.push_str(&event.to_xml_event(vm_data)); + } + } + } + + xml.push_str(""); + xml } } +/// TelemetryData struct to hold the telemetry events send to wire server. +pub struct TelemetryData { + providers: Vec, + vm_data: TelemetryEventVMData, +} + impl TelemetryData { - pub fn new() -> Self { - TelemetryData { events: Vec::new() } + /// Create a new TelemetryData instance with VM data. + pub fn new_with_vm_data(vm_data: TelemetryEventVMData) -> Self { + TelemetryData { + providers: Vec::new(), + vm_data, + } } /// Convert the telemetry data to xml format. @@ -30,13 +190,13 @@ impl TelemetryData { pub fn to_xml(&self) -> String { let mut xml: String = String::new(); - xml.push_str(""); + xml.push_str(""); - for e in &self.events { - xml.push_str(&e.to_xml_event()); + for provider in &self.providers { + xml.push_str(&provider.to_xml(&self.vm_data)); } - xml.push_str(""); + xml.push_str(""); xml } @@ -45,38 +205,91 @@ impl TelemetryData { self.to_xml().len() } + /// Add a telemetry event to the telemetry data. + /// It will be added to the corresponding provider. pub fn add_event(&mut self, event: TelemetryEvent) { - self.events.push(event); + for provider in &mut self.providers { + match &event { + TelemetryEvent::GenericLogsEvent(_) => { + if provider.id == METRICS_PROVIDER_ID { + provider.add_event(event); + return; + } + } + TelemetryEvent::ExtensionEvent(_) => { + if provider.id == STATUS_PROVIDER_ID { + provider.add_event(event); + return; + } + } + } + } + let mut p = TelemetryProvider::new(match &event { + TelemetryEvent::GenericLogsEvent(_) => METRICS_PROVIDER_ID.to_string(), + TelemetryEvent::ExtensionEvent(_) => STATUS_PROVIDER_ID.to_string(), + }); + p.add_event(event); + self.providers.push(p); } - pub fn remove_last_event(&mut self) -> Option { - self.events.pop() + /// Remove the last added telemetry event from the telemetry data. + /// This is used when the telemetry data size exceeds the maximum allowed size. + pub fn remove_last_event(&mut self, last_event: TelemetryEvent) -> Option { + for provider in &mut self.providers { + match &last_event { + TelemetryEvent::GenericLogsEvent(_) => { + if provider.id == METRICS_PROVIDER_ID { + return provider.remove_event(last_event); + } + } + TelemetryEvent::ExtensionEvent(_) => { + if provider.id == STATUS_PROVIDER_ID { + return provider.remove_event(last_event); + } + } + } + } + None } + /// Get the total number of events in the telemetry data. + /// It adds up the event counts from all providers. pub fn event_count(&self) -> usize { - self.events.len() + self.providers.iter().map(|p| p.event_count()).sum() } } -pub struct TelemetryEvent { +#[derive(PartialEq, Eq, Hash, Clone)] +pub enum TelemetryEvent { + GenericLogsEvent(TelemetryGenericLogsEvent), + ExtensionEvent(TelemetryExtensionEventsEvent), +} + +impl TelemetryEvent { + pub fn get_provider_id(&self) -> String { + match self { + TelemetryEvent::GenericLogsEvent(_) => TelemetryGenericLogsEvent::get_provider_id(), + TelemetryEvent::ExtensionEvent(_) => TelemetryExtensionEventsEvent::get_provider_id(), + } + } + + pub fn to_xml_event(&self, vm_data: &TelemetryEventVMData) -> String { + match self { + TelemetryEvent::GenericLogsEvent(event) => event.to_xml_event(vm_data), + TelemetryEvent::ExtensionEvent(event) => event.to_xml_event(vm_data), + } + } +} + +/// Struct to hold Generic Logs telemetry event data without VM metadata. +#[derive(PartialEq, Eq, Hash, Clone)] +pub struct TelemetryGenericLogsEvent { event_pid: u64, event_tid: u64, ga_version: String, - container_id: String, task_name: String, opcode_name: String, - keyword_name: String, - os_version: String, execution_mode: String, - ram: u64, - processors: u64, - tenant_name: String, - role_name: String, - role_instance_name: String, - subscription_id: String, - resource_group_name: String, - vm_id: String, - image_origin: u64, event_name: String, capability_used: String, @@ -85,86 +298,156 @@ pub struct TelemetryEvent { context3: String, } -impl TelemetryEvent { +impl TelemetryGenericLogsEvent { pub fn from_event_log( event_log: &Event, - vm_meta_data: VmMetaData, execution_mode: String, event_name: String, + ga_version: Option, ) -> Self { - TelemetryEvent { + // if ga_version is provided, append event_log.version to event_name + // if ga_version is None, use event_log.Version as ga_version and keep event_name unchanged + let (ga_version, event_name) = match ga_version { + Some(version) => (version, format!("{}-{}", event_name, event_log.Version)), + None => (event_log.Version.clone(), event_name), + }; + // redact secrets in the message before sending to telemetry + let message = event_log.Message.clone(); + let message = crate::secrets_redactor::redact_secrets_string(message); + TelemetryGenericLogsEvent { + event_name, + ga_version, + execution_mode, event_pid: event_log.EventPid.parse::().unwrap_or(0), event_tid: event_log.EventTid.parse::().unwrap_or(0), - ga_version: event_log.Version.to_string(), - task_name: event_log.TaskName.to_string(), - opcode_name: event_log.TimeStamp.to_string(), - capability_used: event_log.EventLevel.to_string(), - context1: event_log.Message.to_string(), - context2: event_log.TimeStamp.to_string(), - context3: event_log.OperationId.to_string(), - - execution_mode, - event_name, - os_version: current_info::get_long_os_version(), - keyword_name: CURRENT_KEYWORD_NAME.to_string(), - ram: current_info::get_ram_in_mb(), - processors: current_info::get_cpu_count() as u64, - - container_id: vm_meta_data.container_id, - tenant_name: vm_meta_data.tenant_name, - role_name: vm_meta_data.role_name, - role_instance_name: vm_meta_data.role_instance_name, - subscription_id: vm_meta_data.subscription_id, - resource_group_name: vm_meta_data.resource_group_name, - vm_id: vm_meta_data.vm_id, - image_origin: vm_meta_data.image_origin, + task_name: event_log.TaskName.clone(), + opcode_name: event_log.TimeStamp.clone(), + capability_used: event_log.EventLevel.clone(), + context1: message, + context2: event_log.TimeStamp.clone(), + context3: event_log.OperationId.clone(), } } - fn to_xml_event(&self) -> String { + pub fn get_provider_id() -> String { + METRICS_PROVIDER_ID.to_string() + } + + fn to_xml_event(&self, vm_data: &TelemetryEventVMData) -> String { let mut xml: String = String::new(); + // Event ID 7 is for Generic Logs Events xml.push_str("", - misc_helpers::xml_escape(self.opcode_name.to_string()) + "", + self.event_pid )); xml.push_str(&format!( - "", - misc_helpers::xml_escape(self.keyword_name.to_string()) + "", + self.event_tid )); xml.push_str(&format!( - "", - misc_helpers::xml_escape(self.task_name.to_string()) + "", + misc_helpers::xml_escape(self.ga_version.clone()) )); xml.push_str(&format!( - "", - misc_helpers::xml_escape(self.tenant_name.to_string()) + "", + misc_helpers::xml_escape(self.execution_mode.clone()) )); xml.push_str(&format!( - "", - misc_helpers::xml_escape(self.role_name.to_string()) + "", + misc_helpers::xml_escape(self.task_name.clone()) )); xml.push_str(&format!( - "", - misc_helpers::xml_escape(self.role_instance_name.to_string()) + "", + misc_helpers::xml_escape(self.opcode_name.clone()) )); + xml.push_str(&format!( - "", - misc_helpers::xml_escape(self.container_id.to_string()) + "", + misc_helpers::xml_escape(self.event_name.clone()) )); xml.push_str(&format!( - "", - misc_helpers::xml_escape(self.resource_group_name.to_string()) + "", + misc_helpers::xml_escape(self.capability_used.clone()) )); xml.push_str(&format!( - "", - misc_helpers::xml_escape(self.subscription_id.to_string()) + "", + misc_helpers::xml_escape(self.context1.clone()) )); xml.push_str(&format!( - "", - misc_helpers::xml_escape(self.vm_id.to_string()) + "", + misc_helpers::xml_escape(self.context2.clone()) )); + xml.push_str(&format!( + "", + misc_helpers::xml_escape(self.context3.clone()) + )); + + xml.push_str("]]>"); + xml + } +} + +#[derive(PartialEq, Eq, Hash, Clone)] +pub struct TelemetryExtensionEventsEvent { + event_pid: u64, + event_tid: u64, + ga_version: String, + task_name: String, + opcode_name: String, + execution_mode: String, + + extension_type: String, + is_internal: bool, + name: String, + version: String, + operation: String, + operation_success: bool, + message: String, + duration: u64, +} + +impl TelemetryExtensionEventsEvent { + pub fn from_extension_status_event( + event: &ExtensionStatusEvent, + execution_mode: String, + ga_version: String, + ) -> Self { + // redact secrets in the message before sending to telemetry + let message = event.operation_status.message.clone(); + let message = crate::secrets_redactor::redact_secrets_string(message); + TelemetryExtensionEventsEvent { + ga_version, + execution_mode, + event_pid: event.event_pid.parse::().unwrap_or(0), + event_tid: event.event_tid.parse::().unwrap_or(0), + opcode_name: event.time_stamp.clone(), + extension_type: event.extension.extension_type.clone(), + is_internal: event.extension.is_internal, + name: event.extension.name.clone(), + version: event.extension.version.clone(), + operation: event.operation_status.operation.clone(), + task_name: event.operation_status.task_name.clone(), + operation_success: event.operation_status.operation_success, + message, + duration: event.operation_status.duration as u64, + } + } + + pub fn get_provider_id() -> String { + STATUS_PROVIDER_ID.to_string() + } + + fn to_xml_event(&self, vm_data: &TelemetryEventVMData) -> String { + let mut xml: String = String::new(); + // Event ID 1 is for Extension Events + xml.push_str("", self.event_pid @@ -174,50 +457,56 @@ impl TelemetryEvent { self.event_tid )); xml.push_str(&format!( - "", - self.image_origin + "", + misc_helpers::xml_escape(self.ga_version.clone()) )); - xml.push_str(&format!( "", - misc_helpers::xml_escape(self.execution_mode.to_string()) + misc_helpers::xml_escape(self.execution_mode.clone()) )); xml.push_str(&format!( - "", - misc_helpers::xml_escape(self.os_version.to_string()) + "", + misc_helpers::xml_escape(self.task_name.clone()) )); xml.push_str(&format!( - "", - misc_helpers::xml_escape(self.ga_version.to_string()) + "", + misc_helpers::xml_escape(self.opcode_name.clone()) )); xml.push_str(&format!( - "", - self.ram + "", + misc_helpers::xml_escape(self.extension_type.clone()) )); xml.push_str(&format!( - "", - self.processors + "", + if self.is_internal { "True" } else { "False" } )); - xml.push_str(&format!( - "", - misc_helpers::xml_escape(self.event_name.to_string()) + "", + misc_helpers::xml_escape(self.name.clone()) )); xml.push_str(&format!( - "", - misc_helpers::xml_escape(self.capability_used.to_string()) + "", + misc_helpers::xml_escape(self.version.clone()) )); xml.push_str(&format!( - "", - misc_helpers::xml_escape(self.context1.to_string()) + "", + misc_helpers::xml_escape(self.operation.clone()) )); xml.push_str(&format!( - "", - misc_helpers::xml_escape(self.context2.to_string()) + "", + if self.operation_success { + "True" + } else { + "False" + } )); xml.push_str(&format!( - "", - misc_helpers::xml_escape(self.context3.to_string()) + "", + misc_helpers::xml_escape(self.message.clone()) + )); + xml.push_str(&format!( + "", + self.duration )); xml.push_str("]]>"); @@ -245,3 +534,437 @@ impl KeywordName { serde_json::to_string(self).unwrap_or_else(|_| "".to_owned()) } } + +#[cfg(test)] +mod tests { + use super::*; + + fn create_test_vm_meta_data() -> VmMetaData { + VmMetaData { + container_id: "test-container-id".to_string(), + tenant_name: "test-tenant".to_string(), + role_name: "test-role".to_string(), + role_instance_name: "test-role-instance".to_string(), + subscription_id: "test-subscription-id".to_string(), + resource_group_name: "test-resource-group".to_string(), + vm_id: "test-vm-id".to_string(), + image_origin: 1, + } + } + + fn create_test_vm_data() -> TelemetryEventVMData { + TelemetryEventVMData::new_from_vm_meta_data(&create_test_vm_meta_data()) + } + + fn create_test_event(message: &str) -> Event { + Event::new( + "Informational".to_string(), + message.to_string(), + "test_task".to_string(), + "test_module".to_string(), + ) + } + + fn create_test_telemetry_event(message: &str) -> TelemetryEvent { + let event_log = create_test_event(message); + TelemetryEvent::GenericLogsEvent(TelemetryGenericLogsEvent::from_event_log( + &event_log, + "test_execution_mode".to_string(), + "test_event_name".to_string(), + Some("1.0.0".to_string()), + )) + } + + /// Tests VmMetaData, TelemetryEventVMData creation and XML params generation + #[test] + fn test_vm_meta_data_and_vm_data() { + // Test VmMetaData clone + let meta_data = create_test_vm_meta_data(); + let cloned = meta_data.clone(); + assert_eq!(cloned.container_id, "test-container-id"); + assert_eq!(cloned.tenant_name, "test-tenant"); + assert_eq!(cloned.vm_id, "test-vm-id"); + assert_eq!(cloned.image_origin, 1); + + // Test TelemetryEventVMData creation from VmMetaData + let vm_data = TelemetryEventVMData::new_from_vm_meta_data(&meta_data); + assert_eq!(vm_data.container_id, "test-container-id"); + assert_eq!(vm_data.tenant_name, "test-tenant"); + assert_eq!(vm_data.role_name, "test-role"); + assert_eq!(vm_data.role_instance_name, "test-role-instance"); + assert_eq!(vm_data.subscription_id, "test-subscription-id"); + assert_eq!(vm_data.resource_group_name, "test-resource-group"); + assert_eq!(vm_data.vm_id, "test-vm-id"); + assert_eq!(vm_data.image_origin, 1); + // These are populated from current_info + assert!(!vm_data.keyword_name.is_empty()); + assert!(!vm_data.os_version.is_empty()); + assert!(vm_data.ram > 0); + assert!(vm_data.processors > 0); + + // Test XML params generation + let xml = vm_data.to_xml_params(); + assert!(xml.contains("KeywordName")); + assert!(xml.contains("TenantName")); + assert!(xml.contains("test-tenant")); + assert!(xml.contains("RoleName")); + assert!(xml.contains("ContainerId")); + assert!(xml.contains("ResourceGroupName")); + assert!(xml.contains("SubscriptionId")); + assert!(xml.contains("VMId")); + assert!(xml.contains("ImageOrigin")); + assert!(xml.contains("OSVersion")); + assert!(xml.contains("RAM")); + assert!(xml.contains("Processors")); + } + + /// Tests TelemetryProvider operations: add, remove, count, and XML generation + #[test] + fn test_telemetry_provider() { + let mut provider = TelemetryProvider::new(METRICS_PROVIDER_ID.to_string()); + assert_eq!(provider.id, METRICS_PROVIDER_ID); + assert_eq!(provider.event_count(), 0); + + // Add events + let event1 = create_test_telemetry_event("Test message 1"); + let event2 = create_test_telemetry_event("Test message 2"); + + provider.add_event(event1.clone()); + assert_eq!(provider.event_count(), 1); + + provider.add_event(event2); + assert_eq!(provider.event_count(), 2); + + // Remove event + let removed = provider.remove_event(event1.clone()); + assert!(removed.is_some()); + assert_eq!(provider.event_count(), 1); + + // Remove non-existent event returns None + let removed = provider.remove_event(event1); + assert!(removed.is_none()); + + // Test XML generation + let vm_data = create_test_vm_data(); + let xml = provider.to_xml(&vm_data); + assert!(xml.starts_with(&format!("", METRICS_PROVIDER_ID))); + assert!(xml.ends_with("")); + assert!(xml.contains("")); + } + + /// Tests TelemetryData operations: add, remove, count, size, and XML generation + #[test] + fn test_telemetry_data() { + let vm_data = create_test_vm_data(); + let mut telemetry_data = TelemetryData::new_with_vm_data(vm_data); + assert_eq!(telemetry_data.event_count(), 0); + + // Test empty XML + let empty_xml = telemetry_data.to_xml(); + assert!(empty_xml.starts_with("")); + assert!(empty_xml.contains("")); + assert!(empty_xml.contains("")); + assert!(!empty_xml.contains(" 0); + + // Add events + let event1 = create_test_telemetry_event("Test message 1"); + let event2 = create_test_telemetry_event("Test message 2"); + let event3 = create_test_telemetry_event("Test message 3"); + + telemetry_data.add_event(event1); + assert_eq!(telemetry_data.event_count(), 1); + + telemetry_data.add_event(event2); + telemetry_data.add_event(event3.clone()); + assert_eq!(telemetry_data.event_count(), 3); + + // Size should increase after adding events + let new_size = telemetry_data.get_size(); + assert!(new_size > initial_size); + + // Remove last event + let removed = telemetry_data.remove_last_event(event3); + assert!(removed.is_some()); + assert_eq!(telemetry_data.event_count(), 2); + + // Test XML with events + let xml = telemetry_data.to_xml(); + assert!(xml.starts_with("")); + assert!(xml.contains("")); + assert!(xml.contains(&format!("", METRICS_PROVIDER_ID))); + assert!(xml.contains("")); + } + + /// Tests TelemetryEvent and TelemetryGenericLogsEvent + #[test] + fn test_telemetry_event() { + // Test provider ID + let event = create_test_telemetry_event("Test message"); + assert_eq!(event.get_provider_id(), METRICS_PROVIDER_ID); + assert_eq!( + TelemetryGenericLogsEvent::get_provider_id(), + METRICS_PROVIDER_ID + ); + + // Test XML event generation + let vm_data = create_test_vm_data(); + let xml = event.to_xml_event(&vm_data); + assert!(xml.contains("")); + assert!(xml.contains("")); + assert!(xml.contains("EventName")); + assert!(xml.contains("CapabilityUsed")); + assert!(xml.contains("Context1")); + assert!(xml.contains("Context2")); + assert!(xml.contains("Context3")); + + // Test that different messages produce different events + let event2 = create_test_telemetry_event("Different message"); + assert!(event != event2); // Different messages create different events + } + + /// Tests TelemetryGenericLogsEvent from_event_log with and without ga_version + #[test] + fn test_telemetry_generic_logs_event_creation() { + let event_log = create_test_event("Test message"); + + // With ga_version provided + let event_with_version = TelemetryGenericLogsEvent::from_event_log( + &event_log, + "execution_mode".to_string(), + "event_name".to_string(), + Some("1.0.0".to_string()), + ); + assert_eq!(event_with_version.ga_version, "1.0.0"); + assert!(event_with_version.event_name.starts_with("event_name-")); + assert_eq!(event_with_version.execution_mode, "execution_mode"); + + // Without ga_version (None) + let event_without_version = TelemetryGenericLogsEvent::from_event_log( + &event_log, + "execution_mode".to_string(), + "event_name".to_string(), + None, + ); + assert_eq!(event_without_version.ga_version, event_log.Version); + assert_eq!(event_without_version.event_name, "event_name"); + } + + /// Tests KeywordName JSON serialization + #[test] + fn test_keyword_name() { + let keyword = KeywordName::new("x86_64".to_string()); + let json = keyword.to_json(); + + assert!(json.contains("CpuArchitecture")); + assert!(json.contains("x86_64")); + } + + fn create_test_extension_status_event() -> ExtensionStatusEvent { + let extension = crate::telemetry::Extension { + name: "test_extension".to_string(), + version: "2.0.0".to_string(), + is_internal: true, + extension_type: "test_type".to_string(), + }; + let operation_status = crate::telemetry::OperationStatus { + operation_success: true, + operation: "install".to_string(), + task_name: "test_task".to_string(), + message: "Installation successful".to_string(), + duration: 500, + }; + ExtensionStatusEvent::new(extension, operation_status) + } + + /// Tests TelemetryExtensionEventsEvent creation and XML generation + #[test] + fn test_telemetry_extension_events_event() { + let extension_status_event = create_test_extension_status_event(); + let telemetry_event = TelemetryExtensionEventsEvent::from_extension_status_event( + &extension_status_event, + "production".to_string(), + "1.0.0".to_string(), + ); + + // Verify field mappings + assert_eq!(telemetry_event.ga_version, "1.0.0"); + assert_eq!(telemetry_event.execution_mode, "production"); + assert_eq!(telemetry_event.extension_type, "test_type"); + assert!(telemetry_event.is_internal); + assert_eq!(telemetry_event.name, "test_extension"); + assert_eq!(telemetry_event.version, "2.0.0"); + assert_eq!(telemetry_event.operation, "install"); + assert_eq!(telemetry_event.task_name, "test_task"); + assert!(telemetry_event.operation_success); + assert_eq!(telemetry_event.message, "Installation successful"); + assert_eq!(telemetry_event.duration, 500); + + // Verify provider ID + assert_eq!( + TelemetryExtensionEventsEvent::get_provider_id(), + STATUS_PROVIDER_ID + ); + + // Test XML generation + let vm_data = create_test_vm_data(); + let xml = telemetry_event.to_xml_event(&vm_data); + assert!(xml.contains("")); + assert!(xml.contains("")); + assert!(xml.contains("ExtensionType")); + assert!(xml.contains("test_type")); + assert!(xml.contains("IsInternal")); + assert!(xml.contains("True")); // is_internal = true + assert!(xml.contains("Name")); + assert!(xml.contains("test_extension")); + assert!(xml.contains("Version")); + assert!(xml.contains("2.0.0")); + assert!(xml.contains("Operation")); + assert!(xml.contains("install")); + assert!(xml.contains("OperationSuccess")); + assert!(xml.contains("Message")); + assert!(xml.contains("Installation successful")); + assert!(xml.contains("Duration")); + assert!(xml.contains("500")); + } + + /// Tests TelemetryExtensionEventsEvent with operation failure + #[test] + fn test_telemetry_extension_events_event_failure() { + let extension = crate::telemetry::Extension { + name: "failed_extension".to_string(), + version: "1.0.0".to_string(), + is_internal: false, + extension_type: "external_type".to_string(), + }; + let operation_status = crate::telemetry::OperationStatus { + operation_success: false, + operation: "enable".to_string(), + task_name: "enable_task".to_string(), + message: "Enable failed with error".to_string(), + duration: 100, + }; + let extension_status_event = ExtensionStatusEvent::new(extension, operation_status); + + let telemetry_event = TelemetryExtensionEventsEvent::from_extension_status_event( + &extension_status_event, + "test".to_string(), + "2.0.0".to_string(), + ); + + assert!(!telemetry_event.is_internal); + assert!(!telemetry_event.operation_success); + assert_eq!(telemetry_event.name, "failed_extension"); + assert_eq!(telemetry_event.operation, "enable"); + + // Test XML with False values + let vm_data = create_test_vm_data(); + let xml = telemetry_event.to_xml_event(&vm_data); + assert!(xml.contains("IsInternal")); + assert!(xml.contains("\"False\"")); // is_internal = false + assert!(xml.contains("OperationSuccess")); + } + + /// Tests TelemetryEvent enum with ExtensionEvent variant + #[test] + fn test_telemetry_event_extension_variant() { + let extension_status_event = create_test_extension_status_event(); + let telemetry_event = TelemetryExtensionEventsEvent::from_extension_status_event( + &extension_status_event, + "production".to_string(), + "1.0.0".to_string(), + ); + let event = TelemetryEvent::ExtensionEvent(telemetry_event); + + // Test provider ID through enum + assert_eq!(event.get_provider_id(), STATUS_PROVIDER_ID); + + // Test XML generation through enum + let vm_data = create_test_vm_data(); + let xml = event.to_xml_event(&vm_data); + assert!(xml.contains("")); + assert!(xml.contains("ExtensionType")); + } + + /// Tests TelemetryData with mixed event types (GenericLogs and Extension events) + #[test] + fn test_telemetry_data_mixed_events() { + let vm_data = create_test_vm_data(); + let mut telemetry_data = TelemetryData::new_with_vm_data(vm_data); + + // Add generic logs event + let generic_event = create_test_telemetry_event("Generic log message"); + telemetry_data.add_event(generic_event); + assert_eq!(telemetry_data.event_count(), 1); + + // Add extension event + let extension_status_event = create_test_extension_status_event(); + let extension_telemetry_event = TelemetryExtensionEventsEvent::from_extension_status_event( + &extension_status_event, + "production".to_string(), + "1.0.0".to_string(), + ); + let extension_event = TelemetryEvent::ExtensionEvent(extension_telemetry_event); + telemetry_data.add_event(extension_event.clone()); + assert_eq!(telemetry_data.event_count(), 2); + + // Add another extension event + let extension_status_event2 = create_test_extension_status_event(); + let extension_telemetry_event2 = TelemetryExtensionEventsEvent::from_extension_status_event( + &extension_status_event2, + "production".to_string(), + "1.0.0".to_string(), + ); + let extension_event2 = TelemetryEvent::ExtensionEvent(extension_telemetry_event2); + telemetry_data.add_event(extension_event2); + assert_eq!(telemetry_data.event_count(), 3); + + // Verify XML contains both provider types + let xml = telemetry_data.to_xml(); + assert!(xml.contains(&format!("", METRICS_PROVIDER_ID))); + assert!(xml.contains(&format!("", STATUS_PROVIDER_ID))); + assert!(xml.contains("")); // Generic logs event + assert!(xml.contains("")); // Extension event + println!("{xml}"); + + // Remove extension event + let removed = telemetry_data.remove_last_event(extension_event); + assert!(removed.is_some()); + assert_eq!(telemetry_data.event_count(), 2); + } + + /// Tests TelemetryProvider with extension events + #[test] + fn test_telemetry_provider_with_extension_events() { + let mut provider = TelemetryProvider::new(STATUS_PROVIDER_ID.to_string()); + assert_eq!(provider.id, STATUS_PROVIDER_ID); + assert_eq!(provider.event_count(), 0); + + // Add extension events + let extension_status_event = create_test_extension_status_event(); + let telemetry_event = TelemetryExtensionEventsEvent::from_extension_status_event( + &extension_status_event, + "production".to_string(), + "1.0.0".to_string(), + ); + let event = TelemetryEvent::ExtensionEvent(telemetry_event); + provider.add_event(event.clone()); + assert_eq!(provider.event_count(), 1); + + // Test XML generation + let vm_data = create_test_vm_data(); + let xml = provider.to_xml(&vm_data); + assert!(xml.starts_with(&format!("", STATUS_PROVIDER_ID))); + assert!(xml.ends_with("")); + assert!(xml.contains("")); + + // Remove event + let removed = provider.remove_event(event); + assert!(removed.is_some()); + assert_eq!(provider.event_count(), 0); + } +} diff --git a/proxy_agent_shared/src/time_buckets.rs b/proxy_agent_shared/src/time_buckets.rs new file mode 100644 index 00000000..8c586394 --- /dev/null +++ b/proxy_agent_shared/src/time_buckets.rs @@ -0,0 +1,235 @@ +// Copyright (c) Microsoft Corporation +// SPDX-License-Identifier: MIT + +use std::collections::VecDeque; +use std::time::{Duration, SystemTime}; + +/// Trait for items that have a count field. +pub trait Countable { + fn set_count(&mut self, count: u64); +} + +/// A generic container that buckets counts over time. +/// This helps in aging out old items. +pub struct TimeBucketedItem { + item: T, // base info (count will be computed) + buckets: VecDeque<(SystemTime, u64)>, // (bucket_start, count_in_bucket) + bucket_duration: Duration, + max_age: Duration, +} + +impl TimeBucketedItem { + pub fn new(item: T, bucket_duration: Duration, max_age: Duration) -> Self { + let now = SystemTime::now(); + let mut buckets = VecDeque::new(); + buckets.push_back((now, 1)); + Self { + item, + buckets, + bucket_duration, + max_age, + } + } + + /// Adds one to the count. + /// Returns true if a new bucket was created. + pub fn add_one(&mut self) -> bool { + let now = SystemTime::now(); + self.prune_old_buckets(now); + + // Check if we can add to current bucket + if let Some((bucket_time, count)) = self.buckets.back_mut() { + if now.duration_since(*bucket_time).unwrap_or_default() < self.bucket_duration { + *count += 1; + return false; + } + } + // Create new bucket + self.buckets.push_back((now, 1)); + true + } + + /// Prunes buckets older than MAX_AGE_SECS. + fn prune_old_buckets(&mut self, now: SystemTime) { + let max_age = self.max_age; + while let Some((bucket_time, _)) = self.buckets.front() { + if now.duration_since(*bucket_time).unwrap_or(max_age) >= max_age { + self.buckets.pop_front(); + } else { + break; + } + } + } + + /// Gets the total count across all buckets. + fn get_count(&mut self) -> u64 { + self.prune_old_buckets(SystemTime::now()); + self.buckets.iter().map(|(_, c)| c).sum() + } + + /// Checks if there are no buckets left. + pub fn is_empty(&mut self) -> bool { + self.prune_old_buckets(SystemTime::now()); + self.buckets.is_empty() + } +} + +impl TimeBucketedItem { + /// Converts to the item type with updated count. + pub fn to_item(&mut self) -> T { + let mut result = self.item.clone(); + result.set_count(self.get_count()); + result + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::thread::sleep; + + #[derive(Clone, Debug, PartialEq)] + struct TestItem { + name: String, + count: u64, + } + + impl Countable for TestItem { + fn set_count(&mut self, count: u64) { + self.count = count; + } + } + + #[test] + fn test_new_creates_single_bucket() { + let item = TestItem { + name: "test".to_string(), + count: 0, + }; + let bucket_duration = Duration::from_secs(60); + let max_age = Duration::from_secs(300); + + // test_new_creates_single_bucket_with_count_one + let mut bucketed = TimeBucketedItem::new(item, bucket_duration, max_age); + assert_eq!(bucketed.get_count(), 1); + assert!(!bucketed.is_empty()); + + // test_add_one_increments_count_in_same_bucket() + let new_bucket = bucketed.add_one(); + + assert!(!new_bucket); // Should not create new bucket + assert_eq!(bucketed.get_count(), 2); + } + + #[test] + fn test_add_one_creates_new_bucket_after_duration() { + let item = TestItem { + name: "test".to_string(), + count: 0, + }; + // Use very short bucket duration for testing + let bucket_duration = Duration::from_millis(10); + let max_age = Duration::from_secs(300); + + let mut bucketed = TimeBucketedItem::new(item, bucket_duration, max_age); + + // Wait for bucket duration to pass + sleep(Duration::from_millis(15)); + + let new_bucket = bucketed.add_one(); + + assert!(new_bucket); // Should create new bucket + assert_eq!(bucketed.get_count(), 2); // Both buckets should count + } + + #[test] + fn test_prune_old_buckets_removes_expired_buckets() { + let item = TestItem { + name: "test".to_string(), + count: 0, + }; + // Use very short durations for testing + let bucket_duration = Duration::from_millis(5); + let max_age = Duration::from_millis(20); + + let mut bucketed = TimeBucketedItem::new(item, bucket_duration, max_age); + + // Add counts over time to create multiple buckets + sleep(Duration::from_millis(10)); + bucketed.add_one(); + + // Wait for max_age to pass for the first bucket + sleep(Duration::from_millis(25)); + + // This should prune the old bucket and create a new one + bucketed.add_one(); + + // The initial bucket should be pruned, only newer counts should remain + let count = bucketed.get_count(); + assert!(count <= 2); // Should have pruned at least the first bucket + } + + #[test] + fn test_is_empty_after_all_buckets_expire() { + let item = TestItem { + name: "test".to_string(), + count: 0, + }; + // Use very short max_age for testing + let bucket_duration = Duration::from_millis(5); + let max_age = Duration::from_millis(10); + + let mut bucketed = TimeBucketedItem::new(item, bucket_duration, max_age); + assert!(!bucketed.is_empty()); + + // Wait for all buckets to expire + sleep(Duration::from_millis(20)); + + assert!(bucketed.is_empty()); + } + + #[test] + fn test_to_item_returns_cloned_item_with_count() { + let item = TestItem { + name: "test".to_string(), + count: 0, + }; + let bucket_duration = Duration::from_secs(60); + let max_age = Duration::from_secs(300); + + let mut bucketed = TimeBucketedItem::new(item, bucket_duration, max_age); + bucketed.add_one(); + bucketed.add_one(); + + assert_eq!(bucketed.get_count(), 3); // 1 from new + 2 from add_one + let result = bucketed.to_item(); + + assert_eq!(result.name, "test"); + assert_eq!(result.count, 3); // 1 from new + 2 from add_one + } + + #[test] + fn test_bucket_count_accumulates_across_buckets() { + let item = TestItem { + name: "test".to_string(), + count: 0, + }; + // Use short bucket duration to force new buckets + let bucket_duration = Duration::from_millis(5); + let max_age = Duration::from_secs(60); + + let mut bucketed = TimeBucketedItem::new(item, bucket_duration, max_age); + + // Add some counts + bucketed.add_one(); + bucketed.add_one(); + + // Wait to create new bucket + sleep(Duration::from_millis(10)); + bucketed.add_one(); + bucketed.add_one(); + + // Total should be 5 (1 from new + 4 from add_one) + assert_eq!(bucketed.get_count(), 5); + } +}