Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions proxy_agent/src/key_keeper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
//! use std::time::Duration;
//!
//! let shared_state = SharedState::start_all();
//! let base_url = "http://127:0.0.1:8081/";
//! let host = "127.0.0.1".to_string();
//! let port = 8081u16;
//! let key_dir = PathBuf::from("path");
//! let interval = Duration::from_secs(10);
//! let config_start_redirector = false;
//! let key_keeper = key_keeper::KeyKeeper::new(base_url.parse().unwrap(), key_dir, interval, config_start_redirector, &shared_state);
//! let key_keeper = key_keeper::KeyKeeper::new(host, port, key_dir, interval, config_start_redirector, &shared_state);
//! tokio::spawn(key_keeper.poll_secure_channel_status());
//! ```

Expand All @@ -40,7 +41,6 @@ use crate::shared_state::provision_wrapper::ProvisionSharedState;
use crate::shared_state::redirector_wrapper::RedirectorSharedState;
use crate::shared_state::{EventThreadsSharedState, SharedState};
use crate::{acl, redirector};
use hyper::Uri;
use proxy_agent_shared::common_state::CommonState;
use proxy_agent_shared::logger::LoggerLevel;
use proxy_agent_shared::misc_helpers;
Expand All @@ -64,8 +64,10 @@ const DELAY_START_EVENT_THREADS_IN_MILLISECONDS: u128 = 60000; // 1 minute

#[derive(Clone)]
pub struct KeyKeeper {
/// base_url: the WireServer endpoint to poll the secure channel status
base_url: Uri,
/// host: the WireServer host to poll the secure channel status
host: String,
/// port: the WireServer port to poll the secure channel status
port: u16,
/// key_dir: the folder to save the key details
key_dir: PathBuf,
/// status_dir: the folder to log the access control rule details
Expand Down Expand Up @@ -101,14 +103,16 @@ enum WakeReason {

impl KeyKeeper {
pub fn new(
base_url: Uri,
host: String,
port: u16,
key_dir: PathBuf,
status_dir: PathBuf,
interval: Duration,
shared_state: &SharedState,
) -> Self {
KeyKeeper {
base_url,
host,
port,
key_dir,
status_dir,
interval,
Expand Down Expand Up @@ -247,7 +251,7 @@ impl KeyKeeper {
.await;
started_event_threads = self.handle_event_threads_start(started_event_threads).await;

let status = match key::get_status(&self.base_url).await {
let status = match key::get_status(&self.host, self.port).await {
Ok(s) => s,
Err(e) => {
self.update_status_message(format!("Failed to get key status - {e}"), true)
Expand Down Expand Up @@ -629,7 +633,7 @@ impl KeyKeeper {
/// Acquire key from server, persist it, and attest it
/// Returns true if successful, false if should continue to next iteration
async fn acquire_key_from_server(&self) -> bool {
let key = match key::acquire_key(&self.base_url).await {
let key = match key::acquire_key(&self.host, self.port).await {
Ok(k) => k,
Err(e) => {
self.update_status_message(format!("Failed to acquire key details: {e:?}"), true)
Expand Down Expand Up @@ -660,7 +664,7 @@ impl KeyKeeper {
}

// attest the key
match key::attest_key(&self.base_url, &key).await {
match key::attest_key(&self.host, self.port, &key).await {
Ok(()) => {
// update in memory
if let Err(e) = self.update_key_to_shared_state(key.clone()).await {
Expand Down Expand Up @@ -1051,7 +1055,8 @@ mod tests {
// start poll_secure_channel_status
let cloned_keys_dir = keys_dir.to_path_buf();
let key_keeper = KeyKeeper {
base_url: (format!("http://{}:{}/", ip, port)).parse().unwrap(),
host: ip.to_string(),
port,
key_dir: cloned_keys_dir.clone(),
status_dir: cloned_keys_dir.clone(),
interval: Duration::from_millis(10),
Expand Down
87 changes: 36 additions & 51 deletions proxy_agent/src/key_keeper/key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -720,40 +720,22 @@ impl Display for KeyAction {
const STATUS_URL: &str = "/secure-channel/status";
const KEY_URL: &str = "/secure-channel/key";

pub async fn get_status(base_url: &Uri) -> Result<KeyStatus> {
let (host, port) = hyper_client::host_port_from_uri(base_url)?;
let url = format!("http://{host}:{port}{STATUS_URL}");
let url: Uri = url.parse().map_err(|e| {
Error::Key(KeyErrorType::ParseKeyUrl(
base_url.to_string(),
STATUS_URL.to_string(),
e,
))
})?;
pub async fn get_status(host: &str, port: u16) -> Result<KeyStatus> {
let endpoint = hyper_client::HostEndpoint::new(host, port, STATUS_URL);
let mut headers = HashMap::new();
headers.insert(
hyper_client::METADATA_HEADER.to_string(),
"True ".to_string(),
);
let status: KeyStatus =
hyper_client::get(&url, &headers, None, None, logger::write_warning).await?;
hyper_client::get(&endpoint, &headers, None, None, logger::write_warning).await?;
status.validate()?;

Ok(status)
}

pub async fn acquire_key(base_url: &Uri) -> Result<Key> {
let (host, port) = hyper_client::host_port_from_uri(base_url)?;
let url = format!("http://{host}:{port}{KEY_URL}");
let url: Uri = url.parse().map_err(|e| {
Error::Key(KeyErrorType::ParseKeyUrl(
base_url.to_string(),
KEY_URL.to_string(),
e,
))
})?;

let (host, port) = hyper_client::host_port_from_uri(&url)?;
pub async fn acquire_key(host: &str, port: u16) -> Result<Key> {
let endpoint = hyper_client::HostEndpoint::new(host, port, KEY_URL);
let mut headers = HashMap::new();
headers.insert(
hyper_client::METADATA_HEADER.to_string(),
Expand All @@ -763,21 +745,26 @@ pub async fn acquire_key(base_url: &Uri) -> Result<Key> {
let body = r#"{"authorizationScheme": "Azure-HMAC-SHA256"}"#.to_string();
let request = hyper_client::build_request(
hyper::Method::POST,
&url,
&endpoint,
&headers,
Some(body.as_bytes()),
None,
None,
)?;

let response = hyper_client::send_request(&host, port, request, logger::write_warning)
.await
.map_err(|e| {
Error::Key(KeyErrorType::SendKeyRequest(
format!("{}", KeyAction::Acquire),
e.to_string(),
))
})?;
let response = hyper_client::send_request(
&endpoint.host,
endpoint.port,
request,
logger::write_warning,
)
.await
.map_err(|e| {
Error::Key(KeyErrorType::SendKeyRequest(
format!("{}", KeyAction::Acquire),
e.to_string(),
))
})?;

if response.status() != StatusCode::OK {
return Err(Error::Key(KeyErrorType::KeyResponse(
Expand All @@ -790,39 +777,37 @@ pub async fn acquire_key(base_url: &Uri) -> Result<Key> {
.map_err(Error::ProxyAgentSharedError)
}

pub async fn attest_key(base_url: &Uri, key: &Key) -> Result<()> {
pub async fn attest_key(host: &str, port: u16, key: &Key) -> Result<()> {
// secure-channel/key/{key_guid}/key-attestation
let (host, port) = hyper_client::host_port_from_uri(base_url)?;
let url = format!(
"http://{}:{}{}/{}/key-attestation",
host, port, KEY_URL, key.guid
);
let url: Uri = url
.parse()
.map_err(|e| Error::Key(KeyErrorType::ParseKeyUrl(base_url.to_string(), url, e)))?;

let path = format!("{}/{}/key-attestation", KEY_URL, key.guid);
let endpoint = hyper_client::HostEndpoint::new(host, port, &path);
let mut headers = HashMap::new();
headers.insert(
hyper_client::METADATA_HEADER.to_string(),
"True ".to_string(),
);
let request = hyper_client::build_request(
Method::POST,
&url,
&endpoint,
&headers,
None,
Some(key.guid.to_string()),
Some(key.key.to_string()),
)?;

let response = hyper_client::send_request(&host, port, request, logger::write_warning)
.await
.map_err(|e| {
Error::Key(KeyErrorType::SendKeyRequest(
format!("{}", KeyAction::Attest),
e.to_string(),
))
})?;
let response = hyper_client::send_request(
&endpoint.host,
endpoint.port,
request,
logger::write_warning,
)
.await
.map_err(|e| {
Error::Key(KeyErrorType::SendKeyRequest(
format!("{}", KeyAction::Attest),
e.to_string(),
))
})?;

if response.status() != StatusCode::OK {
return Err(Error::Key(KeyErrorType::KeyResponse(
Expand Down
13 changes: 4 additions & 9 deletions proxy_agent/src/provision.rs
Original file line number Diff line number Diff line change
Expand Up @@ -661,17 +661,12 @@ pub mod provision_query {
// bool - true provision finished; false provision not finished
// String - provision error message, empty means provision success or provision failed.
async fn get_current_provision_status(&self, notify: bool) -> Result<ProvisionState> {
let provision_url: String = format!(
"http://{}:{}{}",
Ipv4Addr::LOCALHOST,
let endpoint = hyper_client::HostEndpoint::new(
Ipv4Addr::LOCALHOST.to_string(),
self.port,
PROVISION_URL_PATH
PROVISION_URL_PATH,
Comment thread
ZhidongPeng marked this conversation as resolved.
);

let provision_url: hyper::Uri = provision_url
.parse::<hyper::Uri>()
.map_err(|e| Error::ParseUrl(provision_url, e.to_string()))?;

let mut headers = HashMap::new();
headers.insert(
hyper_client::METADATA_HEADER.to_string(),
Expand All @@ -684,7 +679,7 @@ pub mod provision_query {
if notify {
headers.insert(constants::NOTIFY_HEADER.to_string(), "true".to_string());
}
hyper_client::get(&provision_url, &headers, None, None, logger::write_warning)
hyper_client::get(&endpoint, &headers, None, None, logger::write_warning)
.await
.map_err(Error::ProxyAgentSharedError)
}
Expand Down
12 changes: 5 additions & 7 deletions proxy_agent/src/proxy/proxy_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1045,10 +1045,10 @@ mod tests {
let sleep_duration = Duration::from_millis(100);
tokio::time::sleep(sleep_duration).await;

let url: hyper::Uri = format!("http://{}:{}/", host, port).parse().unwrap();
let endpoint = hyper_client::HostEndpoint::new(host, port, "/");
let request = hyper_client::build_request(
Method::GET,
&url,
&endpoint,
&HashMap::new(),
None,
key_keeper_shared_state
Expand Down Expand Up @@ -1085,12 +1085,10 @@ mod tests {
);

// test with traversal characters
let url: hyper::Uri = format!("http://{}:{}/test/../", host, port)
.parse()
.unwrap();
let endpoint = hyper_client::HostEndpoint::new(host, port, "/test/../");
let request = hyper_client::build_request(
Method::GET,
&url,
&endpoint,
&HashMap::new(),
None,
key_keeper_shared_state
Expand All @@ -1116,7 +1114,7 @@ mod tests {
let body = vec![88u8; super::REQUEST_BODY_LOW_LIMIT_SIZE + 1];
let request = hyper_client::build_request(
Method::POST,
&url,
&endpoint,
&HashMap::new(),
Some(body.as_slice()),
key_keeper_shared_state
Expand Down
6 changes: 3 additions & 3 deletions proxy_agent/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::proxy::proxy_server::ProxyServer;
use crate::redirector::{self, Redirector};
use crate::shared_state::SharedState;
use proxy_agent_shared::current_info;
use proxy_agent_shared::hyper_client::HostEndpoint;
use proxy_agent_shared::logger::rolling_logger::RollingLogger;
use proxy_agent_shared::logger::{logger_manager, LoggerLevel};
use proxy_agent_shared::proxy_agent_aggregate_status;
Expand Down Expand Up @@ -57,9 +58,8 @@ pub async fn start_service(shared_state: SharedState) {

tokio::spawn({
let key_keeper = KeyKeeper::new(
(format!("http://{}/", constants::WIRE_SERVER_IP))
.parse()
.unwrap(),
constants::WIRE_SERVER_IP.to_string(),
Comment thread
ZhidongPeng marked this conversation as resolved.
HostEndpoint::DEFAULT_HTTP_PORT,
config::get_keys_dir(),
proxy_agent_aggregate_status::get_proxy_agent_aggregate_status_folder(),
config::get_poll_key_status_duration(),
Expand Down
26 changes: 16 additions & 10 deletions proxy_agent_shared/src/host_clients/imds_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,17 @@
//! The GPA service uses the IMDS service to get the instance information of the VM.

use super::instance_info::InstanceInfo;
use crate::hyper_client;
use crate::hyper_client::{self, HostEndpoint};
use crate::logger::logger_manager;
use crate::{error::Error, result::Result};
use hyper::Uri;
use crate::result::Result;
use std::collections::HashMap;

pub struct ImdsClient {
ip: String,
port: u16,
}

const IMDS_URI: &str = "metadata/instance?api-version=2018-02-01";
const IMDS_URI: &str = "/metadata/instance?api-version=2018-02-01";
Comment thread
ZhidongPeng marked this conversation as resolved.

impl ImdsClient {
pub fn new(ip: &str, port: u16) -> Self {
Expand All @@ -27,19 +26,26 @@ impl ImdsClient {
}
}

fn endpoint(&self, path: &str) -> HostEndpoint {
HostEndpoint::new(&self.ip, self.port, path)
}

pub async fn get_imds_instance_info(
&self,
key_guid: Option<String>,
key: Option<String>,
) -> Result<InstanceInfo> {
let url: String = format!("http://{}:{}/{}", self.ip, self.port, IMDS_URI);

let url: Uri = url
.parse::<hyper::Uri>()
.map_err(|e| Error::ParseUrl(url, e.to_string()))?;
let endpoint = self.endpoint(IMDS_URI);
let mut headers = HashMap::new();
headers.insert("Metadata".to_string(), "true".to_string());

hyper_client::get(&url, &headers, key_guid, key, logger_manager::write_warn).await
hyper_client::get(
&endpoint,
&headers,
key_guid,
key,
logger_manager::write_warn,
)
.await
}
}
Loading
Loading