Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 40 additions & 15 deletions proxy_agent/src/proxy/proxy_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ impl ProxyServer {
LoggerLevel::Error,
format!("Failed to increase connection count: {e}"),
);
return Ok(Self::empty_response(StatusCode::INTERNAL_SERVER_ERROR));
return Ok(Self::closed_response(StatusCode::INTERNAL_SERVER_ERROR));
}
};

Expand Down Expand Up @@ -405,7 +405,7 @@ impl ProxyServer {
"Traversal characters found in the request, return NOT_FOUND!".to_string(),
)
.await;
return Ok(Self::empty_response(StatusCode::NOT_FOUND));
return Ok(Self::closed_response(StatusCode::NOT_FOUND));
}

if http_connection_context.url == provision::provision_query::PROVISION_URL_PATH {
Expand All @@ -427,7 +427,7 @@ impl ProxyServer {
"No remote destination_ip found in the request, return!".to_string(),
)
.await;
return Ok(Self::empty_response(StatusCode::MISDIRECTED_REQUEST));
return Ok(Self::closed_response(StatusCode::MISDIRECTED_REQUEST));
}
};
let port = tcp_connection_context.destination_port;
Expand All @@ -441,7 +441,7 @@ impl ProxyServer {
"No claims found in the request, return!".to_string(),
)
.await;
return Ok(Self::empty_response(StatusCode::MISDIRECTED_REQUEST));
return Ok(Self::closed_response(StatusCode::MISDIRECTED_REQUEST));
}
};
http_connection_context.log(LoggerLevel::Trace, format!("Use lookup value:{ip}:{port}."));
Expand All @@ -455,7 +455,7 @@ impl ProxyServer {
format!("Failed to get claims json string: {e}"),
)
.await;
return Ok(Self::empty_response(StatusCode::MISDIRECTED_REQUEST));
return Ok(Self::closed_response(StatusCode::MISDIRECTED_REQUEST));
}
};
http_connection_context.log(LoggerLevel::Trace, claim_details.to_string());
Expand All @@ -477,7 +477,7 @@ impl ProxyServer {
format!("Failed to get access control rules: {e}"),
)
.await;
return Ok(Self::empty_response(StatusCode::INTERNAL_SERVER_ERROR));
return Ok(Self::closed_response(StatusCode::INTERNAL_SERVER_ERROR));
}
};
let result = proxy_authorizer::authorize(
Expand Down Expand Up @@ -505,7 +505,7 @@ impl ProxyServer {
format!("Block unauthorized request: {claim_details}"),
)
.await;
return Ok(Self::empty_response(StatusCode::FORBIDDEN));
return Ok(Self::closed_response(StatusCode::FORBIDDEN));
}
}

Expand All @@ -527,7 +527,7 @@ impl ProxyServer {
LoggerLevel::Error,
format!("Failed to add claims header: {host_claims} with error: {e}"),
);
return Ok(Self::empty_response(StatusCode::BAD_GATEWAY));
return Ok(Self::closed_response(StatusCode::BAD_GATEWAY));
}
},
);
Expand All @@ -540,7 +540,7 @@ impl ProxyServer {
LoggerLevel::Error,
format!("Failed to add date header with error: {e}"),
);
return Ok(Self::empty_response(StatusCode::BAD_GATEWAY));
return Ok(Self::closed_response(StatusCode::BAD_GATEWAY));
}
},
);
Expand All @@ -567,7 +567,7 @@ impl ProxyServer {
LoggerLevel::Error,
format!("Failed to convert request: {e}"),
);
return Ok(Self::empty_response(StatusCode::BAD_REQUEST));
return Ok(Self::closed_response(StatusCode::BAD_REQUEST));
}
};
let proxy_response = http_connection_context.send_request(request).await;
Expand Down Expand Up @@ -600,7 +600,7 @@ impl ProxyServer {
LoggerLevel::Warn,
"No MetaData header found in the request.".to_string(),
);
return Ok(Self::empty_response(StatusCode::BAD_REQUEST));
return Ok(Self::closed_response(StatusCode::BAD_REQUEST));
}
// Get the query time_tick
let query_time_tick = match request.headers().get(constants::TIME_TICK_HEADER) {
Expand Down Expand Up @@ -695,7 +695,7 @@ impl ProxyServer {
format!("Failed to send request to host: {e}"),
)
.await;
return Ok(Self::empty_response(http_status_code));
return Ok(Self::closed_response(http_status_code));
}
};

Expand Down Expand Up @@ -828,6 +828,17 @@ impl ProxyServer {
response
}

fn closed_response(status_code: StatusCode) -> Response<BoxBody<Bytes, hyper::Error>> {
let mut response = Self::empty_response(status_code);

// Add the Connection: close header to close the tcp connection
response
.headers_mut()
.insert(hyper::header::CONNECTION, HeaderValue::from_static("close"));

response
}

async fn handle_request_with_signature(
&self,
mut http_connection_context: HttpConnectionContext,
Expand All @@ -841,7 +852,7 @@ impl ProxyServer {
LoggerLevel::Error,
format!("Failed to receive the request body: {e}"),
);
return Ok(Self::empty_response(StatusCode::BAD_REQUEST));
return Ok(Self::closed_response(StatusCode::BAD_REQUEST));
}
};

Expand Down Expand Up @@ -887,7 +898,7 @@ impl ProxyServer {
"Failed to add authorization header: {authorization_value} with error: {e}"
),
);
return Ok(Self::empty_response(StatusCode::BAD_GATEWAY));
return Ok(Self::closed_response(StatusCode::BAD_GATEWAY));
}
},
);
Expand Down Expand Up @@ -965,15 +976,29 @@ mod tests {
.unwrap_or(None),
)
.unwrap();
let response = hyper_client::send_request(host, port, request, logger::write_warning)
let mut sender = hyper_client::build_http_sender(host, port, logger::write_warning)
.await
.unwrap();
let response = sender.send_request(request).await.unwrap();
assert_eq!(
http::StatusCode::MISDIRECTED_REQUEST,
response.status(),
"response.status must be MISDIRECTED_REQUEST."
);

// verify the connection is closed
response.headers().get("connection").map(|v| {
assert_eq!(
v.to_str().unwrap(),
"close",
"response.headers.connection must be close."
);
});
assert!(
sender.is_closed(),
"sender must be closed after the request."
);

// test with traversal characters
let url: hyper::Uri = format!("http://{}:{}/test/../", host, port)
.parse()
Expand Down
Loading