Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import functools
from typing import Any, Iterable, Literal, Unpack
from typing import Any, Literal, Unpack

import openai
import structlog
Expand All @@ -15,12 +15,11 @@
from openai.types.completion_usage import CompletionUsage
from preparedness_turn_completer.turn_completer import TurnCompleter
from preparedness_turn_completer.utils import (
DEFAULT_RETRY_CONFIG,
RetryConfig,
get_model_context_window_length,
warn_about_non_empty_params,
)
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator

logger = structlog.stdlib.get_logger(component=__name__)

Expand All @@ -34,9 +33,9 @@ def __init__(
temperature: float | None | NotGiven = NOT_GIVEN,
max_tokens: int | None | NotGiven = NOT_GIVEN,
top_p: float | None | NotGiven = NOT_GIVEN,
tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
retry_config: RetryConfig = DEFAULT_RETRY_CONFIG,
retry_config: RetryConfig | None = None,
):
self.model = model
self.reasoning_effort = reasoning_effort
Expand All @@ -47,7 +46,7 @@ def __init__(
self.tools = tools
self.tool_choice = tool_choice
self.encoding_name: str
self.retry_config = retry_config
self.retry_config = retry_config or RetryConfig()
try:
self.encoding_name = tiktoken.encoding_name_for_model(model)
except KeyError:
Expand All @@ -74,9 +73,9 @@ class Config(TurnCompleter.Config):
temperature: float | None | NotGiven = NOT_GIVEN
max_tokens: int | None | NotGiven = NOT_GIVEN
top_p: float | None | NotGiven = NOT_GIVEN
tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN
tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN
tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN
retry_config: RetryConfig = DEFAULT_RETRY_CONFIG
retry_config: RetryConfig = Field(default_factory=RetryConfig)

def build(self) -> OpenAICompletionsTurnCompleter:
return OpenAICompletionsTurnCompleter(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import functools
from typing import Any, Iterable, Unpack
from typing import Any, Literal, Unpack

import openai
import structlog
Expand All @@ -19,27 +19,34 @@
)
from preparedness_turn_completer.turn_completer import TurnCompleter
from preparedness_turn_completer.utils import (
DEFAULT_RETRY_CONFIG,
RetryConfig,
get_model_context_window_length,
warn_about_non_empty_params,
)
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator

logger = structlog.stdlib.get_logger(component=__name__)


class ReasoningConfig(BaseModel):
"""chz-friendly wrapper around openai.types.shared_params.reasoning:Reasoning"""

effort: Literal["minimal", "low", "medium", "high"] | None = None
generate_summary: Literal["auto", "concise", "detailed"] | None = None
summary: Literal["auto", "concise", "detailed"] | None = None


class OpenAIResponsesTurnCompleter(TurnCompleter):
def __init__(
self,
model: str,
reasoning: Reasoning | None | NotGiven = NOT_GIVEN,
text_format: type[BaseModel] | NotGiven = NOT_GIVEN,
tools: Iterable[ParseableToolParam] | NotGiven = NOT_GIVEN,
tools: list[ParseableToolParam] | NotGiven = NOT_GIVEN,
temperature: float | None | NotGiven = NOT_GIVEN,
max_output_tokens: int | None | NotGiven = NOT_GIVEN,
top_p: float | None | NotGiven = NOT_GIVEN,
retry_config: RetryConfig = DEFAULT_RETRY_CONFIG,
retry_config: RetryConfig | None = None,
):
self.model = model
self.reasoning = reasoning
Expand All @@ -49,7 +56,7 @@ def __init__(
self.max_output_tokens = max_output_tokens
self.top_p = top_p
self.encoding_name: str
self.retry_config = retry_config
self.retry_config = retry_config or RetryConfig()
try:
self.encoding_name = tiktoken.encoding_name_for_model(model)
except KeyError:
Expand All @@ -71,18 +78,28 @@ class Config(TurnCompleter.Config):
)

model: str
reasoning: Reasoning | None | NotGiven = NOT_GIVEN
reasoning: ReasoningConfig | None | NotGiven = NOT_GIVEN
text_format: type[BaseModel] | NotGiven = NOT_GIVEN
tools: Iterable[ParseableToolParam] | NotGiven = NOT_GIVEN
tools: list[ParseableToolParam] | NotGiven = NOT_GIVEN
temperature: float | None | NotGiven = NOT_GIVEN
max_output_tokens: int | None | NotGiven = NOT_GIVEN
top_p: float | None | NotGiven = NOT_GIVEN
retry_config: RetryConfig = DEFAULT_RETRY_CONFIG
retry_config: RetryConfig = Field(default_factory=RetryConfig)

def build(self) -> OpenAIResponsesTurnCompleter:
reasoning_param: Reasoning | None | NotGiven
if isinstance(self.reasoning, ReasoningConfig):
reasoning_param = Reasoning(
effort=self.reasoning.effort,
generate_summary=self.reasoning.generate_summary,
summary=self.reasoning.summary,
)
else:
reasoning_param = self.reasoning

return OpenAIResponsesTurnCompleter(
model=self.model,
reasoning=self.reasoning,
reasoning=reasoning_param,
text_format=self.text_format,
tools=self.tools,
temperature=self.temperature,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,17 @@
from openai.types.responses.response_output_text import (
AnnotationURLCitation as ResponsesAnnotationURLCitation,
)
from preparedness_turn_completer.oai_responses_turn_completer.type_helpers import (
from preparedness_turn_completer.turn_completer import TurnCompleter
from preparedness_turn_completer.type_helpers import (
ChatCompletionContent,
is_assistant_message,
is_chat_completion_assistant_message_param,
is_chat_completion_tool_message_param,
is_content_array_list,
is_content_part_list,
is_custom_tool_call_param,
is_function_tool_call_param,
is_text_parts_list,
is_tool_message,
)
from preparedness_turn_completer.turn_completer import TurnCompleter

logger = structlog.stdlib.get_logger(component=__name__)

Expand Down Expand Up @@ -140,19 +140,22 @@ def _user_completion_to_response_input_items(
def _assistant_completion_to_response_input_items(
message: ChatCompletionMessageParam,
) -> list[ResponseInputItemParam]:
content = message["content"]
role: Literal["assistant"] = "assistant"
input_items: list[ResponseInputItemParam] = []
assert is_assistant_message(message)
assert isinstance(content, str) or is_content_array_list(content), (
f"Expected content to be str or list of content arrays, got {content!r}"
)
if isinstance(content, str):
input_items.append(EasyInputMessageParam(content=content, role=role, type="message"))
elif is_content_array_list(content):
input_items.extend(_content_array_list_to_response_input_items(content))
else:
raise ValueError(f"Expected content to be str or list of content arrays, got {content!r}")
assert is_chat_completion_assistant_message_param(message)
if message.get("content") is not None:
content = message["content"]
assert isinstance(content, str) or is_content_array_list(content), (
f"Expected content to be str or list of content arrays, got {content!r}"
)
if isinstance(content, str):
input_items.append(EasyInputMessageParam(content=content, role=role, type="message"))
elif is_content_array_list(content):
input_items.extend(_content_array_list_to_response_input_items(content))
else:
raise ValueError(
f"Expected content to be str or list of content arrays, got {content!r}"
)

refusal = message.get("refusal", None)
if isinstance(refusal, str):
Expand Down Expand Up @@ -225,7 +228,7 @@ def _content_array_list_to_response_input_items(
def _tool_completion_to_response_input_items(
message: ChatCompletionMessageParam,
) -> list[ResponseInputItemParam]:
assert is_tool_message(message)
assert is_chat_completion_tool_message_param(message)
content = message["content"]
output_str: str
if isinstance(content, str):
Expand All @@ -235,11 +238,7 @@ def _tool_completion_to_response_input_items(

return [
FunctionCallOutput(
call_id=message["tool_call_id"],
output=output_str,
type="function_call_output",
id=uuid.uuid4().hex,
status="completed",
call_id=message["tool_call_id"], output=output_str, type="function_call_output"
)
]

Expand Down

This file was deleted.

Loading