From f9d41655317829276a510504cd7e11c1afc84067 Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Mon, 30 Mar 2026 16:57:14 -0400 Subject: [PATCH 1/3] feat: add TraceProvider support as alternative to task in evaluations Introduce a `provider` parameter to `run_evaluations` and `run_evaluations_async` that accepts a `TraceProvider` to fetch evaluation data by session_id instead of requiring a task callable. Add `_resolve_task` static method to validate mutual exclusivity between `task` and `provider` and generate the appropriate callable. --- src/strands_evals/experiment.py | 46 ++++++++++- tests/strands_evals/test_experiment.py | 105 +++++++++++++++++++++++++ 2 files changed, 147 insertions(+), 4 deletions(-) diff --git a/src/strands_evals/experiment.py b/src/strands_evals/experiment.py index 8afeb37e..21e5ecc7 100644 --- a/src/strands_evals/experiment.py +++ b/src/strands_evals/experiment.py @@ -23,6 +23,7 @@ from .evaluators.interactions_evaluator import InteractionsEvaluator from .evaluators.output_evaluator import OutputEvaluator from .evaluators.trajectory_evaluator import TrajectoryEvaluator +from .providers.trace_provider import TraceProvider from .telemetry import get_tracer, serialize from .telemetry._cloudwatch_logger import _send_to_cloudwatch from .types.evaluation import EvaluationData, InputT, OutputT @@ -168,6 +169,31 @@ def _validate_case_names(self) -> None: if len(case_names) != len(set(case_names)): raise ValueError("All case names must be unique when using an evaluation_data_store.") + @staticmethod + def _resolve_task( + task: Callable | None, + provider: TraceProvider | None, + ) -> Callable: + """Resolve task and provider into a single task callable. + + Args: + task: User-provided task function. + provider: TraceProvider that fetches evaluation data by session_id. + + Returns: + The task callable to use for evaluation. + + Raises: + ValueError: If both or neither of task and provider are specified. + """ + if task is not None and provider is not None: + raise ValueError("Cannot specify both 'task' and 'provider'. Use one or the other.") + if task is None and provider is None: + raise ValueError("Must specify either 'task' or 'provider'.") + if provider is not None: + return lambda case: provider.get_evaluation_data(case.session_id) + return task + async def _run_task_async( self, task: Callable[[Case[InputT, OutputT]], OutputT | dict[str, Any]], case: Case[InputT, OutputT] ) -> EvaluationData[InputT, OutputT]: @@ -499,8 +525,9 @@ async def _worker( def run_evaluations( self, - task: Callable[[Case[InputT, OutputT]], OutputT | dict[str, Any]], + task: Callable[[Case[InputT, OutputT]], OutputT | dict[str, Any]] | None = None, evaluation_data_store: EvaluationDataStore | None = None, + provider: TraceProvider | None = None, ) -> list[EvaluationReport]: """ Run the evaluations for all of the test cases with all evaluators. @@ -509,21 +536,29 @@ def run_evaluations( Args: task: The task to run the test case on. This function should take in InputT and returns either - OutputT or {"output": OutputT, "trajectory": ...}. + OutputT or {"output": OutputT, "trajectory": ...}. Mutually exclusive with provider. evaluation_data_store: Optional store for loading/saving evaluation data. When provided, cached results are loaded instead of running the task, and new results are saved after task execution. + provider: A TraceProvider that fetches evaluation data by session_id. When given, a task function + is generated automatically. Mutually exclusive with task. Return: A list of EvaluationReport objects, one for each evaluator, containing the overall score, individual case results, and basic feedback for each test case. """ + task = self._resolve_task(task, provider) + if asyncio.iscoroutinefunction(task): raise ValueError("Async task is not supported. Please use run_evaluations_async instead.") return asyncio.run(self.run_evaluations_async(task, max_workers=1, evaluation_data_store=evaluation_data_store)) async def run_evaluations_async( - self, task: Callable, max_workers: int = 10, evaluation_data_store: EvaluationDataStore | None = None + self, + task: Callable | None = None, + max_workers: int = 10, + evaluation_data_store: EvaluationDataStore | None = None, + provider: TraceProvider | None = None, ) -> list[EvaluationReport]: """ Run evaluations asynchronously using a queue for parallel processing. @@ -531,14 +566,17 @@ async def run_evaluations_async( Args: task: The task function to run on each case. This function should take in InputT and returns either OutputT or {"output": OutputT, "trajectory": ...}. The task can either run - synchronously or asynchronously. + synchronously or asynchronously. Mutually exclusive with provider. max_workers: Maximum number of parallel workers (default: 10) evaluation_data_store: Optional store for loading/saving evaluation data. When provided, cached results are loaded instead of running the task, and new results are saved after task execution. + provider: A TraceProvider that fetches evaluation data by session_id. When given, a task function + is generated automatically. Mutually exclusive with task. Returns: List of EvaluationReport objects, one for each evaluator, containing evaluation results """ + task = self._resolve_task(task, provider) if evaluation_data_store is not None: self._validate_case_names() diff --git a/tests/strands_evals/test_experiment.py b/tests/strands_evals/test_experiment.py index a806b5c9..c8e3ceee 100644 --- a/tests/strands_evals/test_experiment.py +++ b/tests/strands_evals/test_experiment.py @@ -1756,3 +1756,108 @@ def test_run_evaluations_without_store_unchanged(self): assert len(reports) == 1 assert reports[0].scores[0] == 1.0 + + +class MockTraceProvider: + """Simple mock provider for testing.""" + + def __init__(self, data: dict[str, dict]): + self._data = data + self.call_count = 0 + self.called_session_ids: list[str] = [] + + def get_evaluation_data(self, session_id: str) -> dict: + self.call_count += 1 + self.called_session_ids.append(session_id) + return self._data[session_id] + + +class TestProviderIntegration: + def test_run_evaluations_with_provider(self): + """Provider should be called with each case's session_id and results passed to evaluators.""" + cases = [ + Case(name="c1", session_id="sess-1", input="hello", expected_output="hello"), + Case(name="c2", session_id="sess-2", input="foo", expected_output="foo"), + ] + provider = MockTraceProvider({ + "sess-1": {"output": "hello"}, + "sess-2": {"output": "foo"}, + }) + experiment = Experiment(cases=cases, evaluators=[MockEvaluator()]) + + reports = experiment.run_evaluations(provider=provider) + + assert provider.call_count == 2 + assert set(provider.called_session_ids) == {"sess-1", "sess-2"} + assert len(reports) == 1 + assert reports[0].scores == [1.0, 1.0] + + def test_run_evaluations_with_provider_and_task_raises(self): + """Passing both task and provider should raise ValueError.""" + cases = [Case(name="c1", input="hello", expected_output="hello")] + provider = MockTraceProvider({"sess": {"output": "hello"}}) + experiment = Experiment(cases=cases, evaluators=[MockEvaluator()]) + + with pytest.raises(ValueError, match="Cannot specify both"): + experiment.run_evaluations(task=lambda c: c.input, provider=provider) + + def test_run_evaluations_with_neither_task_nor_provider_raises(self): + """Passing neither task nor provider should raise ValueError.""" + cases = [Case(name="c1", input="hello", expected_output="hello")] + experiment = Experiment(cases=cases, evaluators=[MockEvaluator()]) + + with pytest.raises(ValueError, match="Must specify either"): + experiment.run_evaluations() + + @pytest.mark.asyncio + async def test_run_evaluations_async_with_provider(self): + """Async variant should also accept provider parameter.""" + cases = [ + Case(name="c1", session_id="sess-1", input="hello", expected_output="hello"), + ] + provider = MockTraceProvider({ + "sess-1": {"output": "hello"}, + }) + experiment = Experiment(cases=cases, evaluators=[MockEvaluator()]) + + reports = await experiment.run_evaluations_async(provider=provider) + + assert provider.call_count == 1 + assert provider.called_session_ids == ["sess-1"] + assert len(reports) == 1 + assert reports[0].scores == [1.0] + + def test_run_evaluations_with_provider_and_data_store_caches(self): + """When data store has cached data, provider should not be called for that case.""" + store = DictEvaluationDataStore() + cached_data = EvaluationData( + input="hello", + actual_output="hello", + name="c1", + expected_output="hello", + ) + store.save("c1", cached_data) + + provider = MockTraceProvider({ + "sess-1": {"output": "hello"}, + }) + cases = [Case(name="c1", session_id="sess-1", input="hello", expected_output="hello")] + experiment = Experiment(cases=cases, evaluators=[MockEvaluator()]) + + reports = experiment.run_evaluations(provider=provider, evaluation_data_store=store) + + # Provider should NOT have been called - data was cached + assert provider.call_count == 0 + assert len(reports) == 1 + assert reports[0].scores == [1.0] + + def test_run_evaluations_with_task_positional_arg_unchanged(self): + """Existing positional task argument should continue to work.""" + cases = [Case(name="c1", input="hello", expected_output="hello")] + experiment = Experiment(cases=cases, evaluators=[MockEvaluator()]) + + # Positional arg - existing behavior + reports = experiment.run_evaluations(lambda c: c.input) + + assert len(reports) == 1 + assert reports[0].scores == [1.0] From 0e8eb7cc63ecde717a59ad1671b821fa00eed81e Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Tue, 31 Mar 2026 12:50:42 -0400 Subject: [PATCH 2/3] refactor: approach 2. remove TraceProvider support from Experiment class Remove the TraceProvider integration from run_evaluations and run_evaluations_async, making the task parameter required instead of optional. This removes the _resolve_task static method and the provider parameter, simplifying the evaluation API by requiring callers to always pass a task callable directly. --- src/strands_evals/experiment.py | 43 ++------------ src/strands_evals/providers/trace_provider.py | 9 +++ .../providers/test_trace_provider.py | 19 +++++++ tests/strands_evals/test_experiment.py | 56 ++++++++----------- 4 files changed, 55 insertions(+), 72 deletions(-) diff --git a/src/strands_evals/experiment.py b/src/strands_evals/experiment.py index 21e5ecc7..b94869fd 100644 --- a/src/strands_evals/experiment.py +++ b/src/strands_evals/experiment.py @@ -23,7 +23,6 @@ from .evaluators.interactions_evaluator import InteractionsEvaluator from .evaluators.output_evaluator import OutputEvaluator from .evaluators.trajectory_evaluator import TrajectoryEvaluator -from .providers.trace_provider import TraceProvider from .telemetry import get_tracer, serialize from .telemetry._cloudwatch_logger import _send_to_cloudwatch from .types.evaluation import EvaluationData, InputT, OutputT @@ -169,31 +168,6 @@ def _validate_case_names(self) -> None: if len(case_names) != len(set(case_names)): raise ValueError("All case names must be unique when using an evaluation_data_store.") - @staticmethod - def _resolve_task( - task: Callable | None, - provider: TraceProvider | None, - ) -> Callable: - """Resolve task and provider into a single task callable. - - Args: - task: User-provided task function. - provider: TraceProvider that fetches evaluation data by session_id. - - Returns: - The task callable to use for evaluation. - - Raises: - ValueError: If both or neither of task and provider are specified. - """ - if task is not None and provider is not None: - raise ValueError("Cannot specify both 'task' and 'provider'. Use one or the other.") - if task is None and provider is None: - raise ValueError("Must specify either 'task' or 'provider'.") - if provider is not None: - return lambda case: provider.get_evaluation_data(case.session_id) - return task - async def _run_task_async( self, task: Callable[[Case[InputT, OutputT]], OutputT | dict[str, Any]], case: Case[InputT, OutputT] ) -> EvaluationData[InputT, OutputT]: @@ -525,9 +499,8 @@ async def _worker( def run_evaluations( self, - task: Callable[[Case[InputT, OutputT]], OutputT | dict[str, Any]] | None = None, + task: Callable[[Case[InputT, OutputT]], OutputT | dict[str, Any]], evaluation_data_store: EvaluationDataStore | None = None, - provider: TraceProvider | None = None, ) -> list[EvaluationReport]: """ Run the evaluations for all of the test cases with all evaluators. @@ -536,18 +509,14 @@ def run_evaluations( Args: task: The task to run the test case on. This function should take in InputT and returns either - OutputT or {"output": OutputT, "trajectory": ...}. Mutually exclusive with provider. + OutputT or {"output": OutputT, "trajectory": ...}. evaluation_data_store: Optional store for loading/saving evaluation data. When provided, cached results are loaded instead of running the task, and new results are saved after task execution. - provider: A TraceProvider that fetches evaluation data by session_id. When given, a task function - is generated automatically. Mutually exclusive with task. Return: A list of EvaluationReport objects, one for each evaluator, containing the overall score, individual case results, and basic feedback for each test case. """ - task = self._resolve_task(task, provider) - if asyncio.iscoroutinefunction(task): raise ValueError("Async task is not supported. Please use run_evaluations_async instead.") @@ -555,10 +524,9 @@ def run_evaluations( async def run_evaluations_async( self, - task: Callable | None = None, + task: Callable, max_workers: int = 10, evaluation_data_store: EvaluationDataStore | None = None, - provider: TraceProvider | None = None, ) -> list[EvaluationReport]: """ Run evaluations asynchronously using a queue for parallel processing. @@ -566,17 +534,14 @@ async def run_evaluations_async( Args: task: The task function to run on each case. This function should take in InputT and returns either OutputT or {"output": OutputT, "trajectory": ...}. The task can either run - synchronously or asynchronously. Mutually exclusive with provider. + synchronously or asynchronously. max_workers: Maximum number of parallel workers (default: 10) evaluation_data_store: Optional store for loading/saving evaluation data. When provided, cached results are loaded instead of running the task, and new results are saved after task execution. - provider: A TraceProvider that fetches evaluation data by session_id. When given, a task function - is generated automatically. Mutually exclusive with task. Returns: List of EvaluationReport objects, one for each evaluator, containing evaluation results """ - task = self._resolve_task(task, provider) if evaluation_data_store is not None: self._validate_case_names() diff --git a/src/strands_evals/providers/trace_provider.py b/src/strands_evals/providers/trace_provider.py index 84e9ce54..42b13ccf 100644 --- a/src/strands_evals/providers/trace_provider.py +++ b/src/strands_evals/providers/trace_provider.py @@ -1,6 +1,7 @@ """TraceProvider interface for retrieving agent trace data from observability backends.""" from abc import ABC, abstractmethod +from collections.abc import Callable from ..types.evaluation import TaskOutput @@ -32,3 +33,11 @@ def get_evaluation_data(self, session_id: str) -> TaskOutput: ProviderError: If the provider is unreachable or returns an error """ ... + + def as_task(self) -> Callable: + """Return a task callable that fetches evaluation data by session_id. + + Returns: + A callable that takes a Case and returns the result of get_evaluation_data. + """ + return lambda case: self.get_evaluation_data(case.session_id) diff --git a/tests/strands_evals/providers/test_trace_provider.py b/tests/strands_evals/providers/test_trace_provider.py index fc5ee46f..579c09be 100644 --- a/tests/strands_evals/providers/test_trace_provider.py +++ b/tests/strands_evals/providers/test_trace_provider.py @@ -70,3 +70,22 @@ def test_get_evaluation_data_raises_session_not_found(self): provider = ConcreteProvider(session=None) with pytest.raises(SessionNotFoundError, match="No session found"): provider.get_evaluation_data("missing") + + def test_as_task_returns_callable(self): + session = Session(session_id="s1", traces=[]) + provider = ConcreteProvider(session=session) + task = provider.as_task() + assert callable(task) + + def test_as_task_callable_delegates_to_get_evaluation_data(self): + """as_task() callable should pass case.session_id to get_evaluation_data.""" + session = Session(session_id="s1", traces=[]) + provider = ConcreteProvider(session=session) + task = provider.as_task() + + class FakeCase: + session_id = "s1" + + result = task(FakeCase()) + assert result["output"] == "test response" + assert result["trajectory"] == session diff --git a/tests/strands_evals/test_experiment.py b/tests/strands_evals/test_experiment.py index c8e3ceee..71824b38 100644 --- a/tests/strands_evals/test_experiment.py +++ b/tests/strands_evals/test_experiment.py @@ -19,6 +19,7 @@ ) from strands_evals.evaluators.evaluator import DEFAULT_BEDROCK_MODEL_ID from strands_evals.experiment import is_throttling_error +from strands_evals.providers.trace_provider import TraceProvider from strands_evals.types import EvaluationData, EvaluationOutput @@ -1758,7 +1759,7 @@ def test_run_evaluations_without_store_unchanged(self): assert reports[0].scores[0] == 1.0 -class MockTraceProvider: +class MockTraceProvider(TraceProvider): """Simple mock provider for testing.""" def __init__(self, data: dict[str, dict]): @@ -1774,53 +1775,40 @@ def get_evaluation_data(self, session_id: str) -> dict: class TestProviderIntegration: def test_run_evaluations_with_provider(self): - """Provider should be called with each case's session_id and results passed to evaluators.""" + """provider.as_task() should be called with each case's session_id.""" cases = [ Case(name="c1", session_id="sess-1", input="hello", expected_output="hello"), Case(name="c2", session_id="sess-2", input="foo", expected_output="foo"), ] - provider = MockTraceProvider({ - "sess-1": {"output": "hello"}, - "sess-2": {"output": "foo"}, - }) + provider = MockTraceProvider( + { + "sess-1": {"output": "hello"}, + "sess-2": {"output": "foo"}, + } + ) experiment = Experiment(cases=cases, evaluators=[MockEvaluator()]) - reports = experiment.run_evaluations(provider=provider) + reports = experiment.run_evaluations(provider.as_task()) assert provider.call_count == 2 assert set(provider.called_session_ids) == {"sess-1", "sess-2"} assert len(reports) == 1 assert reports[0].scores == [1.0, 1.0] - def test_run_evaluations_with_provider_and_task_raises(self): - """Passing both task and provider should raise ValueError.""" - cases = [Case(name="c1", input="hello", expected_output="hello")] - provider = MockTraceProvider({"sess": {"output": "hello"}}) - experiment = Experiment(cases=cases, evaluators=[MockEvaluator()]) - - with pytest.raises(ValueError, match="Cannot specify both"): - experiment.run_evaluations(task=lambda c: c.input, provider=provider) - - def test_run_evaluations_with_neither_task_nor_provider_raises(self): - """Passing neither task nor provider should raise ValueError.""" - cases = [Case(name="c1", input="hello", expected_output="hello")] - experiment = Experiment(cases=cases, evaluators=[MockEvaluator()]) - - with pytest.raises(ValueError, match="Must specify either"): - experiment.run_evaluations() - @pytest.mark.asyncio async def test_run_evaluations_async_with_provider(self): - """Async variant should also accept provider parameter.""" + """Async variant should also accept a provider as the task arg.""" cases = [ Case(name="c1", session_id="sess-1", input="hello", expected_output="hello"), ] - provider = MockTraceProvider({ - "sess-1": {"output": "hello"}, - }) + provider = MockTraceProvider( + { + "sess-1": {"output": "hello"}, + } + ) experiment = Experiment(cases=cases, evaluators=[MockEvaluator()]) - reports = await experiment.run_evaluations_async(provider=provider) + reports = await experiment.run_evaluations_async(provider.as_task()) assert provider.call_count == 1 assert provider.called_session_ids == ["sess-1"] @@ -1838,13 +1826,15 @@ def test_run_evaluations_with_provider_and_data_store_caches(self): ) store.save("c1", cached_data) - provider = MockTraceProvider({ - "sess-1": {"output": "hello"}, - }) + provider = MockTraceProvider( + { + "sess-1": {"output": "hello"}, + } + ) cases = [Case(name="c1", session_id="sess-1", input="hello", expected_output="hello")] experiment = Experiment(cases=cases, evaluators=[MockEvaluator()]) - reports = experiment.run_evaluations(provider=provider, evaluation_data_store=store) + reports = experiment.run_evaluations(provider.as_task(), evaluation_data_store=store) # Provider should NOT have been called - data was cached assert provider.call_count == 0 From b8517d288125bffce3542f258e183ae2cf7fcfe4 Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Tue, 31 Mar 2026 17:49:29 -0400 Subject: [PATCH 3/3] fix: add explicit type parameters to TraceProvider.as_task return type Narrow the return type from bare `Callable` to `Callable[[Case], TaskOutput]` and update the docstring to match. This improves type safety and editor support for consumers of `as_task`. --- src/strands_evals/providers/trace_provider.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/strands_evals/providers/trace_provider.py b/src/strands_evals/providers/trace_provider.py index 42b13ccf..c3eff4f9 100644 --- a/src/strands_evals/providers/trace_provider.py +++ b/src/strands_evals/providers/trace_provider.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable +from ..case import Case from ..types.evaluation import TaskOutput @@ -34,10 +35,11 @@ def get_evaluation_data(self, session_id: str) -> TaskOutput: """ ... - def as_task(self) -> Callable: + def as_task(self) -> Callable[[Case], TaskOutput]: """Return a task callable that fetches evaluation data by session_id. Returns: - A callable that takes a Case and returns the result of get_evaluation_data. + A callable that takes a single Case and returns the TaskOutput + for that case's session. """ return lambda case: self.get_evaluation_data(case.session_id)