Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,12 @@ type AgentDefinition struct {
Tools []string // tool-name allow-list for this type (empty = all parent tools)
Model string // optional model override resolved via the agent LLM factory
Temperature float32 // optional sampling temperature for this type
Iterations int // optional per-type iteration cap (0 = inherit parent)
MaxAttempts int // optional per-type attempt cap (0 = inherit parent)
MaxRetries int // optional per-type retry cap (0 = inherit parent)
// Metadata is an optional per-request metadata object for this type,
// passed to the agent LLM factory and attached to its requests.
Metadata map[string]string
Iterations int // optional per-type iteration cap (0 = inherit parent)
MaxAttempts int // optional per-type attempt cap (0 = inherit parent)
MaxRetries int // optional per-type retry cap (0 = inherit parent)
}

// findAgentDefinition returns the definition with the given name, or nil.
Expand Down Expand Up @@ -296,7 +299,7 @@ type spawnAgentRunner struct {
agentSpawnCallback func(*AgentState)
completionFormatter func(*AgentState) string
agentDefinitions []AgentDefinition
llmFactory func(model string, temperature float32) LLM
llmFactory func(model string, temperature float32, metadata map[string]string) LLM
}

func (r *spawnAgentRunner) Run(args SpawnAgentArgs) (string, any, error) {
Expand Down Expand Up @@ -508,14 +511,16 @@ func derefFragment(f *Fragment) any {
// resolveLLM selects the LLM a spawned agent runs with.
//
// The model is args.Model, falling back to def.Model when the spawn arguments
// leave it empty. Temperature and metadata are taken from def whenever a
// definition is supplied. If an llmFactory is configured and the guard below
// is satisfied, the factory builds the LLM; otherwise the runner's own llm is
// returned unchanged.
//
// NOTE(review): this span looks like a rendered diff hunk — it lists both a
// model-only guard with a two-argument factory call and a model-or-metadata
// guard with a three-argument call. Confirm which pair is current.
// NOTE(review): a definition that sets only Temperature (no model, no
// metadata) does not reach the factory under either guard — confirm that
// this is intended.
func (r *spawnAgentRunner) resolveLLM(args SpawnAgentArgs, def *AgentDefinition) LLM {
// A model named in the spawn arguments takes precedence over the definition.
model := args.Model
var temp float32
var meta map[string]string
if def != nil {
if model == "" {
model = def.Model
}
temp = def.Temperature
meta = def.Metadata
}
if model != "" && r.llmFactory != nil {
return r.llmFactory(model, temp)
// The factory may be called with an empty model when only metadata is set.
if (model != "" || len(meta) > 0) && r.llmFactory != nil {
return r.llmFactory(model, temp, meta)
}
return r.llm
}
Expand Down Expand Up @@ -587,7 +592,7 @@ func newSpawnAgentTool(
spawnCB func(*AgentState),
completionFormatter func(*AgentState) string,
defs []AgentDefinition,
llmFactory func(model string, temperature float32) LLM,
llmFactory func(model string, temperature float32, metadata map[string]string) LLM,
) ToolDefinitionInterface {
return NewToolDefinition(
&spawnAgentRunner{
Expand Down
31 changes: 28 additions & 3 deletions agent_definitions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func TestSpawnUnknownAgentTypeErrorsCleanly(t *testing.T) {
func TestFactoryResolvesModelAndTemperature(t *testing.T) {
var gotModel string
var gotTemp float32
factory := func(model string, temp float32) LLM {
factory := func(model string, temp float32, _ map[string]string) LLM {
gotModel, gotTemp = model, temp
return newInspectingLLM(func(Fragment, []string) {})
}
Expand All @@ -184,9 +184,34 @@ func TestFactoryResolvesModelAndTemperature(t *testing.T) {
}
}

func TestFactoryFiresOnMetadataOnlyOverride(t *testing.T) {
var called bool
var gotMeta map[string]string
factory := func(_ string, _ float32, meta map[string]string) LLM {
called = true
gotMeta = meta
return newInspectingLLM(func(Fragment, []string) {})
}
defs := []AgentDefinition{{Name: "nothink", Metadata: map[string]string{"enable_thinking": "false"}}}
runner := &spawnAgentRunner{
llm: newInspectingLLM(func(Fragment, []string) {}),
manager: NewAgentManager(),
ctx: context.Background(),
agentDefinitions: defs,
llmFactory: factory,
}
_, _, _ = runner.Run(SpawnAgentArgs{AgentType: "nothink", Task: "x", Background: false})
if !called {
t.Fatal("factory should fire for a metadata-only override")
}
if gotMeta["enable_thinking"] != "false" {
t.Fatalf("factory got metadata %v, want enable_thinking=false", gotMeta)
}
}

func TestSpawnArgModelBeatsDefinition(t *testing.T) {
var gotModel string
factory := func(model string, temp float32) LLM {
factory := func(model string, temp float32, _ map[string]string) LLM {
gotModel = model
return newInspectingLLM(func(Fragment, []string) {})
}
Expand All @@ -203,7 +228,7 @@ func TestSpawnArgModelBeatsDefinition(t *testing.T) {

func TestWithAgentLLMFactoryStores(t *testing.T) {
o := defaultOptions()
o.Apply(WithAgentLLMFactory(func(string, float32) LLM { return nil }))
o.Apply(WithAgentLLMFactory(func(string, float32, map[string]string) LLM { return nil }))
if o.agentLLMFactory == nil {
t.Fatal("factory not stored")
}
Expand Down
15 changes: 15 additions & 0 deletions clients/openai_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,16 @@ type OpenAIClient struct {
model string
client *openai.Client
temperature float32
metadata map[string]string
}

// OpenAIOptions carries optional per-client settings.
// The zero value applies no overrides to outgoing requests.
type OpenAIOptions struct {
// Temperature is the sampling temperature stored on the client. Ask and
// CreateChatCompletionStream copy it onto the request only when it is
// non-zero, so 0 leaves the request's own temperature untouched.
Temperature float32
// Metadata is attached verbatim to every chat-completion request as the
// OpenAI "metadata" object. Backends such as LocalAI use it to carry
// per-request flags, e.g. {"enable_thinking": "false"} to disable reasoning.
// A nil or empty map is not attached to the request.
Metadata map[string]string
}

func NewOpenAILLM(model, apiKey, baseURL string) *OpenAIClient {
Expand All @@ -34,6 +39,7 @@ func NewOpenAILLMWithOptions(model, apiKey, baseURL string, opts OpenAIOptions)
model: model,
client: client,
temperature: opts.Temperature,
metadata: opts.Metadata,
}
}

Expand All @@ -54,6 +60,9 @@ func (llm *OpenAIClient) Ask(ctx context.Context, f cogito.Fragment) (cogito.Fra
if llm.temperature != 0 {
req.Temperature = llm.temperature
}
if len(llm.metadata) > 0 {
req.Metadata = llm.metadata
}

resp, err := llm.client.CreateChatCompletion(ctx, req)

Expand Down Expand Up @@ -83,6 +92,9 @@ func (llm *OpenAIClient) Ask(ctx context.Context, f cogito.Fragment) (cogito.Fra
}
func (llm *OpenAIClient) CreateChatCompletion(ctx context.Context, request openai.ChatCompletionRequest) (cogito.LLMReply, cogito.LLMUsage, error) {
request.Model = llm.model
if len(llm.metadata) > 0 {
request.Metadata = llm.metadata
}
response, err := llm.client.CreateChatCompletion(ctx, request)
if err != nil {
return cogito.LLMReply{}, cogito.LLMUsage{}, err
Expand All @@ -107,6 +119,9 @@ func (llm *OpenAIClient) CreateChatCompletionStream(ctx context.Context, request
if llm.temperature != 0 {
request.Temperature = llm.temperature
}
if len(llm.metadata) > 0 {
request.Metadata = llm.metadata
}

stream, err := llm.client.CreateChatCompletionStream(ctx, request)
if err != nil {
Expand Down
50 changes: 49 additions & 1 deletion clients/openai_client_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,54 @@
package clients

import "testing"
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"

"github.com/sashabaranov/go-openai"
)

func TestNewOpenAILLMWithOptionsSetsMetadata(t *testing.T) {
llm := NewOpenAILLMWithOptions("m", "k", "http://localhost", OpenAIOptions{
Metadata: map[string]string{"enable_thinking": "false"},
})
if llm.metadata["enable_thinking"] != "false" {
t.Fatalf("expected metadata enable_thinking=false, got %v", llm.metadata)
}
}

// TestCreateChatCompletionSendsMetadata verifies the configured metadata is
// serialized into the outgoing request body as the OpenAI "metadata" object.
//
// A stub HTTP server captures the request body, decodes its "metadata" field,
// and answers with a minimal chat-completion response so the client call
// succeeds. Failures while reading, decoding, or replying are reported
// directly instead of surfacing later as an unexplained empty metadata map.
func TestCreateChatCompletionSendsMetadata(t *testing.T) {
	var gotMetadata map[string]string
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// t.Errorf (not Fatalf) is used here: the handler runs on the server's
		// goroutine, where FailNow-style calls are not permitted.
		body, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("reading request body: %v", err)
		}
		var req struct {
			Metadata map[string]string `json:"metadata"`
		}
		if err := json.Unmarshal(body, &req); err != nil {
			t.Errorf("decoding request body %q: %v", body, err)
		}
		gotMetadata = req.Metadata

		// Always reply, even after a decode failure, so the client call below
		// returns promptly and the test reports every problem it found.
		w.Header().Set("Content-Type", "application/json")
		if _, err := w.Write([]byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":"ok"}}]}`)); err != nil {
			t.Errorf("writing stub response: %v", err)
		}
	}))
	// Close blocks until outstanding requests finish, so the handler cannot
	// touch t after the test has returned.
	defer srv.Close()

	llm := NewOpenAILLMWithOptions("m", "k", srv.URL+"/v1", OpenAIOptions{
		Metadata: map[string]string{"enable_thinking": "false"},
	})
	_, _, err := llm.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "hi"}},
	})
	if err != nil {
		t.Fatalf("CreateChatCompletion: %v", err)
	}
	if gotMetadata["enable_thinking"] != "false" {
		t.Fatalf("request metadata = %v, want enable_thinking=false", gotMetadata)
	}
}

func TestNewOpenAILLMWithOptionsSetsTemperature(t *testing.T) {
llm := NewOpenAILLMWithOptions("m", "k", "http://localhost", OpenAIOptions{Temperature: 0.7})
Expand Down
9 changes: 5 additions & 4 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ type Options struct {
agentSpawnCallback func(*AgentState)
agentCompletionFormatter func(*AgentState) string
agentDefinitions []AgentDefinition
agentLLMFactory func(model string, temperature float32) LLM
agentLLMFactory func(model string, temperature float32, metadata map[string]string) LLM
}

type Option func(*Options)
Expand Down Expand Up @@ -530,9 +530,10 @@ func WithAgentDefinitions(defs ...AgentDefinition) Option {
}

// WithAgentLLMFactory sets a factory that builds an LLM for a sub-agent from a
// model name and temperature. Used to resolve per-agent-type or per-spawn model
// overrides while reusing the parent's endpoint/credentials.
func WithAgentLLMFactory(fn func(model string, temperature float32) LLM) Option {
// model name, temperature, and per-request metadata. Used to resolve
// per-agent-type or per-spawn model/metadata overrides while reusing the
// parent's endpoint/credentials.
func WithAgentLLMFactory(fn func(model string, temperature float32, metadata map[string]string) LLM) Option {
return func(o *Options) {
o.agentLLMFactory = fn
}
Expand Down
Loading