Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cli/azd/extensions/azure.ai.agents/cspell.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,4 @@ words:
- parseable
- azd's
- deepseek
- ttfb
36 changes: 36 additions & 0 deletions cli/azd/extensions/azure.ai.agents/internal/cmd/invoke.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (

"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/azure/azure-dev/cli/azd/pkg/azdext"
"github.com/fatih/color"
"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -496,6 +497,24 @@ func contentTypeForBody(data []byte) string {
return "text/plain"
}

// printInvokeTiming writes one green-colored timing summary line to w,
// reporting the total response time together with the time-to-first-byte
// (TTFB). Both durations are rendered through formatDuration. The result of
// the write is discarded because the line is purely informational.
//
// Intended for success paths only; failures should not display timing to
// avoid confusion.
//
// Output format:
//
//	⏱ Server responded in 6.667s (first byte: 1.111s)
func printInvokeTiming(w io.Writer, total, ttfb time.Duration) {
	green := color.New(color.FgGreen)
	totalText := formatDuration(total)
	firstByteText := formatDuration(ttfb)
	_, _ = green.Fprintf(w, "\n⏱ Server responded in %s (first byte: %s)\n", totalText, firstByteText)
}

// formatDuration renders d as fractional seconds with exactly three decimal
// places followed by an "s" suffix (for example "6.667s"), so every timing
// line uses the same unit regardless of magnitude.
func formatDuration(d time.Duration) string {
	seconds := d.Seconds()
	return fmt.Sprintf("%.3f", seconds) + "s"
}

func (a *InvokeAction) responsesLocal(ctx context.Context) error {
port := a.flags.port

Expand Down Expand Up @@ -568,13 +587,15 @@ func (a *InvokeAction) responsesLocal(ctx context.Context) error {
}

client := &http.Client{Timeout: a.httpTimeout()}
invokeStart := time.Now()
resp, err := client.Do(req) //nolint:gosec // G704: URL targets localhost with user-configured port
if err != nil {
return fmt.Errorf(
"could not connect to localhost:%d -- is the agent running? Start it with: azd ai agent run",
port,
)
}
ttfb := time.Since(invokeStart)
defer resp.Body.Close()

if raw {
Expand All @@ -595,6 +616,7 @@ func (a *InvokeAction) responsesLocal(ctx context.Context) error {
if err != nil {
return fmt.Errorf("failed to read response: %w", err)
}
totalDuration := time.Since(invokeStart)

if resp.StatusCode >= 400 {
if traceID := responseTraceID(resp); traceID != "" {
Expand All @@ -611,13 +633,15 @@ func (a *InvokeAction) responsesLocal(ctx context.Context) error {
if err := json.Unmarshal(respBody, &result); err != nil {
// Not JSON -- just print raw response
fmt.Println(string(respBody))
printInvokeTiming(os.Stdout, totalDuration, ttfb)
a.emitInvokeSuccessNextStep(nextstep.InvokeLocal, "")
return nil
}

if err := printAgentResponse(result, "local"); err != nil {
return err
}
printInvokeTiming(os.Stdout, totalDuration, ttfb)
a.emitInvokeSuccessNextStep(nextstep.InvokeLocal, "")
return nil
}
Expand Down Expand Up @@ -963,11 +987,13 @@ func (a *InvokeAction) responsesRemote(ctx context.Context) error {
}

client := &http.Client{Timeout: a.httpTimeout()}
invokeStart := time.Now()
//nolint:gosec // G704: URL is built from a validated Foundry endpoint (env or --agent-endpoint)
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("POST %s failed: %w", respURL, err)
}
ttfb := time.Since(invokeStart)
defer resp.Body.Close()

// Always capture session state from response headers (needed even in raw mode
Expand Down Expand Up @@ -1005,6 +1031,8 @@ func (a *InvokeAction) responsesRemote(ctx context.Context) error {
if err := readSSEStream(resp.Body, rc.name); err != nil {
return err
}
totalDuration := time.Since(invokeStart)
printInvokeTiming(os.Stdout, totalDuration, ttfb)
a.emitInvokeSuccessNextStep(nextstep.InvokeRemote, rc.nextStepName())
return nil
}
Expand Down Expand Up @@ -1084,13 +1112,15 @@ func (a *InvokeAction) invocationsLocal(ctx context.Context) error {
}

client := &http.Client{Timeout: a.httpTimeout()}
invokeStart := time.Now()
resp, err := client.Do(req) //nolint:gosec // G704: URL targets localhost with user-configured port
if err != nil {
return fmt.Errorf(
"could not connect to localhost:%d -- is the agent running? Start it with: azd ai agent run",
port,
)
}
ttfb := time.Since(invokeStart)
defer resp.Body.Close()

// Print the invocation ID if the agent returned one.
Expand All @@ -1107,7 +1137,9 @@ func (a *InvokeAction) invocationsLocal(ctx context.Context) error {
}
return err
}
totalDuration := time.Since(invokeStart)
if !raw {
printInvokeTiming(os.Stdout, totalDuration, ttfb)
a.emitInvokeSuccessNextStep(nextstep.InvokeLocal, agentName)
}
return nil
Expand Down Expand Up @@ -1191,11 +1223,13 @@ func (a *InvokeAction) invocationsRemote(ctx context.Context) error {
}

client := &http.Client{Timeout: a.httpTimeout()}
invokeStart := time.Now()
//nolint:gosec // G704: URL is built from a validated Foundry endpoint (env or --agent-endpoint)
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("POST %s failed: %w", invURL, err)
}
ttfb := time.Since(invokeStart)
defer resp.Body.Close()

// Print the invocation ID if the agent returned one. We do not persist it
Expand Down Expand Up @@ -1241,7 +1275,9 @@ func (a *InvokeAction) invocationsRemote(ctx context.Context) error {
}
return err
}
totalDuration := time.Since(invokeStart)
if !raw {
printInvokeTiming(os.Stdout, totalDuration, ttfb)
a.emitInvokeSuccessNextStep(nextstep.InvokeRemote, rc.nextStepName())
}
return nil
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package cmd

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
)

// All tests in this file are sequential (not parallel) because the integration
// tests mutate the global os.Stdout via withCapturedStdout.

// TestFormatDuration checks that formatDuration always renders seconds with
// three decimal places, from zero up to multi-second values.
func TestFormatDuration(t *testing.T) {
	type testCase struct {
		input    time.Duration
		expected string
	}

	table := []testCase{
		{input: 0, expected: "0.000s"},
		{input: 1 * time.Millisecond, expected: "0.001s"},
		{input: 943 * time.Millisecond, expected: "0.943s"},
		{input: 1000 * time.Millisecond, expected: "1.000s"},
		{input: 6667 * time.Millisecond, expected: "6.667s"},
		{input: 14727 * time.Millisecond, expected: "14.727s"},
	}

	for _, c := range table {
		// The expected string doubles as the subtest name.
		t.Run(c.expected, func(t *testing.T) {
			got := formatDuration(c.input)
			if got == c.expected {
				return
			}
			t.Errorf("formatDuration(%v) = %q, want %q", c.input, got, c.expected)
		})
	}
}

// TestPrintInvokeTiming writes a timing line into an in-memory buffer and
// checks that the glyph, the total time and the first-byte time all appear.
func TestPrintInvokeTiming(t *testing.T) {
	var out bytes.Buffer
	printInvokeTiming(&out, 19734*time.Millisecond, 13697*time.Millisecond)

	rendered := out.String()
	expectedFragments := []string{"⏱", "19.734s", "first byte: 13.697s"}
	for _, fragment := range expectedFragments {
		if strings.Contains(rendered, fragment) {
			continue
		}
		t.Errorf("output %q missing %q", rendered, fragment)
	}
}

// TestResponsesLocal_Timing runs responsesLocal against a stub HTTP server and
// checks whether the timing line shows up on stdout: it is expected for a
// successful non-raw call and must be absent for an error status or raw mode.
func TestResponsesLocal_Timing(t *testing.T) {
	// Minimal response payload carrying a single output_text item.
	textPart := map[string]any{"type": "output_text", "text": "hi"}
	outputItem := map[string]any{"content": []any{textPart}}
	// Marshaling plain maps of strings cannot fail, so the error is ignored.
	okBody, _ := json.Marshal(map[string]any{"output": []any{outputItem}})

	type scenario struct {
		name      string
		status    int
		body      string
		raw       bool
		wantTimer bool
	}

	scenarios := []scenario{
		{name: "success", status: 200, body: string(okBody), raw: false, wantTimer: true},
		{name: "failure", status: 500, body: "error", raw: false, wantTimer: false},
		{name: "raw_mode", status: 200, body: string(okBody), raw: true, wantTimer: false},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			handler := func(w http.ResponseWriter, _ *http.Request) {
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(sc.status)
				fmt.Fprint(w, sc.body)
			}
			srv := httptest.NewServer(http.HandlerFunc(handler))
			defer srv.Close()

			var outputFmt string
			if sc.raw {
				outputFmt = outputRaw
			}

			flags := &invokeFlags{
				message:   "hi",
				port:      testPort(t, srv.URL),
				local:     true,
				protocol:  "responses",
				outputFmt: outputFmt,
			}
			action := &InvokeAction{flags: flags, noPrompt: true}

			// Only stdout is inspected; the returned error is intentionally dropped.
			output := withCapturedStdout(t, func() { _ = action.responsesLocal(t.Context()) })

			hasTimer := strings.Contains(output, "⏱")
			switch {
			case sc.wantTimer && !hasTimer:
				t.Errorf("expected timing, got:\n%s", output)
			case !sc.wantTimer && hasTimer:
				t.Errorf("unexpected timing in output:\n%s", output)
			}
		})
	}
}

// TestInvocationsLocal_Timing runs invocationsLocal against a stub HTTP server
// for JSON and SSE responses and checks whether the timing line appears on
// stdout: expected on non-raw success, absent on an error status or raw mode.
func TestInvocationsLocal_Timing(t *testing.T) {
	type scenario struct {
		name        string
		contentType string
		status      int
		body        string
		raw         bool
		wantTimer   bool
	}

	const jsonType = "application/json"

	scenarios := []scenario{
		{name: "sync_json_success", contentType: jsonType, status: 200, body: `{"result":"ok"}`, wantTimer: true},
		{name: "sse_success", contentType: "text/event-stream", status: 200, body: "data: hello\n\n", wantTimer: true},
		{name: "failure", contentType: jsonType, status: 400, body: `{"error":"bad"}`},
		{name: "raw_mode", contentType: jsonType, status: 200, body: `{"result":"ok"}`, raw: true},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				// Any request whose path mentions /openapi gets a bare 404
				// (presumably a spec probe issued by invocationsLocal -- confirm).
				if strings.Contains(r.URL.Path, "/openapi") {
					w.WriteHeader(404)
					return
				}
				w.Header().Set("Content-Type", sc.contentType)
				w.WriteHeader(sc.status)
				fmt.Fprint(w, sc.body)
			})
			srv := httptest.NewServer(handler)
			defer srv.Close()

			var outputFmt string
			if sc.raw {
				outputFmt = outputRaw
			}

			action := &InvokeAction{
				flags: &invokeFlags{
					message:   "hi",
					port:      testPort(t, srv.URL),
					local:     true,
					protocol:  "invocations",
					outputFmt: outputFmt,
				},
				noPrompt: true,
			}

			// Only stdout is inspected; the returned error is intentionally dropped.
			run := func() { _ = action.invocationsLocal(t.Context()) }
			output := withCapturedStdout(t, run)

			sawTimer := strings.Contains(output, "⏱")
			if sc.wantTimer {
				if !sawTimer {
					t.Errorf("expected timing, got:\n%s", output)
				}
				return
			}
			if sawTimer {
				t.Errorf("unexpected timing in output:\n%s", output)
			}
		})
	}
}

// --- helpers ---

// withCapturedStdout redirects os.Stdout to a pipe, runs fn, then returns
// everything fn wrote to stdout.
//
// The read end of the pipe is drained by a goroutine while fn runs. Reading
// only after fn returns would deadlock as soon as fn wrote more than the OS
// pipe buffer holds, because the write would block with nobody reading.
//
// os.Stdout is restored and both pipe ends are closed before returning. The
// same steps are also registered with t.Cleanup, so they still happen if fn
// never returns normally (for example because it ends the test early).
//
// Because it mutates the global os.Stdout, callers must not run in parallel.
func withCapturedStdout(t *testing.T, fn func()) string {
	t.Helper()

	r, w, err := os.Pipe()
	if err != nil {
		t.Fatalf("os.Pipe: %v", err)
	}

	orig := os.Stdout
	os.Stdout = w

	t.Cleanup(func() {
		os.Stdout = orig
		// Closing twice is harmless; the errors are ignored on purpose.
		_ = w.Close()
		_ = r.Close()
	})

	type readResult struct {
		data []byte
		err  error
	}
	// Buffered so the reader goroutine can always deliver its result and exit,
	// even when nobody receives because fn aborted the test.
	done := make(chan readResult, 1)
	go func() {
		data, readErr := io.ReadAll(r)
		done <- readResult{data: data, err: readErr}
	}()

	fn()

	// Closing the write end signals EOF to the reader goroutine.
	_ = w.Close()
	os.Stdout = orig

	res := <-done
	_ = r.Close()
	if res.err != nil {
		t.Fatalf("io.ReadAll: %v", res.err)
	}

	return string(res.data)
}

// testPort extracts the numeric port from rawURL by parsing whatever follows
// the final ':' as a decimal integer. The test is aborted via t.Fatalf when no
// integer can be parsed from that tail.
func testPort(t *testing.T, rawURL string) int {
	t.Helper()

	// Everything after the last colon; the whole string when there is none.
	tail := rawURL[strings.LastIndex(rawURL, ":")+1:]

	var port int
	_, err := fmt.Sscanf(tail, "%d", &port)
	if err != nil {
		t.Fatalf("cannot parse port from %q: %v", rawURL, err)
	}
	return port
}
Loading