Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,13 @@ type AgentState struct {
Fragment *Fragment
Error error
Cancel context.CancelFunc
done chan struct{}
inject chan openai.ChatCompletionMessage
// Background reports whether the agent was spawned to run in the background
// (spawn_agent background=true) rather than in the foreground. Embedders use
// it to tell unattended background work apart from a foreground sub-agent
// whose result is consumed inline by the spawn call.
Background bool
done chan struct{}
inject chan openai.ChatCompletionMessage
// detach, when non-nil, lets an embedder promote a running foreground
// agent to the background: a non-blocking send here unblocks the
// spawn_agent call so it returns the agent ID while the goroutine keeps
Expand Down Expand Up @@ -405,13 +410,14 @@ func (r *spawnAgentRunner) Run(args SpawnAgentArgs) (string, any, error) {

// Background: launch goroutine, return ID immediately.
agent := &AgentState{
ID: agentID,
Task: args.Task,
Type: args.AgentType,
Status: AgentStatusRunning,
Cancel: cancel,
done: make(chan struct{}),
inject: make(chan openai.ChatCompletionMessage, 8),
ID: agentID,
Task: args.Task,
Type: args.AgentType,
Status: AgentStatusRunning,
Cancel: cancel,
Background: true,
done: make(chan struct{}),
inject: make(chan openai.ChatCompletionMessage, 8),
}
r.manager.Register(agent)
if r.agentSpawnCallback != nil {
Expand Down
28 changes: 28 additions & 0 deletions agent_background_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package cogito

import (
"context"
"testing"
)

// TestBackgroundSpawnSetsBackgroundFlag verifies that a background spawn marks
// the registered AgentState with Background = true, so embedders can tell it
// apart from a foreground sub-agent.
func TestBackgroundSpawnSetsBackgroundFlag(t *testing.T) {
	m := NewAgentManager()
	llm := newBlockingLLM(make(chan struct{})) // blocks; background spawn returns immediately
	runner := &spawnAgentRunner{llm: llm, manager: m, ctx: context.Background()}

	_, idAny, err := runner.Run(SpawnAgentArgs{Task: "bg job", Background: true})
	if err != nil {
		t.Fatalf("background Run: %v", err)
	}
	// Check the assertion explicitly: an unexpected payload type should be
	// reported as such rather than as a failed lookup of the empty ID.
	id, ok := idAny.(string)
	if !ok {
		t.Fatalf("background Run returned %T, want the agent ID as a string", idAny)
	}
	a, ok := m.Get(id)
	if !ok {
		t.Fatal("background agent should be registered")
	}
	// Deferred so the blocked goroutine is released even when an assertion
	// below aborts the test.
	defer a.Cancel()

	if !a.Background {
		t.Fatal("a background spawn should set AgentState.Background = true")
	}
}
1 change: 1 addition & 0 deletions fragment.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type InjectedMessage struct {

type Status struct {
LastUsage LLMUsage // Track token usage from the last LLM call
CumulativeUsage LLMUsage // Sum of token usage across every LLM call in the run
Iterations int
ToolsCalled Tools
ToolResults []ToolStatus
Expand Down
14 changes: 13 additions & 1 deletion tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -1143,7 +1143,7 @@ func askWithStreaming(ctx context.Context, llm LLM, f Fragment, streamCB StreamC

// ExecuteTools runs a fragment through an LLM, and executes Tools. It returns a new fragment with the tool result at the end
// The result is guaranteed that can be called afterwards with llm.Ask() to explain the result to the user.
func ExecuteTools(llm LLM, f Fragment, opts ...Option) (Fragment, error) {
func ExecuteTools(llm LLM, f Fragment, opts ...Option) (result Fragment, retErr error) {
o := defaultOptions()
o.Apply(opts...)

Expand Down Expand Up @@ -1206,6 +1206,18 @@ func ExecuteTools(llm LLM, f Fragment, opts ...Option) (Fragment, error) {
o.messageInjectionChan = make(chan openai.ChatCompletionMessage, 16)
}

// Accumulate token usage across every LLM call in this run and stamp the
// total onto the returned fragment, so callers (and sub-agent completion
// callbacks) can report cumulative usage. The sub-agent fallback LLM
// (agentLLM, captured above) stays unwrapped so its usage is not folded in.
runUsage := &usageCounter{}
llm = newCountingLLM(llm, runUsage)
defer func() {
if result.Status != nil {
result.Status.CumulativeUsage = runUsage.snapshot()
}
}()

// should I plan?
if o.autoPlan {
xlog.Debug("Checking if planning is needed")
Expand Down
50 changes: 50 additions & 0 deletions tools_cumulative_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package cogito_test

import (
. "github.com/mudler/cogito"
"github.com/mudler/cogito/tests/mock"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/sashabaranov/go-openai"
)

// Spec: the fragment returned by ExecuteTools carries, in
// Status.CumulativeUsage, the sum of the usage of every LLM call made during
// the run.
var _ = Describe("ExecuteTools cumulative usage", func() {
	It("sums token usage across every LLM call in the run", func() {
		llm := mock.NewMockOpenAIClient()

		// Script the mock: a "search" tool call first, then a plain assistant
		// message that ends the tool loop, with an Ask response available for
		// the final answer.
		llm.AddCreateChatCompletionFunction("search", `{"query": "test"}`)
		searchTool := mock.NewMockTool("search", "Search for information")
		mock.SetRunResult(searchTool, "Result")
		llm.SetAskResponse("Final answer")
		noMoreTools := openai.ChatCompletionMessage{Role: "assistant", Content: "No more tools needed."}
		llm.SetCreateChatCompletionResponse(openai.ChatCompletionResponse{
			Choices: []openai.ChatCompletionChoice{{Message: noMoreTools}},
		})

		// Queue three usage entries of 100 total tokens each.
		for i := 0; i < 3; i++ {
			llm.SetUsage(40, 60, 100)
		}

		task := NewEmptyFragment().AddMessage(UserMessageRole, "Task")
		out, err := ExecuteTools(llm, task, WithTools(searchTool))
		Expect(err).ToNot(HaveOccurred())

		// The expectation is derived from what the mock actually dispensed:
		// every usage entry below the mock's usage indices.
		want := 0
		for i := 0; i < llm.CreateChatCompletionUsageIndex; i++ {
			want += llm.CreateChatCompletionUsage[i].TotalTokens
		}
		for i := 0; i < llm.AskUsageIndex; i++ {
			want += llm.AskUsage[i].TotalTokens
		}

		Expect(want).To(BeNumerically(">", 100), "test must drive at least two billed calls")

		cumulative := out.Status.CumulativeUsage.TotalTokens
		Expect(cumulative).To(Equal(want))

		lastCall := out.Status.LastUsage.TotalTokens
		Expect(cumulative).To(
			BeNumerically(">", lastCall),
			"cumulative must exceed the last single call",
		)
	})
})
109 changes: 109 additions & 0 deletions usage_counter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package cogito

import (
"context"
"sync/atomic"

"github.com/sashabaranov/go-openai"
)

// usageCounter is a running tally of the token usage reported by the LLM calls
// routed through a countingLLM. Every field is an atomic counter, so add can be
// called from several goroutines at once (for instance from the goroutine that
// forwards streaming events) without additional locking.
type usageCounter struct {
	// prompt, completion and total hold the summed PromptTokens,
	// CompletionTokens and TotalTokens of every usage passed to add.
	prompt, completion, total atomic.Int64
}

// add folds the usage of one call into the running totals. Each of the three
// counters is advanced with its own atomic addition.
func (c *usageCounter) add(u LLMUsage) {
	promptTokens := int64(u.PromptTokens)
	completionTokens := int64(u.CompletionTokens)
	totalTokens := int64(u.TotalTokens)

	c.prompt.Add(promptTokens)
	c.completion.Add(completionTokens)
	c.total.Add(totalTokens)
}

// snapshot returns the totals accumulated so far as an LLMUsage value. The
// three counters are loaded one after another, so a snapshot taken while add
// runs concurrently may combine values from before and after that add.
func (c *usageCounter) snapshot() LLMUsage {
	var usage LLMUsage
	usage.PromptTokens = int(c.prompt.Load())
	usage.CompletionTokens = int(c.completion.Load())
	usage.TotalTokens = int(c.total.Load())
	return usage
}

// countingLLM decorates an LLM so that the token usage of each call is added
// to counter. Usage is taken from two places: the LLMUsage value returned by
// CreateChatCompletion and, because Ask has no usage in its signature, the
// Status.LastUsage of the fragment that Ask returns.
type countingLLM struct {
	// LLM is the wrapped client.
	LLM
	// counter receives the usage of every successful call.
	counter *usageCounter
}

// CreateChatCompletion forwards the request to the wrapped LLM and, when the
// call succeeds, adds the usage it reported to the counter. The reply, usage
// and error are returned to the caller unchanged.
func (c *countingLLM) CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (LLMReply, LLMUsage, error) {
	reply, usage, err := c.LLM.CreateChatCompletion(ctx, req)
	if err != nil {
		// Failed calls are not counted.
		return reply, usage, err
	}
	c.counter.add(usage)
	return reply, usage, nil
}

// Ask forwards to the wrapped LLM and, on success, adds the returned
// fragment's Status.LastUsage to the counter; a fragment without a Status
// contributes nothing.
//
// This relies on the wrapped Ask refreshing LastUsage on every call. An
// implementation that returned a fragment still carrying the usage of an
// earlier call would have that usage counted again here.
func (c *countingLLM) Ask(ctx context.Context, f Fragment) (Fragment, error) {
	answer, err := c.LLM.Ask(ctx, f)
	if err != nil {
		return answer, err
	}
	if status := answer.Status; status != nil {
		c.counter.add(status.LastUsage)
	}
	return answer, nil
}

// countingStreamingLLM is the streaming-capable variant of countingLLM. It
// keeps the StreamingLLM interface available on the wrapper, so callers that
// use the streaming code path are not forced off it by the wrapping. Usage on
// that path is taken from the Usage field of the StreamEventDone event.
//
// NOTE: cogito's bundled clients (clients/openai_client.go,
// clients/localai_client.go) do not currently populate StreamEvent.Usage on the
// done event, so streaming-path token accumulation is zero in production until
// those clients request usage from the API (e.g.
// StreamOptions{IncludeUsage: true}). The non-streaming path
// (CreateChatCompletion / Ask) is fully counted.
type countingStreamingLLM struct {
	// countingLLM supplies the counting CreateChatCompletion and Ask.
	countingLLM
	// streaming is the wrapped client's streaming interface.
	streaming StreamingLLM
}

// CreateChatCompletionStream opens a stream on the wrapped client and returns
// a channel that relays its events. An error from the wrapped client is
// returned as-is with a nil channel.
//
// A goroutine copies every upstream event to the returned channel (buffered to
// 64 events) and closes it when the upstream channel is closed. The Usage of
// each StreamEventDone event is added to the counter before the event is
// relayed. If ctx is cancelled while an event is waiting to be delivered, the
// goroutine stops relaying and closes the returned channel without reading the
// rest of the upstream events, so a consumer that has gone away cannot keep the
// goroutine alive.
func (c *countingStreamingLLM) CreateChatCompletionStream(ctx context.Context, req openai.ChatCompletionRequest) (<-chan StreamEvent, error) {
	upstream, err := c.streaming.CreateChatCompletionStream(ctx, req)
	if err != nil {
		return nil, err
	}

	relayed := make(chan StreamEvent, 64)
	relay := func() {
		defer close(relayed)
		for event := range upstream {
			if event.Type == StreamEventDone {
				c.counter.add(event.Usage)
			}
			select {
			case <-ctx.Done():
				return
			case relayed <- event:
			}
		}
	}
	go relay()

	return relayed, nil
}

// newCountingLLM returns llm wrapped so that its token usage accumulates into
// counter. If llm also implements StreamingLLM, the wrapper returned is a
// *countingStreamingLLM and therefore implements StreamingLLM as well;
// otherwise it is a plain *countingLLM.
func newCountingLLM(llm LLM, counter *usageCounter) LLM {
	streamer, canStream := llm.(StreamingLLM)
	if !canStream {
		return &countingLLM{LLM: llm, counter: counter}
	}
	return &countingStreamingLLM{
		countingLLM: countingLLM{LLM: llm, counter: counter},
		streaming:   streamer,
	}
}
79 changes: 79 additions & 0 deletions usage_counter_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package cogito

import (
"context"
"testing"

"github.com/sashabaranov/go-openai"
)

// fakeLLM is a test double for LLM with canned usage figures: ccUsage is what
// every CreateChatCompletion call reports, and askUsage is what Ask writes to
// the Status.LastUsage of the fragment it returns.
type fakeLLM struct {
	ccUsage, askUsage LLMUsage
}

// CreateChatCompletion ignores its arguments and returns a reply holding a
// single empty assistant message, together with the configured ccUsage.
func (f *fakeLLM) CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (LLMReply, LLMUsage, error) {
	assistant := openai.ChatCompletionMessage{Role: "assistant"}
	resp := openai.ChatCompletionResponse{
		Choices: []openai.ChatCompletionChoice{{Message: assistant}},
	}
	return LLMReply{ChatCompletionResponse: resp}, f.ccUsage, nil
}

// Ask ignores its arguments and returns a new fragment whose Status.LastUsage
// is the configured askUsage.
func (f *fakeLLM) Ask(ctx context.Context, frag Fragment) (Fragment, error) {
	return Fragment{Status: &Status{LastUsage: f.askUsage}}, nil
}

func TestCountingLLMAccumulatesBothPaths(t *testing.T) {
inner := &fakeLLM{
ccUsage: LLMUsage{PromptTokens: 10, CompletionTokens: 5, TotalTokens: 15},
askUsage: LLMUsage{PromptTokens: 7, CompletionTokens: 3, TotalTokens: 10},
}
counter := &usageCounter{}
llm := newCountingLLM(inner, counter)

if _, _, err := llm.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{}); err != nil {
t.Fatalf("CreateChatCompletion: %v", err)
}
if _, _, err := llm.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{}); err != nil {
t.Fatalf("CreateChatCompletion: %v", err)
}
if _, err := llm.Ask(context.Background(), NewEmptyFragment()); err != nil {
t.Fatalf("Ask: %v", err)
}

got := counter.snapshot()
if got.TotalTokens != 40 { // 15 + 15 + 10
t.Errorf("TotalTokens = %d, want 40", got.TotalTokens)
}
if got.PromptTokens != 27 { // 10 + 10 + 7
t.Errorf("PromptTokens = %d, want 27", got.PromptTokens)
}
if got.CompletionTokens != 13 { // 5 + 5 + 3
t.Errorf("CompletionTokens = %d, want 13", got.CompletionTokens)
}
}

// streamingFake is a fakeLLM that also provides CreateChatCompletionStream,
// making it usable wherever a StreamingLLM is expected.
type streamingFake struct {
	fakeLLM
}

// CreateChatCompletionStream returns an already-closed channel holding one
// StreamEventDone event that reports 99 total tokens.
func (s *streamingFake) CreateChatCompletionStream(ctx context.Context, req openai.ChatCompletionRequest) (<-chan StreamEvent, error) {
	done := StreamEvent{Type: StreamEventDone, Usage: LLMUsage{TotalTokens: 99}}

	events := make(chan StreamEvent, 1)
	events <- done
	close(events)
	return events, nil
}

// TestNewCountingLLMPreservesStreaming checks that newCountingLLM yields a
// StreamingLLM exactly when the wrapped LLM is one, and that the usage carried
// by the stream's done event reaches the counter.
func TestNewCountingLLMPreservesStreaming(t *testing.T) {
	plain := newCountingLLM(&fakeLLM{}, &usageCounter{})
	if _, ok := plain.(StreamingLLM); ok {
		t.Error("wrapping a non-streaming LLM must not yield a StreamingLLM")
	}

	counter := &usageCounter{}
	wrapped := newCountingLLM(&streamingFake{}, counter)
	streaming, ok := wrapped.(StreamingLLM)
	if !ok {
		t.Fatal("wrapping a StreamingLLM must yield a StreamingLLM")
	}

	events, err := streaming.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{})
	if err != nil {
		t.Fatalf("CreateChatCompletionStream: %v", err)
	}
	// Drain until the wrapper closes the channel. The done event's usage is
	// added before the event is relayed, so the counter is final by then.
	sawDone := false
	for ev := range events {
		if ev.Type == StreamEventDone {
			sawDone = true
		}
	}
	if !sawDone {
		t.Error("the done event must be relayed to the consumer")
	}
	if got := counter.snapshot().TotalTokens; got != 99 {
		t.Errorf("streamed TotalTokens = %d, want 99", got)
	}
}
Loading