Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/configs/agent_simple.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
type: simple # agent type, support `simple | orchestrator` for now
max_turns: 50 # max number of turns

runner: openai # `openai | react` you can choose customized agent Runner

# -------------------------------------------------------------------------------------------
# model configs
model:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ dependencies = [
"jinja2>=3.1.6",
"mcp>=1.12.3",
# https://github.com/openai/openai-agents-python/issues/1395 | openai==1.97.1 & openai-agents==0.2.5
"openai-agents==0.5.0",
"openai-agents==0.6.4",
# https://github.com/Arize-ai/openinference/tree/main/python/instrumentation/openinference-instrumentation-openai
"openinference-instrumentation-openai>=0.1.30",
# https://opentelemetry.io/docs/languages/python/
Expand Down
3 changes: 0 additions & 3 deletions utu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@
from agents.run import set_default_agent_runner

from .utils import EnvUtils, setup_logging
from .patch.runner import UTUAgentRunner
from .tracing import setup_tracing

EnvUtils.assert_env(["UTU_LLM_TYPE", "UTU_LLM_MODEL"])
setup_logging(EnvUtils.get_env("UTU_LOG_LEVEL", "WARNING"))
setup_tracing()
# patched runner
set_default_agent_runner(UTUAgentRunner())
12 changes: 7 additions & 5 deletions utu/agents/simple_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
ModelSettings,
RunConfig,
RunHooks,
Runner,
StopAtTools,
TContext,
Tool,
Expand All @@ -25,6 +24,7 @@
from ..db import DBService, TrajectoryModel
from ..env import _BaseEnv, get_env
from ..hooks import get_run_hooks
from ..runner import get_runner
from ..tools import TOOLKIT_MAP, AsyncBaseToolkit
from ..utils import AgentsMCPUtils, AgentsUtils, get_logger, load_class_from_file
from .common import QueueCompleteSentinel, TaskRecorder
Expand Down Expand Up @@ -260,11 +260,12 @@ async def run(
if isinstance(input, str): # only add history when input is str?
input = self.input_items + [{"content": input, "role": "user"}]
run_kwargs = self._prepare_run_kwargs(input)
runner = get_runner(self.config.runner)
if AgentsUtils.get_current_trace():
run_result = await Runner.run(**run_kwargs)
run_result = await runner.run(**run_kwargs)
else:
with trace(workflow_name="simple_agent", trace_id=recorder.trace_id):
run_result = await Runner.run(**run_kwargs)
run_result = await runner.run(**run_kwargs)
# save final output and trajectory
recorder.add_run_result(run_result)
if save:
Expand Down Expand Up @@ -304,11 +305,12 @@ async def _start_streaming(self, recorder: TaskRecorder, save: bool = False, log
if isinstance(input, str): # only add history when input is str?
input = self.input_items + [{"content": input, "role": "user"}]
run_kwargs = self._prepare_run_kwargs(input)
runner = get_runner(self.config.runner)
if AgentsUtils.get_current_trace():
run_streamed_result = Runner.run_streamed(**run_kwargs)
run_streamed_result = runner.run_streamed(**run_kwargs)
else:
with trace(workflow_name="simple_agent", trace_id=recorder.trace_id):
run_streamed_result = Runner.run_streamed(**run_kwargs)
run_streamed_result = runner.run_streamed(**run_kwargs)
async for event in run_streamed_result.stream_events():
recorder._event_queue.put_nowait(event)
# save final output and trajectory
Expand Down
2 changes: 2 additions & 0 deletions utu/config/agent_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class AgentConfig(ConfigBaseModel):
"""Max turns for simple agent. This param is derived from @openai-agents"""
stop_at_tool_names: list[str] | None = None
"""Stop at tools for simple agent. This param is derived from @openai-agents"""
runner: Literal["openai", "react"] = "openai"
"""Runner name for simple agent."""

# orchestra agent config
planner_model: ModelConfigs = Field(default_factory=ModelConfigs)
Expand Down
27 changes: 27 additions & 0 deletions utu/runner/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Literal

from agents import RunConfig, Runner

from .openai_runner import UTUAgentRunner
from .react_runner import ReactRunner


def get_runner(name: Literal["openai", "react"] = "openai") -> object:
"""Get a runner class by name.

Args:
name: Runner name ("openai" for default, "react" for ReactRunner)

Returns:
Runner class (not instance)
"""
# TODO: add a protocol for runner
if name == "react":
return ReactRunner
elif name == "openai":
return UTUAgentRunner()
else:
raise ValueError(f"Unknown runner name: {name}")


__all__ = ["Runner", "RunConfig", "ReactRunner", "get_runner"]
134 changes: 52 additions & 82 deletions utu/patch/runner.py → utu/runner/openai_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@

import asyncio
import logging
from typing import cast
from typing_extensions import Unpack
from typing import Unpack, cast

from agents import (
Agent,
Expand All @@ -22,22 +21,33 @@
TResponseInputItem,
)
from agents._run_impl import (
RunImpl, NextStepHandoff, TraceCtxManager, NextStepFinalOutput, NextStepRunAgain,
get_model_tracing_impl,
NextStepFinalOutput,
NextStepHandoff,
NextStepRunAgain,
RunImpl,
TraceCtxManager,
get_model_tracing_impl,
)
from agents.exceptions import ModelBehaviorError, MaxTurnsExceeded, AgentsException, RunErrorDetails
from agents.exceptions import AgentsException, MaxTurnsExceeded, ModelBehaviorError, RunErrorDetails
from agents.guardrail import InputGuardrailResult
from agents.items import HandoffCallItem, ModelResponse, ToolCallItem, ToolCallItemTypes, ReasoningItem
from agents.items import HandoffCallItem, ModelResponse, ReasoningItem, ToolCallItem, ToolCallItemTypes
from agents.result import RunResult
from agents.run import (
AgentRunner, AgentToolUseTracker, RunResultStreaming, SingleStepResult, RunOptions,
_TOOL_CALL_TYPES, _ServerConversationTracker, _copy_str_or_list, DEFAULT_MAX_TURNS
_TOOL_CALL_TYPES,
DEFAULT_MAX_TURNS,
AgentRunner,
AgentToolUseTracker,
RunOptions,
RunResultStreaming,
SingleStepResult,
_copy_str_or_list,
_ServerConversationTracker,
)
from agents.stream_events import RawResponsesStreamEvent, RunItemStreamEvent
from agents.tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult
from agents.tracing import Span, AgentSpanData, SpanError, agent_span
from agents.util import _coro, _error_tracing
from agents.tracing import AgentSpanData, Span, SpanError, agent_span
from agents.usage import Usage
from agents.util import _coro, _error_tracing
from openai.types.responses import (
ResponseCompletedEvent,
ResponseFunctionToolCall,
Expand All @@ -53,7 +63,6 @@


class UTUAgentRunner(AgentRunner):

async def run(
self,
starting_agent: Agent[TContext],
Expand All @@ -79,9 +88,7 @@ async def run(

# Keep original user input separate from session-prepared input
original_user_input = input
prepared_input = await self._prepare_input_with_session(
input, session, run_config.session_input_callback
)
prepared_input = await self._prepare_input_with_session(input, session, run_config.session_input_callback)

tool_use_tracker = AgentToolUseTracker()

Expand Down Expand Up @@ -120,8 +127,7 @@ async def run(
# agent changes, or if the agent loop ends.
if current_span is None:
handoff_names = [
h.agent_name
for h in await AgentRunner._get_handoffs(current_agent, context_wrapper)
h.agent_name for h in await AgentRunner._get_handoffs(current_agent, context_wrapper)
]
if output_schema := AgentRunner._get_output_schema(current_agent):
output_type_name = output_schema.name()
Expand Down Expand Up @@ -152,17 +158,18 @@ async def run(
)
# ADD: inject context infos
if isinstance(context_wrapper.context, dict):
context_wrapper.context.update({
"current_turn": current_turn,
"max_turns": max_turns,
})
context_wrapper.context.update(
{
"current_turn": current_turn,
"max_turns": max_turns,
}
)

if current_turn == 1:
input_guardrail_results, turn_result = await asyncio.gather(
self._run_input_guardrails(
starting_agent,
starting_agent.input_guardrails
+ (run_config.input_guardrails or []),
starting_agent.input_guardrails + (run_config.input_guardrails or []),
_copy_str_or_list(prepared_input),
context_wrapper,
),
Expand Down Expand Up @@ -225,12 +232,9 @@ async def run(
context_wrapper=context_wrapper,
)
if not any(
guardrail_result.output.tripwire_triggered
for guardrail_result in input_guardrail_results
guardrail_result.output.tripwire_triggered for guardrail_result in input_guardrail_results
):
await self._save_result_to_session(
session, [], turn_result.new_step_items
)
await self._save_result_to_session(session, [], turn_result.new_step_items)

return result
elif isinstance(turn_result.next_step, NextStepHandoff):
Expand All @@ -240,16 +244,11 @@ async def run(
should_run_agent_start_hooks = True
elif isinstance(turn_result.next_step, NextStepRunAgain):
if not any(
guardrail_result.output.tripwire_triggered
for guardrail_result in input_guardrail_results
guardrail_result.output.tripwire_triggered for guardrail_result in input_guardrail_results
):
await self._save_result_to_session(
session, [], turn_result.new_step_items
)
await self._save_result_to_session(session, [], turn_result.new_step_items)
else:
raise AgentsException(
f"Unknown next step type: {type(turn_result.next_step)}"
)
raise AgentsException(f"Unknown next step type: {type(turn_result.next_step)}")
except AgentsException as exc:
exc.run_data = RunErrorDetails(
input=original_input,
Expand Down Expand Up @@ -284,11 +283,7 @@ async def _run_single_turn_streamed(
if should_run_agent_start_hooks:
await asyncio.gather(
hooks.on_agent_start(context_wrapper, agent),
(
agent.hooks.on_start(context_wrapper, agent)
if agent.hooks
else _coro.noop_coroutine()
),
(agent.hooks.on_start(context_wrapper, agent) if agent.hooks else _coro.noop_coroutine()),
)

output_schema = cls._get_output_schema(agent)
Expand All @@ -309,20 +304,20 @@ async def _run_single_turn_streamed(
final_response: ModelResponse | None = None

if server_conversation_tracker is not None:
input = server_conversation_tracker.prepare_input(
streamed_result.input, streamed_result.new_items
)
input = server_conversation_tracker.prepare_input(streamed_result.input, streamed_result.new_items)
else:
input = ItemHelpers.input_to_new_input_list(streamed_result.input)
input.extend([item.to_input_item() for item in streamed_result.new_items])

# ADD: inject context infos
if isinstance(context_wrapper.context, dict):
context_wrapper.context.update({
"current_turn": streamed_result.current_turn,
"max_turns": streamed_result.max_turns,
# "streamed_result": streamed_result
})
context_wrapper.context.update(
{
"current_turn": streamed_result.current_turn,
"max_turns": streamed_result.max_turns,
# "streamed_result": streamed_result
}
)
input = cls._context_manager_preprocess(input, context_wrapper)

# THIS IS THE RESOLVED CONFLICT BLOCK
Expand All @@ -338,22 +333,14 @@ async def _run_single_turn_streamed(
await asyncio.gather(
hooks.on_llm_start(context_wrapper, agent, filtered.instructions, filtered.input),
(
agent.hooks.on_llm_start(
context_wrapper, agent, filtered.instructions, filtered.input
)
agent.hooks.on_llm_start(context_wrapper, agent, filtered.instructions, filtered.input)
if agent.hooks
else _coro.noop_coroutine()
),
)

previous_response_id = (
server_conversation_tracker.previous_response_id
if server_conversation_tracker
else None
)
conversation_id = (
server_conversation_tracker.conversation_id if server_conversation_tracker else None
)
previous_response_id = server_conversation_tracker.previous_response_id if server_conversation_tracker else None
conversation_id = server_conversation_tracker.conversation_id if server_conversation_tracker else None

# 1. Stream the output events
async for event in model.stream_response(
Expand All @@ -363,9 +350,7 @@ async def _run_single_turn_streamed(
all_tools,
output_schema,
handoffs,
get_model_tracing_impl(
run_config.tracing_disabled, run_config.trace_include_sensitive_data
),
get_model_tracing_impl(run_config.tracing_disabled, run_config.trace_include_sensitive_data),
previous_response_id=previous_response_id,
conversation_id=conversation_id,
prompt=prompt_config,
Expand Down Expand Up @@ -397,9 +382,7 @@ async def _run_single_turn_streamed(
output_item = event.item

if isinstance(output_item, _TOOL_CALL_TYPES):
call_id: str | None = getattr(
output_item, "call_id", getattr(output_item, "id", None)
)
call_id: str | None = getattr(output_item, "call_id", getattr(output_item, "id", None))

if call_id and call_id not in emitted_tool_call_ids:
emitted_tool_call_ids.add(call_id)
Expand All @@ -408,9 +391,7 @@ async def _run_single_turn_streamed(
raw_item=cast(ToolCallItemTypes, output_item),
agent=agent,
)
streamed_result._event_queue.put_nowait(
RunItemStreamEvent(item=tool_item, name="tool_called")
)
streamed_result._event_queue.put_nowait(RunItemStreamEvent(item=tool_item, name="tool_called"))

elif isinstance(output_item, ResponseReasoningItem):
reasoning_id: str | None = getattr(output_item, "id", None)
Expand Down Expand Up @@ -469,11 +450,7 @@ async def _run_single_turn_streamed(
for item in items_to_filter
if not (
isinstance(item, ToolCallItem)
and (
call_id := getattr(
item.raw_item, "call_id", getattr(item.raw_item, "id", None)
)
)
and (call_id := getattr(item.raw_item, "call_id", getattr(item.raw_item, "id", None)))
and call_id in emitted_tool_call_ids
)
]
Expand All @@ -491,9 +468,7 @@ async def _run_single_turn_streamed(
]

# Filter out HandoffCallItem to avoid duplicates (already sent earlier)
items_to_filter = [
item for item in items_to_filter if not isinstance(item, HandoffCallItem)
]
items_to_filter = [item for item in items_to_filter if not isinstance(item, HandoffCallItem)]

# Create filtered result and send to queue
filtered_result = _dc.replace(single_step_result, new_step_items=items_to_filter)
Expand All @@ -519,11 +494,7 @@ async def _run_single_turn(
if should_run_agent_start_hooks:
await asyncio.gather(
hooks.on_agent_start(context_wrapper, agent),
(
agent.hooks.on_start(context_wrapper, agent)
if agent.hooks
else _coro.noop_coroutine()
),
(agent.hooks.on_start(context_wrapper, agent) if agent.hooks else _coro.noop_coroutine()),
)

system_prompt, prompt_config = await asyncio.gather(
Expand Down Expand Up @@ -574,7 +545,6 @@ async def _run_single_turn(
tool_use_tracker=tool_use_tracker,
)


@classmethod
def _context_manager_preprocess(
cls, input: list[TResponseInputItem], context_wrapper: RunContextWrapper[TContext]
Expand Down
Loading