diff --git a/cli/azd/extensions/azure.ai.agents/cspell.yaml b/cli/azd/extensions/azure.ai.agents/cspell.yaml index 2d5c10d89d6..c5f45e7dd71 100644 --- a/cli/azd/extensions/azure.ai.agents/cspell.yaml +++ b/cli/azd/extensions/azure.ai.agents/cspell.yaml @@ -79,3 +79,4 @@ words: - parseable - azd's - deepseek + - ttfb diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/invoke.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/invoke.go index 350ce18846f..e3e25cd2b2d 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/invoke.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/invoke.go @@ -24,6 +24,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/fatih/color" "github.com/spf13/cobra" ) @@ -496,6 +497,24 @@ func contentTypeForBody(data []byte) string { return "text/plain" } +// printInvokeTiming prints a green timing line to stdout showing the total +// response time and time-to-first-byte (TTFB). Only call on success paths; +// failures should not display timing to avoid confusion. +// +// Output format: +// +// ⏱ Server responded in 6.667s (first byte: 1.111s) +func printInvokeTiming(w io.Writer, total, ttfb time.Duration) { + _, _ = color.New(color.FgGreen).Fprintf(w, "\n⏱ Server responded in %s (first byte: %s)\n", + formatDuration(total), formatDuration(ttfb)) +} + +// formatDuration formats a duration for display in timing output. +// Always uses seconds with 3 decimal places for consistency. 
//
// The value is rendered with the %.3f verb, so anything finer than a
// millisecond is rounded rather than shown.
func formatDuration(d time.Duration) string {
	seconds := d.Seconds()
	return fmt.Sprintf("%.3fs", seconds)
}
from response headers (needed even in raw mode @@ -1005,6 +1031,8 @@ func (a *InvokeAction) responsesRemote(ctx context.Context) error { if err := readSSEStream(resp.Body, rc.name); err != nil { return err } + totalDuration := time.Since(invokeStart) + printInvokeTiming(os.Stdout, totalDuration, ttfb) a.emitInvokeSuccessNextStep(nextstep.InvokeRemote, rc.nextStepName()) return nil } @@ -1084,6 +1112,7 @@ func (a *InvokeAction) invocationsLocal(ctx context.Context) error { } client := &http.Client{Timeout: a.httpTimeout()} + invokeStart := time.Now() resp, err := client.Do(req) //nolint:gosec // G704: URL targets localhost with user-configured port if err != nil { return fmt.Errorf( @@ -1091,6 +1120,7 @@ func (a *InvokeAction) invocationsLocal(ctx context.Context) error { port, ) } + ttfb := time.Since(invokeStart) defer resp.Body.Close() // Print the invocation ID if the agent returned one. @@ -1107,7 +1137,9 @@ func (a *InvokeAction) invocationsLocal(ctx context.Context) error { } return err } + totalDuration := time.Since(invokeStart) if !raw { + printInvokeTiming(os.Stdout, totalDuration, ttfb) a.emitInvokeSuccessNextStep(nextstep.InvokeLocal, agentName) } return nil @@ -1191,11 +1223,13 @@ func (a *InvokeAction) invocationsRemote(ctx context.Context) error { } client := &http.Client{Timeout: a.httpTimeout()} + invokeStart := time.Now() //nolint:gosec // G704: URL is built from a validated Foundry endpoint (env or --agent-endpoint) resp, err := client.Do(req) if err != nil { return fmt.Errorf("POST %s failed: %w", invURL, err) } + ttfb := time.Since(invokeStart) defer resp.Body.Close() // Print the invocation ID if the agent returned one. 
We do not persist it @@ -1241,7 +1275,9 @@ func (a *InvokeAction) invocationsRemote(ctx context.Context) error { } return err } + totalDuration := time.Since(invokeStart) if !raw { + printInvokeTiming(os.Stdout, totalDuration, ttfb) a.emitInvokeSuccessNextStep(nextstep.InvokeRemote, rc.nextStepName()) } return nil diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/invoke_timing_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/invoke_timing_test.go new file mode 100644 index 00000000000..86a07fe8fb0 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/invoke_timing_test.go @@ -0,0 +1,198 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" +) + +// All tests in this file are sequential (not parallel) because the integration +// tests mutate the global os.Stdout via withCapturedStdout. 
+ +func TestFormatDuration(t *testing.T) { + cases := []struct { + d time.Duration + want string + }{ + {0, "0.000s"}, + {1 * time.Millisecond, "0.001s"}, + {943 * time.Millisecond, "0.943s"}, + {1000 * time.Millisecond, "1.000s"}, + {6667 * time.Millisecond, "6.667s"}, + {14727 * time.Millisecond, "14.727s"}, + } + + for _, tc := range cases { + t.Run(tc.want, func(t *testing.T) { + if got := formatDuration(tc.d); got != tc.want { + t.Errorf("formatDuration(%v) = %q, want %q", tc.d, got, tc.want) + } + }) + } +} + +func TestPrintInvokeTiming(t *testing.T) { + var buf bytes.Buffer + printInvokeTiming(&buf, 19734*time.Millisecond, 13697*time.Millisecond) + got := buf.String() + + for _, want := range []string{"⏱", "19.734s", "first byte: 13.697s"} { + if !strings.Contains(got, want) { + t.Errorf("output %q missing %q", got, want) + } + } +} + +func TestResponsesLocal_Timing(t *testing.T) { + okBody, _ := json.Marshal(map[string]any{ + "output": []any{map[string]any{"content": []any{map[string]any{"type": "output_text", "text": "hi"}}}}, + }) + + cases := []struct { + name string + status int + body string + raw bool + wantTimer bool + }{ + {"success", 200, string(okBody), false, true}, + {"failure", 500, "error", false, false}, + {"raw_mode", 200, string(okBody), true, false}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(tc.status) + fmt.Fprint(w, tc.body) + })) + defer srv.Close() + + outputFmt := "" + if tc.raw { + outputFmt = outputRaw + } + + action := &InvokeAction{ + flags: &invokeFlags{message: "hi", port: testPort(t, srv.URL), local: true, protocol: "responses", outputFmt: outputFmt}, + noPrompt: true, + } + + output := withCapturedStdout(t, func() { _ = action.responsesLocal(t.Context()) }) + + if tc.wantTimer && !strings.Contains(output, "⏱") { + t.Errorf("expected 
timing, got:\n%s", output) + } + if !tc.wantTimer && strings.Contains(output, "⏱") { + t.Errorf("unexpected timing in output:\n%s", output) + } + }) + } +} + +func TestInvocationsLocal_Timing(t *testing.T) { + cases := []struct { + name string + contentType string + status int + body string + raw bool + wantTimer bool + }{ + {"sync_json_success", "application/json", 200, `{"result":"ok"}`, false, true}, + {"sse_success", "text/event-stream", 200, "data: hello\n\n", false, true}, + {"failure", "application/json", 400, `{"error":"bad"}`, false, false}, + {"raw_mode", "application/json", 200, `{"result":"ok"}`, true, false}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/openapi") { + w.WriteHeader(404) + return + } + w.Header().Set("Content-Type", tc.contentType) + w.WriteHeader(tc.status) + fmt.Fprint(w, tc.body) + })) + defer srv.Close() + + outputFmt := "" + if tc.raw { + outputFmt = outputRaw + } + + action := &InvokeAction{ + flags: &invokeFlags{message: "hi", port: testPort(t, srv.URL), local: true, protocol: "invocations", outputFmt: outputFmt}, + noPrompt: true, + } + + output := withCapturedStdout(t, func() { _ = action.invocationsLocal(t.Context()) }) + + if tc.wantTimer && !strings.Contains(output, "⏱") { + t.Errorf("expected timing, got:\n%s", output) + } + if !tc.wantTimer && strings.Contains(output, "⏱") { + t.Errorf("unexpected timing in output:\n%s", output) + } + }) + } +} + +// --- helpers --- + +// withCapturedStdout redirects os.Stdout to a pipe, runs fn, then returns +// everything written to stdout. Uses t.Cleanup to guarantee restoration even +// if the test fails or panics. 
//
// The pipe is drained by a goroutine while fn runs, so fn may write more than
// the OS pipe buffer holds without blocking on a full pipe.
func withCapturedStdout(t *testing.T, fn func()) string {
	t.Helper()

	r, w, err := os.Pipe()
	if err != nil {
		t.Fatalf("os.Pipe: %v", err)
	}

	orig := os.Stdout
	os.Stdout = w

	type capture struct {
		data []byte
		err  error
	}
	// Buffered so the reader goroutine can always deliver its result and
	// exit, even when the test goroutine bails out and never receives.
	captured := make(chan capture, 1)
	go func() {
		data, readErr := io.ReadAll(r)
		captured <- capture{data: data, err: readErr}
	}()

	// restore puts os.Stdout back and closes the write end exactly once;
	// closing the write end is what lets the reader goroutine see EOF.
	restored := false
	restore := func() {
		if restored {
			return
		}
		restored = true
		os.Stdout = orig
		_ = w.Close()
	}

	t.Cleanup(func() {
		restore()
		_ = r.Close()
	})

	fn()
	restore()

	got := <-captured
	if got.err != nil {
		t.Fatalf("io.ReadAll: %v", got.err)
	}

	return string(got.data)
}

// testPort extracts the integer port that follows the last ':' in rawURL
// (for example an httptest server URL). The test is failed immediately when
// no integer can be parsed from that suffix.
func testPort(t *testing.T, rawURL string) int {
	t.Helper()

	// LastIndex returns -1 when there is no ':', which makes the suffix the
	// whole string and lets Sscanf report the parse failure.
	suffix := rawURL[strings.LastIndex(rawURL, ":")+1:]

	var port int
	if _, err := fmt.Sscanf(suffix, "%d", &port); err != nil {
		t.Fatalf("cannot parse port from %q: %v", rawURL, err)
	}
	return port
}