Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
410 changes: 354 additions & 56 deletions agent.go

Large diffs are not rendered by default.

210 changes: 210 additions & 0 deletions agent_definitions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
package cogito

import (
"context"
"strings"
"sync"
"testing"

"github.com/sashabaranov/go-openai"
)

// newNamedTool returns an echo-style tool registered under name and described
// as "<name> tool". It is backed by an echoRunner whose mutex and call counter
// are private to the returned tool and never inspected: the tool only has to
// exist in a parent tool set so a test can assert that a sub-agent's
// allow-list filters it out.
func newNamedTool(name string) ToolDefinitionInterface {
	var (
		lock  sync.Mutex
		calls int
	)
	runner := echoRunner{mu: &lock, count: &calls}
	return NewToolDefinition[EchoArgs](runner, EchoArgs{}, name, name+" tool")
}

// inspectingLLM is a test double for the LLM used by a sub-agent. On its first
// CreateChatCompletion call it hands the request's messages (rebuilt into a
// Fragment) and the names of the offered tools to fn, and it always answers
// with a plain assistant message containing no tool call. Tests use the
// captured values to assert which system prompt and which tools the sub-agent
// was actually given.
type inspectingLLM struct {
	// mu guards fn so the callback fires at most once even if
	// CreateChatCompletion is invoked concurrently.
	mu sync.Mutex

	// fn receives the reconstructed fragment and tool names; it is set to nil
	// after its first invocation.
	fn func(f Fragment, toolNames []string)
}

// newInspectingLLM returns an inspectingLLM that will call fn once, on the
// first CreateChatCompletion request, with the fragment and tool names taken
// from that request.
func newInspectingLLM(fn func(f Fragment, toolNames []string)) *inspectingLLM {
	llm := new(inspectingLLM)
	llm.fn = fn
	return llm
}

// Ask ignores the context and returns f with an assistant message "done"
// appended; it never fails.
func (m *inspectingLLM) Ask(_ context.Context, f Fragment) (Fragment, error) {
	answered := f.AddMessage(AssistantMessageRole, "done")
	return answered, nil
}

// CreateChatCompletion reports the first request it sees to the recording
// callback and then answers with a plain assistant message ("done") that
// carries no tool call.
//
// On the first call only, the names of all tools in req.Tools that have a
// non-nil Function, together with a Fragment rebuilt from req.Messages, are
// passed to m.fn; the callback is then cleared so later turns are not
// recorded. The returned usage is zero and the error is always nil.
func (m *inspectingLLM) CreateChatCompletion(_ context.Context, req openai.ChatCompletionRequest) (LLMReply, LLMUsage, error) {
	m.mu.Lock()
	// Deferred so the mutex is released even when the callback panics or
	// terminates the goroutine (e.g. via a test's Fatal); a manual Unlock
	// after the callback would leave the mutex held in that case.
	defer m.mu.Unlock()

	if m.fn != nil {
		var names []string
		for _, tool := range req.Tools {
			if tool.Function != nil {
				names = append(names, tool.Function.Name)
			}
		}
		m.fn(NewFragment(req.Messages...), names)
		m.fn = nil // record only the first selection turn
	}

	// No tool call: a plain assistant message so the loop terminates.
	return LLMReply{ChatCompletionResponse: openai.ChatCompletionResponse{
		Choices: []openai.ChatCompletionChoice{{
			Message: openai.ChatCompletionMessage{Role: AssistantMessageRole.String(), Content: "done"},
		}},
	}}, LLMUsage{}, nil
}

// firstSystemContent scans f's messages in order and returns the content of
// the first one whose role is the system role. It returns the empty string
// when f holds no system message.
func firstSystemContent(f Fragment) string {
	for _, m := range f.GetMessages() {
		if m.Role != SystemMessageRole.String() {
			continue
		}
		return m.Content
	}
	return ""
}

// contains reports whether needle is exactly equal to at least one element of
// haystack. A nil or empty haystack never contains anything.
func contains(haystack []string, needle string) bool {
	for i := 0; i < len(haystack); i++ {
		if haystack[i] == needle {
			return true
		}
	}
	return false
}

// TestWithAgentDefinitionsStoresDefs checks that applying WithAgentDefinitions
// to the default options stores the supplied definitions in
// agentDefinitions.
func TestWithAgentDefinitionsStoresDefs(t *testing.T) {
	defs := []AgentDefinition{{
		Name:         "explore",
		Description:  "read-only exploration",
		SystemPrompt: "You explore.",
		Tools:        []string{"echo"},
		Model:        "small-model",
		Temperature:  0.2,
		Iterations:   20,
		MaxAttempts:  2,
		MaxRetries:   2,
	}}

	opts := defaultOptions()
	opts.Apply(WithAgentDefinitions(defs...))

	// Exactly one definition, and it must be the one we passed in.
	if len(opts.agentDefinitions) == 1 && opts.agentDefinitions[0].Name == "explore" {
		return
	}
	t.Fatalf("agent definitions not stored: %+v", opts.agentDefinitions)
}

// TestFindAgentDefinition checks both outcomes of findAgentDefinition: a
// known name yields the matching definition, an unknown name yields nil.
func TestFindAgentDefinition(t *testing.T) {
	known := []AgentDefinition{{Name: "plan"}, {Name: "explore"}}

	hit := findAgentDefinition(known, "explore")
	if hit == nil || hit.Name != "explore" {
		t.Fatalf("expected to find explore, got %+v", hit)
	}

	miss := findAgentDefinition(known, "missing")
	if miss != nil {
		t.Fatalf("expected nil for missing type, got %+v", miss)
	}
}

// TestSpawnAppliesDefinitionSystemPromptAndTools spawns a foreground
// sub-agent of type "explore" from a parent that owns both an "echo" and a
// "secret" tool. The definition allow-lists only "echo" and sets a system
// prompt; the inspecting LLM captures what the sub-agent actually sends so
// the test can assert the prompt was seeded and "secret" was filtered out.
func TestSpawnAppliesDefinitionSystemPromptAndTools(t *testing.T) {
	var emu sync.Mutex
	ecount := 0
	echo := newEchoTool(&emu, &ecount)
	secret := newNamedTool("secret") // the explore type must NOT receive this
	defs := []AgentDefinition{{
		Name: "explore", SystemPrompt: "You are EXPLORE.",
		Tools: []string{"echo"}, Iterations: 7,
	}}

	var gotSystem string
	var gotToolNames []string
	llm := newInspectingLLM(func(f Fragment, toolNames []string) {
		gotSystem = firstSystemContent(f)
		gotToolNames = append(gotToolNames, toolNames...)
	})

	runner := &spawnAgentRunner{
		llm:              llm,
		parentTools:      Tools{echo, secret},
		manager:          NewAgentManager(),
		ctx:              context.Background(),
		agentDefinitions: defs,
	}
	// The assertions below only concern what reached the LLM, so a spawn
	// error is not fatal by itself. It is logged rather than discarded so
	// that a failing assertion can be traced back to its cause.
	if _, _, err := runner.Run(SpawnAgentArgs{AgentType: "explore", Task: "look around", Background: false}); err != nil {
		t.Logf("spawn returned an error: %v", err)
	}

	if gotSystem != "You are EXPLORE." {
		t.Fatalf("definition system prompt not seeded, got %q", gotSystem)
	}
	if contains(gotToolNames, "secret") {
		t.Fatalf("explore must not receive 'secret' tool, got %v", gotToolNames)
	}
	if !contains(gotToolNames, "echo") {
		t.Fatalf("explore should receive 'echo' tool, got %v", gotToolNames)
	}
}

func TestSpawnUnknownAgentTypeErrorsCleanly(t *testing.T) {
runner := &spawnAgentRunner{
llm: newInspectingLLM(func(Fragment, []string) {}),
manager: NewAgentManager(),
ctx: context.Background(),
agentDefinitions: []AgentDefinition{{Name: "explore"}},
}
out, _, err := runner.Run(SpawnAgentArgs{AgentType: "nope", Task: "x", Background: false})
if err != nil {
t.Fatalf("unknown type should not hard-error, got %v", err)
}
if !strings.Contains(out, "unknown agent type") {
t.Fatalf("expected a clear message, got %q", out)
}
}

// TestFactoryResolvesModelAndTemperature spawns an agent of a type whose
// definition sets Model and Temperature, and checks that the runner's
// llmFactory is invoked with exactly those values.
func TestFactoryResolvesModelAndTemperature(t *testing.T) {
	var gotModel string
	var gotTemp float32
	factory := func(model string, temp float32) LLM {
		gotModel, gotTemp = model, temp
		return newInspectingLLM(func(Fragment, []string) {})
	}
	defs := []AgentDefinition{{Name: "cheap", Model: "small", Temperature: 0.3}}
	runner := &spawnAgentRunner{
		llm:              newInspectingLLM(func(Fragment, []string) {}),
		manager:          NewAgentManager(),
		ctx:              context.Background(),
		agentDefinitions: defs,
		llmFactory:       factory,
	}
	// Only the factory arguments are under test, so a spawn error is not
	// fatal by itself. It is logged rather than discarded so a failing
	// assertion can be traced back to its cause.
	if _, _, err := runner.Run(SpawnAgentArgs{AgentType: "cheap", Task: "x", Background: false}); err != nil {
		t.Logf("spawn returned an error: %v", err)
	}
	if gotModel != "small" || gotTemp != 0.3 {
		t.Fatalf("factory got (%q,%v), want (small,0.3)", gotModel, gotTemp)
	}
}

// TestSpawnArgModelBeatsDefinition checks precedence between two model
// sources: when both the agent definition ("small") and the spawn arguments
// ("big") name a model, the llmFactory must receive the spawn-argument value.
func TestSpawnArgModelBeatsDefinition(t *testing.T) {
	var gotModel string
	// The temperature argument is irrelevant to this test.
	factory := func(model string, _ float32) LLM {
		gotModel = model
		return newInspectingLLM(func(Fragment, []string) {})
	}
	defs := []AgentDefinition{{Name: "cheap", Model: "small"}}
	runner := &spawnAgentRunner{
		llm:              newInspectingLLM(func(Fragment, []string) {}),
		manager:          NewAgentManager(),
		ctx:              context.Background(),
		agentDefinitions: defs,
		llmFactory:       factory,
	}
	// Only the factory argument is under test, so a spawn error is not fatal
	// by itself. It is logged rather than discarded so a failing assertion
	// can be traced back to its cause.
	if _, _, err := runner.Run(SpawnAgentArgs{AgentType: "cheap", Model: "big", Task: "x", Background: false}); err != nil {
		t.Logf("spawn returned an error: %v", err)
	}
	if gotModel != "big" {
		t.Fatalf("spawn-arg model should win, got %q", gotModel)
	}
}

func TestWithAgentLLMFactoryStores(t *testing.T) {
o := defaultOptions()
o.Apply(WithAgentLLMFactory(func(string, float32) LLM { return nil }))
if o.agentLLMFactory == nil {
t.Fatal("factory not stored")
}
}
132 changes: 132 additions & 0 deletions agent_detach_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package cogito

import (
"context"
"strings"
"testing"
"time"

"github.com/sashabaranov/go-openai"
)

// blockingLLM is a test double whose CreateChatCompletion does not return
// until release is closed (or the call's context is cancelled). Once
// unblocked it answers with a plain, tool-less assistant message. The detach
// tests use it to keep a foreground sub-agent in flight while the parent
// detaches it.
type blockingLLM struct {
	// release unblocks CreateChatCompletion when it is closed.
	release chan struct{}

	// reply is the assistant content returned by Ask and, once unblocked, by
	// CreateChatCompletion.
	reply string
}

// newBlockingLLM returns a blockingLLM gated on release whose reply text is
// "blocked done". The caller owns release and is expected to close it when
// the blocked completion call should proceed.
func newBlockingLLM(release chan struct{}) *blockingLLM {
	llm := new(blockingLLM)
	llm.release = release
	llm.reply = "blocked done"
	return llm
}

// Ask never blocks: it ignores the context and returns f with the configured
// reply appended as an assistant message.
func (m *blockingLLM) Ask(_ context.Context, f Fragment) (Fragment, error) {
	answered := f.AddMessage(AssistantMessageRole, m.reply)
	return answered, nil
}

// CreateChatCompletion waits until either m.release is closed or ctx is done.
// If ctx finishes first it returns zero values together with ctx.Err();
// otherwise it returns a single-choice response whose assistant message
// carries m.reply and no tool call. The request itself is ignored.
func (m *blockingLLM) CreateChatCompletion(ctx context.Context, _ openai.ChatCompletionRequest) (LLMReply, LLMUsage, error) {
	select {
	case <-ctx.Done():
		return LLMReply{}, LLMUsage{}, ctx.Err()
	case <-m.release:
		// Released: fall through and answer.
	}

	answer := openai.ChatCompletionMessage{
		Role:    AssistantMessageRole.String(),
		Content: m.reply,
	}
	response := openai.ChatCompletionResponse{
		Choices: []openai.ChatCompletionChoice{{Message: answer}},
	}
	return LLMReply{ChatCompletionResponse: response}, LLMUsage{}, nil
}

// TestDetachReturnsBeforeCompletion starts a foreground sub-agent whose LLM
// is blocked, waits for the agent to appear in the manager, detaches it, and
// asserts that the foreground Run call returns (with a non-nil id) while the
// LLM is still blocked.
func TestDetachReturnsBeforeCompletion(t *testing.T) {
	m := NewAgentManager()
	release := make(chan struct{})
	// Deferred rather than closed on the last line: the runner uses
	// context.Background(), so closing release is the only way to unblock
	// the LLM. Closing it on every exit path, including t.Fatal, keeps a
	// failing run from leaving the sub-agent goroutine blocked forever.
	defer close(release)
	// An LLM that blocks until released, simulating a long-running foreground agent.
	llm := newBlockingLLM(release)

	runner := &spawnAgentRunner{
		llm: llm, manager: m, ctx: context.Background(),
	}

	type res struct {
		out string
		id  any
	}
	// Buffered so the goroutine can deliver its result and exit even if the
	// test has already failed and nobody is receiving.
	resCh := make(chan res, 1)
	go func() {
		out, id, _ := runner.Run(SpawnAgentArgs{Task: "long job", Background: false})
		resCh <- res{out, id}
	}()

	// Wait for the foreground agent to register, then detach it.
	var id string
	deadline := time.After(2 * time.Second)
	for {
		agents := m.List()
		if len(agents) == 1 {
			id = agents[0].ID
			break
		}
		select {
		case <-deadline:
			t.Fatal("foreground agent never registered")
		case <-time.After(10 * time.Millisecond):
		}
	}

	if err := m.Detach(id); err != nil {
		t.Fatalf("detach errored: %v", err)
	}

	select {
	case r := <-resCh:
		if r.id == nil {
			t.Fatal("expected detach to return the agent id")
		}
	case <-time.After(2 * time.Second):
		t.Fatal("Run did not return after detach")
	}
}

// TestDetachUnknownAgentErrors checks that detaching an id that was never
// registered with the manager yields an error.
func TestDetachUnknownAgentErrors(t *testing.T) {
	manager := NewAgentManager()

	err := manager.Detach("missing")
	if err != nil {
		return
	}
	t.Fatal("expected error for unknown agent")
}

// TestDetachNonDetachableAgentErrors registers a running agent that was built
// with only a done channel, the shape this file treats as an
// already-background agent, and checks that Detach refuses it with an error.
func TestDetachNonDetachableAgentErrors(t *testing.T) {
	manager := NewAgentManager()

	// No detach channel is set: the agent counts as already detached.
	background := &AgentState{
		ID:     "bg1",
		Status: AgentStatusRunning,
		done:   make(chan struct{}),
	}
	manager.Register(background)

	err := manager.Detach("bg1")
	if err != nil {
		return
	}
	t.Fatal("expected error for non-detachable agent")
}

// TestForegroundSpawnUnchangedWithoutDetach is the key safety property: when no
// detach ever fires, the foreground select returns on agent.done and yields the
// same result content the old synchronous path returned via
// result.LastMessage().Content. The plan sketch used
// newScriptedLLM(scriptReply("final answer")); that signature does not exist in
// this repo, so we use the reply-only newReplyLLM helper (added in A8), which
// terminates the loop immediately with the fixed reply.
func TestForegroundSpawnUnchangedWithoutDetach(t *testing.T) {
m := NewAgentManager()
llm := newReplyLLM("final answer")
runner := &spawnAgentRunner{llm: llm, manager: m, ctx: context.Background()}
out, _, err := runner.Run(SpawnAgentArgs{Task: "quick", Background: false})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(out, "final answer") {
t.Fatalf("foreground result changed, got %q", out)
}
}
Loading
Loading