diff --git a/python/copilot/session.py b/python/copilot/session.py index 0dc569f25..b9a328ff7 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -36,7 +36,6 @@ CanvasProviderOpenResult, ClientSessionApiHandlers, CommandsHandlePendingCommandRequest, - ExternalToolTextResultForLlm, HandlePendingToolCallRequest, LogRequest, ModelSwitchToRequest, @@ -79,7 +78,13 @@ from .generated.session_events import ( ReasoningSummary as _RpcReasoningSummary, ) -from .tools import Tool, ToolHandler, ToolInvocation, ToolResult +from .tools import ( + Tool, + ToolHandler, + ToolInvocation, + ToolResult, + tool_result_to_external_tool_text_result_for_llm, +) logger = logging.getLogger(__name__) @@ -1862,12 +1867,7 @@ async def _execute_tool_and_respond( await self.rpc.tools.handle_pending_tool_call( HandlePendingToolCallRequest( request_id=request_id, - result=ExternalToolTextResultForLlm( - text_result_for_llm=tool_result.text_result_for_llm, - error=tool_result.error, - result_type=tool_result.result_type, - tool_telemetry=tool_result.tool_telemetry, - ), + result=tool_result_to_external_tool_text_result_for_llm(tool_result), ) ) log_timing( diff --git a/python/copilot/tools.py b/python/copilot/tools.py index a82a48b1e..80a3c4637 100644 --- a/python/copilot/tools.py +++ b/python/copilot/tools.py @@ -15,6 +15,12 @@ from pydantic import BaseModel +from .generated.rpc import ( + ExternalToolTextResultForLlm, + ExternalToolTextResultForLlmBinaryResultsForLlm, + ExternalToolTextResultForLlmBinaryResultsForLlmType, +) + ToolResultType = Literal["success", "failure", "rejected", "denied", "timeout"] @@ -371,3 +377,29 @@ def convert_mcp_call_tool_result(call_result: dict[str, Any]) -> ToolResult: result_type="failure" if call_result.get("isError") is True else "success", binary_results_for_llm=binary_results if binary_results else None, ) + + +def tool_result_to_external_tool_text_result_for_llm( + tool_result: ToolResult, +) -> ExternalToolTextResultForLlm: + """Convert a ToolResult into the RPC payload sent to HandlePendingToolCall.""" + binary_results_for_llm = None + if tool_result.binary_results_for_llm: + binary_results_for_llm = [ + ExternalToolTextResultForLlmBinaryResultsForLlm( + data=binary_result.data, + mime_type=binary_result.mime_type, + type=ExternalToolTextResultForLlmBinaryResultsForLlmType(binary_result.type), + description=binary_result.description or None, + ) + for binary_result in tool_result.binary_results_for_llm + ] + + return ExternalToolTextResultForLlm( + text_result_for_llm=tool_result.text_result_for_llm, + binary_results_for_llm=binary_results_for_llm, + error=tool_result.error, + result_type=tool_result.result_type, + session_log=tool_result.session_log, + tool_telemetry=tool_result.tool_telemetry, + ) diff --git a/python/test_tools.py b/python/test_tools.py index d583b59c0..16f14635a 100644 --- a/python/test_tools.py +++ b/python/test_tools.py @@ -7,10 +7,12 @@ from copilot import define_tool from copilot.tools import ( + ToolBinaryResult, ToolInvocation, ToolResult, _normalize_result, convert_mcp_call_tool_result, + tool_result_to_external_tool_text_result_for_llm, ) @@ -427,3 +429,39 @@ def test_call_tool_result_dict_is_json_serialized_by_normalize(self): result = _normalize_result({"content": [{"type": "text", "text": "hello"}]}) parsed = json.loads(result.text_result_for_llm) assert parsed == {"content": [{"type": "text", "text": "hello"}]} + + +class TestToolResultToExternalToolTextResultForLlm: + def test_forwards_binary_results_and_session_log(self): + tool_result = ToolResult( + text_result_for_llm="screenshot captured", + binary_results_for_llm=[ + ToolBinaryResult( + data="base64data", + mime_type="image/png", + type="image", + description="screenshot.png", + ) + ], + session_log="tool execution details", + tool_telemetry={"duration_ms": 42}, + ) + + rpc_result = tool_result_to_external_tool_text_result_for_llm(tool_result) + + assert rpc_result.text_result_for_llm == "screenshot captured" + assert rpc_result.session_log == "tool execution details" + assert rpc_result.tool_telemetry == {"duration_ms": 42} + assert rpc_result.binary_results_for_llm is not None + assert len(rpc_result.binary_results_for_llm) == 1 + assert rpc_result.binary_results_for_llm[0].data == "base64data" + assert rpc_result.binary_results_for_llm[0].mime_type == "image/png" + assert rpc_result.binary_results_for_llm[0].type.value == "image" + assert rpc_result.binary_results_for_llm[0].description == "screenshot.png" + + def test_omits_binary_results_when_none(self): + tool_result = ToolResult(text_result_for_llm="done") + rpc_result = tool_result_to_external_tool_text_result_for_llm(tool_result) + assert rpc_result.binary_results_for_llm is None + assert rpc_result.session_log is None +