diff --git a/agent.go b/agent.go index b6e2775..421ed3e 100644 --- a/agent.go +++ b/agent.go @@ -3,6 +3,7 @@ package cogito import ( "context" "fmt" + "strings" "sync" "github.com/google/uuid" @@ -19,13 +20,39 @@ const ( ) // agentToolNames are the names of the built-in agent management tools. -var agentToolNames = []string{"spawn_agent", "check_agent", "get_agent_result"} +var agentToolNames = []string{"spawn_agent", "check_agent", "get_agent_result", "send_agent_message"} // SpawnAgentArgs are the arguments the LLM provides when spawning a sub-agent. type SpawnAgentArgs struct { + AgentType string `json:"agent_type" description:"Optional named agent type to use (persona/system prompt/tools/model). If empty, a generic sub-agent is used."` Task string `json:"task" description:"The task or prompt for the sub-agent to execute"` Background bool `json:"background" description:"If true, the agent runs in the background and returns an ID immediately. If false, blocks until the agent completes."` - Tools []string `json:"tools" description:"Optional subset of tool names available to the sub-agent. If empty, all parent tools (except agent tools) are given."` + Tools []string `json:"tools" description:"Optional subset of tool names available to the sub-agent. If empty, the agent type's tools (or all parent tools) are used."` + Model string `json:"model" description:"Optional model override for this sub-agent."` +} + +// AgentDefinition is a named sub-agent "type" (persona). The embedder registers +// definitions via WithAgentDefinitions; spawn_agent selects one by Name. 
+type AgentDefinition struct {
+	Name         string   // unique identifier referenced by spawn_agent.agent_type
+	Description  string   // shown to the LLM in the spawn tool description
+	SystemPrompt string   // seeded as the sub-agent's first system message
+	Tools        []string // tool-name allow-list for this type (empty = all parent tools)
+	Model        string   // optional model override resolved via the agent LLM factory
+	Temperature  float32  // optional sampling temperature for this type
+	Iterations   int      // optional per-type iteration cap (0 = inherit parent)
+	MaxAttempts  int      // optional per-type attempt cap (0 = inherit parent)
+	MaxRetries   int      // optional per-type retry cap (0 = inherit parent)
+}
+
+// findAgentDefinition returns the definition with the given name, or nil.
+func findAgentDefinition(defs []AgentDefinition, name string) *AgentDefinition {
+	for i := range defs {
+		if defs[i].Name == name {
+			return &defs[i]
+		}
+	}
+	return nil
 }
 
 // CheckAgentArgs are the arguments for checking a background agent's status.
@@ -43,12 +70,19 @@ type GetAgentResultArgs struct {
 type AgentState struct {
 	ID       string
 	Task     string
+	Type     string // requested agent type name (empty for generic)
 	Status   AgentStatusType
 	Result   string
 	Fragment *Fragment
 	Error    error
 	Cancel   context.CancelFunc
 	done     chan struct{}
+	inject   chan openai.ChatCompletionMessage
+	// detach, when non-nil, lets an embedder promote a running foreground
+	// agent to the background: a non-blocking send here unblocks the
+	// spawn_agent call so it returns the agent ID while the goroutine keeps
+	// running. Background agents leave this nil (they are already detached).
+	detach chan struct{}
 }
 
 // AgentManager is a thread-safe registry of background sub-agents.
@@ -110,6 +144,57 @@ func (m *AgentManager) Wait(id string) (*AgentState, error) {
 	return agent, nil
 }
 
+// Inject pushes a user-role follow-up message into a running agent's loop.
+// It never blocks: an error is returned if the agent is unknown, has no
+// injection channel, has already finished, or has a full injection buffer.
+func (m *AgentManager) Inject(id, message string) error {
+	m.mu.RLock()
+	a, ok := m.agents[id]
+	m.mu.RUnlock()
+	if !ok {
+		return fmt.Errorf("agent %s not found", id)
+	}
+	if a.inject == nil {
+		return fmt.Errorf("agent %s does not accept injections", id)
+	}
+	// The agent goroutine closes done when it exits, after which its loop can
+	// no longer consume injections. A nil done channel simply falls through.
+	select {
+	case <-a.done:
+		return fmt.Errorf("agent %s has already finished", id)
+	default:
+	}
+	// Non-blocking send, like Detach: a full buffer must not stall the caller.
+	select {
+	case a.inject <- openai.ChatCompletionMessage{Role: "user", Content: message}:
+		return nil
+	default:
+		return fmt.Errorf("agent %s injection buffer is full", id)
+	}
+}
+
+// Detach promotes a running foreground agent to background. The blocked
+// spawn_agent call returns immediately with the agent ID; the agent's goroutine
+// keeps running and the agent becomes an ordinary background agent. Returns an
+// error if the agent is unknown or not detachable (already-background agents
+// carry a nil detach channel).
+func (m *AgentManager) Detach(id string) error {
+	m.mu.RLock()
+	a, ok := m.agents[id]
+	m.mu.RUnlock()
+	if !ok {
+		return fmt.Errorf("agent %s not found", id)
+	}
+	if a.detach == nil {
+		return fmt.Errorf("agent %s is not detachable", id)
+	}
+	select {
+	case a.detach <- struct{}{}:
+	default:
+	}
+	return nil
+}
+
 // isAgentTool returns true if the tool name is one of the built-in agent tools.
 func isAgentTool(name string) bool {
 	for _, n := range agentToolNames {
@@ -187,6 +259,25 @@ func formatAgentCompletion(a *AgentState, formatter func(*AgentState) string) st
 	return fmt.Sprintf("Background agent %s has failed.\nTask: %s\nError: %v", a.ID, a.Task, a.Error)
 }
 
+// withAgentIDStamp wraps the option set so that, when ExecuteTools invokes the
+// tool-call callback, SessionState.AgentID carries the given sub-agent id. It
+// composes with the propagated parent callback rather than replacing it: if no
+// callback is set, it is a no-op.
+func withAgentIDStamp(id string) Option { + return func(o *Options) { + inner := o.toolCallCallback + if inner == nil { + return + } + o.toolCallCallback = func(tc *ToolChoice, st *SessionState) ToolCallDecision { + if st != nil { + st.AgentID = id + } + return inner(tc, st) + } + } +} + // spawnAgentRunner implements Tool[SpawnAgentArgs]. type spawnAgentRunner struct { llm LLM @@ -197,11 +288,28 @@ type spawnAgentRunner struct { streamCB StreamCallback messageInjectionChan chan openai.ChatCompletionMessage agentCompletionCallback func(*AgentState) + agentSpawnCallback func(*AgentState) completionFormatter func(*AgentState) string + agentDefinitions []AgentDefinition + llmFactory func(model string, temperature float32) LLM } func (r *spawnAgentRunner) Run(args SpawnAgentArgs) (string, any, error) { - subTools := FilterToolsForSubAgent(r.parentTools, args.Tools) + // Resolve the named agent definition (persona), if one was requested. + var def *AgentDefinition + if args.AgentType != "" { + def = findAgentDefinition(r.agentDefinitions, args.AgentType) + if def == nil { + return fmt.Sprintf("Cannot spawn: unknown agent type %q", args.AgentType), nil, nil + } + } + + // Resolve the tool allow-list: explicit spawn arg > definition tools > all parent tools. + requestedTools := args.Tools + if len(requestedTools) == 0 && def != nil { + requestedTools = def.Tools + } + subTools := FilterToolsForSubAgent(r.parentTools, requestedTools) subOpts := append([]Option{}, WithTools(subTools...), @@ -209,89 +317,195 @@ func (r *spawnAgentRunner) Run(args SpawnAgentArgs) (string, any, error) { ) subOpts = append(subOpts, r.parentOpts...) - subFragment := NewFragment( - openai.ChatCompletionMessage{Role: "user", Content: args.Task}, - ) + // Per-type execution limits override the propagated parent limits. 
+ if def != nil { + if def.Iterations > 0 { + subOpts = append(subOpts, WithIterations(def.Iterations)) + } + if def.MaxAttempts > 0 { + subOpts = append(subOpts, WithMaxAttempts(def.MaxAttempts)) + } + if def.MaxRetries > 0 { + subOpts = append(subOpts, WithMaxRetries(def.MaxRetries)) + } + } + + // Seed the system prompt from the definition. + var subFragment Fragment + if def != nil && def.SystemPrompt != "" { + subFragment = NewFragment( + openai.ChatCompletionMessage{Role: "system", Content: def.SystemPrompt}, + openai.ChatCompletionMessage{Role: "user", Content: args.Task}, + ) + } else { + subFragment = NewFragment( + openai.ChatCompletionMessage{Role: "user", Content: args.Task}, + ) + } + + // Resolve the LLM (model/temperature) for this sub-agent. + subLLM := r.resolveLLM(args, def) + + agentID := uuid.New().String() + subCtx, cancel := context.WithCancel(r.ctx) if !args.Background { - // Foreground: execute synchronously + // Foreground: register the agent and run it in a goroutine so the + // embedder can promote it to the background (detach). When no detach + // fires we behave exactly like the old synchronous path: block on + // agent.done and return agent.Result (== result.LastMessage().Content). + agent := &AgentState{ + ID: agentID, + Task: args.Task, + Type: args.AgentType, + Status: AgentStatusRunning, + Cancel: cancel, + done: make(chan struct{}), + inject: make(chan openai.ChatCompletionMessage, 8), + detach: make(chan struct{}, 1), + } + r.manager.Register(agent) + if r.agentSpawnCallback != nil { + r.agentSpawnCallback(agent) + } + + fgOpts := append([]Option{}, subOpts...) + fgOpts = append(fgOpts, withAgentIDStamp(agentID)) + fgOpts = append(fgOpts, WithMessageInjectionChan(agent.inject)) + fgOpts = append(fgOpts, WithContext(subCtx)) if r.streamCB != nil { - subOpts = append(subOpts, WithStreamCallback(r.streamCB)) + fgOpts = append(fgOpts, WithStreamCallback(r.streamCB)) } - result, err := ExecuteTools(r.llm, subFragment, subOpts...) 
- if err != nil { - return fmt.Sprintf("Sub-agent failed: %v", err), nil, nil + + go r.runAgent(agent, subLLM, subFragment, fgOpts, cancel) + + select { + case <-agent.done: + // Completed before any detach: behave like the old synchronous path. + r.manager.mu.RLock() + defer r.manager.mu.RUnlock() + if agent.Status == AgentStatusFailed { + return fmt.Sprintf("Sub-agent failed: %v", agent.Error), nil, nil + } + return agent.Result, derefFragment(agent.Fragment), nil + case <-agent.detach: + // Promoted to background: return the ID, leave the goroutine running. + return fmt.Sprintf("Agent detached to background with ID: %s", agentID), agentID, nil + case <-r.ctx.Done(): + cancel() + return "Sub-agent cancelled", nil, r.ctx.Err() } - msg := result.LastMessage().Content - return msg, result, nil } - // Background: launch goroutine, return ID immediately - agentID := uuid.New().String() + // Background: launch goroutine, return ID immediately. agent := &AgentState{ ID: agentID, Task: args.Task, + Type: args.AgentType, Status: AgentStatusRunning, + Cancel: cancel, done: make(chan struct{}), + inject: make(chan openai.ChatCompletionMessage, 8), } r.manager.Register(agent) + if r.agentSpawnCallback != nil { + r.agentSpawnCallback(agent) + } - subCtx, cancel := context.WithCancel(r.ctx) - agent.Cancel = cancel + bgOpts := append([]Option{}, subOpts...) + // Stamp the real registry ID so sub-agent tool calls route through the + // parent callback with the correct AgentID (matching the foreground path). + bgOpts = append(bgOpts, withAgentIDStamp(agentID)) + // Give the running sub-agent its own injection channel so a follow-up + // message (via AgentManager.Inject / send_agent_message) reaches its loop. + bgOpts = append(bgOpts, WithMessageInjectionChan(agent.inject)) - // Wrap stream callback to tag events with agent ID + // Wrap stream callback to tag events with agent ID. 
if r.streamCB != nil { parentCB := r.streamCB - subOpts = append(subOpts, WithStreamCallback(func(ev StreamEvent) { + bgOpts = append(bgOpts, WithStreamCallback(func(ev StreamEvent) { ev.AgentID = agentID ev.Type = StreamEventSubAgent parentCB(ev) })) } - // Override context for sub-agent - subOpts = append(subOpts, WithContext(subCtx)) + // Override context for sub-agent. + bgOpts = append(bgOpts, WithContext(subCtx)) - go func() { - defer close(agent.done) - defer cancel() + go r.runAgent(agent, subLLM, subFragment, bgOpts, cancel) - result, err := ExecuteTools(r.llm, subFragment, subOpts...) + return fmt.Sprintf("Agent spawned in background with ID: %s", agentID), agentID, nil +} - r.manager.mu.Lock() - if err != nil { - agent.Status = AgentStatusFailed - agent.Error = err - agent.Result = fmt.Sprintf("Failed: %v", err) - } else { - agent.Status = AgentStatusCompleted - agent.Result = result.LastMessage().Content - agent.Fragment = &result - } - r.manager.mu.Unlock() +// runAgent executes a sub-agent to completion and records its terminal state, +// firing the completion callback and injecting a completion notification into +// the parent loop. Shared by the foreground (detachable) and background spawn +// branches so the lifecycle bookkeeping lives in one place. +func (r *spawnAgentRunner) runAgent(agent *AgentState, llm LLM, frag Fragment, opts []Option, cancel context.CancelFunc) { + defer close(agent.done) + defer cancel() + + result, err := ExecuteTools(llm, frag, opts...) + + r.manager.mu.Lock() + if err != nil { + agent.Status = AgentStatusFailed + agent.Error = err + agent.Result = fmt.Sprintf("Failed: %v", err) + } else { + agent.Status = AgentStatusCompleted + agent.Result = result.LastMessage().Content + agent.Fragment = &result + } + r.manager.mu.Unlock() - // Fire completion callback - if r.agentCompletionCallback != nil { - r.agentCompletionCallback(agent) - } + // Fire completion callback. 
+ if r.agentCompletionCallback != nil { + r.agentCompletionCallback(agent) + } - // Inject completion notification into parent's loop. The content - // is built by formatAgentCompletion so an embedder can override - // it via WithAgentCompletionFormatter (see helper docs). - if r.messageInjectionChan != nil { - content := formatAgentCompletion(agent, r.completionFormatter) - select { - case r.messageInjectionChan <- openai.ChatCompletionMessage{ - Role: "user", - Content: content, - }: - default: - // Non-blocking: if the channel is full or closed, skip notification - } + // Inject completion notification into parent's loop. The content is built + // by formatAgentCompletion so an embedder can override it via + // WithAgentCompletionFormatter (see helper docs). + if r.messageInjectionChan != nil { + content := formatAgentCompletion(agent, r.completionFormatter) + select { + case r.messageInjectionChan <- openai.ChatCompletionMessage{ + Role: "user", + Content: content, + }: + default: + // Non-blocking: if the channel is full or closed, skip notification. } - }() + } +} - return fmt.Sprintf("Agent spawned in background with ID: %s", agentID), agentID, nil +// derefFragment returns the pointed-to Fragment as an any, or nil if the +// pointer is nil. Used by the foreground branch to return the completed +// sub-agent's fragment in the same shape the old synchronous path did. +func derefFragment(f *Fragment) any { + if f == nil { + return nil + } + return *f +} + +// resolveLLM picks the LLM for a sub-agent. Order: spawn-arg model > definition +// model/temperature via the factory > parent LLM. Fully wired in Task A6. 
+func (r *spawnAgentRunner) resolveLLM(args SpawnAgentArgs, def *AgentDefinition) LLM {
+	var temperature float32
+	modelName := args.Model
+	if def != nil {
+		temperature = def.Temperature
+		if modelName == "" {
+			modelName = def.Model
+		}
+	}
+	if modelName == "" || r.llmFactory == nil {
+		return r.llm
+	}
+	return r.llmFactory(modelName, temperature)
 }
 
 // checkAgentRunner implements Tool[CheckAgentArgs].
@@ -358,7 +572,10 @@ func newSpawnAgentTool(
 	streamCB StreamCallback,
 	injectionChan chan openai.ChatCompletionMessage,
 	completionCB func(*AgentState),
+	spawnCB func(*AgentState),
 	completionFormatter func(*AgentState) string,
+	defs []AgentDefinition,
+	llmFactory func(model string, temperature float32) LLM,
 ) ToolDefinitionInterface {
 	return NewToolDefinition(
 		&spawnAgentRunner{
@@ -370,14 +587,36 @@ func newSpawnAgentTool(
 			streamCB:                streamCB,
 			messageInjectionChan:    injectionChan,
 			agentCompletionCallback: completionCB,
+			agentSpawnCallback:      spawnCB,
 			completionFormatter:     completionFormatter,
+			agentDefinitions:        defs,
+			llmFactory:              llmFactory,
 		},
 		SpawnAgentArgs{},
 		"spawn_agent",
-		"Spawn a sub-agent to handle a task. Use background=true for non-blocking execution, or background=false to wait for the result.",
+		spawnToolDescription(defs),
 	)
 }
 
+// spawnToolDescription builds the spawn_agent tool description. With no
+// definitions it is the base text alone; otherwise each registered type is
+// appended by name, followed by its description in parentheses when set.
+func spawnToolDescription(defs []AgentDefinition) string {
+	const base = "Spawn a sub-agent to handle a task. Use background=true for non-blocking execution, or background=false to wait for the result."
+	if len(defs) == 0 {
+		return base
+	}
+	entries := make([]string, 0, len(defs))
+	for _, def := range defs {
+		entry := def.Name
+		if def.Description != "" {
+			entry += " (" + def.Description + ")"
+		}
+		entries = append(entries, entry)
+	}
+	return base + " Available agent_type values: " + strings.Join(entries, ", ")
+}
+
 // newCheckAgentTool creates the check_agent tool definition.
func newCheckAgentTool(manager *AgentManager) ToolDefinitionInterface {
 	return NewToolDefinition(
@@ -397,3 +638,72 @@ func newGetAgentResultTool(manager *AgentManager, ctx context.Context) ToolDefin
 		"Get the result of a background sub-agent. Set wait=true to block until the agent finishes.",
 	)
 }
+
+// SendAgentMessageArgs is the argument for the unified resume/inject tool.
+type SendAgentMessageArgs struct {
+	AgentID string `json:"agent_id" description:"The ID of the agent to message"`
+	Message string `json:"message" description:"The follow-up message. Injected live if the agent is running, or re-runs the agent with prior context if it has finished."`
+}
+
+// sendAgentMessageRunner implements Tool[SendAgentMessageArgs]. It either injects
+// a live message into a running agent or re-runs a finished agent from its prior
+// context with the new message appended.
+type sendAgentMessageRunner struct {
+	manager *AgentManager
+	ctx     context.Context
+	llm     LLM
+	subOpts []Option
+}
+
+// Run delivers args.Message to the agent identified by args.AgentID. A running
+// agent receives it through AgentManager.Inject; a finished agent is re-run
+// from its stored fragment with the message appended, and the registry entry
+// is updated with the new result. Failures are reported in the returned
+// string, never as an error.
+func (r *sendAgentMessageRunner) Run(args SendAgentMessageArgs) (string, any, error) {
+	agent, ok := r.manager.Get(args.AgentID)
+	if !ok {
+		return fmt.Sprintf("Agent %s not found", args.AgentID), nil, nil
+	}
+
+	// Snapshot the mutable fields under the manager lock: the agent goroutine
+	// writes Status/Result/Fragment while holding it, so bare reads would race.
+	r.manager.mu.RLock()
+	status, prior := agent.Status, agent.Fragment
+	r.manager.mu.RUnlock()
+
+	if status == AgentStatusRunning {
+		if err := r.manager.Inject(args.AgentID, args.Message); err != nil {
+			return fmt.Sprintf("Could not message agent %s: %v", args.AgentID, err), nil, nil
+		}
+		return fmt.Sprintf("Message delivered to running agent %s.", args.AgentID), nil, nil
+	}
+
+	// Completed/failed: resume by appending the message to the stored fragment and re-running.
+	if prior == nil {
+		return fmt.Sprintf("Agent %s has no stored context to resume", args.AgentID), nil, nil
+	}
+	resumed := prior.AddMessage(UserMessageRole, args.Message)
+	opts := append([]Option{WithContext(r.ctx)}, r.subOpts...)
+	result, err := ExecuteTools(r.llm, resumed, opts...)
+	if err != nil {
+		return fmt.Sprintf("Resume of agent %s failed: %v", args.AgentID, err), nil, nil
+	}
+	content := result.LastMessage().Content
+	r.manager.mu.Lock()
+	agent.Status = AgentStatusCompleted
+	agent.Result = content
+	agent.Fragment = &result
+	r.manager.mu.Unlock()
+	return content, result, nil
+}
+
+// newSendAgentMessageTool creates the send_agent_message tool definition.
+func newSendAgentMessageTool(manager *AgentManager, ctx context.Context, llm LLM, subOpts []Option) ToolDefinitionInterface {
+	return NewToolDefinition(
+		&sendAgentMessageRunner{manager: manager, ctx: ctx, llm: llm, subOpts: subOpts},
+		SendAgentMessageArgs{},
+		"send_agent_message",
+		"Send a follow-up message to a sub-agent. If it is still running the message is injected live; if it has finished, the agent resumes from its prior context.",
+	)
+}
diff --git a/agent_definitions_test.go b/agent_definitions_test.go
new file mode 100644
index 0000000..980bb30
--- /dev/null
+++ b/agent_definitions_test.go
@@ -0,0 +1,210 @@
+package cogito
+
+import (
+	"context"
+	"strings"
+	"sync"
+	"testing"
+
+	"github.com/sashabaranov/go-openai"
+)
+
+// newNamedTool builds an echo-style tool with the given name. It reuses the
+// echoRunner machinery from agent_propagation_test.go (mutex+counter) but the
+// counter is throwaway here — the tool only needs to be present in the parent
+// set so we can assert the sub-agent's allow-list excludes it.
+func newNamedTool(name string) ToolDefinitionInterface {
+	var mu sync.Mutex
+	count := 0
+	return NewToolDefinition[EchoArgs](
+		echoRunner{mu: &mu, count: &count},
+		EchoArgs{},
+		name,
+		name+" tool",
+	)
+}
+
+// inspectingLLM records the fragment (messages) and tool set the sub-agent runs
+// with on its first tool-selection turn, then replies plainly so the loop
+// terminates without executing any tool. Tools reach the LLM via the
+// ChatCompletionRequest's Tools field (tools.ToOpenAI()); the system prompt and
+// task reach it via the request's Messages. We reconstruct a Fragment from the
+// messages and capture tool names from the request so the assertions can match
+// the plan's intent (system prompt seeded; excluded tool absent).
+type inspectingLLM struct {
+	mu sync.Mutex
+	fn func(f Fragment, toolNames []string)
+}
+
+// newInspectingLLM builds an LLM whose first CreateChatCompletion call invokes
+// fn with the fragment + tool names it was asked to choose from.
+func newInspectingLLM(fn func(f Fragment, toolNames []string)) *inspectingLLM {
+	return &inspectingLLM{fn: fn}
+}
+
+func (m *inspectingLLM) Ask(_ context.Context, f Fragment) (Fragment, error) {
+	return f.AddMessage(AssistantMessageRole, "done"), nil
+}
+
+func (m *inspectingLLM) CreateChatCompletion(_ context.Context, req openai.ChatCompletionRequest) (LLMReply, LLMUsage, error) {
+	m.mu.Lock()
+	if m.fn != nil {
+		var names []string
+		for _, t := range req.Tools {
+			if t.Function != nil {
+				names = append(names, t.Function.Name)
+			}
+		}
+		m.fn(NewFragment(req.Messages...), names)
+		m.fn = nil // record only the first selection turn
+	}
+	m.mu.Unlock()
+	// No tool call: a plain assistant message so the loop terminates.
+	return LLMReply{ChatCompletionResponse: openai.ChatCompletionResponse{
+		Choices: []openai.ChatCompletionChoice{{
+			Message: openai.ChatCompletionMessage{Role: AssistantMessageRole.String(), Content: "done"},
+		}},
+	}}, LLMUsage{}, nil
+}
+
+// firstSystemContent returns the content of the first system message in f, or "".
+func firstSystemContent(f Fragment) string { + for _, msg := range f.GetMessages() { + if msg.Role == SystemMessageRole.String() { + return msg.Content + } + } + return "" +} + +func contains(haystack []string, needle string) bool { + for _, s := range haystack { + if s == needle { + return true + } + } + return false +} + +func TestWithAgentDefinitionsStoresDefs(t *testing.T) { + defs := []AgentDefinition{ + {Name: "explore", Description: "read-only exploration", + SystemPrompt: "You explore.", Tools: []string{"echo"}, + Model: "small-model", Temperature: 0.2, + Iterations: 20, MaxAttempts: 2, MaxRetries: 2}, + } + o := defaultOptions() + o.Apply(WithAgentDefinitions(defs...)) + if len(o.agentDefinitions) != 1 || o.agentDefinitions[0].Name != "explore" { + t.Fatalf("agent definitions not stored: %+v", o.agentDefinitions) + } +} + +func TestFindAgentDefinition(t *testing.T) { + defs := []AgentDefinition{{Name: "plan"}, {Name: "explore"}} + if d := findAgentDefinition(defs, "explore"); d == nil || d.Name != "explore" { + t.Fatalf("expected to find explore, got %+v", d) + } + if d := findAgentDefinition(defs, "missing"); d != nil { + t.Fatalf("expected nil for missing type, got %+v", d) + } +} + +func TestSpawnAppliesDefinitionSystemPromptAndTools(t *testing.T) { + var emu sync.Mutex + ecount := 0 + echo := newEchoTool(&emu, &ecount) + secret := newNamedTool("secret") // the explore type must NOT receive this + defs := []AgentDefinition{{ + Name: "explore", SystemPrompt: "You are EXPLORE.", + Tools: []string{"echo"}, Iterations: 7, + }} + + var gotSystem string + var gotToolNames []string + llm := newInspectingLLM(func(f Fragment, toolNames []string) { + gotSystem = firstSystemContent(f) + gotToolNames = append(gotToolNames, toolNames...) 
+ }) + + runner := &spawnAgentRunner{ + llm: llm, + parentTools: Tools{echo, secret}, + manager: NewAgentManager(), + ctx: context.Background(), + agentDefinitions: defs, + } + _, _, _ = runner.Run(SpawnAgentArgs{AgentType: "explore", Task: "look around", Background: false}) + + if gotSystem != "You are EXPLORE." { + t.Fatalf("definition system prompt not seeded, got %q", gotSystem) + } + if contains(gotToolNames, "secret") { + t.Fatalf("explore must not receive 'secret' tool, got %v", gotToolNames) + } + if !contains(gotToolNames, "echo") { + t.Fatalf("explore should receive 'echo' tool, got %v", gotToolNames) + } +} + +func TestSpawnUnknownAgentTypeErrorsCleanly(t *testing.T) { + runner := &spawnAgentRunner{ + llm: newInspectingLLM(func(Fragment, []string) {}), + manager: NewAgentManager(), + ctx: context.Background(), + agentDefinitions: []AgentDefinition{{Name: "explore"}}, + } + out, _, err := runner.Run(SpawnAgentArgs{AgentType: "nope", Task: "x", Background: false}) + if err != nil { + t.Fatalf("unknown type should not hard-error, got %v", err) + } + if !strings.Contains(out, "unknown agent type") { + t.Fatalf("expected a clear message, got %q", out) + } +} + +func TestFactoryResolvesModelAndTemperature(t *testing.T) { + var gotModel string + var gotTemp float32 + factory := func(model string, temp float32) LLM { + gotModel, gotTemp = model, temp + return newInspectingLLM(func(Fragment, []string) {}) + } + defs := []AgentDefinition{{Name: "cheap", Model: "small", Temperature: 0.3}} + runner := &spawnAgentRunner{ + llm: newInspectingLLM(func(Fragment, []string) {}), + manager: NewAgentManager(), + ctx: context.Background(), + agentDefinitions: defs, + llmFactory: factory, + } + _, _, _ = runner.Run(SpawnAgentArgs{AgentType: "cheap", Task: "x", Background: false}) + if gotModel != "small" || gotTemp != 0.3 { + t.Fatalf("factory got (%q,%v), want (small,0.3)", gotModel, gotTemp) + } +} + +func TestSpawnArgModelBeatsDefinition(t *testing.T) { + var gotModel 
string + factory := func(model string, temp float32) LLM { + gotModel = model + return newInspectingLLM(func(Fragment, []string) {}) + } + defs := []AgentDefinition{{Name: "cheap", Model: "small"}} + runner := &spawnAgentRunner{ + llm: newInspectingLLM(func(Fragment, []string) {}), manager: NewAgentManager(), + ctx: context.Background(), agentDefinitions: defs, llmFactory: factory, + } + _, _, _ = runner.Run(SpawnAgentArgs{AgentType: "cheap", Model: "big", Task: "x", Background: false}) + if gotModel != "big" { + t.Fatalf("spawn-arg model should win, got %q", gotModel) + } +} + +func TestWithAgentLLMFactoryStores(t *testing.T) { + o := defaultOptions() + o.Apply(WithAgentLLMFactory(func(string, float32) LLM { return nil })) + if o.agentLLMFactory == nil { + t.Fatal("factory not stored") + } +} diff --git a/agent_detach_test.go b/agent_detach_test.go new file mode 100644 index 0000000..b36ee9c --- /dev/null +++ b/agent_detach_test.go @@ -0,0 +1,132 @@ +package cogito + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/sashabaranov/go-openai" +) + +// blockingLLM blocks in CreateChatCompletion until release is closed, then +// returns a plain (tool-less) reply so the agent loop terminates. It models a +// long-running foreground sub-agent that an embedder can promote to background. +type blockingLLM struct { + release chan struct{} + reply string +} + +// newBlockingLLM builds an LLM whose tool-selection turn blocks on <-release +// before returning a sink reply. Used by the detach test to keep a foreground +// agent in-flight while the parent detaches it. 
+func newBlockingLLM(release chan struct{}) *blockingLLM { + return &blockingLLM{release: release, reply: "blocked done"} +} + +func (m *blockingLLM) Ask(_ context.Context, f Fragment) (Fragment, error) { + return f.AddMessage(AssistantMessageRole, m.reply), nil +} + +func (m *blockingLLM) CreateChatCompletion(ctx context.Context, _ openai.ChatCompletionRequest) (LLMReply, LLMUsage, error) { + select { + case <-m.release: + case <-ctx.Done(): + return LLMReply{}, LLMUsage{}, ctx.Err() + } + return LLMReply{ChatCompletionResponse: openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{{ + Message: openai.ChatCompletionMessage{Role: AssistantMessageRole.String(), Content: m.reply}, + }}, + }}, LLMUsage{}, nil +} + +func TestDetachReturnsBeforeCompletion(t *testing.T) { + m := NewAgentManager() + release := make(chan struct{}) + // An LLM that blocks until released, simulating a long-running foreground agent. + llm := newBlockingLLM(release) + + runner := &spawnAgentRunner{ + llm: llm, manager: m, ctx: context.Background(), + } + + type res struct { + out string + id any + } + resCh := make(chan res, 1) + go func() { + out, id, _ := runner.Run(SpawnAgentArgs{Task: "long job", Background: false}) + resCh <- res{out, id} + }() + + // Wait for the foreground agent to register, then detach it. + var id string + deadline := time.After(2 * time.Second) + for { + agents := m.List() + if len(agents) == 1 { + id = agents[0].ID + break + } + select { + case <-deadline: + t.Fatal("foreground agent never registered") + case <-time.After(10 * time.Millisecond): + } + } + + if err := m.Detach(id); err != nil { + t.Fatalf("detach errored: %v", err) + } + + select { + case r := <-resCh: + if r.id == nil { + t.Fatal("expected detach to return the agent id") + } + case <-time.After(2 * time.Second): + t.Fatal("Run did not return after detach") + } + + // The goroutine is still running; release it so the test can clean up. 
+ close(release) +} + +func TestDetachUnknownAgentErrors(t *testing.T) { + m := NewAgentManager() + if err := m.Detach("missing"); err == nil { + t.Fatal("expected error for unknown agent") + } +} + +func TestDetachNonDetachableAgentErrors(t *testing.T) { + m := NewAgentManager() + // A background-style agent has no detach channel: it is already detached. + agent := &AgentState{ID: "bg1", Status: AgentStatusRunning, done: make(chan struct{})} + m.Register(agent) + if err := m.Detach("bg1"); err == nil { + t.Fatal("expected error for non-detachable agent") + } +} + +// TestForegroundSpawnUnchangedWithoutDetach is the key safety property: when no +// detach ever fires, the foreground select returns on agent.done and yields the +// same result content the old synchronous path returned via +// result.LastMessage().Content. The plan sketch used +// newScriptedLLM(scriptReply("final answer")); that signature does not exist in +// this repo, so we use the reply-only newReplyLLM helper (added in A8), which +// terminates the loop immediately with the fixed reply. 
+func TestForegroundSpawnUnchangedWithoutDetach(t *testing.T) { + m := NewAgentManager() + llm := newReplyLLM("final answer") + runner := &spawnAgentRunner{llm: llm, manager: m, ctx: context.Background()} + out, _, err := runner.Run(SpawnAgentArgs{Task: "quick", Background: false}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(out, "final answer") { + t.Fatalf("foreground result changed, got %q", out) + } +} diff --git a/agent_propagation_test.go b/agent_propagation_test.go new file mode 100644 index 0000000..d4cd4fc --- /dev/null +++ b/agent_propagation_test.go @@ -0,0 +1,174 @@ +package cogito + +import ( + "context" + "sync" + "testing" + + "github.com/sashabaranov/go-openai" +) + +func TestSessionStateHasAgentID(t *testing.T) { + s := SessionState{AgentID: "abc-123"} + if s.AgentID != "abc-123" { + t.Fatalf("expected AgentID to round-trip, got %q", s.AgentID) + } +} + +// scriptedLLM is a minimal internal LLM mock for sub-agent tests. It drives a +// single tool call (via CreateChatCompletion) then a plain reply (via Ask), +// mirroring how the public mock.MockOpenAIClient is scripted in agent_test.go. +// We use an internal mock because these tests live in package cogito and need +// to construct spawnAgentRunner directly. +type scriptedLLM struct { + mu sync.Mutex + toolName string + toolArgs string + reply string + called bool +} + +// newScriptedLLM builds an LLM that on its first tool-selection call selects +// toolName(toolArgs), and on the subsequent (post-tool) turn replies plainly. 
+func newScriptedLLM(toolName, toolArgs, reply string) *scriptedLLM { + return &scriptedLLM{toolName: toolName, toolArgs: toolArgs, reply: reply} +} + +func (m *scriptedLLM) Ask(_ context.Context, f Fragment) (Fragment, error) { + return f.AddMessage(AssistantMessageRole, m.reply), nil +} + +func (m *scriptedLLM) CreateChatCompletion(_ context.Context, _ openai.ChatCompletionRequest) (LLMReply, LLMUsage, error) { + m.mu.Lock() + defer m.mu.Unlock() + if !m.called { + m.called = true + return LLMReply{ChatCompletionResponse: openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{{ + Message: openai.ChatCompletionMessage{ + Role: AssistantMessageRole.String(), + ToolCalls: []openai.ToolCall{{ + Type: openai.ToolTypeFunction, + Function: openai.FunctionCall{ + Name: m.toolName, + Arguments: m.toolArgs, + }, + }}, + }, + }}, + }}, LLMUsage{}, nil + } + // No further tool calls: plain assistant message so the loop terminates. + return LLMReply{ChatCompletionResponse: openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{{ + Message: openai.ChatCompletionMessage{Role: AssistantMessageRole.String(), Content: m.reply}, + }}, + }}, LLMUsage{}, nil +} + +// echoRunner is a trivial tool that records how many times it ran. The counter +// lets tests assert the tool executed (approve path) or did not (reject path). +type echoRunner struct { + mu *sync.Mutex + count *int +} + +func (r echoRunner) Run(EchoArgs) (string, any, error) { + r.mu.Lock() + *r.count++ + r.mu.Unlock() + return "echo", nil, nil +} + +// EchoArgs is the argument type for the echo test tool. +type EchoArgs struct { + Text string `json:"text" description:"text to echo"` +} + +// newEchoTool builds an echo tool whose invocation count is tracked via the +// supplied mutex+counter. 
+func newEchoTool(mu *sync.Mutex, count *int) ToolDefinitionInterface { + return NewToolDefinition[EchoArgs]( + echoRunner{mu: mu, count: count}, + EchoArgs{}, + "echo", + "echo text", + ) +} + +func TestSubAgentToolCallReachesParentCallback(t *testing.T) { + var mu sync.Mutex + var seenAgentIDs []string + + parentCB := func(tc *ToolChoice, st *SessionState) ToolCallDecision { + mu.Lock() + seenAgentIDs = append(seenAgentIDs, st.AgentID) + mu.Unlock() + return ToolCallDecision{Approved: true} + } + + var echoMu sync.Mutex + echoCount := 0 + echo := newEchoTool(&echoMu, &echoCount) + llm := newScriptedLLM("echo", `{"text": "hi"}`, "done") + + runner := &spawnAgentRunner{ + llm: llm, + parentTools: Tools{echo}, + parentOpts: []Option{WithToolCallBack(parentCB)}, + manager: NewAgentManager(), + ctx: context.Background(), + } + + _, _, err := runner.Run(SpawnAgentArgs{Task: "say hi", Background: false}) + if err != nil { + t.Fatalf("foreground spawn errored: %v", err) + } + + mu.Lock() + defer mu.Unlock() + if len(seenAgentIDs) == 0 { + t.Fatal("parent tool callback was never invoked from the sub-agent") + } + for _, id := range seenAgentIDs { + if id == "" { + t.Fatal("expected a non-empty AgentID in sub-agent tool callback") + } + } + + echoMu.Lock() + defer echoMu.Unlock() + if echoCount != 1 { + t.Fatalf("expected approved echo tool to run exactly once, ran %d times", echoCount) + } +} + +func TestSubAgentToolRejectionIsHonored(t *testing.T) { + rejectCB := func(tc *ToolChoice, st *SessionState) ToolCallDecision { + return ToolCallDecision{Approved: false} + } + + var echoMu sync.Mutex + echoCount := 0 + echo := newEchoTool(&echoMu, &echoCount) + llm := newScriptedLLM("echo", `{"text": "hi"}`, "done") + + runner := &spawnAgentRunner{ + llm: llm, + parentTools: Tools{echo}, + parentOpts: []Option{WithToolCallBack(rejectCB)}, + manager: NewAgentManager(), + ctx: context.Background(), + } + + _, _, err := runner.Run(SpawnAgentArgs{Task: "say hi", Background: false}) + 
if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + echoMu.Lock() + defer echoMu.Unlock() + if echoCount != 0 { + t.Fatalf("rejected echo tool must not run, ran %d times", echoCount) + } +} diff --git a/agent_resume_test.go b/agent_resume_test.go new file mode 100644 index 0000000..91ef7c9 --- /dev/null +++ b/agent_resume_test.go @@ -0,0 +1,125 @@ +package cogito + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/sashabaranov/go-openai" +) + +// replyLLM is a reply-only LLM mock: it never selects a tool, it just replies +// with a fixed message on every turn. Used by the send_agent_message resume +// tests where the re-run should terminate immediately with a plain reply. +type replyLLM struct { + reply string +} + +// newReplyLLM builds an LLM that replies plainly (no tool calls) on every turn. +// The plan's A8 sketch used newScriptedLLM(scriptReply("...")), which does not +// exist in this repo; this is the cleanest reply-only equivalent. +func newReplyLLM(reply string) *replyLLM { + return &replyLLM{reply: reply} +} + +func (m *replyLLM) Ask(_ context.Context, f Fragment) (Fragment, error) { + return f.AddMessage(AssistantMessageRole, m.reply), nil +} + +func (m *replyLLM) CreateChatCompletion(_ context.Context, _ openai.ChatCompletionRequest) (LLMReply, LLMUsage, error) { + return LLMReply{ChatCompletionResponse: openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{{ + Message: openai.ChatCompletionMessage{Role: AssistantMessageRole.String(), Content: m.reply}, + }}, + }}, LLMUsage{}, nil +} + +func closedChan() chan struct{} { c := make(chan struct{}); close(c); return c } + +func TestInjectDeliversToRunningAgent(t *testing.T) { + m := NewAgentManager() + delivered := make(chan string, 1) + agent := &AgentState{ + ID: "a1", Status: AgentStatusRunning, + done: make(chan struct{}), + inject: make(chan openai.ChatCompletionMessage, 1), + } + m.Register(agent) + + go func() { + msg := <-agent.inject + delivered 
<- msg.Content + }() + + if err := m.Inject("a1", "keep going"); err != nil { + t.Fatalf("inject errored: %v", err) + } + select { + case got := <-delivered: + if got != "keep going" { + t.Fatalf("got %q", got) + } + case <-time.After(time.Second): + t.Fatal("inject not delivered") + } + _ = context.Background() +} + +func TestInjectUnknownAgentErrors(t *testing.T) { + m := NewAgentManager() + if err := m.Inject("missing", "x"); err == nil { + t.Fatal("expected error for unknown agent") + } +} + +func TestSendAgentMessageResumesCompletedAgent(t *testing.T) { + m := NewAgentManager() + frag := NewFragment(openai.ChatCompletionMessage{Role: "user", Content: "first task"}) + agent := &AgentState{ + ID: "done1", Status: AgentStatusCompleted, + Result: "first result", Fragment: &frag, + done: closedChan(), + } + m.Register(agent) + + llm := newReplyLLM("second result") + runner := &sendAgentMessageRunner{manager: m, ctx: context.Background(), llm: llm} + out, _, err := runner.Run(SendAgentMessageArgs{AgentID: "done1", Message: "now do more"}) + if err != nil { + t.Fatalf("resume errored: %v", err) + } + if !strings.Contains(out, "second result") { + t.Fatalf("expected re-run result, got %q", out) + } +} + +func TestSendAgentMessageInjectsRunningAgent(t *testing.T) { + m := NewAgentManager() + agent := &AgentState{ID: "run1", Status: AgentStatusRunning, + done: make(chan struct{}), inject: make(chan openai.ChatCompletionMessage, 1)} + m.Register(agent) + runner := &sendAgentMessageRunner{manager: m, ctx: context.Background()} + out, _, err := runner.Run(SendAgentMessageArgs{AgentID: "run1", Message: "hint"}) + if err != nil { + t.Fatalf("inject errored: %v", err) + } + if got := <-agent.inject; got.Content != "hint" { + t.Fatalf("injected %q", got.Content) + } + if !strings.Contains(out, "run1") { + t.Fatalf("expected ack mentioning agent id, got %q", out) + } +} + +func TestSendAgentMessageUnknownAgent(t *testing.T) { + m := NewAgentManager() + runner := 
&sendAgentMessageRunner{manager: m, ctx: context.Background()} + out, _, err := runner.Run(SendAgentMessageArgs{AgentID: "nope", Message: "hi"}) + if err != nil { + t.Fatalf("unknown agent should not hard-error, got %v", err) + } + if !strings.Contains(out, "not found") { + t.Fatalf("expected not-found message, got %q", out) + } +} diff --git a/agent_spawn_callback_test.go b/agent_spawn_callback_test.go new file mode 100644 index 0000000..a6f10ab --- /dev/null +++ b/agent_spawn_callback_test.go @@ -0,0 +1,121 @@ +package cogito + +import ( + "context" + "sync" + "testing" +) + +// TestWithAgentSpawnCallbackStores asserts the option stores the fn on Options. +func TestWithAgentSpawnCallbackStores(t *testing.T) { + o := defaultOptions() + o.Apply(WithAgentSpawnCallback(func(*AgentState) {})) + if o.agentSpawnCallback == nil { + t.Fatal("spawn callback not stored") + } +} + +// TestSpawnCallbackFiresForeground asserts a foreground spawn fires the spawn +// callback with a running AgentState whose Type matches the requested type. +func TestSpawnCallbackFiresForeground(t *testing.T) { + var mu sync.Mutex + var fired bool + var gotStatus AgentStatusType + var gotType string + var nonNil bool + + defs := []AgentDefinition{{Name: "explore", SystemPrompt: "You are EXPLORE."}} + runner := &spawnAgentRunner{ + llm: newReplyLLM("foreground done"), + manager: NewAgentManager(), + ctx: context.Background(), + agentDefinitions: defs, + // Snapshot the AgentState fields at callback time. The foreground agent + // runs in a goroutine and may mutate Status to "completed" by the time + // Run returns, so we must capture the values inside the callback (while + // Status is still running) rather than reading the live pointer after. 
+ agentSpawnCallback: func(a *AgentState) { + mu.Lock() + fired = true + nonNil = a != nil + if a != nil { + gotStatus = a.Status + gotType = a.Type + } + mu.Unlock() + }, + } + + _, _, err := runner.Run(SpawnAgentArgs{AgentType: "explore", Task: "look around", Background: false}) + if err != nil { + t.Fatalf("foreground spawn errored: %v", err) + } + + mu.Lock() + defer mu.Unlock() + if !fired { + t.Fatal("spawn callback did not fire for foreground spawn") + } + if !nonNil { + t.Fatal("spawn callback received a nil AgentState") + } + if gotStatus != AgentStatusRunning { + t.Fatalf("spawn callback AgentState status = %q, want %q", gotStatus, AgentStatusRunning) + } + if gotType != "explore" { + t.Fatalf("spawn callback AgentState type = %q, want %q", gotType, "explore") + } +} + +// TestSpawnCallbackFiresBackground asserts a background spawn fires the spawn +// callback synchronously (before Run returns the ID) with a running AgentState +// whose Type matches the requested type. +func TestSpawnCallbackFiresBackground(t *testing.T) { + var mu sync.Mutex + var fired bool + var gotStatus AgentStatusType + var gotType string + var nonNil bool + + defs := []AgentDefinition{{Name: "plan", SystemPrompt: "You are PLAN."}} + runner := &spawnAgentRunner{ + llm: newReplyLLM("background done"), + manager: NewAgentManager(), + ctx: context.Background(), + agentDefinitions: defs, + // The background spawn fires the callback synchronously (before Run + // returns the ID) while Status is still running, but the agent's + // goroutine may mutate Status afterward, so snapshot inside the callback. 
+ agentSpawnCallback: func(a *AgentState) { + mu.Lock() + fired = true + nonNil = a != nil + if a != nil { + gotStatus = a.Status + gotType = a.Type + } + mu.Unlock() + }, + } + + out, _, err := runner.Run(SpawnAgentArgs{AgentType: "plan", Task: "make a plan", Background: true}) + if err != nil { + t.Fatalf("background spawn errored: %v", err) + } + _ = out + + mu.Lock() + defer mu.Unlock() + if !fired { + t.Fatal("spawn callback did not fire for background spawn") + } + if !nonNil { + t.Fatal("spawn callback received a nil AgentState") + } + if gotStatus != AgentStatusRunning { + t.Fatalf("spawn callback AgentState status = %q, want %q", gotStatus, AgentStatusRunning) + } + if gotType != "plan" { + t.Fatalf("spawn callback AgentState type = %q, want %q", gotType, "plan") + } +} diff --git a/agent_test.go b/agent_test.go index 2d357c0..a6a24aa 100644 --- a/agent_test.go +++ b/agent_test.go @@ -500,6 +500,256 @@ var _ = Describe("Sub-Agent Spawning", func() { }) }) + Context("Agent definitions and approval propagation through ExecuteTools", func() { + // Drives a foreground spawn_agent call through the PUBLIC ExecuteTools API + // with a scripted parent mock and a SEPARATE sub-agent mock (via + // WithAgentLLM) so the two response queues are independent and + // deterministic. Proves the security-critical property that a sub-agent's + // tool call reaches the embedder's approval callback with a NON-EMPTY + // SessionState.AgentID, while the parent's own spawn_agent call reaches it + // with an EMPTY AgentID. The sub-agent's restricted "echo" tool running + // proves the AgentDefinition's tool restriction took effect. + It("propagates an empty AgentID for the parent's tool call and a non-empty AgentID for the restricted sub-agent tool", func() { + parentMock := mock.NewMockOpenAIClient() + subMock := mock.NewMockOpenAIClient() + + // --- Parent script --- + // 1. Parent iteration 1: LLM decides to call spawn_agent (foreground). 
+ parentMock.AddCreateChatCompletionFunction("spawn_agent", + `{"agent_type":"explore","task":"investigate","background":false}`) + // 2. Parent iteration 2: no more tools (sink state). + parentMock.SetCreateChatCompletionResponse(openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{{ + Message: openai.ChatCompletionMessage{ + Role: AssistantMessageRole.String(), + Content: "Parent done.", + }, + }}, + }) + // 3. Parent final Ask after the sink state. + parentMock.SetAskResponse("The explore sub-agent finished investigating.") + + // --- Sub-agent script (its OWN mock, independent queue) --- + // 1. Sub-agent iteration 1: LLM decides to call the echo tool. + subMock.AddCreateChatCompletionFunction("echo", `{"text":"hi"}`) + // 2. Sub-agent iteration 2: no more tools (sink state). + subMock.SetCreateChatCompletionResponse(openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{{ + Message: openai.ChatCompletionMessage{ + Role: AssistantMessageRole.String(), + Content: "Sub-agent done.", + }, + }}, + }) + // 3. Sub-agent final Ask after the sink state. + subMock.SetAskResponse("echoed: hi") + + // The echo tool the sub-agent is allowed to use. + echoTool := mock.NewMockTool("echo", "Echo back the provided text") + mock.SetRunResult(echoTool, "echoed: hi") + + // The named sub-agent persona, restricted to the echo tool. 
+ def := AgentDefinition{ + Name: "explore", + Description: "An exploration agent", + SystemPrompt: "You are EXPLORE.", + Tools: []string{"echo"}, + } + + type callbackEntry struct { + tool string + agentID string + } + var ( + mu sync.Mutex + entries []callbackEntry + ) + cb := func(tc *ToolChoice, state *SessionState) ToolCallDecision { + mu.Lock() + id := "" + if state != nil { + id = state.AgentID + } + name := "" + if tc != nil { + name = tc.Name + } + entries = append(entries, callbackEntry{tool: name, agentID: id}) + mu.Unlock() + return ToolCallDecision{Approved: true} + } + + fragment := NewEmptyFragment().AddMessage(UserMessageRole, "Investigate something") + + result, err := ExecuteTools(parentMock, fragment, + EnableAgentSpawning, + WithAgentLLM(subMock), + WithTools(echoTool), + WithAgentDefinitions(def), + WithToolCallBack(cb), + WithIterations(5), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(result.LastMessage().Content).ToNot(BeEmpty()) + + mu.Lock() + defer mu.Unlock() + Expect(entries).ToNot(BeEmpty()) + + var ( + sawSpawn bool + spawnAgentID string + sawEcho bool + echoAgentID string + ) + for _, e := range entries { + switch e.tool { + case "spawn_agent": + sawSpawn = true + spawnAgentID = e.agentID + case "echo": + sawEcho = true + echoAgentID = e.agentID + } + } + + // The parent's own spawn_agent call must reach the callback with an + // EMPTY AgentID. + Expect(sawSpawn).To(BeTrue(), "expected a callback entry for spawn_agent (parent tool); entries=%+v", entries) + Expect(spawnAgentID).To(BeEmpty(), "expected EMPTY AgentID for the parent's spawn_agent call; entries=%+v", entries) + + // The echo tool running inside the sub-agent proves the sub-agent + // executed with its restricted tool set (only "echo" from the + // definition), and that the approval gate fired for it with a + // non-empty sub-agent AgentID — the security property. 
+ Expect(sawEcho).To(BeTrue(), "expected the sub-agent's restricted echo tool to reach the approval callback; entries=%+v", entries) + Expect(echoAgentID).ToNot(BeEmpty(), "SECURITY: expected NON-EMPTY AgentID for the sub-agent's echo call; entries=%+v", entries) + }) + }) + + Context("Spawn callback and background completion", func() { + // Drives a background spawn_agent call through the PUBLIC ExecuteTools API. + // The parent decides to spawn an agent in the background; the parent loop + // continues (parking on the auto-created injection channel until the + // background agent completes). A shared AgentManager lets us assert the + // agent registered and reached AgentStatusCompleted via mgr.Wait (robust + // against background-queue timing rather than racing on status reads). + // WithAgentSpawnCallback must fire at spawn time with a running explore + // agent, and WithAgentCompletionCallback must fire when it completes. + It("fires the spawn callback with a running explore agent and completes the background agent", func() { + parentMock := mock.NewMockOpenAIClient() + subMock := mock.NewMockOpenAIClient() + + // --- Parent script --- + // 1. Parent: spawn an explore agent in the background. + parentMock.AddCreateChatCompletionFunction("spawn_agent", + `{"agent_type":"explore","task":"investigate in background","background":true}`) + // 2+. After spawn_agent returns the ID, the parent keeps looping. + // Depending on timing it may pick a no-tool sink BEFORE the background + // completion message is injected, then loop again AFTER the injection — + // each loop consumes one CreateChatCompletion response. Queue several + // no-tool sink responses so the parent always has a response regardless + // of injection timing; the loop ends on the first no-tool reply once no + // background agents remain running. 
+ for i := 0; i < 6; i++ { + parentMock.SetCreateChatCompletionResponse(openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{{ + Message: openai.ChatCompletionMessage{ + Role: AssistantMessageRole.String(), + Content: "Background agent finished, all done.", + }, + }}, + }) + } + // Parent final Ask after the loop terminates on a no-tool sink. + parentMock.SetAskResponse("Spawned and completed a background explore agent.") + + // --- Sub-agent script (its OWN mock) --- + subMock.AddCreateChatCompletionFunction("echo", `{"text":"bg"}`) + subMock.SetCreateChatCompletionResponse(openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{{ + Message: openai.ChatCompletionMessage{ + Role: AssistantMessageRole.String(), + Content: "Sub-agent done.", + }, + }}, + }) + subMock.SetAskResponse("echoed: bg") + + echoTool := mock.NewMockTool("echo", "Echo back the provided text") + mock.SetRunResult(echoTool, "echoed: bg") + + def := AgentDefinition{ + Name: "explore", + Description: "An exploration agent", + SystemPrompt: "You are EXPLORE.", + Tools: []string{"echo"}, + } + + var ( + evMu sync.Mutex + spawnedAgent *AgentState + doneAgent *AgentState + ) + spawnCB := func(a *AgentState) { + evMu.Lock() + spawnedAgent = a + evMu.Unlock() + } + completionCB := func(a *AgentState) { + evMu.Lock() + doneAgent = a + evMu.Unlock() + } + + mgr := NewAgentManager() + fragment := NewEmptyFragment().AddMessage(UserMessageRole, "Investigate in the background") + + result, err := ExecuteTools(parentMock, fragment, + EnableAgentSpawning, + WithAgentManager(mgr), + WithAgentLLM(subMock), + WithTools(echoTool), + WithAgentDefinitions(def), + WithAgentSpawnCallback(spawnCB), + WithAgentCompletionCallback(completionCB), + WithIterations(10), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(result.LastMessage().Content).ToNot(BeEmpty()) + + // Spawn event must have fired at spawn time with a running explore agent. 
+ evMu.Lock() + sp := spawnedAgent + evMu.Unlock() + Expect(sp).ToNot(BeNil(), "expected the spawn callback to fire") + Expect(sp.Type).To(Equal("explore")) + Expect(sp.ID).ToNot(BeEmpty()) + + // Robust completion check: wait on the agent's done channel rather than + // racing on status reads. + finished, werr := mgr.Wait(sp.ID) + Expect(werr).ToNot(HaveOccurred()) + Expect(finished.Status).To(Equal(AgentStatusCompleted)) + + // The completion callback must fire for the completed agent. It fires + // from the sub-agent goroutine just before done, so give it a window. + Eventually(func() *AgentState { + evMu.Lock() + defer evMu.Unlock() + return doneAgent + }, 2*time.Second, 10*time.Millisecond).ShouldNot(BeNil()) + evMu.Lock() + Expect(doneAgent.Status).To(Equal(AgentStatusCompleted)) + evMu.Unlock() + + // The agent must be registered in the shared manager and completed. + got, ok := mgr.Get(sp.ID) + Expect(ok).To(BeTrue()) + Expect(got.Status).To(Equal(AgentStatusCompleted)) + }) + }) + Context("Context cancellation", func() { It("should cancel sub-agents when parent context is cancelled", func() { ctx, cancel := context.WithCancel(context.Background()) diff --git a/clients/openai_client.go b/clients/openai_client.go index 519251b..2bef331 100644 --- a/clients/openai_client.go +++ b/clients/openai_client.go @@ -13,16 +13,27 @@ var _ cogito.LLM = (*OpenAIClient)(nil) var _ cogito.StreamingLLM = (*OpenAIClient)(nil) type OpenAIClient struct { - model string - client *openai.Client + model string + client *openai.Client + temperature float32 +} + +// OpenAIOptions carries optional per-client settings. 
+type OpenAIOptions struct { + Temperature float32 } func NewOpenAILLM(model, apiKey, baseURL string) *OpenAIClient { + return NewOpenAILLMWithOptions(model, apiKey, baseURL, OpenAIOptions{}) +} + +func NewOpenAILLMWithOptions(model, apiKey, baseURL string, opts OpenAIOptions) *OpenAIClient { client := openaiClient(apiKey, baseURL) return &OpenAIClient{ - model: model, - client: client, + model: model, + client: client, + temperature: opts.Temperature, } } @@ -36,13 +47,15 @@ func (llm *OpenAIClient) Ask(ctx context.Context, f cogito.Fragment) (cogito.Fra // system message when tool calls are detected in the conversation messages := f.GetMessages() - resp, err := llm.client.CreateChatCompletion( - ctx, - openai.ChatCompletionRequest{ - Model: llm.model, - Messages: messages, - }, - ) + req := openai.ChatCompletionRequest{ + Model: llm.model, + Messages: messages, + } + if llm.temperature != 0 { + req.Temperature = llm.temperature + } + + resp, err := llm.client.CreateChatCompletion(ctx, req) if err != nil { return cogito.Fragment{}, err @@ -91,6 +104,9 @@ func (llm *OpenAIClient) CreateChatCompletion(ctx context.Context, request opena func (llm *OpenAIClient) CreateChatCompletionStream(ctx context.Context, request openai.ChatCompletionRequest) (<-chan cogito.StreamEvent, error) { request.Model = llm.model request.Stream = true + if llm.temperature != 0 { + request.Temperature = llm.temperature + } stream, err := llm.client.CreateChatCompletionStream(ctx, request) if err != nil { diff --git a/clients/openai_client_test.go b/clients/openai_client_test.go new file mode 100644 index 0000000..d8ab2e3 --- /dev/null +++ b/clients/openai_client_test.go @@ -0,0 +1,17 @@ +package clients + +import "testing" + +func TestNewOpenAILLMWithOptionsSetsTemperature(t *testing.T) { + llm := NewOpenAILLMWithOptions("m", "k", "http://localhost", OpenAIOptions{Temperature: 0.7}) + if llm.temperature != 0.7 { + t.Fatalf("expected temperature 0.7, got %v", llm.temperature) + } +} + +func 
TestNewOpenAILLMDefaultsTemperatureZeroMeansUnset(t *testing.T) { + llm := NewOpenAILLM("m", "k", "http://localhost") + if llm.temperature != 0 { + t.Fatalf("expected default temperature 0 (unset), got %v", llm.temperature) + } +} diff --git a/options.go b/options.go index 452e569..562fb94 100644 --- a/options.go +++ b/options.go @@ -92,7 +92,10 @@ type Options struct { agentManager *AgentManager agentLLM LLM agentCompletionCallback func(*AgentState) + agentSpawnCallback func(*AgentState) agentCompletionFormatter func(*AgentState) string + agentDefinitions []AgentDefinition + agentLLMFactory func(model string, temperature float32) LLM } type Option func(*Options) @@ -517,6 +520,24 @@ func WithAgentLLM(llm LLM) Option { } } +// WithAgentDefinitions registers named sub-agent types (personas). spawn_agent +// can select one via its agent_type argument; the chosen definition supplies the +// system prompt, tool allow-list, model, temperature, and per-type execution limits. +func WithAgentDefinitions(defs ...AgentDefinition) Option { + return func(o *Options) { + o.agentDefinitions = defs + } +} + +// WithAgentLLMFactory sets a factory that builds an LLM for a sub-agent from a +// model name and temperature. Used to resolve per-agent-type or per-spawn model +// overrides while reusing the parent's endpoint/credentials. +func WithAgentLLMFactory(fn func(model string, temperature float32) LLM) Option { + return func(o *Options) { + o.agentLLMFactory = fn + } +} + // WithAgentCompletionCallback sets a callback that fires when any background sub-agent finishes. // Useful for external monitoring or UI updates outside the LLM loop. func WithAgentCompletionCallback(fn func(*AgentState)) Option { @@ -525,6 +546,13 @@ func WithAgentCompletionCallback(fn func(*AgentState)) Option { } } +// WithAgentSpawnCallback sets a callback that fires when a sub-agent starts +// (is registered and about to run), for both foreground and background spawns. 
+// Useful for UIs that show running agents. The AgentState has Status=running. +func WithAgentSpawnCallback(fn func(*AgentState)) Option { + return func(o *Options) { o.agentSpawnCallback = fn } +} + // WithAgentCompletionFormatter overrides the message a finished background // sub-agent injects into the parent loop. By default cogito injects a // fixed prose notification ("Background agent has completed…"); set diff --git a/tools.go b/tools.go index 0cc4918..7959f52 100644 --- a/tools.go +++ b/tools.go @@ -33,6 +33,10 @@ type ToolStatus struct { type SessionState struct { ToolChoice *ToolChoice `json:"tool_choice"` Fragment Fragment `json:"fragment"` + // AgentID identifies the sub-agent whose tool call is being evaluated. + // Empty for the root agent. Set when the tool-call callback is invoked + // from within a spawned sub-agent (see WithToolCallBack propagation). + AgentID string `json:"agent_id,omitempty"` } // decisionResult holds the result of a tool decision from the LLM @@ -1138,11 +1142,21 @@ func ExecuteTools(llm LLM, f Fragment, opts ...Option) (Fragment, error) { if o.maxRetries > 0 { subAgentOpts = append(subAgentOpts, WithMaxRetries(o.maxRetries)) } + // Security-critical: propagate the parent's tool-call approval gate and + // MCP sessions so sub-agent tool calls flow through the same callback + // (stamped with the sub-agent's AgentID) instead of bypassing approval. 
+ if o.toolCallCallback != nil { + subAgentOpts = append(subAgentOpts, WithToolCallBack(o.toolCallCallback)) + } + if len(o.mcpSessions) > 0 { + subAgentOpts = append(subAgentOpts, WithMCPs(o.mcpSessions...)) + } agentTools := []ToolDefinitionInterface{ - newSpawnAgentTool(agentLLM, o.tools, o.agentManager, o.context, subAgentOpts, o.streamCallback, o.messageInjectionChan, o.agentCompletionCallback, o.agentCompletionFormatter), + newSpawnAgentTool(agentLLM, o.tools, o.agentManager, o.context, subAgentOpts, o.streamCallback, o.messageInjectionChan, o.agentCompletionCallback, o.agentSpawnCallback, o.agentCompletionFormatter, o.agentDefinitions, o.agentLLMFactory), newCheckAgentTool(o.agentManager), newGetAgentResultTool(o.agentManager, o.context), + newSendAgentMessageTool(o.agentManager, o.context, agentLLM, subAgentOpts), } // Append agent tools to both o.tools (for this call) and opts (so usableTools sees them)