diff --git a/proxy_agent/src/key_keeper.rs b/proxy_agent/src/key_keeper.rs index 40b92c65..1c090d1d 100644 --- a/proxy_agent/src/key_keeper.rs +++ b/proxy_agent/src/key_keeper.rs @@ -15,11 +15,12 @@ //! use std::time::Duration; //! //! let shared_state = SharedState::start_all(); -//! let base_url = "http://127:0.0.1:8081/"; +//! let host = "127.0.0.1".to_string(); +//! let port = 8081u16; //! let key_dir = PathBuf::from("path"); //! let interval = Duration::from_secs(10); //! let config_start_redirector = false; -//! let key_keeper = key_keeper::KeyKeeper::new(base_url.parse().unwrap(), key_dir, interval, config_start_redirector, &shared_state); +//! let key_keeper = key_keeper::KeyKeeper::new(host, port, key_dir, interval, config_start_redirector, &shared_state); //! tokio::spawn(key_keeper.poll_secure_channel_status()); //! ``` @@ -40,7 +41,6 @@ use crate::shared_state::provision_wrapper::ProvisionSharedState; use crate::shared_state::redirector_wrapper::RedirectorSharedState; use crate::shared_state::{EventThreadsSharedState, SharedState}; use crate::{acl, redirector}; -use hyper::Uri; use proxy_agent_shared::common_state::CommonState; use proxy_agent_shared::logger::LoggerLevel; use proxy_agent_shared::misc_helpers; @@ -64,8 +64,10 @@ const DELAY_START_EVENT_THREADS_IN_MILLISECONDS: u128 = 60000; // 1 minute #[derive(Clone)] pub struct KeyKeeper { - /// base_url: the WireServer endpoint to poll the secure channel status - base_url: Uri, + /// host: the WireServer host to poll the secure channel status + host: String, + /// port: the WireServer port to poll the secure channel status + port: u16, /// key_dir: the folder to save the key details key_dir: PathBuf, /// status_dir: the folder to log the access control rule details @@ -101,14 +103,16 @@ enum WakeReason { impl KeyKeeper { pub fn new( - base_url: Uri, + host: String, + port: u16, key_dir: PathBuf, status_dir: PathBuf, interval: Duration, shared_state: &SharedState, ) -> Self { KeyKeeper { - base_url, + host, + port, key_dir, status_dir, interval, @@ -247,7 +251,7 @@ impl KeyKeeper { .await; started_event_threads = self.handle_event_threads_start(started_event_threads).await; - let status = match key::get_status(&self.base_url).await { + let status = match key::get_status(&self.host, self.port).await { Ok(s) => s, Err(e) => { self.update_status_message(format!("Failed to get key status - {e}"), true) @@ -629,7 +633,7 @@ impl KeyKeeper { /// Acquire key from server, persist it, and attest it /// Returns true if successful, false if should continue to next iteration async fn acquire_key_from_server(&self) -> bool { - let key = match key::acquire_key(&self.base_url).await { + let key = match key::acquire_key(&self.host, self.port).await { Ok(k) => k, Err(e) => { self.update_status_message(format!("Failed to acquire key details: {e:?}"), true) @@ -660,7 +664,7 @@ impl KeyKeeper { } // attest the key - match key::attest_key(&self.base_url, &key).await { + match key::attest_key(&self.host, self.port, &key).await { Ok(()) => { // update in memory if let Err(e) = self.update_key_to_shared_state(key.clone()).await { @@ -1051,7 +1055,8 @@ mod tests { // start poll_secure_channel_status let cloned_keys_dir = keys_dir.to_path_buf(); let key_keeper = KeyKeeper { - base_url: (format!("http://{}:{}/", ip, port)).parse().unwrap(), + host: ip.to_string(), + port, key_dir: cloned_keys_dir.clone(), status_dir: cloned_keys_dir.clone(), interval: Duration::from_millis(10), diff --git a/proxy_agent/src/key_keeper/key.rs b/proxy_agent/src/key_keeper/key.rs index 667b0932..b413e3e2 100644 --- a/proxy_agent/src/key_keeper/key.rs +++ b/proxy_agent/src/key_keeper/key.rs @@ -720,40 +720,22 @@ impl Display for KeyAction { const STATUS_URL: &str = "/secure-channel/status"; const KEY_URL: &str = "/secure-channel/key"; -pub async fn get_status(base_url: &Uri) -> Result { - let (host, port) = hyper_client::host_port_from_uri(base_url)?; - let url = format!("http://{host}:{port}{STATUS_URL}"); - let url: Uri = url.parse().map_err(|e| { - Error::Key(KeyErrorType::ParseKeyUrl( - base_url.to_string(), - STATUS_URL.to_string(), - e, - )) - })?; +pub async fn get_status(host: &str, port: u16) -> Result { + let endpoint = hyper_client::HostEndpoint::new(host, port, STATUS_URL); let mut headers = HashMap::new(); headers.insert( hyper_client::METADATA_HEADER.to_string(), "True ".to_string(), ); let status: KeyStatus = - hyper_client::get(&url, &headers, None, None, logger::write_warning).await?; + hyper_client::get(&endpoint, &headers, None, None, logger::write_warning).await?; status.validate()?; Ok(status) } -pub async fn acquire_key(base_url: &Uri) -> Result { - let (host, port) = hyper_client::host_port_from_uri(base_url)?; - let url = format!("http://{host}:{port}{KEY_URL}"); - let url: Uri = url.parse().map_err(|e| { - Error::Key(KeyErrorType::ParseKeyUrl( - base_url.to_string(), - KEY_URL.to_string(), - e, - )) - })?; - - let (host, port) = hyper_client::host_port_from_uri(&url)?; +pub async fn acquire_key(host: &str, port: u16) -> Result { + let endpoint = hyper_client::HostEndpoint::new(host, port, KEY_URL); let mut headers = HashMap::new(); headers.insert( hyper_client::METADATA_HEADER.to_string(), @@ -763,21 +745,26 @@ pub async fn acquire_key(base_url: &Uri) -> Result { let body = r#"{"authorizationScheme": "Azure-HMAC-SHA256"}"#.to_string(); let request = hyper_client::build_request( hyper::Method::POST, - &url, + &endpoint, &headers, Some(body.as_bytes()), None, None, )?; - let response = hyper_client::send_request(&host, port, request, logger::write_warning) - .await - .map_err(|e| { - Error::Key(KeyErrorType::SendKeyRequest( - format!("{}", KeyAction::Acquire), - e.to_string(), - )) - })?; + let response = hyper_client::send_request( + &endpoint.host, + endpoint.port, + request, + logger::write_warning, + ) + .await + .map_err(|e| { + Error::Key(KeyErrorType::SendKeyRequest( + format!("{}", KeyAction::Acquire), + e.to_string(), + )) + })?; if response.status() != StatusCode::OK { return Err(Error::Key(KeyErrorType::KeyResponse( @@ -790,17 +777,10 @@ pub async fn acquire_key(base_url: &Uri) -> Result { .map_err(Error::ProxyAgentSharedError) } -pub async fn attest_key(base_url: &Uri, key: &Key) -> Result<()> { +pub async fn attest_key(host: &str, port: u16, key: &Key) -> Result<()> { // secure-channel/key/{key_guid}/key-attestation - let (host, port) = hyper_client::host_port_from_uri(base_url)?; - let url = format!( - "http://{}:{}{}/{}/key-attestation", - host, port, KEY_URL, key.guid - ); - let url: Uri = url - .parse() - .map_err(|e| Error::Key(KeyErrorType::ParseKeyUrl(base_url.to_string(), url, e)))?; - + let path = format!("{}/{}/key-attestation", KEY_URL, key.guid); + let endpoint = hyper_client::HostEndpoint::new(host, port, &path); let mut headers = HashMap::new(); headers.insert( hyper_client::METADATA_HEADER.to_string(), @@ -808,21 +788,26 @@ pub async fn attest_key(base_url: &Uri, key: &Key) -> Result<()> { ); let request = hyper_client::build_request( Method::POST, - &url, + &endpoint, &headers, None, Some(key.guid.to_string()), Some(key.key.to_string()), )?; - let response = hyper_client::send_request(&host, port, request, logger::write_warning) - .await - .map_err(|e| { - Error::Key(KeyErrorType::SendKeyRequest( - format!("{}", KeyAction::Attest), - e.to_string(), - )) - })?; + let response = hyper_client::send_request( + &endpoint.host, + endpoint.port, + request, + logger::write_warning, + ) + .await + .map_err(|e| { + Error::Key(KeyErrorType::SendKeyRequest( + format!("{}", KeyAction::Attest), + e.to_string(), + )) + })?; if response.status() != StatusCode::OK { return Err(Error::Key(KeyErrorType::KeyResponse( diff --git a/proxy_agent/src/provision.rs b/proxy_agent/src/provision.rs index 4dc27281..82f5772c 100644 --- a/proxy_agent/src/provision.rs +++ b/proxy_agent/src/provision.rs @@ -661,17 +661,12 @@ pub mod provision_query { // bool - true provision finished; false provision not finished // String - provision error message, empty means provision success or provision failed. async fn get_current_provision_status(&self, notify: bool) -> Result { - let provision_url: String = format!( - "http://{}:{}{}", - Ipv4Addr::LOCALHOST, + let endpoint = hyper_client::HostEndpoint::new( + Ipv4Addr::LOCALHOST.to_string(), self.port, - PROVISION_URL_PATH + PROVISION_URL_PATH, ); - let provision_url: hyper::Uri = provision_url - .parse::() - .map_err(|e| Error::ParseUrl(provision_url, e.to_string()))?; - let mut headers = HashMap::new(); headers.insert( hyper_client::METADATA_HEADER.to_string(), @@ -684,7 +679,7 @@ pub mod provision_query { if notify { headers.insert(constants::NOTIFY_HEADER.to_string(), "true".to_string()); } - hyper_client::get(&provision_url, &headers, None, None, logger::write_warning) + hyper_client::get(&endpoint, &headers, None, None, logger::write_warning) .await .map_err(Error::ProxyAgentSharedError) } diff --git a/proxy_agent/src/proxy/proxy_server.rs b/proxy_agent/src/proxy/proxy_server.rs index 4519c80d..fb2a0abd 100644 --- a/proxy_agent/src/proxy/proxy_server.rs +++ b/proxy_agent/src/proxy/proxy_server.rs @@ -1045,10 +1045,10 @@ mod tests { let sleep_duration = Duration::from_millis(100); tokio::time::sleep(sleep_duration).await; - let url: hyper::Uri = format!("http://{}:{}/", host, port).parse().unwrap(); + let endpoint = hyper_client::HostEndpoint::new(host, port, "/"); let request = hyper_client::build_request( Method::GET, - &url, + &endpoint, &HashMap::new(), None, key_keeper_shared_state @@ -1085,12 +1085,10 @@ mod tests { ); // test with traversal characters - let url: hyper::Uri = format!("http://{}:{}/test/../", host, port) - .parse() - .unwrap(); + let endpoint = hyper_client::HostEndpoint::new(host, port, "/test/../"); let request = hyper_client::build_request( Method::GET, - &url, + &endpoint, &HashMap::new(), None, key_keeper_shared_state @@ -1116,7 +1114,7 @@ mod tests { let body = vec![88u8; super::REQUEST_BODY_LOW_LIMIT_SIZE + 1]; let request = hyper_client::build_request( Method::POST, - &url, + &endpoint, &HashMap::new(), Some(body.as_slice()), key_keeper_shared_state diff --git a/proxy_agent/src/service.rs b/proxy_agent/src/service.rs index db0f2459..c6d74b77 100644 --- a/proxy_agent/src/service.rs +++ b/proxy_agent/src/service.rs @@ -10,6 +10,7 @@ use crate::proxy::proxy_server::ProxyServer; use crate::redirector::{self, Redirector}; use crate::shared_state::SharedState; use proxy_agent_shared::current_info; +use proxy_agent_shared::hyper_client::HostEndpoint; use proxy_agent_shared::logger::rolling_logger::RollingLogger; use proxy_agent_shared::logger::{logger_manager, LoggerLevel}; use proxy_agent_shared::proxy_agent_aggregate_status; @@ -57,9 +58,8 @@ pub async fn start_service(shared_state: SharedState) { tokio::spawn({ let key_keeper = KeyKeeper::new( - (format!("http://{}/", constants::WIRE_SERVER_IP)) - .parse() - .unwrap(), + constants::WIRE_SERVER_IP.to_string(), + HostEndpoint::DEFAULT_HTTP_PORT, config::get_keys_dir(), proxy_agent_aggregate_status::get_proxy_agent_aggregate_status_folder(), config::get_poll_key_status_duration(), diff --git a/proxy_agent_shared/src/host_clients/imds_client.rs b/proxy_agent_shared/src/host_clients/imds_client.rs index 265e3794..cc067545 100644 --- a/proxy_agent_shared/src/host_clients/imds_client.rs +++ b/proxy_agent_shared/src/host_clients/imds_client.rs @@ -6,10 +6,9 @@ //! The GPA service uses the IMDS service to get the instance information of the VM. use super::instance_info::InstanceInfo; -use crate::hyper_client; +use crate::hyper_client::{self, HostEndpoint}; use crate::logger::logger_manager; -use crate::{error::Error, result::Result}; -use hyper::Uri; +use crate::result::Result; use std::collections::HashMap; pub struct ImdsClient { @@ -17,7 +16,7 @@ pub struct ImdsClient { port: u16, } -const IMDS_URI: &str = "metadata/instance?api-version=2018-02-01"; +const IMDS_URI: &str = "/metadata/instance?api-version=2018-02-01"; impl ImdsClient { pub fn new(ip: &str, port: u16) -> Self { @@ -27,19 +26,26 @@ impl ImdsClient { } } + fn endpoint(&self, path: &str) -> HostEndpoint { + HostEndpoint::new(&self.ip, self.port, path) + } + pub async fn get_imds_instance_info( &self, key_guid: Option, key: Option, ) -> Result { - let url: String = format!("http://{}:{}/{}", self.ip, self.port, IMDS_URI); - - let url: Uri = url - .parse::() - .map_err(|e| Error::ParseUrl(url, e.to_string()))?; + let endpoint = self.endpoint(IMDS_URI); let mut headers = HashMap::new(); headers.insert("Metadata".to_string(), "true".to_string()); - hyper_client::get(&url, &headers, key_guid, key, logger_manager::write_warn).await + hyper_client::get( + &endpoint, + &headers, + key_guid, + key, + logger_manager::write_warn, + ) + .await } } diff --git a/proxy_agent_shared/src/host_clients/wire_server_client.rs b/proxy_agent_shared/src/host_clients/wire_server_client.rs index 5d7ada35..e36eece3 100644 --- a/proxy_agent_shared/src/host_clients/wire_server_client.rs +++ b/proxy_agent_shared/src/host_clients/wire_server_client.rs @@ -4,14 +4,13 @@ //! This module contains the logic to interact with the wire server for sending telemetry data and getting goal state. use crate::host_clients::goal_state::{GoalState, SharedConfig}; -use crate::hyper_client; +use crate::hyper_client::{self, HostEndpoint}; use crate::{ error::{Error, WireServerErrorType}, logger::logger_manager, result::Result, }; use http::Method; -use hyper::Uri; use std::collections::HashMap; pub struct WireServerClient { @@ -19,8 +18,8 @@ pub struct WireServerClient { port: u16, } -const TELEMETRY_DATA_URI: &str = "machine/?comp=telemetrydata"; -const GOALSTATE_URI: &str = "machine?comp=goalstate"; +const TELEMETRY_DATA_URI: &str = "/machine/?comp=telemetrydata"; +const GOALSTATE_URI: &str = "/machine?comp=goalstate"; impl WireServerClient { pub fn new(ip: &str, port: u16) -> Self { @@ -30,15 +29,16 @@ impl WireServerClient { } } + fn endpoint(&self, path: &str) -> HostEndpoint { + HostEndpoint::new(&self.ip, self.port, path) + } + pub async fn send_telemetry_data(&self, xml_data: String) -> Result<()> { if xml_data.is_empty() { return Ok(()); } - let url = format!("http://{}:{}/{}", self.ip, self.port, TELEMETRY_DATA_URI); - let url: Uri = url - .parse::() - .map_err(|e| Error::ParseUrl(url, e.to_string()))?; + let endpoint = self.endpoint(TELEMETRY_DATA_URI); let mut headers = HashMap::new(); headers.insert("x-ms-version".to_string(), "2012-11-30".to_string()); headers.insert( @@ -48,15 +48,15 @@ impl WireServerClient { let request = hyper_client::build_request( Method::POST, - &url, + &endpoint, &headers, Some(xml_data.as_bytes()), None, // post telemetry data does not require signing None, )?; let response = match hyper_client::send_request( - &self.ip, - self.port, + &endpoint.host, + endpoint.port, request, logger_manager::write_warn, ) @@ -75,7 +75,7 @@ impl WireServerClient { if !status.is_success() { return Err(Error::WireServer( WireServerErrorType::Telemetry, - format!("Failed to get response from {url}, status code: {status}"), + format!("Failed to get response from {endpoint}, status code: {status}"), )); } @@ -87,16 +87,19 @@ impl WireServerClient { key_guid: Option, key: Option, ) -> Result { - let url = format!("http://{}:{}/{}", self.ip, self.port, GOALSTATE_URI); - let url = url - .parse::() - .map_err(|e| Error::ParseUrl(url, e.to_string()))?; + let endpoint = self.endpoint(GOALSTATE_URI); let mut headers = HashMap::new(); headers.insert("x-ms-version".to_string(), "2012-11-30".to_string()); - hyper_client::get(&url, &headers, key_guid, key, logger_manager::write_warn) - .await - .map_err(|e| Error::WireServer(WireServerErrorType::GoalState, e.to_string())) + hyper_client::get( + &endpoint, + &headers, + key_guid, + key, + logger_manager::write_warn, + ) + .await + .map_err(|e| Error::WireServer(WireServerErrorType::GoalState, e.to_string())) } pub async fn get_shared_config( @@ -106,13 +109,20 @@ impl WireServerClient { key: Option, ) -> Result { let mut headers = HashMap::new(); - let url = url - .parse::() - .map_err(|e| Error::ParseUrl(url, e.to_string()))?; headers.insert("x-ms-version".to_string(), "2012-11-30".to_string()); - hyper_client::get(&url, &headers, key_guid, key, logger_manager::write_warn) - .await - .map_err(|e| Error::WireServer(WireServerErrorType::SharedConfig, e.to_string())) + let uri = url + .parse::() + .map_err(|e| Error::ParseUrl(url.clone(), e.to_string()))?; + let endpoint = HostEndpoint::from_full_uri(uri)?; + hyper_client::get( + &endpoint, + &headers, + key_guid, + key, + logger_manager::write_warn, + ) + .await + .map_err(|e| Error::WireServer(WireServerErrorType::SharedConfig, e.to_string())) } } diff --git a/proxy_agent_shared/src/hyper_client.rs b/proxy_agent_shared/src/hyper_client.rs index 52ed8452..8fe933a6 100644 --- a/proxy_agent_shared/src/hyper_client.rs +++ b/proxy_agent_shared/src/hyper_client.rs @@ -31,8 +31,94 @@ pub const CLAIMS_IS_ROOT: &str = "isRoot"; const LF: &str = "\n"; +/// Pre-parsed HTTP endpoint containing host, port, and path/query. +/// Use this to avoid re-parsing URIs multiple times which is performance-sensitive. +#[derive(Debug, Clone)] +pub struct HostEndpoint { + pub host: String, + pub port: u16, + /// The path and query portion of the URI (e.g., "/api/status?version=1") + pub path_and_query: String, +} + +impl HostEndpoint { + pub const DEFAULT_HTTP_PORT: u16 = 80; + pub const DEFAULT_HTTPS_PORT: u16 = 443; + + /// Create a new HostEndpoint with explicit components + pub fn new(host: impl Into, port: u16, path_and_query: impl Into) -> Self { + Self { + host: host.into(), + port, + path_and_query: path_and_query.into(), + } + } + + /// Create a HostEndpoint from a full URI string (e.g., "http://host:port/path?query") + /// This will parse the URI and extract the host, port, and path/query components. + /// Remark: Do not use this function in performance-sensitive code paths, as URI parsing can be relatively expensive. + /// Instead, use the `new` constructor with pre-parsed components when possible. + /// Remark: This function assumes the URI is well-formed and contains a host. It will return an error if the URI is invalid or missing required components. + pub fn from_full_uri(uri: Uri) -> Result { + let host = match uri.host() { + Some(h) => h.to_string(), + None => { + return Err(Error::Hyper(HyperErrorType::RequestBuilder( + "URI must have a host".to_string(), + ))); + } + }; + let default_port = if uri.scheme_str() == Some("https") { + Self::DEFAULT_HTTPS_PORT + } else { + Self::DEFAULT_HTTP_PORT + }; + let port = uri.port_u16().unwrap_or(default_port); + let path_and_query = match uri.path_and_query() { + Some(pq) => pq.as_str().to_string(), + None => "/".to_string(), // default to root path + }; + + Ok(Self { + host, + port, + path_and_query, + }) + } + + /// Create a HostEndpoint from a URI string (e.g., "http://host:port/path?query") + /// This will parse the URI and extract the host, port, and path/query components. + /// Remark: Do not use this function in performance-sensitive code paths, as URI parsing can be relatively expensive. + /// Instead, use the `new` constructor with pre-parsed components when possible. + pub fn from_uri_str(uri_str: &str) -> Result { + let uri = uri_str.parse::().map_err(|e| { + Error::Hyper(HyperErrorType::RequestBuilder(format!( + "Failed to parse URI string: {uri_str} with error: {e}" + ))) + })?; + Self::from_full_uri(uri) + } + + /// Get the address string for TCP connection (host:port) + #[inline] + pub fn addr(&self) -> String { + format!("{}:{}", self.host, self.port) + } +} + +impl std::fmt::Display for HostEndpoint { + /// Format as full URI string (e.g., "http://host:port/path?query") + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "http://{}:{}{}", + self.host, self.port, self.path_and_query + ) + } +} + pub async fn get( - full_url: &Uri, + endpoint: &HostEndpoint, headers: &HashMap, key_guid: Option, key: Option, @@ -42,14 +128,13 @@ where T: DeserializeOwned, F: Fn(String) + Send + 'static, { - let request = build_request(Method::GET, full_url, headers, None, key_guid, key)?; + let request = build_request(Method::GET, endpoint, headers, None, key_guid, key)?; - let (host, port) = host_port_from_uri(full_url)?; - let response = send_request(&host, port, request, log_fun).await?; + let response = send_request(&endpoint.host, endpoint.port, request, log_fun).await?; let status = response.status(); if !status.is_success() { return Err(Error::Hyper(HyperErrorType::ServerError( - full_url.to_string(), + endpoint.to_string(), status, ))); } @@ -162,22 +247,20 @@ where pub fn build_request( method: http::Method, - full_url: &Uri, + endpoint: &HostEndpoint, headers: &HashMap, body: Option<&[u8]>, key_guid: Option, key: Option, ) -> Result>> { - let (host, _) = host_port_from_uri(full_url)?; - let mut request_builder = Request::builder() .method(method) - .uri(match full_url.path_and_query() { - Some(pq) => pq.as_str(), - None => full_url.path(), - }) + .uri(&endpoint.path_and_query) .header(DATE_HEADER, misc_helpers::get_date_time_rfc1123_string()) - .header(hyper::header::HOST, host) + // The header() method accepts types that implement Into, and &str implements this trait. + // The HeaderValue will internally copy the bytes (which is unavoidable since it needs to own the data), + // So you're not creating any intermediate String allocations. + .header(hyper::header::HOST, &endpoint.host) .header( CLAIMS_HEADER, format!("{{ \"{}\": \"{}\"}}", CLAIMS_IS_ROOT, true,), @@ -278,21 +361,6 @@ where Ok(sender) } -pub fn host_port_from_uri(full_url: &Uri) -> Result<(String, u16)> { - let host = match full_url.host() { - Some(h) => h.to_string(), - None => { - return Err(Error::ParseUrl( - full_url.to_string(), - "Failed to get host from uri".to_string(), - )) - } - }; - let port = full_url.port_u16().unwrap_or(80); - - Ok((host, port)) -} - /* StringToSign = Method + "\n" + HexEncoded(Body) + "\n" +