diff --git a/crates/aide/src/axum/inputs.rs b/crates/aide/src/axum/inputs.rs index 8da2d4a2..0d725eb0 100644 --- a/crates/aide/src/axum/inputs.rs +++ b/crates/aide/src/axum/inputs.rs @@ -207,14 +207,34 @@ where } } +/// Internal marker carrying scalar/tuple `Path` schemas to the route +/// step, where placeholder names are known. Drained before emit. +pub(crate) const PENDING_PATH_PARAMS_EXT: &str = "x-aide-pending-path-params"; + impl OperationInput for Path where T: JsonSchema, { fn operation_input(ctx: &mut crate::generate::GenContext, operation: &mut Operation) { let schema = ctx.schema.subschema_for::(); - let params = parameters_from_schema(ctx, schema, ParamLocation::Path); - add_parameters(ctx, operation, params); + let resolved = ctx.resolve_schema(&schema); + + // An object schema is the struct form: walk its fields as named params. + if resolved.get("properties").is_some() { + let params = parameters_from_schema(ctx, schema, ParamLocation::Path); + add_parameters(ctx, operation, params); + return; + } + + let pending = match resolved.get("prefixItems") { + Some(serde_json::Value::Array(items)) => items.clone(), + _ => vec![resolved.clone().to_value()], + }; + + operation.extensions.insert( + PENDING_PATH_PARAMS_EXT.to_string(), + serde_json::Value::Array(pending), + ); } } diff --git a/crates/aide/src/axum/mod.rs b/crates/aide/src/axum/mod.rs index 4329e69f..0f0316a9 100644 --- a/crates/aide/src/axum/mod.rs +++ b/crates/aide/src/axum/mod.rs @@ -171,10 +171,14 @@ use std::{convert::Infallible, future::Future, pin::Pin}; use crate::{ - generate::{self, in_context}, - openapi::{OpenApi, PathItem, ReferenceOr, SchemaObject}, + error::Error, + generate::{self, in_context, GenContext}, + openapi::{ + OpenApi, Parameter, ParameterData, ParameterSchemaOrContent, PathItem, PathStyle, + ReferenceOr, SchemaObject, + }, operation::OperationHandler, - util::{merge_paths, path_for_nested_route}, + util::{iter_operations_mut, merge_paths, path_for_nested_route}, OperationInput, OperationOutput, }; #[cfg(feature = "axum-tokio")] @@ -206,6 +210,115 @@ pub mod routing; #[cfg(all(feature = "macros", feature = "axum-extra-typed-routing"))] pub use aide_macros::axum_typed_path as typed_path; +/// Names of `{...}` placeholders in order. Strips leading `*` from +/// wildcard captures like `{*rest}`. +fn extract_path_placeholder_names(path: &str) -> Vec { + let mut names = Vec::new(); + let mut current = String::new(); + let mut in_param = false; + + for ch in path.chars() { + match ch { + '{' if !in_param => { + in_param = true; + current.clear(); + } + '}' if in_param => { + let name = current.trim().trim_start_matches('*'); + if !name.is_empty() { + names.push(name.to_string()); + } + in_param = false; + } + _ if in_param => current.push(ch), + _ => {} + } + } + + names +} + +/// Infer path parameters from stashed `Path` schemas, then add the path +/// item to `paths`, merging into any existing entry. +fn register_path_item( + paths: &mut IndexMap, + ctx: &mut GenContext, + path: &str, + mut path_item: PathItem, +) { + infer_path_parameters(ctx, path, &mut path_item); + if let Some(existing) = paths.get_mut(path) { + merge_paths(ctx, path, existing, path_item); + } else { + paths.insert(path.into(), path_item); + } +} + +/// Turn the schemas stashed by `Path::operation_input` into named +/// `Parameter::Path` entries. A parameter already set under the same name +/// (e.g. by a user transform) wins. +fn infer_path_parameters(ctx: &mut GenContext, path: &str, path_item: &mut PathItem) { + let placeholders = extract_path_placeholder_names(path); + + for (_method, operation) in iter_operations_mut(path_item) { + let Some(pending) = operation + .extensions + .swap_remove(inputs::PENDING_PATH_PARAMS_EXT) + else { + continue; + }; + + let serde_json::Value::Array(schemas) = pending else { + continue; + }; + + if schemas.len() != placeholders.len() { + ctx.error(Error::PathPlaceholderArityMismatch( + path.to_string(), + placeholders.len(), + schemas.len(), + )); + continue; + } + + for (name, schema_value) in placeholders.iter().zip(schemas) { + let already_present = operation.parameters.iter().any(|p| { + matches!( + p, + ReferenceOr::Item(Parameter::Path { parameter_data, .. }) + if parameter_data.name == *name + ) + }); + if already_present { + continue; + } + + let json_schema: schemars::Schema = schema_value.try_into().unwrap_or_default(); + + operation + .parameters + .push(ReferenceOr::Item(Parameter::Path { + parameter_data: ParameterData { + name: name.clone(), + description: None, + required: true, + format: ParameterSchemaOrContent::Schema(SchemaObject { + json_schema, + example: None, + external_docs: None, + }), + extensions: Default::default(), + deprecated: None, + example: None, + examples: Default::default(), + explode: None, + }, + style: PathStyle::Simple, + })); + } + } +} + /// A wrapper over [`axum::Router`] that adds /// API documentation-specific features. #[must_use] @@ -300,13 +413,7 @@ where #[tracing::instrument(skip_all, fields(path = path))] pub fn api_route(mut self, path: &str, mut method_router: ApiMethodRouter) -> Self { in_context(|ctx| { - let new_path_item = method_router.take_path_item(); - - if let Some(path_item) = self.paths.get_mut(path) { - merge_paths(ctx, path, path_item, new_path_item); - } else { - self.paths.insert(path.into(), new_path_item); - } + register_path_item(&mut self.paths, ctx, path, method_router.take_path_item()); }); self.router = self.router.route(path, method_router.router); @@ -323,13 +430,7 @@ where #[tracing::instrument(skip_all, fields(path = path))] pub fn api_route_with_tsr(mut self, path: &str, mut method_router: ApiMethodRouter) -> Self { in_context(|ctx| { - let new_path_item = method_router.take_path_item(); - - if let Some(path_item) = self.paths.get_mut(path) { - merge_paths(ctx, path, path_item, new_path_item); - } else { - self.paths.insert(path.into(), new_path_item); - } + register_path_item(&mut self.paths, ctx, path, method_router.take_path_item()); }); self.router = self.router.route_with_tsr(path, method_router.router); @@ -352,14 +453,9 @@ where ) -> Self { in_context(|ctx| { let mut p = method_router.take_path_item(); - let t = transform(TransformPathItem::new(&mut p)); - - if !t.hidden { - if let Some(path_item) = self.paths.get_mut(path) { - merge_paths(ctx, path, path_item, p); - } else { - self.paths.insert(path.into(), p); - } + let hidden = transform(TransformPathItem::new(&mut p)).hidden; + if !hidden { + register_path_item(&mut self.paths, ctx, path, p); } }); @@ -384,14 +480,9 @@ where ) -> Self { in_context(|ctx| { let mut p = method_router.take_path_item(); - let t = transform(TransformPathItem::new(&mut p)); - - if !t.hidden { - if let Some(path_item) = self.paths.get_mut(path) { - merge_paths(ctx, path, path_item, p); - } else { - self.paths.insert(path.into(), p); - } + let hidden = transform(TransformPathItem::new(&mut p)).hidden; + if !hidden { + register_path_item(&mut self.paths, ctx, path, p); } }); @@ -1010,4 +1101,216 @@ mod tests { routing::get(test_handler3.layer(tower_layer::Identity::new())), ); } + + #[test] + fn extract_path_placeholder_names_basic() { + use super::extract_path_placeholder_names; + + assert_eq!(extract_path_placeholder_names("/"), Vec::::new()); + assert_eq!( + extract_path_placeholder_names("/users"), + Vec::::new() + ); + assert_eq!( + extract_path_placeholder_names("/users/{id}"), + vec!["id".to_string()] + ); + assert_eq!( + extract_path_placeholder_names("/users/{user_id}/posts/{post_id}"), + vec!["user_id".to_string(), "post_id".to_string()], + ); + assert_eq!( + extract_path_placeholder_names("/assets/{*rest}"), + vec!["rest".to_string()] + ); + } + + #[cfg(feature = "axum")] + mod path_params { + use super::super::{infer_path_parameters, inputs::PENDING_PATH_PARAMS_EXT}; + use crate::{ + axum::{routing, ApiRouter}, + openapi::{Parameter, ParameterSchemaOrContent, ReferenceOr}, + }; + use axum::extract::Path; + use schemars::JsonSchema; + + fn path_param<'a>(app: &'a ApiRouter, path: &str, name: &str) -> &'a Parameter { + let op = app + .paths + .get(path) + .and_then(|p| p.get.as_ref()) + .unwrap_or_else(|| panic!("missing GET for {path}")); + op.parameters + .iter() + .find_map(|p| match p { + ReferenceOr::Item(param @ Parameter::Path { parameter_data, .. }) + if parameter_data.name == name => + { + Some(param) + } + _ => None, + }) + .unwrap_or_else(|| panic!("missing path param {name:?} for {path}")) + } + + fn param_schema(param: &Parameter) -> &schemars::Schema { + let ParameterSchemaOrContent::Schema(s) = ¶m.parameter_data_ref().format else { + panic!("expected schema-form parameter"); + }; + &s.json_schema + } + + #[test] + fn scalar_path_generates_named_parameter() { + async fn h(Path(_): Path) {} + + let app: ApiRouter = ApiRouter::new().api_route("/users/{user_id}", routing::get(h)); + + let param = path_param(&app, "/users/{user_id}", "user_id"); + assert!(param.parameter_data_ref().required); + assert_eq!(param_schema(param).get("type"), Some(&"integer".into())); + } + + #[test] + fn tuple_path_generates_parameters_in_order() { + async fn h(Path(_): Path<(u64, String)>) {} + + let app: ApiRouter = + ApiRouter::new().api_route("/users/{user_id}/posts/{post_id}", routing::get(h)); + + let user_id = path_param(&app, "/users/{user_id}/posts/{post_id}", "user_id"); + let post_id = path_param(&app, "/users/{user_id}/posts/{post_id}", "post_id"); + assert_eq!(param_schema(user_id).get("type"), Some(&"integer".into())); + assert_eq!(param_schema(post_id).get("type"), Some(&"string".into())); + } + + #[test] + fn struct_path_regression() { + #[derive(JsonSchema, serde::Deserialize)] + #[allow(dead_code)] + struct Params { + user_id: u64, + post_id: String, + } + async fn h(Path(_): Path) {} + + let app: ApiRouter = + ApiRouter::new().api_route("/users/{user_id}/posts/{post_id}", routing::get(h)); + + let _ = path_param(&app, "/users/{user_id}/posts/{post_id}", "user_id"); + let _ = path_param(&app, "/users/{user_id}/posts/{post_id}", "post_id"); + } + + #[test] + fn pending_extension_does_not_leak() { + async fn h(Path(_): Path) {} + + let app: ApiRouter = ApiRouter::new().api_route("/users/{id}", routing::get(h)); + + let op = app + .paths + .get("/users/{id}") + .and_then(|p| p.get.as_ref()) + .unwrap(); + assert!(!op.extensions.contains_key(PENDING_PATH_PARAMS_EXT)); + } + + #[test] + fn transform_override_is_respected() { + use crate::openapi::{ + ParameterData, ParameterSchemaOrContent, PathStyle, SchemaObject, + }; + use schemars::json_schema; + + async fn h(Path(_): Path) {} + + let app: ApiRouter = + ApiRouter::new().api_route_with("/users/{id}", routing::get(h), |mut t| { + t.inner_mut() + .get + .as_mut() + .unwrap() + .parameters + .push(ReferenceOr::Item(Parameter::Path { + parameter_data: ParameterData { + name: "id".into(), + description: Some("user-provided".into()), + required: true, + format: ParameterSchemaOrContent::Schema(SchemaObject { + json_schema: json_schema!({ + "type": "string", + "format": "uuid", + }), + example: None, + external_docs: None, + }), + extensions: Default::default(), + deprecated: None, + example: None, + examples: Default::default(), + explode: None, + }, + style: PathStyle::Simple, + })); + t + }); + + let op = app + .paths + .get("/users/{id}") + .and_then(|p| p.get.as_ref()) + .unwrap(); + let id_params: Vec<_> = op + .parameters + .iter() + .filter(|p| { + matches!(p, ReferenceOr::Item(Parameter::Path { parameter_data, .. }) + if parameter_data.name == "id") + }) + .collect(); + assert_eq!(id_params.len(), 1, "should not duplicate overridden param"); + let ReferenceOr::Item(Parameter::Path { parameter_data, .. }) = id_params[0] else { + unreachable!() + }; + assert_eq!(parameter_data.description.as_deref(), Some("user-provided")); + } + + #[test] + fn arity_mismatch_is_reported() { + use crate::{generate, openapi::PathItem}; + use std::sync::{Arc, Mutex}; + + let errors: Arc>> = Arc::new(Mutex::new(Vec::new())); + let captured = errors.clone(); + generate::on_error(move |e| captured.lock().unwrap().push(e.to_string())); + + let mut path_item = PathItem::default(); + let mut op = crate::openapi::Operation::default(); + op.extensions.insert( + PENDING_PATH_PARAMS_EXT.to_string(), + serde_json::json!([{"type": "integer"}, {"type": "integer"}]), + ); + path_item.get = Some(op); + + generate::in_context(|ctx| { + infer_path_parameters(ctx, "/a/{x}/b/{y}/c/{z}", &mut path_item); + }); + + let logged = errors.lock().unwrap(); + assert!( + logged + .iter() + .any(|m| m.contains("placeholder") && m.contains('3') && m.contains('2')), + "expected arity mismatch error, got: {logged:?}" + ); + assert!( + path_item.get.as_ref().unwrap().parameters.is_empty(), + "no parameters should be synthesized on arity mismatch" + ); + + // Clear the error handler so it doesn't leak to other tests on this thread. + generate::reset_context(); + } + } } diff --git a/crates/aide/src/error.rs b/crates/aide/src/error.rs index 101bfb85..fe1c71cb 100644 --- a/crates/aide/src/error.rs +++ b/crates/aide/src/error.rs @@ -27,6 +27,8 @@ pub enum Error { DuplicateRequestBody, #[error(r#"duplicate parameter "{0}" for the operation"#)] DuplicateParameter(String), + #[error(r#"path "{0}" has {1} placeholder(s), but the `Path` extractor has arity {2}"#)] + PathPlaceholderArityMismatch(String, usize, usize), #[error(r#"transformations do not support references"#)] UnexpectedReference, #[error("did not apply inferred response because a response for status {0} already exists")]