Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/strands_evals/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,10 @@ def run_evaluations(
return asyncio.run(self.run_evaluations_async(task, max_workers=1, evaluation_data_store=evaluation_data_store))

async def run_evaluations_async(
self, task: Callable, max_workers: int = 10, evaluation_data_store: EvaluationDataStore | None = None
self,
task: Callable,
max_workers: int = 10,
evaluation_data_store: EvaluationDataStore | None = None,
) -> list[EvaluationReport]:
"""
Run evaluations asynchronously using a queue for parallel processing.
Expand Down
11 changes: 11 additions & 0 deletions src/strands_evals/providers/trace_provider.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""TraceProvider interface for retrieving agent trace data from observability backends."""

from abc import ABC, abstractmethod
from collections.abc import Callable

from ..case import Case
from ..types.evaluation import TaskOutput


Expand Down Expand Up @@ -32,3 +34,12 @@ def get_evaluation_data(self, session_id: str) -> TaskOutput:
ProviderError: If the provider is unreachable or returns an error
"""
...

def as_task(self) -> Callable[[Case], TaskOutput]:
"""Return a task callable that fetches evaluation data by session_id.

Returns:
A callable that takes a single Case and returns the TaskOutput
for that case's session.
"""
return lambda case: self.get_evaluation_data(case.session_id)
19 changes: 19 additions & 0 deletions tests/strands_evals/providers/test_trace_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,22 @@ def test_get_evaluation_data_raises_session_not_found(self):
provider = ConcreteProvider(session=None)
with pytest.raises(SessionNotFoundError, match="No session found"):
provider.get_evaluation_data("missing")

def test_as_task_returns_callable(self):
session = Session(session_id="s1", traces=[])
provider = ConcreteProvider(session=session)
task = provider.as_task()
assert callable(task)

def test_as_task_callable_delegates_to_get_evaluation_data(self):
"""as_task() callable should pass case.session_id to get_evaluation_data."""
session = Session(session_id="s1", traces=[])
provider = ConcreteProvider(session=session)
task = provider.as_task()

class FakeCase:
session_id = "s1"

result = task(FakeCase())
assert result["output"] == "test response"
assert result["trajectory"] == session
95 changes: 95 additions & 0 deletions tests/strands_evals/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from strands_evals.evaluators.evaluator import DEFAULT_BEDROCK_MODEL_ID
from strands_evals.experiment import is_throttling_error
from strands_evals.providers.trace_provider import TraceProvider
from strands_evals.types import EvaluationData, EvaluationOutput


Expand Down Expand Up @@ -1756,3 +1757,97 @@ def test_run_evaluations_without_store_unchanged(self):

assert len(reports) == 1
assert reports[0].scores[0] == 1.0


class MockTraceProvider(TraceProvider):
"""Simple mock provider for testing."""

def __init__(self, data: dict[str, dict]):
self._data = data
self.call_count = 0
self.called_session_ids: list[str] = []

def get_evaluation_data(self, session_id: str) -> dict:
self.call_count += 1
self.called_session_ids.append(session_id)
return self._data[session_id]


class TestProviderIntegration:
def test_run_evaluations_with_provider(self):
"""provider.as_task() should be called with each case's session_id."""
cases = [
Case(name="c1", session_id="sess-1", input="hello", expected_output="hello"),
Case(name="c2", session_id="sess-2", input="foo", expected_output="foo"),
]
provider = MockTraceProvider(
{
"sess-1": {"output": "hello"},
"sess-2": {"output": "foo"},
}
)
experiment = Experiment(cases=cases, evaluators=[MockEvaluator()])

reports = experiment.run_evaluations(provider.as_task())

assert provider.call_count == 2
assert set(provider.called_session_ids) == {"sess-1", "sess-2"}
assert len(reports) == 1
assert reports[0].scores == [1.0, 1.0]

@pytest.mark.asyncio
async def test_run_evaluations_async_with_provider(self):
"""Async variant should also accept a provider as the task arg."""
cases = [
Case(name="c1", session_id="sess-1", input="hello", expected_output="hello"),
]
provider = MockTraceProvider(
{
"sess-1": {"output": "hello"},
}
)
experiment = Experiment(cases=cases, evaluators=[MockEvaluator()])

reports = await experiment.run_evaluations_async(provider.as_task())

assert provider.call_count == 1
assert provider.called_session_ids == ["sess-1"]
assert len(reports) == 1
assert reports[0].scores == [1.0]

def test_run_evaluations_with_provider_and_data_store_caches(self):
"""When data store has cached data, provider should not be called for that case."""
store = DictEvaluationDataStore()
cached_data = EvaluationData(
input="hello",
actual_output="hello",
name="c1",
expected_output="hello",
)
store.save("c1", cached_data)

provider = MockTraceProvider(
{
"sess-1": {"output": "hello"},
}
)
cases = [Case(name="c1", session_id="sess-1", input="hello", expected_output="hello")]
experiment = Experiment(cases=cases, evaluators=[MockEvaluator()])

reports = experiment.run_evaluations(provider.as_task(), evaluation_data_store=store)

# Provider should NOT have been called - data was cached
assert provider.call_count == 0
assert len(reports) == 1
assert reports[0].scores == [1.0]

def test_run_evaluations_with_task_positional_arg_unchanged(self):
"""Existing positional task argument should continue to work."""
cases = [Case(name="c1", input="hello", expected_output="hello")]
experiment = Experiment(cases=cases, evaluators=[MockEvaluator()])

# Positional arg - existing behavior
reports = experiment.run_evaluations(lambda c: c.input)

assert len(reports) == 1
assert reports[0].scores == [1.0]
Loading