diff --git a/src/data/cel.rs b/src/data/cel.rs index 5880cc01..35add1ad 100644 --- a/src/data/cel.rs +++ b/src/data/cel.rs @@ -534,16 +534,7 @@ impl Predicate { }) } - pub fn test(&self, req_ctx: &ReqRespCtx) -> PredicateResult { - let mut cel_ctx = Context::default(); - self.test_with_ctx(req_ctx, &mut cel_ctx) - } - - pub fn test_with_ctx( - &self, - req_ctx: &ReqRespCtx, - cel_ctx: &mut Context<'_>, - ) -> PredicateResult { + pub fn test(&self, req_ctx: &ReqRespCtx, cel_ctx: &mut Context<'_>) -> PredicateResult { match self.expression.eval(req_ctx, cel_ctx) { Ok(AttributeState::Pending) => Ok(AttributeState::Pending), Ok(AttributeState::Available(value)) => match value { @@ -592,8 +583,9 @@ impl PredicateVec for Vec { .collect(); req_ctx.ensure_attributes(&paths); + let mut cel_ctx = Context::default(); for predicate in self.iter() { - match predicate.test(req_ctx)? { + match predicate.test(req_ctx, &mut cel_ctx)? { AttributeState::Pending => { return Ok(AttributeState::Pending); } @@ -1091,7 +1083,7 @@ mod tests { use crate::kuadrant::MockWasmHost; use crate::kuadrant::ReqRespCtx; use cel::objects::ValueType; - use cel::Value; + use cel::{Context, Value}; use std::collections::HashMap; use std::sync::Arc; @@ -1102,7 +1094,9 @@ mod tests { let ctx = ReqRespCtx::new(Arc::new(mock_host)); let predicate = Predicate::new("source.port == 65432").expect("This is valid CEL!"); assert_eq!( - predicate.test(&ctx).expect("This must evaluate properly!"), + predicate + .test(&ctx, &mut Context::default()) + .expect("This must evaluate properly!"), AttributeState::Available(true) ); } @@ -1239,7 +1233,9 @@ mod tests { ) .expect("This is valid!"); assert_eq!( - predicate.test(&ctx).expect("This must evaluate properly!"), + predicate + .test(&ctx, &mut Context::default()) + .expect("This must evaluate properly!"), AttributeState::Available(true) ); @@ -1253,7 +1249,9 @@ mod tests { ) .expect("This is valid!"); assert_eq!( - predicate.test(&ctx).expect("This must evaluate properly!"), + predicate + .test(&ctx, &mut Context::default()) + .expect("This must evaluate properly!"), AttributeState::Available(true) ); @@ -1263,7 +1261,9 @@ mod tests { let predicate = Predicate::route_rule("queryMap(request.query) == {'👾': ''}").expect("This is valid!"); assert_eq!( - predicate.test(&ctx).expect("This must evaluate properly!"), + predicate + .test(&ctx, &mut Context::default()) + .expect("This must evaluate properly!"), AttributeState::Available(true) ); } @@ -1394,7 +1394,10 @@ mod tests { "'👾' in queryMap(request.query) ? queryMap(request.query)['👾'] == '123' : false", ) .expect("This is valid!"); - assert_eq!(predicate.test(&ctx), Ok(AttributeState::Available(true))); + assert_eq!( + predicate.test(&ctx, &mut Context::default()), + Ok(AttributeState::Available(true)) + ); let headers = vec![ ("X-Auth".to_string(), "kuadrant".to_string()), @@ -1404,7 +1407,10 @@ mod tests { let ctx = ReqRespCtx::new(Arc::new(mock_host)); let predicate = Predicate::route_rule("request.headers.exists(h, h.lowerAscii() == 'x-auth' && request.headers[h] == 'kuadrant')").expect("This is valid!"); - assert_eq!(predicate.test(&ctx), Ok(AttributeState::Available(true))); + assert_eq!( + predicate.test(&ctx, &mut Context::default()), + Ok(AttributeState::Available(true)) + ); } #[test] @@ -1474,7 +1480,9 @@ mod tests { let ctx = ReqRespCtx::new(Arc::new(mock_host)); let predicate = Predicate::new("source.port == 65432").expect("This is valid CEL!"); - let result = predicate.test(&ctx).expect("Test should succeed"); + let result = predicate + .test(&ctx, &mut Context::default()) + .expect("Test should succeed"); assert_eq!(result, AttributeState::Pending); } @@ -1603,7 +1611,9 @@ mod tests { let ctx = ReqRespCtx::new(Arc::new(mock_host)); let predicate = Predicate::new("request.grpc.service == 'UserService'").expect("valid CEL"); assert_eq!( - predicate.test(&ctx).expect("must evaluate"), + predicate + .test(&ctx, &mut Context::default()) + .expect("must evaluate"), AttributeState::Available(true) ); } @@ -1623,7 +1633,9 @@ mod tests { let ctx = ReqRespCtx::new(Arc::new(mock_host)); let predicate = Predicate::new("request.grpc.method == 'GetUser'").expect("valid CEL"); assert_eq!( - predicate.test(&ctx).expect("must evaluate"), + predicate + .test(&ctx, &mut Context::default()) + .expect("must evaluate"), AttributeState::Available(true) ); } @@ -1640,7 +1652,9 @@ mod tests { let ctx = ReqRespCtx::new(Arc::new(mock_host)); let predicate = Predicate::new("has(request.grpc)").expect("valid CEL"); assert_eq!( - predicate.test(&ctx).expect("must evaluate"), + predicate + .test(&ctx, &mut Context::default()) + .expect("must evaluate"), AttributeState::Available(true) ); } @@ -1654,7 +1668,9 @@ mod tests { let ctx = ReqRespCtx::new(Arc::new(mock_host)); let predicate = Predicate::new("has(request.grpc)").expect("valid CEL"); assert_eq!( - predicate.test(&ctx).expect("must evaluate"), + predicate + .test(&ctx, &mut Context::default()) + .expect("must evaluate"), AttributeState::Available(false) ); } @@ -1668,7 +1684,9 @@ mod tests { let ctx = ReqRespCtx::new(Arc::new(mock_host)); let predicate = Predicate::new("has(request.grpc)").expect("valid CEL"); assert_eq!( - predicate.test(&ctx).expect("must evaluate"), + predicate + .test(&ctx, &mut Context::default()) + .expect("must evaluate"), AttributeState::Available(false) ); } @@ -1687,7 +1705,9 @@ mod tests { Predicate::new("has(request.grpc) && request.grpc.service == 'UserService'") .expect("valid CEL"); assert_eq!( - predicate.test(&ctx).expect("must evaluate"), + predicate + .test(&ctx, &mut Context::default()) + .expect("must evaluate"), AttributeState::Available(true) ); } @@ -1703,7 +1723,9 @@ mod tests { Predicate::new("has(request.grpc) && request.grpc.service == 'UserService'") .expect("valid CEL"); assert_eq!( - predicate.test(&ctx).expect("must evaluate"), + predicate + .test(&ctx, &mut Context::default()) + .expect("must evaluate"), AttributeState::Available(false) ); } diff --git a/src/kuadrant/context.rs b/src/kuadrant/context.rs index c864ba2c..95d04403 100644 --- a/src/kuadrant/context.rs +++ b/src/kuadrant/context.rs @@ -1,12 +1,13 @@ -use cel::Value; +use cel::{Context, Env, Value}; use std::cell::OnceCell; -use std::collections::{BTreeMap, HashMap}; +use std::collections::{BTreeMap, HashMap, HashSet}; use std::sync::Arc; -use tracing::{debug, warn}; +use tracing::{debug, error, warn}; use crate::data::attribute::{wasm_prop, AttributeError, AttributeState, AttributeValue, Path}; use crate::data::{Expression, Headers}; use crate::kuadrant::cache::{AttributeCache, CachedValue}; +use crate::kuadrant::pipeline::tasks::Task; use crate::kuadrant::resolver::{AttributeResolver, ProxyWasmHost}; use crate::services::ServiceError; use tracing_opentelemetry::OpenTelemetrySpanExt; @@ -28,6 +29,7 @@ pub struct ReqRespCtx { tracker: Tracker, stored_values: BTreeMap, pub barrier: Barrier, + pub cel: CelScope, } impl Default for ReqRespCtx { @@ -49,6 +51,7 @@ impl ReqRespCtx { tracker: Tracker::default(), stored_values: BTreeMap::new(), barrier: Barrier::default(), + cel: CelScope::default(), } } @@ -467,7 +470,7 @@ impl Barrier { match self.count.checked_sub(1) { Some(new_value) => self.count = new_value, None => { - tracing::error!( + error!( "Attempted to lower upstream barrier when count is already 0 - mismatched raise/lower pairs" ); } @@ -514,6 +517,64 @@ impl BodyContext { } } +pub struct CelScope { + env: Arc, + registered: HashSet, + bindings: BTreeMap>, +} + +impl Default for CelScope { + fn default() -> Self { + Self { + env: Arc::new(Env::stdlib()), + registered: HashSet::new(), + bindings: BTreeMap::new(), + } + } +} + +impl CelScope { + pub fn new_ctx(&mut self, task: &dyn Task) -> Context<'static> { + let task_id = task.id(); + + if !self.registered.contains(task_id) { + let types = task.cel_types(); + if !types.is_empty() { + match Arc::get_mut(&mut self.env) { + Some(env) => { + for type_def in types { + env.add_struct(type_def); + } + } + None => error!("Failed to add CEL types: Arc refcount > 1"), + } + } + self.registered.insert(task_id.to_string()); + } + + let mut ctx = Context::with_env(Arc::clone(&self.env)); + for (scope_id, scope_bindings) in &self.bindings { + if is_ancestor(scope_id, task_id) { + for (name, value) in scope_bindings { + ctx.add_variable_from_value(name, value.clone()); + } + } + } + ctx + } + + pub fn add_scoped_binding(&mut self, task_id: &str, name: String, val: Value) { + self.bindings + .entry(task_id.to_string()) + .or_default() + .push((name, val)); + } +} + +fn is_ancestor(scope_id: &str, task_id: &str) -> bool { + task_id == scope_id || task_id.starts_with(&format!("{}.", scope_id)) +} + #[cfg(test)] mod tests { use super::*; @@ -754,4 +815,66 @@ mod tests { Ok(AttributeState::Available(Some(ref s))) if s == "external-user-id" )); } + + #[test] + fn test_cel_scope_hierarchical_bindings() { + use crate::kuadrant::pipeline::tasks::{Task, TaskOutcome}; + + struct MockTask { + id: String, + } + impl Task for MockTask { + fn id(&self) -> &str { + &self.id + } + fn apply(self: Box, _ctx: &mut ReqRespCtx) -> TaskOutcome { + TaskOutcome::Done + } + } + + let mut scope = CelScope::default(); + + // Task "0" gets a binding + scope.add_scoped_binding( + "0", + "my_response".to_string(), + Value::String(Arc::new("user123".to_string())), + ); + + // Task "0.0" should see "my_response" from parent "0" + let task_0_0 = MockTask { + id: "0.0".to_string(), + }; + let ctx_0_0 = scope.new_ctx(&task_0_0); + assert!(ctx_0_0.get_variable("my_response").is_some()); + + // Task "1" should NOT see "my_response" (different branch) + let task_1 = MockTask { + id: "1".to_string(), + }; + let ctx_1 = scope.new_ctx(&task_1); + assert!(ctx_1.get_variable("my_response").is_none()); + } + + #[test] + fn test_is_ancestor() { + // Same task + assert!(is_ancestor("0", "0")); + + // Direct child + assert!(is_ancestor("0", "0.0")); + + // Nested child + assert!(is_ancestor("0", "0.0.1")); + + // Different branch + assert!(!is_ancestor("0", "1")); + assert!(!is_ancestor("0", "1.0")); + + // Sibling + assert!(!is_ancestor("0.0", "0.1")); + + // Parent relationship is not symmetric + assert!(!is_ancestor("0.0", "0")); + } } diff --git a/src/kuadrant/pipeline/blueprint.rs b/src/kuadrant/pipeline/blueprint.rs index 291fe9e3..6c13815b 100644 --- a/src/kuadrant/pipeline/blueprint.rs +++ b/src/kuadrant/pipeline/blueprint.rs @@ -5,8 +5,8 @@ use crate::configuration::{ }; use crate::data::{cel::Predicate, Expression}; use crate::kuadrant::pipeline::tasks::{ - DynamicTask, ExportTracesTask, FailureModeTask, HeaderOperation, HeadersType, - ModifyHeadersTask, Task, TeardownAction, TokenUsageTask, TracingDecoratorTask, + DynamicTask, ExportTracesTask, FailureModeTask, HeadersType, ModifyHeadersTask, Task, + TeardownAction, TokenUsageTask, TracingDecoratorTask, }; use crate::kuadrant::ReqRespCtx; use crate::services::ServiceInstance; @@ -60,6 +60,112 @@ pub(crate) enum Operation { } impl Action { + fn to_core_task(&self, ctx: &mut ReqRespCtx) -> Option> { + match &self.operation { + Operation::Grpc { + service, + var, + message_builder, + on_reply, + } => { + let children: Vec> = on_reply + .iter() + .filter_map(|a| a.to_core_task(ctx)) + .collect(); + + match service { + ServiceInstance::Dynamic(dynamic_service) + | ServiceInstance::Auth(dynamic_service) + | ServiceInstance::RateLimit(dynamic_service) + | ServiceInstance::RateLimitCheck(dynamic_service) + | ServiceInstance::RateLimitReport(dynamic_service) => { + Some(Box::new(DynamicTask::new_with_attributes( + ctx, + self.id.clone(), + Rc::clone(dynamic_service), + var.clone(), + message_builder.clone(), + children, + self.predicate.clone(), + self.dependencies.clone(), + self.is_guard, + ))) + } + ServiceInstance::Tracing(_) => { + ctx.set_public_tracker_id(var.clone()); + #[allow(clippy::expect_used)] + let predicate = Predicate::new("true").expect("Needs to be valid!"); + #[allow(clippy::expect_used)] + let headers_expr = + Expression::new(&format!("[['{var}', '{}']]", ctx.request_id())) + .expect("Needs to be valid CEL!"); + Some(Box::new(ModifyHeadersTask::new( + self.id.clone(), + predicate, + headers_expr, + HeadersType::HttpResponseHeaders, + false, + ))) + } + } + } + Operation::Deny { deny_with } => { + use crate::kuadrant::pipeline::tasks::SendReplyTask; + Some(Box::new(SendReplyTask::new( + self.id.clone(), + self.predicate.clone(), + deny_with.clone(), + self.terminal, + ))) + } + Operation::Headers { + target, + headers: headers_expr, + } => Some(Box::new(ModifyHeadersTask::new( + self.id.clone(), + self.predicate.clone(), + headers_expr.clone(), + target.clone(), + self.terminal, + ))), + Operation::Store { + path, + expression, + export_to_host, + } => { + use crate::kuadrant::pipeline::tasks::StoreTask; + match StoreTask::new( + self.id.clone(), + self.predicate.clone(), + expression.clone(), + path.clone(), + *export_to_host, + self.terminal, + ) { + Ok(task) => Some(Box::new(task)), + Err(e) => { + tracing::error!( + "Failed to create StoreTask for path '{}': {}. Action {} will be skipped.", + path, + e, + self.id + ); + None + } + } + } + Operation::Fail { log_message } => { + use crate::kuadrant::pipeline::tasks::FailTask; + Some(Box::new(FailTask::new( + self.id.clone(), + self.predicate.clone(), + log_message.clone(), + self.terminal, + ))) + } + } + } + pub fn collect_body_values(&self, request_data: &[RequestData]) -> Vec { use std::collections::HashSet; @@ -185,13 +291,13 @@ impl Blueprint { .enumerate() .map(|(i, action_config)| { let id = i.to_string(); - let dependencies = if i > 0 { - vec![(i - 1).to_string()] - } else { - vec![] - }; match action_config { configuration::ActionConfig::Legacy(action) => { + let dependencies = if i > 0 { + vec![(i - 1).to_string()] + } else { + vec![] + }; let legacy_request_data: Vec<((String, String), String)> = request_data .iter() .map(|(key, expr)| (key.clone(), expr.source().to_string())) @@ -199,7 +305,7 @@ impl Blueprint { Action::compile(action, services, id, dependencies, &legacy_request_data) } configuration::ActionConfig::Typed(typed) => { - Action::compile_typed(typed, services, id, dependencies) + Action::compile_typed(typed, services, id, vec![]) } } }) @@ -234,53 +340,33 @@ impl Blueprint { for action in &self.actions { match &action.operation { - Operation::Grpc { - service, - var, - message_builder, - on_reply, - } => { - let abort_on_failure = - service.failure_mode() == configuration::FailureMode::Deny; - - match service { - ServiceInstance::Tracing(tracing_service) => { - ctx.set_public_tracker_id(var.clone()); - tasks.push(Box::new(ModifyHeadersTask::new( - HeaderOperation::Append( - vec![(var.clone(), ctx.request_id().to_string())].into(), - ), - HeadersType::HttpResponseHeaders, + Operation::Grpc { service, .. } => match service { + ServiceInstance::Tracing(tracing_service) => { + if let Some(task) = action.to_core_task(ctx) { + tasks.push(task); + } + if let Some(service) = tracing_service { + teardown_tasks + .push(Box::new(ExportTracesTask::new(ctx, service.clone()))); + } + } + ServiceInstance::Dynamic(_) + | ServiceInstance::Auth(_) + | ServiceInstance::RateLimit(_) + | ServiceInstance::RateLimitCheck(_) + | ServiceInstance::RateLimitReport(_) => { + let body_values = action.collect_body_values(request_data); + if !body_values.is_empty() { + tasks.push(Box::new(TokenUsageTask::with_expected_response_fields( + body_values, ))); - if let Some(service) = tracing_service { - teardown_tasks - .push(Box::new(ExportTracesTask::new(ctx, service.clone()))); - } } - ServiceInstance::Dynamic(dynamic_service) - | ServiceInstance::Auth(dynamic_service) - | ServiceInstance::RateLimit(dynamic_service) - | ServiceInstance::RateLimitCheck(dynamic_service) - | ServiceInstance::RateLimitReport(dynamic_service) => { - let body_values = action.collect_body_values(request_data); - if !body_values.is_empty() { - tasks.push(Box::new( - TokenUsageTask::with_expected_response_fields(body_values), - )); - } - let task: Box = Box::new(DynamicTask::new_with_attributes( - ctx, - action.id.clone(), - Rc::clone(dynamic_service), - var.clone(), - message_builder.clone(), - on_reply.clone(), - action.predicate.clone(), - action.dependencies.clone(), - action.is_guard, - )); - let task = Box::new(FailureModeTask::new(task, abort_on_failure)); + if let Some(mut task) = action.to_core_task(ctx) { + let abort_on_failure = + service.failure_mode() == configuration::FailureMode::Deny; + task = Box::new(FailureModeTask::new(task, abort_on_failure)); + if tracing_enabled { let span_label = match service { ServiceInstance::Auth(_) => "auth", @@ -289,69 +375,22 @@ impl Blueprint { ServiceInstance::RateLimitReport(_) => "ratelimit_report", _ => "dynamic", }; - tasks.push(Box::new(TracingDecoratorTask::new( + task = Box::new(TracingDecoratorTask::new( span_label, task, action.sources.clone(), - ))); - } else { - tasks.push(task); + )); } + + tasks.push(task); } } - } - Operation::Deny { deny_with } => { - use crate::kuadrant::pipeline::tasks::SendReplyTask; - let task = SendReplyTask::new_deferred( - action.predicate.clone(), - deny_with.clone(), - action.terminal, - ); - tasks.push(Box::new(task)); - } - Operation::Headers { - target, - headers: headers_expr, - } => { - let task = ModifyHeadersTask::new_deferred( - action.predicate.clone(), - headers_expr.clone(), - target.clone(), - action.terminal, - ); - tasks.push(Box::new(task)); - } - Operation::Store { - path, - expression, - export_to_host, - } => { - use crate::kuadrant::pipeline::tasks::StoreTask; - match StoreTask::new( - action.predicate.clone(), - expression.clone(), - path.clone(), - *export_to_host, - action.terminal, - ) { - Ok(task) => tasks.push(Box::new(task)), - Err(e) => { - tracing::error!( - "Failed to create StoreTask for path '{}': {}. Action {} will be skipped.", - path, - e, - action.id - ); - } + }, + _ => { + if let Some(task) = action.to_core_task(ctx) { + tasks.push(task); } } - Operation::Fail { log_message } => { - tracing::error!( - "Top-level Fail operation is currently unsupported. Action {}: {}", - action.id, - log_message - ); - } } } @@ -435,12 +474,7 @@ impl Action { .enumerate() .map(|(idx, typed_action)| { let reply_id = format!("{}.{}", id, idx); - let reply_deps = if idx > 0 { - vec![format!("{}.{}", id, idx - 1)] - } else { - vec![] - }; - Action::compile_typed(typed_action, services, reply_id, reply_deps) + Action::compile_typed(typed_action, services, reply_id, vec![]) }) .collect::>()?; @@ -1027,6 +1061,7 @@ mod tests { assert!(matches!(service, ServiceInstance::Dynamic(_))); assert_eq!(on_reply.len(), 1); } - assert_eq!(blueprint.actions[1].dependencies, vec!["0"]); + + assert!(blueprint.actions[1].dependencies.is_empty()); } } diff --git a/src/kuadrant/pipeline/executor.rs b/src/kuadrant/pipeline/executor.rs index e6492813..0ceb24ee 100644 --- a/src/kuadrant/pipeline/executor.rs +++ b/src/kuadrant/pipeline/executor.rs @@ -70,7 +70,7 @@ impl Pipeline { let is_guard = task.is_guard(); // Create a new PendingTask with no-op processor let pending = Box::new(PendingTask::new( - task.id().unwrap_or_default(), + task.id().to_string(), Box::new(noop_response_processor(token_id, is_guard)), is_guard, )) as Box; @@ -113,12 +113,10 @@ impl Pipeline { continue; } - let task_id = task.id(); + let task_id = task.id().to_string(); match task.apply(&mut self.ctx) { TaskOutcome::Done => { - if let Some(id) = task_id { - self.completed_tasks.insert(id); - } + self.completed_tasks.insert(task_id); } TaskOutcome::Deferred { token_id, pending } => { if self.deferred_tasks.insert(token_id, pending).is_some() { @@ -162,17 +160,13 @@ impl Pipeline { Ok(_) => {} Err(err) => error!("Failed to set gRPC response data: {}", err), }; - let task_id = pending.id(); + let task_id = pending.id().to_string(); match pending.apply(&mut self.ctx) { TaskOutcome::Done => { - if let Some(id) = task_id { - self.completed_tasks.insert(id); - } + self.completed_tasks.insert(task_id); } TaskOutcome::Requeued(tasks) => { - if let Some(id) = task_id { - self.completed_tasks.insert(id); - } + self.completed_tasks.insert(task_id); for task in tasks.into_iter().rev() { self.task_queue.insert(0, task); } @@ -271,8 +265,8 @@ mod tests { } } - fn id(&self) -> Option { - Some(self.id.clone()) + fn id(&self) -> &str { + &self.id } fn dependencies(&self) -> &[String] { diff --git a/src/kuadrant/pipeline/mod.rs b/src/kuadrant/pipeline/mod.rs index 5fd928bb..e28bd48e 100644 --- a/src/kuadrant/pipeline/mod.rs +++ b/src/kuadrant/pipeline/mod.rs @@ -1,7 +1,7 @@ mod blueprint; mod executor; mod factory; -mod tasks; +pub(crate) mod tasks; pub(crate) use executor::{Pipeline, PipelineState}; pub(crate) use factory::PipelineFactory; diff --git a/src/kuadrant/pipeline/tasks/dynamic.rs b/src/kuadrant/pipeline/tasks/dynamic.rs index e7e6d2f9..0b6e6694 100644 --- a/src/kuadrant/pipeline/tasks/dynamic.rs +++ b/src/kuadrant/pipeline/tasks/dynamic.rs @@ -1,25 +1,21 @@ use std::rc::Rc; -use cel::Value; use tracing::{debug, error}; use crate::data::attribute::AttributeState; use crate::data::cel::Predicate; use crate::data::Expression; -use crate::kuadrant::pipeline::blueprint::{Action, Operation}; -use crate::kuadrant::pipeline::tasks::{ - HeaderOperation, ModifyHeadersTask, PendingTask, SendReplyTask, Task, TaskOutcome, -}; +use crate::kuadrant::pipeline::tasks::{PendingTask, Task, TaskOutcome}; use crate::kuadrant::ReqRespCtx; use crate::record_error; -use crate::services::{cel_value_to_header_pairs, DynamicService, MessageConverter}; +use crate::services::{DescriptorConverter, DynamicService}; pub struct DynamicTask { task_id: String, service: Rc, name: String, message_builder: Expression, - on_reply: Vec, + on_reply: Vec>, predicate: Predicate, dependencies: Vec, is_guard: bool, @@ -28,50 +24,17 @@ pub struct DynamicTask { impl DynamicTask { #[allow(clippy::too_many_arguments)] pub fn new_with_attributes( - ctx: &ReqRespCtx, + ctx: &mut ReqRespCtx, task_id: String, service: Rc, name: String, message_builder: Expression, - on_reply: Vec, + on_reply: Vec>, predicate: Predicate, dependencies: Vec, is_guard: bool, ) -> Self { - // Warm up the cache - let _ = predicate.test(ctx); - if let Ok(env) = service.cel_env() { - let mut cel_ctx = cel::Context::with_env(env); - let _ = message_builder.eval(ctx, &mut cel_ctx); - - for action in &on_reply { - let _ = action.predicate.test_with_ctx(ctx, &mut cel_ctx); - match &action.operation { - Operation::Grpc { - message_builder, - on_reply: nested_on_reply, - .. - } => { - let _ = message_builder.eval(ctx, &mut cel_ctx); - for nested_action in nested_on_reply { - let _ = nested_action.predicate.test_with_ctx(ctx, &mut cel_ctx); - } - } - Operation::Deny { deny_with } => { - let _ = deny_with.eval(ctx, &mut cel_ctx); - } - Operation::Headers { headers, .. } => { - let _ = headers.eval(ctx, &mut cel_ctx); - } - Operation::Store { expression, .. } => { - let _ = expression.eval(ctx, &mut cel_ctx); - } - Operation::Fail { .. } => {} - } - } - } - - Self { + let task = Self { task_id, service, name, @@ -80,13 +43,22 @@ impl DynamicTask { predicate, dependencies, is_guard, - } + }; + + task.warm(ctx); + task + } + + fn warm(&self, ctx: &mut ReqRespCtx) { + let mut cel_ctx = ctx.cel.new_ctx(self); + let _ = self.predicate.test(ctx, &mut cel_ctx); + let _ = self.message_builder.eval(ctx, &mut cel_ctx); } } impl Task for DynamicTask { - fn id(&self) -> Option { - Some(self.task_id.clone()) + fn id(&self) -> &str { + &self.task_id } fn dependencies(&self) -> &[String] { @@ -97,8 +69,23 @@ impl Task for DynamicTask { self.is_guard } + fn cel_types(&self) -> Vec { + (|| -> Result, Box> { + let input_desc = self.service.input_descriptor()?; + let output_desc = self.service.output_descriptor()?; + let mut types = DescriptorConverter::collect_struct_defs(&input_desc)?; + types.extend(DescriptorConverter::collect_struct_defs(&output_desc)?); + Ok(types) + })() + .unwrap_or_else(|e| { + error!("Failed to collect CEL types: {}", e); + vec![] + }) + } + fn apply(self: Box, ctx: &mut ReqRespCtx) -> TaskOutcome { - match self.predicate.test(ctx) { + let mut cel_ctx = ctx.cel.new_ctx(&*self); + match self.predicate.test(ctx, &mut cel_ctx) { Ok(AttributeState::Pending) => { return if ctx.response_body.is_end_of_stream() { TaskOutcome::Failed @@ -119,15 +106,7 @@ impl Task for DynamicTask { tracing::debug_span!("dynamic_request", task_id = self.task_id, name = self.name) .entered(); - let env = match self.service.cel_env() { - Ok(env) => env, - Err(e) => { - error!("Failed to get CEL environment: {e}"); - return TaskOutcome::Failed; - } - }; - - let mut cel_ctx = cel::Context::with_env(env); + let mut cel_ctx = ctx.cel.new_ctx(&*self); let cel_value = match self.message_builder.eval(ctx, &mut cel_ctx) { Ok(AttributeState::Pending) => { return if ctx.response_body.is_end_of_stream() { @@ -155,7 +134,7 @@ impl Task for DynamicTask { let service = self.service.clone(); let task_id = self.task_id.clone(); let name = self.name.clone(); - let on_reply = self.on_reply.clone(); + let on_reply = self.on_reply; let is_guard = self.is_guard; if is_guard { @@ -165,10 +144,10 @@ impl Task for DynamicTask { TaskOutcome::Deferred { token_id, pending: Box::new(PendingTask::new( - self.task_id, + task_id.clone(), Box::new(move |ctx| { let outcome = process_dynamic_response( - ctx, &service, &task_id, token_id, &name, &on_reply, + ctx, &service, &task_id, token_id, &name, on_reply, ); if is_guard { ctx.barrier.lower(); @@ -187,7 +166,7 @@ fn process_dynamic_response( task_id: &str, token_id: u32, name: &str, - on_reply: &[Action], + on_reply: Vec>, ) -> TaskOutcome { let span = tracing::debug_span!( "dynamic_response", @@ -214,155 +193,20 @@ fn process_dynamic_response( } if on_reply.is_empty() { - debug!("No onReply actions, completing"); + debug!("No onReply tasks, completing"); return TaskOutcome::Done; } - let mut cel_ctx = match service.response_cel_context(ctx, response_size, name) { - Ok(c) => c, + let cel_value = match service.get_response_cel_value(ctx, response_size) { + Ok(val) => val, Err(e) => { - record_error!("Failed to build response context: {e:?}"); + error!("Failed to get response CEL value: {e}"); return TaskOutcome::Failed; } }; - let mut tasks: Vec> = Vec::new(); - - for action in on_reply { - match action.predicate.test_with_ctx(ctx, &mut cel_ctx) { - Ok(AttributeState::Available(true)) => {} - Ok(AttributeState::Available(false)) => continue, - Ok(AttributeState::Pending) => { - //todo(@adam-cattermole): if we requeue here, we lose predicates as headers/store/sendreply are not modelled with predicates - } - Err(e) => { - error!("Failed to apply predicates: {e:?}"); - return TaskOutcome::Failed; - } - } - - match &action.operation { - Operation::Deny { deny_with } => match deny_with.eval(ctx, &mut cel_ctx) { - Ok(AttributeState::Pending) => { - error!("Unexpected pending state in onReply deny"); - return TaskOutcome::Failed; - } - Ok(AttributeState::Available(val @ Value::Struct(_))) => { - match SendReplyTask::try_from(val) { - Ok(task) => { - if action.terminal { - return TaskOutcome::Terminate(Box::new(task)); - } - tasks.push(Box::new(task)); - } - Err(e) => { - error!("Invalid DenyResponse: {e}"); - return TaskOutcome::Failed; - } - } - } - Ok(AttributeState::Available(other)) => { - error!("denyWith must return DenyResponse, got: {other:?}"); - return TaskOutcome::Failed; - } - Err(e) => { - error!("Failed to evaluate denyWith expression: {e}"); - return TaskOutcome::Failed; - } - }, - Operation::Headers { target, headers } => match headers.eval(ctx, &mut cel_ctx) { - Ok(AttributeState::Available(ref val)) => { - let pairs = cel_value_to_header_pairs(val); - if !pairs.is_empty() { - tasks.push(Box::new(ModifyHeadersTask::new( - HeaderOperation::Set(pairs.into()), - target.clone(), - ))); - } - } - Ok(AttributeState::Pending) => { - error!("Unexpected pending state in onReply headers"); - return TaskOutcome::Failed; - } - Err(e) => { - error!("Failed to evaluate headers expression: {e}"); - return TaskOutcome::Failed; - } - }, - Operation::Store { - path, - expression, - export_to_host, - } => match expression.eval(ctx, &mut cel_ctx) { - // todo(@adam-cattermole): this should be delegated to the StoreTask - Ok(AttributeState::Available(val)) => { - if *export_to_host { - match MessageConverter::cel_value_to_bytes(&val) { - Ok(bytes) => { - if let Err(e) = ctx.set_attribute(path, &bytes) { - error!("Failed to store attribute {path}: {e:?}"); - return TaskOutcome::Failed; - } - } - Err(e) => { - error!("Failed to convert value to bytes for '{path}': {e}"); - return TaskOutcome::Failed; - } - } - } - ctx.store_value(path.clone(), val); - } - Ok(AttributeState::Pending) => { - error!("Unexpected pending state in onReply store for '{path}'"); - return TaskOutcome::Failed; - } - Err(e) => { - error!("Failed to evaluate store expression for '{path}': {e}"); - return TaskOutcome::Failed; - } - }, - Operation::Fail { log_message } => { - error!("Action failure: {log_message}"); - return TaskOutcome::Failed; - } - Operation::Grpc { - service, - var, - message_builder, - on_reply: nested_on_reply, - } => match service { - crate::services::ServiceInstance::Dynamic(dynamic_service) - | crate::services::ServiceInstance::Auth(dynamic_service) - | crate::services::ServiceInstance::RateLimit(dynamic_service) - | crate::services::ServiceInstance::RateLimitCheck(dynamic_service) - | crate::services::ServiceInstance::RateLimitReport(dynamic_service) => { - let task = Box::new(DynamicTask::new_with_attributes( - ctx, - action.id.clone(), - Rc::clone(dynamic_service), - var.clone(), - message_builder.clone(), - nested_on_reply.clone(), - action.predicate.clone(), - action.dependencies.clone(), - action.is_guard, - )); - if action.terminal { - return TaskOutcome::Terminate(task); - } - tasks.push(task); - } - _ => { - error!("Unsupported service type for nested gRPC operation"); - return TaskOutcome::Failed; - } - }, - } - } + ctx.cel + .add_scoped_binding(task_id, name.to_string(), cel_value); - if tasks.is_empty() { - TaskOutcome::Done - } else { - TaskOutcome::Requeued(tasks) - } + TaskOutcome::Requeued(on_reply) } diff --git a/src/kuadrant/pipeline/tasks/fail.rs b/src/kuadrant/pipeline/tasks/fail.rs new file mode 100644 index 00000000..2870f294 --- /dev/null +++ b/src/kuadrant/pipeline/tasks/fail.rs @@ -0,0 +1,58 @@ +use tracing::error; + +use crate::data::attribute::AttributeState; +use crate::data::cel::Predicate; +use crate::kuadrant::pipeline::tasks::{SendReplyTask, Task, TaskOutcome}; +use crate::kuadrant::ReqRespCtx; +use crate::metrics::METRICS; + +pub struct FailTask { + task_id: String, + predicate: Predicate, + log_message: String, + terminal: bool, +} + +impl FailTask { + pub fn new(task_id: String, predicate: Predicate, log_message: String, terminal: bool) -> Self { + Self { + task_id, + predicate, + log_message, + terminal, + } + } +} + +impl Task for FailTask { + fn id(&self) -> &str { + &self.task_id + } + + fn apply(self: Box, ctx: &mut ReqRespCtx) -> TaskOutcome { + let mut cel_ctx = ctx.cel.new_ctx(&*self); + match self.predicate.test(ctx, &mut cel_ctx) { + Ok(AttributeState::Available(true)) => { + error!("Action failure: {}", self.log_message); + if self.terminal { + METRICS.errors().increment(); + TaskOutcome::Terminate(Box::new(SendReplyTask::default())) + } else { + TaskOutcome::Done + } + } + Ok(AttributeState::Available(false)) => TaskOutcome::Done, + Ok(AttributeState::Pending) => { + if ctx.response_body.is_end_of_stream() { + TaskOutcome::Failed + } else { + TaskOutcome::Requeued(vec![self]) + } + } + Err(e) => { + error!("Failed to evaluate log task predicate: {e:?}"); + TaskOutcome::Failed + } + } + } +} diff --git a/src/kuadrant/pipeline/tasks/failure_mode.rs b/src/kuadrant/pipeline/tasks/failure_mode.rs index 97760a3a..4ea6472b 100644 --- a/src/kuadrant/pipeline/tasks/failure_mode.rs +++ b/src/kuadrant/pipeline/tasks/failure_mode.rs @@ -39,7 +39,7 @@ impl Task for FailureModeTask { } } - fn id(&self) -> Option { + fn id(&self) -> &str { self.task.id() } diff --git a/src/kuadrant/pipeline/tasks/headers.rs b/src/kuadrant/pipeline/tasks/headers.rs index d9330dae..bf9f13b6 100644 --- a/src/kuadrant/pipeline/tasks/headers.rs +++ b/src/kuadrant/pipeline/tasks/headers.rs @@ -29,51 +29,27 @@ impl From<&HeadersType> for Path { } } -enum HeadersMode { - Concrete { operation: HeaderOperation }, - Deferred { headers_expr: Expression }, -} - #[derive(Clone)] pub struct ModifyHeadersTask { - predicate: Option, - mode: HeadersMode, + task_id: String, + predicate: Predicate, + headers_expr: Expression, target: HeadersType, terminal: bool, } -impl Clone for HeadersMode { - fn clone(&self) -> Self { - match self { - HeadersMode::Concrete { operation } => HeadersMode::Concrete { - operation: operation.clone(), - }, - HeadersMode::Deferred { headers_expr } => HeadersMode::Deferred { - headers_expr: headers_expr.clone(), - }, - } - } -} - impl ModifyHeadersTask { - pub fn new(operation: HeaderOperation, target: HeadersType) -> ModifyHeadersTask { - ModifyHeadersTask { - predicate: None, - mode: HeadersMode::Concrete { operation }, - target, - terminal: false, - } - } - - pub fn new_deferred( + pub fn new( + task_id: String, predicate: Predicate, headers_expr: Expression, target: HeadersType, terminal: bool, ) -> Self { Self { - predicate: Some(predicate), - mode: HeadersMode::Deferred { headers_expr }, + task_id, + predicate, + headers_expr, target, terminal, } @@ -81,42 +57,39 @@ impl ModifyHeadersTask { } impl Task for ModifyHeadersTask { + fn id(&self) -> &str { + &self.task_id + } + fn apply(self: Box, ctx: &mut ReqRespCtx) -> TaskOutcome { - if let Some(ref predicate) = self.predicate { - match predicate.test(ctx) { - Ok(AttributeState::Available(true)) => {} - Ok(AttributeState::Available(false)) => return TaskOutcome::Done, - Ok(AttributeState::Pending) => { - return TaskOutcome::Requeued(vec![self]); - } - Err(e) => { - error!("Failed to evaluate predicate: {e:?}"); - return TaskOutcome::Failed; - } + let mut cel_ctx = ctx.cel.new_ctx(&*self); + match self.predicate.test(ctx, &mut cel_ctx) { + Ok(AttributeState::Available(true)) => {} + Ok(AttributeState::Available(false)) => return TaskOutcome::Done, + Ok(AttributeState::Pending) => { + return TaskOutcome::Requeued(vec![self]); + } + Err(e) => { + error!("Failed to evaluate predicate: {e:?}"); + return TaskOutcome::Failed; } } - let operation = match &self.mode { - HeadersMode::Concrete { operation } => operation.clone(), - HeadersMode::Deferred { headers_expr } => { - let mut cel_ctx = cel::Context::default(); - match headers_expr.eval(ctx, &mut cel_ctx) { - Ok(AttributeState::Pending) => { - error!("Unexpected pending state in headers expression"); - return TaskOutcome::Failed; - } - Ok(AttributeState::Available(ref val)) => { - let pairs = cel_value_to_header_pairs(val); - if pairs.is_empty() { - return TaskOutcome::Done; - } - HeaderOperation::Set(pairs.into()) - } - Err(e) => { - error!("Failed to evaluate headers expression: {e}"); - return TaskOutcome::Failed; - } + let operation = match self.headers_expr.eval(ctx, &mut cel_ctx) { + Ok(AttributeState::Pending) => { + error!("Unexpected pending state in headers expression"); + return TaskOutcome::Failed; + } + Ok(AttributeState::Available(ref val)) => { + let pairs = cel_value_to_header_pairs(val); + if pairs.is_empty() { + return TaskOutcome::Done; } + HeaderOperation::Set(pairs.into()) + } + Err(e) => { + error!("Failed to evaluate headers expression: {e}"); + return TaskOutcome::Failed; } }; @@ -175,6 +148,9 @@ impl Task for ModifyHeadersTask { #[cfg(test)] mod tests { use super::*; + use crate::data::attribute::Path; + use crate::data::cel::Predicate; + use crate::data::Expression; use crate::kuadrant::MockWasmHost; use std::sync::Arc; @@ -186,11 +162,15 @@ mod tests { let backend = Arc::new(mock_host); let mut ctx = ReqRespCtx::new(backend); - let new_headers: Headers = vec![("New-Key".to_string(), "New-Value".to_string())].into(); + let predicate = Predicate::new("true").unwrap(); + let headers_expr = Expression::new("[['New-Key', 'New-Value']]").unwrap(); let task = Box::new(ModifyHeadersTask::new( - HeaderOperation::Append(new_headers), + "0".to_string(), + predicate, + headers_expr, HeadersType::HttpRequestHeaders, + false, )); let outcome = task.apply(&mut ctx); @@ -218,12 +198,15 @@ mod tests { let backend = Arc::new(mock_host); let mut ctx = ReqRespCtx::new(backend); - let new_headers: Headers = - vec![("Content-Type".to_string(), "application/json".to_string())].into(); + let predicate = Predicate::new("true").unwrap(); + let headers_expr = Expression::new("[['Content-Type', 'application/json']]").unwrap(); let task = Box::new(ModifyHeadersTask::new( - HeaderOperation::Set(new_headers), + "0".to_string(), + predicate, + headers_expr, HeadersType::HttpRequestHeaders, + false, )); let outcome = task.apply(&mut ctx); @@ -241,7 +224,7 @@ mod tests { } #[test] - fn remove_headers_task() { + fn empty_headers_expr_returns_done() { let existing_headers = vec![ ("API-Key-To-Remove".to_string(), "API-Value".to_string()), ("X-Origin".to_string(), "Kuadrant".to_string()), @@ -251,23 +234,18 @@ mod tests { let backend = Arc::new(mock_host); let mut ctx = ReqRespCtx::new(backend); - let keys_to_remove = vec!["API-Key-To-Remove".to_string()]; + let predicate = Predicate::new("true").unwrap(); + let headers_expr = Expression::new("[]").unwrap(); let task = Box::new(ModifyHeadersTask::new( - HeaderOperation::Remove(keys_to_remove), + "0".to_string(), + predicate, + headers_expr, HeadersType::HttpResponseHeaders, + false, )); let outcome = task.apply(&mut ctx); assert!(matches!(outcome, TaskOutcome::Done)); - - let result: Result>, _> = - ctx.get_attribute_ref(&Path::from(&HeadersType::HttpResponseHeaders)); - - assert!(matches!(result, Ok(AttributeState::Available(Some(_))))); - if let Ok(AttributeState::Available(Some(headers))) = result { - assert_eq!(headers.len(), 1); - assert_eq!(headers.get("X-Origin"), Some("Kuadrant")); - } } } diff --git a/src/kuadrant/pipeline/tasks/mod.rs b/src/kuadrant/pipeline/tasks/mod.rs index 74e60f90..4a734712 100644 --- a/src/kuadrant/pipeline/tasks/mod.rs +++ b/src/kuadrant/pipeline/tasks/mod.rs @@ -1,5 +1,6 @@ mod dynamic; mod export_traces; +mod fail; mod failure_mode; mod headers; mod send_reply; @@ -9,8 +10,9 @@ mod tracing_decorator; pub use dynamic::DynamicTask; pub use export_traces::ExportTracesTask; +pub use fail::FailTask; pub use failure_mode::FailureModeTask; -pub use headers::{HeaderOperation, HeadersType, ModifyHeadersTask}; +pub use headers::{HeadersType, ModifyHeadersTask}; pub use send_reply::SendReplyTask; pub use store::StoreTask; pub use token_usage::TokenUsageTask; @@ -25,9 +27,7 @@ pub type ResponseProcessor = dyn FnOnce(&mut ReqRespCtx) -> TaskOutcome; pub trait Task { fn apply(self: Box, ctx: &mut ReqRespCtx) -> TaskOutcome; - fn id(&self) -> Option { - None - } + fn id(&self) -> &str; fn dependencies(&self) -> &[String] { &[] @@ -36,6 +36,10 @@ pub trait Task { fn is_guard(&self) -> bool { false } + + fn cel_types(&self) -> Vec { + vec![] + } } pub struct PendingTask { @@ -58,8 +62,8 @@ impl Task for PendingTask { fn apply(self: Box, ctx: &mut ReqRespCtx) -> TaskOutcome { (self.process_response)(ctx) } - fn id(&self) -> Option { - Some(self.task_id.clone()) + fn id(&self) -> &str { + &self.task_id } fn is_guard(&self) -> bool { self.is_guard @@ -117,6 +121,10 @@ pub fn noop_response_processor( pub struct NoopTerminalTask; impl Task for NoopTerminalTask { + fn id(&self) -> &str { + "noop" + } + fn apply(self: Box, _ctx: &mut ReqRespCtx) -> TaskOutcome { TaskOutcome::Done } diff --git a/src/kuadrant/pipeline/tasks/send_reply.rs b/src/kuadrant/pipeline/tasks/send_reply.rs index 547f878f..bc6f0558 100644 --- a/src/kuadrant/pipeline/tasks/send_reply.rs +++ b/src/kuadrant/pipeline/tasks/send_reply.rs @@ -1,7 +1,5 @@ -use std::sync::Arc; - use cel::common::types::{CelString, CelUInt}; -use cel::{Env, Value}; +use cel::Value; use tracing::error; use crate::data::attribute::AttributeState; @@ -13,99 +11,69 @@ use crate::metrics::METRICS; use crate::services::{cel_value_to_header_pairs, deny_response_struct_def}; pub struct SendReplyTask { - predicate: Option, + task_id: String, + predicate: Predicate, deny_with: Expression, terminal: bool, } impl SendReplyTask { - pub fn new(status_code: u32, headers: Vec<(String, String)>, body: Option) -> Self { - let headers = headers - .into_iter() - .map(|(h, v)| format!("['''{h}''', '''{v}''']")) - .collect::>() - .join(", "); - let body_field = body.map(|b| format!("body: '''{b}'''")).unwrap_or_default(); - let expr = format!( - "DenyResponse {{ status: {status_code}u, headers: [{headers}], {body_field} }}" - ); - #[allow(clippy::expect_used)] - let deny_with = Expression::new(&expr).expect("Needs to be valid CEL!"); - Self { - predicate: None, - deny_with, - terminal: false, - } - } - - pub fn new_deferred(predicate: Predicate, deny_with: Expression, terminal: bool) -> Self { + pub fn new( + task_id: String, + predicate: Predicate, + deny_with: Expression, + terminal: bool, + ) -> Self { Self { - predicate: Some(predicate), + task_id, + predicate, deny_with, terminal, } } pub fn default() -> Self { - Self::new( - 500, - Vec::new(), - Some("Internal Server Error.\n".to_string()), + #[allow(clippy::expect_used)] + let deny_with = Expression::new( + r#"DenyResponse { status: 500u, headers: [], body: 'Internal Server Error.\n' }"#, ) + .expect("Needs to be valid CEL!"); + #[allow(clippy::expect_used)] + let predicate = Predicate::new("true").expect("Needs to be valid!"); + Self { + task_id: "default".to_string(), + predicate, + deny_with, + terminal: false, + } } } -impl TryFrom for SendReplyTask { - type Error = String; - - fn try_from(value: Value) -> Result { - let Value::Struct(deny_response) = value else { - return Err(format!("expected DenyResponse struct, got: {value:?}")); - }; +impl Task for SendReplyTask { + fn id(&self) -> &str { + &self.task_id + } - let status = deny_response - .field_value("status") - .and_then(|v| v.downcast_ref::()) - .map(|v| *v.inner() as u32) - .ok_or("DenyResponse missing or invalid 'status' field")?; - - let body = deny_response - .field_value("body") - .and_then(|v| v.downcast_ref::()) - .map(|v| v.inner().to_string()) - .filter(|s| !s.is_empty()); - - let headers = deny_response - .field_value("headers") - .and_then(|v| Value::try_from(v).ok()) - .map(|v| cel_value_to_header_pairs(&v)) - .unwrap_or_default(); - - Ok(Self::new(status, headers, body)) + fn cel_types(&self) -> Vec { + vec![deny_response_struct_def()] } -} -impl Task for SendReplyTask { #[tracing::instrument(name = "send_reply", skip(self, ctx))] fn apply(self: Box, ctx: &mut ReqRespCtx) -> TaskOutcome { - if let Some(ref predicate) = self.predicate { - match predicate.test(ctx) { - Ok(AttributeState::Available(true)) => {} - Ok(AttributeState::Available(false)) => return TaskOutcome::Done, - Ok(AttributeState::Pending) => { - return TaskOutcome::Requeued(vec![self]); - } - Err(e) => { - error!("Failed to evaluate predicate: {e:?}"); - return TaskOutcome::Failed; - } + let mut cel_ctx = ctx.cel.new_ctx(&*self); + match self.predicate.test(ctx, &mut cel_ctx) { + Ok(AttributeState::Available(true)) => {} + Ok(AttributeState::Available(false)) => return TaskOutcome::Done, + Ok(AttributeState::Pending) => { + return TaskOutcome::Requeued(vec![self]); + } + Err(e) => { + error!("Failed to evaluate predicate: {e:?}"); + return TaskOutcome::Failed; } } let (status_code, headers, body) = { - let mut env = Env::stdlib(); - env.add_struct(deny_response_struct_def()); - let mut cel_ctx = cel::Context::with_env(Arc::new(env)); match self.deny_with.eval(ctx, &mut cel_ctx) { Ok(AttributeState::Pending) => { error!("Unexpected pending state in deny expression"); @@ -185,16 +153,15 @@ mod tests { let mock_host = MockWasmHost::new(); let mut ctx = ReqRespCtx::new(Arc::new(mock_host)); + let predicate = Predicate::new("true").unwrap(); + let deny_with = Expression::new( + "DenyResponse { status: 403u, headers: [['content-type', 'text/plain'], ['WWW-Authenticate', 'APIKEY realm=\"api-key-users\"']], body: 'Access Denied' }" + ).unwrap(); let task = Box::new(SendReplyTask::new( - 403, - vec![ - ("content-type".to_string(), "text/plain".to_string()), - ( - "WWW-Authenticate".to_string(), - "APIKEY realm=\"api-key-users\"".to_string(), - ), - ], - Some("Access Denied".to_string()), + "0".to_string(), + predicate, + deny_with, + false, )); let outcome = task.apply(&mut ctx); @@ -206,7 +173,15 @@ mod tests { let mock_host = MockWasmHost::new(); let mut ctx = ReqRespCtx::new(Arc::new(mock_host)); - let task = Box::new(SendReplyTask::new(429, vec![], None)); + let predicate = Predicate::new("true").unwrap(); + let deny_with = + Expression::new("DenyResponse { status: 429u, headers: [], body: '' }").unwrap(); + let task = Box::new(SendReplyTask::new( + "0".to_string(), + predicate, + deny_with, + false, + )); let outcome = task.apply(&mut ctx); assert!(matches!(outcome, TaskOutcome::Done)); diff --git a/src/kuadrant/pipeline/tasks/store.rs b/src/kuadrant/pipeline/tasks/store.rs index 5a46014c..9547eb9f 100644 --- a/src/kuadrant/pipeline/tasks/store.rs +++ b/src/kuadrant/pipeline/tasks/store.rs @@ -16,6 +16,7 @@ enum BodySource { } pub struct StoreTask { + task_id: String, predicate: Option, expression: Expression, path: String, @@ -26,6 +27,7 @@ pub struct StoreTask { impl StoreTask { pub fn new( + task_id: String, predicate: Predicate, expression: Expression, path: String, @@ -34,6 +36,7 @@ impl StoreTask { ) -> Result { let body_parser = create_body_parser(&predicate, &expression)?; Ok(Self { + task_id, predicate: Some(predicate), expression, path, @@ -77,6 +80,10 @@ fn create_body_parser( } impl Task for StoreTask { + fn id(&self) -> &str { + &self.task_id + } + #[tracing::instrument(name = "store", skip(self, ctx), level = tracing::Level::TRACE)] fn apply(mut self: Box, ctx: &mut ReqRespCtx) -> TaskOutcome { if let Some((ref source, ref mut parser)) = self.body_parser { @@ -135,8 +142,10 @@ impl Task for StoreTask { parser.populate(body_ctx_mut); } + let mut cel_ctx = ctx.cel.new_ctx(&*self); + if let Some(ref predicate) = self.predicate { - match predicate.test(ctx) { + match predicate.test(ctx, &mut cel_ctx) { Ok(AttributeState::Available(true)) => {} Ok(AttributeState::Available(false)) => return TaskOutcome::Done, Ok(AttributeState::Pending) => { @@ -148,8 +157,6 @@ impl Task for StoreTask { } } } - - let mut cel_ctx = cel::Context::default(); let value = match self.expression.eval(ctx, &mut cel_ctx) { Ok(AttributeState::Pending) => { return TaskOutcome::Requeued(vec![self]); @@ -202,6 +209,7 @@ mod tests { fn make_store_task(predicate: &str, expression: &str, path: &str) -> Box { Box::new( StoreTask::new( + "0".to_string(), Predicate::new(predicate).unwrap(), Expression::new(expression).unwrap(), path.to_string(), @@ -333,6 +341,7 @@ mod tests { fn invalid_json_pointer_fails_task_creation() { // Invalid JSON pointer format - acutejson expects RFC 6901 format let result = StoreTask::new( + "0".to_string(), Predicate::new("true").unwrap(), Expression::new("requestBodyJSON('not-a-valid-pointer')").unwrap(), "some.path".to_string(), diff --git a/src/kuadrant/pipeline/tasks/token_usage.rs b/src/kuadrant/pipeline/tasks/token_usage.rs index 15f5d07a..d2b9b501 100644 --- a/src/kuadrant/pipeline/tasks/token_usage.rs +++ b/src/kuadrant/pipeline/tasks/token_usage.rs @@ -37,6 +37,10 @@ impl From> for TokenUsageTask { } impl Task for TokenUsageTask { + fn id(&self) -> &str { + "token_usage" + } + #[tracing::instrument(name = "token_usage", skip(self, ctx))] fn apply(self: Box, ctx: &mut ReqRespCtx) -> TaskOutcome { let mut task: TokenUsageTask = self.into(); diff --git a/src/kuadrant/pipeline/tasks/tracing_decorator.rs b/src/kuadrant/pipeline/tasks/tracing_decorator.rs index a20e7061..3bef7b06 100644 --- a/src/kuadrant/pipeline/tasks/tracing_decorator.rs +++ b/src/kuadrant/pipeline/tasks/tracing_decorator.rs @@ -75,7 +75,7 @@ impl Task for TracingDecoratorTask { } } - fn id(&self) -> Option { + fn id(&self) -> &str { self.task.id() } diff --git a/src/services/dynamic.rs b/src/services/dynamic.rs index 4a9c4779..4d85bcae 100644 --- a/src/services/dynamic.rs +++ b/src/services/dynamic.rs @@ -1,9 +1,7 @@ -use std::cell::OnceCell; use std::rc::Rc; -use std::sync::Arc; use std::time::Duration; -use cel::{Context, Env, Value}; +use cel::Value; use prost::Message; use prost_reflect::DynamicMessage; use tracing::debug; @@ -15,7 +13,7 @@ use crate::kuadrant::ReqRespCtx; pub mod converters; -use converters::{deny_response_struct_def, DescriptorConverter, MessageConverter}; +use converters::MessageConverter; pub struct DynamicService { upstream_name: String, @@ -24,7 +22,6 @@ pub struct DynamicService { timeout: Duration, failure_mode: FailureMode, descriptor_manager: Rc, - cel_env: OnceCell>, } impl DynamicService { @@ -45,7 +42,6 @@ impl DynamicService { timeout, failure_mode, descriptor_manager, - cel_env: Default::default(), } } @@ -53,27 +49,6 @@ impl DynamicService { self.failure_mode } - pub fn cel_env(&self) -> Result, ServiceError> { - match self.cel_env.get() { - Some(env) => Ok(Arc::clone(env)), - None => { - let input_descriptor = self.input_descriptor()?; - let output_descriptor = self.output_descriptor()?; - let mut env = Env::stdlib(); - DescriptorConverter::register_message_types(&mut env, &input_descriptor).map_err( - |e| ServiceError::Dispatch(format!("Failed to register message types: {}", e)), - )?; - DescriptorConverter::register_message_types(&mut env, &output_descriptor).map_err( - |e| ServiceError::Dispatch(format!("Failed to register message types: {}", e)), - )?; - env.add_struct(deny_response_struct_def()); - let env_arc = Arc::new(env); - let _ = self.cel_env.set(Arc::clone(&env_arc)); - Ok(env_arc) - } - } - } - pub fn dispatch_value( &self, ctx: &mut ReqRespCtx, @@ -128,29 +103,22 @@ impl DynamicService { Ok(method) } - fn input_descriptor(&self) -> Result { + pub fn input_descriptor(&self) -> Result { Ok(self.method_descriptor()?.input()) } - fn output_descriptor(&self) -> Result { + pub fn output_descriptor(&self) -> Result { Ok(self.method_descriptor()?.output()) } - pub fn response_cel_context( + pub fn get_response_cel_value( &self, ctx: &mut ReqRespCtx, response_size: usize, - name: &str, - ) -> Result, ServiceError> { + ) -> Result { let response = self.get_response(ctx, response_size)?; - let cel_value = MessageConverter::dynamic_message_to_cel(&response).map_err(|e| { - ServiceError::Decode(format!("Failed to convert message to CEL: {}", e)) - })?; - let env = self.cel_env()?; - - let mut cel_ctx = Context::with_env(env); - cel_ctx.add_variable_from_value(name, cel_value); - Ok(cel_ctx) + MessageConverter::dynamic_message_to_cel(&response) + .map_err(|e| ServiceError::Decode(format!("Failed to convert message to CEL: {}", e))) } } @@ -190,9 +158,14 @@ impl Service for DynamicService { #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; - use crate::filter::{DescriptorKey, DescriptorManager}; - use cel::Program; + use crate::{ + filter::{DescriptorKey, DescriptorManager}, + services::DescriptorConverter, + }; + use cel::{Context, Env, Program}; use prost_reflect::DescriptorPool; use prost_types::{ field_descriptor_proto, DescriptorProto, FieldDescriptorProto, FileDescriptorProto, @@ -279,8 +252,11 @@ mod tests { let input_desc = method_desc.input(); let mut env = Env::stdlib(); - DescriptorConverter::register_message_types(&mut env, &input_desc) - .expect("Failed to register types"); + for def in DescriptorConverter::collect_struct_defs(&input_desc) + .expect("Failed to collect struct defs") + { + env.add_struct(def); + } let cel_ctx = Context::with_env(Arc::new(env)); let program = Program::compile(cel_expression).expect("Failed to compile"); diff --git a/src/services/dynamic/converters.rs b/src/services/dynamic/converters.rs index dc63a3e8..be0819ed 100644 --- a/src/services/dynamic/converters.rs +++ b/src/services/dynamic/converters.rs @@ -1,6 +1,6 @@ use cel::common::types::*; use cel::objects::Key; -use cel::{Env, StructDef, Value}; +use cel::{StructDef, Value}; use prost_reflect::Cardinality; use prost_reflect::{ DynamicMessage, FieldDescriptor, Kind as ProtoKind, MapKey, MessageDescriptor, ReflectMessage, @@ -65,10 +65,10 @@ pub struct DescriptorConverter; impl DescriptorConverter { /// Register a message descriptor and all its nested message types with the CEL environment /// This must be called before evaluating CEL expressions that construct these messages - pub fn register_message_types( - env: &mut Env, + pub fn collect_struct_defs( descriptor: &MessageDescriptor, - ) -> Result<(), ConversionError> { + ) -> Result, ConversionError> { + let mut defs = Vec::new(); let mut to_register = vec![descriptor.clone()]; let mut visited = HashSet::new(); @@ -88,10 +88,10 @@ impl DescriptorConverter { } let struct_def = Self::to_struct_def(&desc)?; - env.add_struct(struct_def); + defs.push(struct_def); } - Ok(()) + Ok(defs) } /// Convert a protobuf MessageDescriptor to a CEL StructDef @@ -880,7 +880,7 @@ impl MessageConverter { mod tests { use super::*; use cel::common::value::Val; - use cel::{Context, Program}; + use cel::{Context, Env, Program}; use prost::Message; use prost_types::{field_descriptor_proto, DescriptorProto, FieldDescriptorProto}; use prost_types::{FileDescriptorProto, FileDescriptorSet, OneofDescriptorProto}; @@ -1041,8 +1041,11 @@ mod tests { // Register all message types let mut env = cel::Env::stdlib(); - DescriptorConverter::register_message_types(&mut env, &outer_descriptor) - .expect("Failed to register types"); + for def in DescriptorConverter::collect_struct_defs(&outer_descriptor) + .expect("Failed to collect struct defs") + { + env.add_struct(def); + } let ctx = Context::with_env(Arc::new(env)); @@ -1444,8 +1447,11 @@ mod tests { .expect("MapMessage not found"); let mut env = cel::Env::stdlib(); - DescriptorConverter::register_message_types(&mut env, &descriptor) - .expect("Failed to register types"); + for def in DescriptorConverter::collect_struct_defs(&descriptor) + .expect("Failed to collect struct defs") + { + env.add_struct(def); + } let ctx = Context::with_env(Arc::new(env)); @@ -1886,10 +1892,16 @@ mod tests { .expect("Failed to get Request descriptor"); let mut env = Env::default(); - DescriptorConverter::register_message_types(&mut env, ×tamp_desc) - .expect("Failed to register Timestamp"); - DescriptorConverter::register_message_types(&mut env, &request_desc) - .expect("Failed to register Request"); + for def in DescriptorConverter::collect_struct_defs(×tamp_desc) + .expect("Failed to collect struct defs") + { + env.add_struct(def); + } + for def in DescriptorConverter::collect_struct_defs(&request_desc) + .expect("Failed to collect struct defs") + { + env.add_struct(def); + } // Create a CEL timestamp: 2024-05-16 12:00:00 UTC (1715875200 seconds, 123456789 nanos) let dt: DateTime = DateTime::from_timestamp(1715875200, 123456789) diff --git a/src/services/mod.rs b/src/services/mod.rs index 0c398a29..16c25ba7 100644 --- a/src/services/mod.rs +++ b/src/services/mod.rs @@ -7,7 +7,7 @@ mod dynamic; mod tracing; pub use dynamic::converters::{ - cel_value_to_header_pairs, deny_response_struct_def, MessageConverter, + cel_value_to_header_pairs, deny_response_struct_def, DescriptorConverter, MessageConverter, }; pub use dynamic::DynamicService; pub use tracing::TracingService; diff --git a/tests/response_body.rs b/tests/response_body.rs index 61aff6f3..47bb6cf1 100644 --- a/tests/response_body.rs +++ b/tests/response_body.rs @@ -510,7 +510,7 @@ fn it_handles_errors_on_response_body() { Some(LogLevel::Warn), Some("Missing json property: /usage/total_tokens"), ) - .expect_log(Some(LogLevel::Error), Some("Task failed: Some(\"0\")")) + .expect_log(Some(LogLevel::Error), Some("Task failed: \"0\"")) // on response headers/body, expected action is Continue .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap();