From 41131fb374dd1b2199ff1adfa5e952b833297ddc Mon Sep 17 00:00:00 2001 From: Borislav Borisov Date: Wed, 8 Apr 2026 12:19:32 +0100 Subject: [PATCH] Add property-based tests via proptest Introduce property tests that fuzz batch boundaries, request ordering, permit lifecycle, cancellation patterns, error propagation, concurrent clones, and error display formatting. --- Cargo.toml | 1 + proptest-regressions/property.txt | 6 + tests/main.rs | 85 +------ tests/property.rs | 355 ++++++++++++++++++++++++++++++ tests/support.rs | 122 +++++++++- 5 files changed, 484 insertions(+), 85 deletions(-) create mode 100644 proptest-regressions/property.txt create mode 100644 tests/property.rs diff --git a/Cargo.toml b/Cargo.toml index 54742cc..6348f08 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,5 +42,6 @@ tokio = { version = "1", features = ["macros", "sync", "test-util", "rt-multi-th tokio-test = "0.4" tower = { version = "0.5", features = ["full"] } tower-test = "0.4" +proptest = "1" rusqlite = { version = "0.39", features = ["bundled", "array"] } diff --git a/proptest-regressions/property.txt b/proptest-regressions/property.txt new file mode 100644 index 0000000..341ba3f --- /dev/null +++ b/proptest-regressions/property.txt @@ -0,0 +1,6 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. diff --git a/tests/main.rs b/tests/main.rs index 84da79f..bdd41a6 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -1,10 +1,9 @@ use std::{ - fmt::Debug, future::Future, pin::Pin, sync::{ atomic::{AtomicUsize, Ordering}, - Arc, Mutex, + Arc, }, task::{Context, Poll}, time::Duration, @@ -22,84 +21,7 @@ use tower_test::{ use tower_batch::{error, Batch, BatchControl, BatchLayer, BoxError}; mod support; - -#[derive(Clone)] -struct Aggregator { - items: Arc>>>, - current: Arc, -} - -impl Aggregator { - pub fn new() -> Self { - Self { - items: Arc::new(Mutex::new(Vec::new())), - current: Arc::new(AtomicUsize::new(0)), - } - } - - fn batch_has_size(&self, index: usize, size: usize) -> bool { - if index == self.current.load(Ordering::Acquire) { - return false; - } - let items = &self.items.lock().unwrap(); - items.get(index).is_some_and(|v| v.len() == size) - } - - fn batch_items(&self, index: usize) -> Option> - where - T: Clone, - { - if index == self.current.load(Ordering::Acquire) { - return None; - } - let items = self.items.lock().unwrap(); - items.get(index).cloned() - } -} - -impl Service> for Aggregator -where - T: Debug, -{ - type Response = (); - type Error = BoxError; - type Future = Pin> + Send + Sync + 'static>>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: BatchControl) -> Self::Future { - match req { - BatchControl::Item(item) => { - let mut items = self.items.lock().unwrap(); - match items.get_mut(self.current.load(Ordering::Acquire)) { - None => { - items.push(vec![item]); - } - Some(v) => { - v.push(item); - } - } - } - BatchControl::Flush => { - self.current.fetch_add(1, Ordering::SeqCst); - return Box::pin(async { - tracing::info!("sleeping ..."); - async { - // Simulate some activity to catch any flushing issues - tokio::time::sleep(Duration::from_nanos(5)).await; - } - .await; - tracing::info!("awaking ..."); - Ok(()) - }); - } - } - - Box::pin(futures::future::ready(Ok(()))) - } -} +use support::Aggregator; #[tokio::test] async fn batch_flushes_on_max_size() -> Result<(), BoxError> { @@ -236,8 +158,7 @@ async fn concurrent_clones_send_requests() -> Result<(), BoxError> { tokio::time::sleep(Duration::from_millis(200)).await; // Verify all 9 items were actually delivered to the aggregator. - let items = aggregator.items.lock().unwrap(); - let delivered: usize = items.iter().map(Vec::len).sum(); + let delivered = aggregator.all_items_flat().len(); assert_eq!(delivered, 9, "all 9 items should reach the aggregator"); Ok(()) diff --git a/tests/property.rs b/tests/property.rs new file mode 100644 index 0000000..22b13e3 --- /dev/null +++ b/tests/property.rs @@ -0,0 +1,355 @@ +use std::{ + future::Future, + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + task::{Context, Poll}, + time::Duration, +}; + +use futures::{stream::FuturesUnordered, StreamExt}; +use proptest::prelude::*; +use tower::{Service, ServiceExt}; +use tower_batch::{Batch, BatchControl, BoxError}; + +mod support; +use support::Aggregator; + +const MAX_TIME: Duration = Duration::from_millis(100); + +fn rt() -> tokio::runtime::Runtime { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap() +} + +fn rt_multi_thread() -> tokio::runtime::Runtime { + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap() +} + +// ===== FailingAggregator ===== + +#[derive(Clone, Debug)] +struct FailingAggregator { + flush_count: Arc, + fail_at_flush: usize, +} + +impl FailingAggregator { + fn new(fail_at_flush: usize) -> Self { + Self { + flush_count: Arc::new(AtomicUsize::new(0)), + fail_at_flush, + } + } +} + +impl Service> for FailingAggregator { + type Response = (); + type Error = BoxError; + type Future = Pin> + Send + Sync>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: BatchControl) -> Self::Future { + match req { + BatchControl::Item(_) => Box::pin(futures::future::ready(Ok(()))), + BatchControl::Flush => { + let n = self.flush_count.fetch_add(1, Ordering::SeqCst); + if n == self.fail_at_flush { + Box::pin(futures::future::ready(Err("injected flush error".into()))) + } else { + Box::pin(futures::future::ready(Ok(()))) + } + } + } + } +} + +// ===== Property tests ===== + +proptest! { + #![proptest_config(ProptestConfig::with_cases(256))] + + #[test] + fn batch_flush_boundary( + batch_size in 1usize..=100, + request_count in 1usize..=500, + ) { + let rt = rt(); + let batches = rt.block_on(async { + let aggregator: Aggregator = Aggregator::new(); + let mut batch = Batch::new(aggregator.clone(), batch_size, MAX_TIME); + + let mut futs = FuturesUnordered::new(); + for i in 0..request_count { + batch.ready().await.unwrap(); + #[allow(clippy::cast_possible_truncation)] + futs.push(batch.call(i as u32)); + } + while let Some(result) = futs.next().await { + result.unwrap(); + } + + aggregator.completed_batches() + }); + + let expected_flushes = request_count.div_ceil(batch_size); + prop_assert_eq!(batches.len(), expected_flushes); + for (i, b) in batches.iter().enumerate() { + if i < batches.len() - 1 { + prop_assert_eq!(b.len(), batch_size); + } else { + let rem = request_count % batch_size; + let expected_last = if rem == 0 { batch_size } else { rem }; + prop_assert_eq!(b.len(), expected_last); + } + } + } + + #[test] + fn request_ordering( + batch_size in 1usize..=50, + request_count in 1usize..=200, + ) { + let rt = rt(); + let flat_items = rt.block_on(async { + let aggregator: Aggregator = Aggregator::new(); + let mut batch = Batch::new(aggregator.clone(), batch_size, MAX_TIME); + + let mut futs = FuturesUnordered::new(); + #[allow(clippy::cast_possible_truncation)] + for i in 0..request_count as u32 { + batch.ready().await.unwrap(); + futs.push(batch.call(i)); + } + while let Some(r) = futs.next().await { + r.unwrap(); + } + + aggregator.all_items_flat() + }); + + #[allow(clippy::cast_possible_truncation)] + let expected: Vec = (0..request_count as u32).collect(); + prop_assert_eq!(flat_items, expected); + } + + #[test] + fn permit_accounting_no_leak( + batch_size in 2usize..=10, + ops in prop::collection::vec(any::(), 1..=30), + ) { + let rt = rt(); + rt.block_on(async { + let aggregator: Aggregator = Aggregator::new(); + let mut batch = Batch::new(aggregator.clone(), batch_size, MAX_TIME); + let mut call_futs = FuturesUnordered::new(); + + for (i, do_call) in ops.iter().enumerate() { + if *do_call { + batch.ready().await.unwrap(); + #[allow(clippy::cast_possible_truncation)] + call_futs.push(batch.call(i as u32)); + } else { + let mut clone = batch.clone(); + clone.ready().await.unwrap(); + drop(clone); + } + } + + // Drain all call futures + while let Some(r) = call_futs.next().await { + r.unwrap(); + } + + // Verify all permits are available by acquiring batch_size of them via clones. + // If any leaked, this will deadlock and the timeout will catch it. + let mut clones: Vec<_> = (0..batch_size).map(|_| batch.clone()).collect(); + for c in &mut clones { + tokio::time::timeout(Duration::from_secs(5), c.ready()) + .await + .expect("timed out waiting for permit — likely a permit leak") + .unwrap(); + } + drop(clones); + }); + } + + #[test] + fn cancellation_patterns( + (batch_size, cancel_mask) in + (1usize..=20, 1usize..=50).prop_flat_map(|(bs, rc)| { + ( + Just(bs), + prop::collection::vec(any::(), rc..=rc), + ) + }) + ) { + let rt = rt(); + let (results, batches) = rt.block_on(async { + let aggregator: Aggregator = Aggregator::new(); + let mut batch = Batch::new(aggregator.clone(), batch_size, MAX_TIME); + + let mut kept_futs = FuturesUnordered::new(); + for (i, cancelled) in cancel_mask.iter().enumerate() { + batch.ready().await.unwrap(); + #[allow(clippy::cast_possible_truncation)] + let fut = batch.call(i as u32); + if *cancelled { + drop(fut); + } else { + kept_futs.push(fut); + } + } + + let mut results = Vec::new(); + while let Some(r) = kept_futs.next().await { + results.push(r.is_ok()); + } + + drop(batch); + tokio::time::sleep(Duration::from_millis(150)).await; + (results, aggregator.completed_batches()) + }); + + // All non-cancelled requests got Ok responses + for (i, ok) in results.iter().enumerate() { + prop_assert!(*ok, "non-cancelled request {} should have succeeded", i); + } + // No batch exceeds batch_size + for b in &batches { + prop_assert!(b.len() <= batch_size); + } + } + + #[test] + fn error_propagation( + (batch_size, request_count, fail_at_flush) in + (1usize..=20, 2usize..=50).prop_flat_map(|(bs, rc)| { + let num_flushes = rc.div_ceil(bs); + (Just(bs), Just(rc), 0..num_flushes) + }) + ) { + let rt = rt(); + let results = rt.block_on(async { + let svc = FailingAggregator::new(fail_at_flush); + let mut batch = Batch::new(svc, batch_size, MAX_TIME); + + let mut futs = FuturesUnordered::new(); + #[allow(clippy::cast_possible_truncation)] + for i in 0..request_count as u32 { + match batch.ready().await { + Ok(_) => futs.push(batch.call(i)), + Err(_) => break, + } + } + + let mut results = Vec::new(); + while let Some(r) = futs.next().await { + results.push(r.is_ok()); + } + results + }); + + let has_errors = results.iter().any(|ok| !ok); + prop_assert!(has_errors, "should have at least one error from the injected flush failure"); + } +} + +proptest! { + #![proptest_config(ProptestConfig::with_cases(128))] + + #[test] + fn concurrent_clone_stress( + num_clones in 2usize..=10, + requests_per_clone in 1usize..=30, + batch_size in 1usize..=20, + ) { + let rt = rt_multi_thread(); + let (total_delivered, total_succeeded) = rt.block_on(async { + let aggregator: Aggregator = Aggregator::new(); + let batch = Batch::new(aggregator.clone(), batch_size, MAX_TIME); + + let mut handles = Vec::new(); + for clone_id in 0..num_clones { + let mut svc = batch.clone(); + handles.push(tokio::spawn(async move { + let mut count = 0usize; + for i in 0..requests_per_clone { + #[allow(clippy::cast_possible_truncation)] + if svc.ready().await.is_ok() + && svc.call((clone_id * 1000 + i) as u32).await.is_ok() + { + count += 1; + } + } + count + })); + } + + let mut total_succeeded = 0usize; + for h in handles { + total_succeeded += h.await.unwrap(); + } + + // Drop the original handle and wait for partial-batch flush + drop(batch); + tokio::time::sleep(Duration::from_millis(150)).await; + + let total_delivered: usize = + aggregator.completed_batches().iter().map(Vec::len).sum(); + (total_delivered, total_succeeded) + }); + + prop_assert_eq!(total_delivered, total_succeeded); + } + + #[test] + fn service_error_display_contains_message(msg in "[a-zA-Z0-9 ]{1,50}") { + let rt = rt(); + let err_display = rt.block_on(async { + use tower_test::mock; + + let (service, mut handle) = mock::pair::, ()>(); + let mut batch = Batch::new(service, 1, Duration::from_secs(1)); + + // Allow the first request through so the worker picks it up + handle.allow(1); + + batch.ready().await.unwrap(); + let resp_fut = batch.call(()); + + // The mock receives BatchControl::Item(()), respond to it + let (request, send_response) = handle.next_request().await.unwrap(); + assert_eq!(request, BatchControl::Item(())); + send_response.send_response(()); + + // Now the worker will try to flush; the mock receives BatchControl::Flush + handle.allow(1); + let (request, send_response) = handle.next_request().await.unwrap(); + assert_eq!(request, BatchControl::Flush); + send_response.send_error(msg.clone()); + + let result = resp_fut.await; + result.unwrap_err().to_string() + }); + + prop_assert!( + err_display.contains("batch service failed:"), + "display was: {}", err_display + ); + prop_assert!( + err_display.contains(&msg), + "display should contain inner message '{}', was: {}", msg, err_display + ); + } +} diff --git a/tests/support.rs b/tests/support.rs index e0dca37..133fa1f 100644 --- a/tests/support.rs +++ b/tests/support.rs @@ -1,13 +1,20 @@ -#![allow(dead_code)] +#![allow(dead_code, clippy::missing_panics_doc)] //! The code below is borrowed from Tower's test suite. use std::{ - fmt, future, + fmt, + future::{self, Future}, + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Mutex, + }, task::{Context, Poll}, + time::Duration, }; use tower::Service; -use tower_batch::BatchControl; +use tower_batch::{BatchControl, BoxError}; #[must_use] pub fn trace_init() -> tracing::subscriber::DefaultGuard { @@ -81,3 +88,112 @@ impl Service> for AssertSpanSvc { future::ready(self.check("call")) } } + +// ===== Aggregator ===== + +#[derive(Clone)] +pub struct Aggregator { + items: Arc>>>, + current: Arc, +} + +impl Default for Aggregator { + fn default() -> Self { + Self::new() + } +} + +impl Aggregator { + #[must_use] + pub fn new() -> Self { + Self { + items: Arc::new(Mutex::new(Vec::new())), + current: Arc::new(AtomicUsize::new(0)), + } + } + + #[must_use] + pub fn batch_has_size(&self, index: usize, size: usize) -> bool { + if index == self.current.load(Ordering::Acquire) { + return false; + } + let items = &self.items.lock().unwrap(); + items.get(index).is_some_and(|v| v.len() == size) + } + + #[must_use] + pub fn batch_items(&self, index: usize) -> Option> + where + T: Clone, + { + if index == self.current.load(Ordering::Acquire) { + return None; + } + let items = self.items.lock().unwrap(); + items.get(index).cloned() + } + + #[must_use] + pub fn all_items_flat(&self) -> Vec + where + T: Clone, + { + let items = self.items.lock().unwrap(); + let current = self.current.load(Ordering::Acquire); + items.iter().take(current).flatten().cloned().collect() + } + + #[must_use] + pub fn completed_batches(&self) -> Vec> + where + T: Clone, + { + let items = self.items.lock().unwrap(); + let current = self.current.load(Ordering::Acquire); + items.iter().take(current).cloned().collect() + } +} + +impl Service> for Aggregator +where + T: fmt::Debug, +{ + type Response = (); + type Error = BoxError; + type Future = Pin> + Send + Sync + 'static>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: BatchControl) -> Self::Future { + match req { + BatchControl::Item(item) => { + let mut items = self.items.lock().unwrap(); + match items.get_mut(self.current.load(Ordering::Acquire)) { + None => { + items.push(vec![item]); + } + Some(v) => { + v.push(item); + } + } + } + BatchControl::Flush => { + self.current.fetch_add(1, Ordering::SeqCst); + return Box::pin(async { + tracing::info!("sleeping ..."); + async { + // Simulate some activity to catch any flushing issues + tokio::time::sleep(Duration::from_nanos(5)).await; + } + .await; + tracing::info!("awaking ..."); + Ok(()) + }); + } + } + + Box::pin(futures::future::ready(Ok(()))) + } +}