diff --git a/cmd/opencodereview/review_cmd.go b/cmd/opencodereview/review_cmd.go index be68b32c..53d8a8ad 100644 --- a/cmd/opencodereview/review_cmd.go +++ b/cmd/opencodereview/review_cmd.go @@ -68,10 +68,17 @@ func runReview(args []string) error { return runPreview(repoDir, opts, fileFilter) } + mode := tool.ParseReviewMode(opts.from, opts.to, opts.commit) + ref, _ := mode.RefValue(opts.to, opts.commit) + toolEntries, err := toolsconfig.Load(opts.toolConfigPath) if err != nil { return fmt.Errorf("load tools: %w", err) } + codeGraph := detectCodeGraphForReview(repoDir, ref) + if !codeGraph.Available { + toolEntries = toolsconfig.ExcludeByName(toolEntries, tool.CodeGraphContext.Name()) + } planToolDefs := agent.BuildToolDefs(toolEntries, true) mainToolDefs := agent.BuildToolDefs(toolEntries, false) @@ -101,15 +108,13 @@ func runReview(args []string) error { gitRunner := gitcmd.New(opts.maxGitProcs) collector := tool.NewCommentCollector() - mode := tool.ParseReviewMode(opts.from, opts.to, opts.commit) - ref, _ := mode.RefValue(opts.to, opts.commit) fileReader := &tool.FileReader{ RepoDir: repoDir, Mode: mode, Ref: ref, Runner: gitRunner, } - tools := buildToolRegistry(collector, fileReader) + tools := buildToolRegistry(collector, fileReader, codeGraph) ag := agent.New(agent.Args{ RepoDir: repoDir, @@ -269,12 +274,48 @@ func runPreview(repoDir string, opts reviewOptions, fileFilter *rules.FileFilter return nil } -func buildToolRegistry(collector *tool.CommentCollector, fr *tool.FileReader) *tool.Registry { +func detectCodeGraphForReview(repoDir, ref string) tool.CodeGraphAvailability { + codeGraph := tool.DetectCodeGraph(repoDir) + if !codeGraph.Available || ref == "" { + return codeGraph + } + + head, err := resolveCommit(repoDir, "HEAD") + if err != nil { + codeGraph.Available = false + codeGraph.Reason = "cannot resolve HEAD for CodeGraph ref compatibility check" + return codeGraph + } + target, err := resolveCommit(repoDir, ref) + if err != nil { + codeGraph.Available = false + codeGraph.Reason = "cannot resolve review target ref for CodeGraph compatibility check" + return codeGraph + } + if head != target { + codeGraph.Available = false + codeGraph.Reason = "CodeGraph index is for current checkout, which differs from review target ref" + } + return codeGraph +} + +func resolveCommit(repoDir, ref string) (string, error) { + out, err := runGitCmd(repoDir, "rev-parse", "--verify", "--end-of-options", ref+"^{commit}") + if err != nil { + return "", err + } + return strings.TrimSpace(string(out)), nil +} + +func buildToolRegistry(collector *tool.CommentCollector, fr *tool.FileReader, codeGraph tool.CodeGraphAvailability) *tool.Registry { reg := tool.NewRegistry() reg.Register(tool.NewFileRead(fr)) reg.Register(tool.NewFileFind(fr)) reg.Register(tool.NewFileReadDiff(tool.DiffMap{})) reg.Register(tool.NewCodeSearch(fr)) + if codeGraph.Available { + reg.Register(tool.NewCodeGraph(fr.RepoDir, codeGraph.BinPath)) + } reg.Register(&tool.CodeCommentProvider{Collector: collector}) return reg } diff --git a/internal/config/toolsconfig/tools.json b/internal/config/toolsconfig/tools.json index 1784c234..c9e1423d 100644 --- a/internal/config/toolsconfig/tools.json +++ b/internal/config/toolsconfig/tools.json @@ -155,6 +155,50 @@ } } }, + { + "name": "code_graph_context", + "plan_task": true, + "main_task": true, + "definition": { + "name": "code_graph_context", + "description": "Use this optional structural-code tool when text search is not enough and you need to understand symbols, callers, callees, or impact radius. Prefer this tool for changes to public functions, APIs, interfaces, method signatures, shared utilities, route handlers, data models, security/auth logic, concurrency/lifecycle code, or when a risk depends on how other files call the changed symbol. Do not use it for simple literal text lookup; use code_search instead. This tool is only available when a compatible CodeGraph index is detected for the repository.", + "parameters": { + "type": "object", + "properties": { + "mode": { + "type": "string", + "enum": [ + "explore", + "search", + "callers", + "callees", + "impact" + ], + "description": "Structural query mode. Use search to find symbols, explore for source plus relationship context, callers/callees for dependency direction, and impact to estimate what may be affected by changing a symbol. Defaults to explore." + }, + "query": { + "type": "string", + "description": "A symbol name, file path, or concise structural query. Prefer exact function/type/class/interface names when available." + }, + "kind": { + "type": "string", + "description": "Optional symbol kind filter for mode=search, such as function, method, class, interface, type, variable, route, or component." + }, + "limit": { + "type": "integer", + "description": "Maximum number of symbols or relationships to return for search/callers/callees. Defaults to 12, maximum 30." + }, + "max_files": { + "type": "integer", + "description": "Maximum number of source files to include for mode=explore. Defaults to 4, maximum 8." + } + }, + "required": [ + "query" + ] + } + } + }, { "name": "file_find", "plan_task": true, @@ -180,4 +224,4 @@ } } } -] \ No newline at end of file +] diff --git a/internal/config/toolsconfig/toolsconfig.go b/internal/config/toolsconfig/toolsconfig.go index 9213def3..3bd777fc 100644 --- a/internal/config/toolsconfig/toolsconfig.go +++ b/internal/config/toolsconfig/toolsconfig.go @@ -39,6 +39,25 @@ func Load(path string) ([]ToolConfigEntry, error) { return tools, nil } +// ExcludeByName returns entries excluding any tool whose name appears in names. +func ExcludeByName(entries []ToolConfigEntry, names ...string) []ToolConfigEntry { + if len(names) == 0 { + return entries + } + excluded := make(map[string]bool, len(names)) + for _, name := range names { + excluded[name] = true + } + out := make([]ToolConfigEntry, 0, len(entries)) + for _, entry := range entries { + if excluded[entry.Name] { + continue + } + out = append(out, entry) + } + return out +} + // ToolDefsByPhase returns the parsed tool definitions filtered by phase. // planOnly=true returns only tools with plan_task:true. // planOnly=false returns only tools with main_task:true. diff --git a/internal/config/toolsconfig/toolsconfig_test.go b/internal/config/toolsconfig/toolsconfig_test.go new file mode 100644 index 00000000..434fb523 --- /dev/null +++ b/internal/config/toolsconfig/toolsconfig_test.go @@ -0,0 +1,21 @@ +package toolsconfig + +import "testing" + +func TestExcludeByName(t *testing.T) { + entries := []ToolConfigEntry{ + {Name: "code_search"}, + {Name: "code_graph_context"}, + {Name: "file_read"}, + } + + filtered := ExcludeByName(entries, "code_graph_context") + if len(filtered) != 2 { + t.Fatalf("expected 2 entries, got %d", len(filtered)) + } + for _, entry := range filtered { + if entry.Name == "code_graph_context" { + t.Fatal("expected code_graph_context to be filtered") + } + } +} diff --git a/internal/tool/code_graph.go b/internal/tool/code_graph.go new file mode 100644 index 00000000..0af3f55d --- /dev/null +++ b/internal/tool/code_graph.go @@ -0,0 +1,234 @@ +package tool + +import ( + "bytes" + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "regexp" + "strconv" + "strings" + "time" + "unicode/utf8" +) + +const ( + codeGraphTimeout = 15 * time.Second + codeGraphDetectTimeout = 3 * time.Second + codeGraphMaxOutput = 12000 + codeGraphMaxFiles = 8 + codeGraphDefaultFiles = 4 + codeGraphMaxLimit = 30 + codeGraphDefaultLimit = 12 +) + +var ansiEscapePattern = regexp.MustCompile(`\x1b\[[0-9;]*[A-Za-z]`) + +// CodeGraphProvider retrieves structural code context from an optional external +// CodeGraph installation. It is intentionally CLI-backed so OCR does not take a +// hard dependency on CodeGraph's database schema or Go libraries. +type CodeGraphProvider struct { + RepoDir string + BinPath string +} + +// CodeGraphAvailability describes whether the optional CodeGraph integration can +// be exposed to the model for this review run. +type CodeGraphAvailability struct { + Available bool + BinPath string + Version string + Reason string +} + +func NewCodeGraph(repoDir, binPath string) *CodeGraphProvider { + return &CodeGraphProvider{RepoDir: repoDir, BinPath: binPath} +} + +func (p *CodeGraphProvider) Tool() Tool { return CodeGraphContext } + +func (p *CodeGraphProvider) Execute(ctx context.Context, args map[string]any) (string, error) { + mode := stringArg(args, "mode") + if mode == "" { + mode = "explore" + } + query := strings.TrimSpace(stringArg(args, "query")) + if query == "" { + return "Error: query is required", nil + } + + limit := intArg(args, "limit", codeGraphDefaultLimit) + if limit <= 0 || limit > codeGraphMaxLimit { + limit = codeGraphDefaultLimit + } + maxFiles := intArg(args, "max_files", codeGraphDefaultFiles) + if maxFiles <= 0 { + maxFiles = codeGraphDefaultFiles + } + if maxFiles > codeGraphMaxFiles { + maxFiles = codeGraphMaxFiles + } + + cmdArgs := []string{} + switch mode { + case "search": + cmdArgs = []string{"query", "-p", p.RepoDir, "-l", strconv.Itoa(limit)} + if kind := strings.TrimSpace(stringArg(args, "kind")); kind != "" { + cmdArgs = append(cmdArgs, "-k", kind) + } + cmdArgs = append(cmdArgs, "--", query) + case "explore": + cmdArgs = []string{"explore", "-p", p.RepoDir, "--max-files", strconv.Itoa(maxFiles), "--", query} + case "callers": + cmdArgs = []string{"callers", "-p", p.RepoDir, "-l", strconv.Itoa(limit), "--", query} + case "callees": + cmdArgs = []string{"callees", "-p", p.RepoDir, "-l", strconv.Itoa(limit), "--", query} + case "impact": + cmdArgs = []string{"impact", "-p", p.RepoDir, "--", query} + default: + return fmt.Sprintf("Error: unsupported mode %q. Supported modes: search, explore, callers, callees, impact", mode), nil + } + + out, err := p.run(ctx, cmdArgs...) + if err != nil { + return "", fmt.Errorf("code_graph_context failed: %w", err) + } + return out, nil +} + +func (p *CodeGraphProvider) run(parentCtx context.Context, args ...string) (string, error) { + ctx, cancel := context.WithTimeout(parentCtx, codeGraphTimeout) + defer cancel() + + cmd := exec.CommandContext(ctx, p.BinPath, args...) + cmd.Dir = p.RepoDir + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + err := cmd.Run() + if ctx.Err() != nil { + return codeGraphTimeoutMessage(stdout.String(), stderr.String()), nil + } + + out := strings.TrimSpace(stdout.String()) + errOut := strings.TrimSpace(stderr.String()) + if err != nil { + if out == "" && errOut == "" { + return "Error: codegraph command failed", nil + } + if out == "" { + return "Error: " + stripANSI(errOut), nil + } + } + if errOut != "" { + out += "\nWarning: " + stripANSI(errOut) + } + out = stripANSI(out) + out = truncateToolOutput(out) + if out == "" { + return "No structural context found", nil + } + return out, nil +} + +func truncateToolOutput(out string) string { + if len(out) <= codeGraphMaxOutput { + return out + } + truncated := out[:codeGraphMaxOutput] + for len(truncated) > 0 && !utf8.ValidString(truncated) { + truncated = truncated[:len(truncated)-1] + } + return truncated + "\n\n[truncated: CodeGraph output exceeded tool limit]" +} + +func codeGraphTimeoutMessage(stdout, stderr string) string { + msg := "code_graph_context timed out. Try using mode=search with a specific symbol, or reduce max_files/limit." + if partial := strings.TrimSpace(stdout); partial != "" { + msg += "\n\nPartial output:\n" + truncateToolOutput(stripANSI(partial)) + } + if partialErr := strings.TrimSpace(stderr); partialErr != "" { + msg += "\n\nPartial error output:\n" + truncateToolOutput(stripANSI(partialErr)) + } + return msg +} + +// DetectCodeGraph checks whether CodeGraph can be used for repoDir. A negative +// result means the tool definition should be hidden from the model entirely. +func DetectCodeGraph(repoDir string) CodeGraphAvailability { + dbPath := filepath.Join(repoDir, ".codegraph", "codegraph.db") + if _, err := os.Stat(dbPath); err != nil { + return CodeGraphAvailability{Reason: ".codegraph/codegraph.db not found"} + } + + binPath, err := exec.LookPath("codegraph") + if err != nil { + return CodeGraphAvailability{Reason: "codegraph executable not found in PATH"} + } + + ctx, cancel := context.WithTimeout(context.Background(), codeGraphDetectTimeout) + defer cancel() + versionOut, err := exec.CommandContext(ctx, binPath, "version").Output() + if ctx.Err() != nil { + return CodeGraphAvailability{Reason: "codegraph version check timed out"} + } + if err != nil { + return CodeGraphAvailability{Reason: "codegraph version check failed"} + } + version := strings.TrimSpace(string(versionOut)) + if !isSupportedCodeGraphVersion(version) { + return CodeGraphAvailability{Version: version, Reason: "unsupported codegraph version"} + } + + ctx, cancel = context.WithTimeout(context.Background(), codeGraphDetectTimeout) + defer cancel() + cmd := exec.CommandContext(ctx, binPath, "status", repoDir) + cmd.Dir = repoDir + if err := cmd.Run(); err != nil { + if ctx.Err() != nil { + return CodeGraphAvailability{BinPath: binPath, Version: version, Reason: "codegraph status timed out"} + } + return CodeGraphAvailability{BinPath: binPath, Version: version, Reason: "codegraph status failed"} + } + + return CodeGraphAvailability{Available: true, BinPath: binPath, Version: version} +} + +func isSupportedCodeGraphVersion(version string) bool { + if version == "" { + return false + } + parts := strings.Split(strings.TrimPrefix(version, "v"), ".") + if len(parts) == 0 { + return false + } + major, err := strconv.Atoi(parts[0]) + if err != nil { + return false + } + return major == 1 +} + +func stringArg(args map[string]any, key string) string { + value, _ := args[key].(string) + return value +} + +func intArg(args map[string]any, key string, fallback int) int { + switch value := args[key].(type) { + case float64: + return int(value) + case int: + return value + default: + return fallback + } +} + +func stripANSI(s string) string { + return ansiEscapePattern.ReplaceAllString(s, "") +} diff --git a/internal/tool/code_graph_test.go b/internal/tool/code_graph_test.go new file mode 100644 index 00000000..84136f16 --- /dev/null +++ b/internal/tool/code_graph_test.go @@ -0,0 +1,115 @@ +package tool + +import ( + "context" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + "unicode/utf8" +) + +func TestCodeGraphExecuteExplore(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("shell script test helper is Unix-only") + } + dir := t.TempDir() + bin := writeFakeCodeGraph(t, dir, `#!/bin/sh +printf 'args:%s\n' "$*" +printf '\033[32mSymbol: Foo\033[0m\n' +`) + + p := NewCodeGraph(dir, bin) + result, err := p.Execute(context.Background(), map[string]any{ + "query": "Foo", + "max_files": float64(2), + }) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(result, "args:explore -p "+dir+" --max-files 2 -- Foo") { + t.Fatalf("unexpected command args: %s", result) + } + if strings.Contains(result, "\033[") { + t.Fatalf("expected ANSI escapes to be stripped, got: %q", result) + } + if !strings.Contains(result, "Symbol: Foo") { + t.Fatalf("expected command output, got: %s", result) + } +} + +func TestCodeGraphExecuteSearchWithKindAndLimit(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("shell script test helper is Unix-only") + } + dir := t.TempDir() + bin := writeFakeCodeGraph(t, dir, `#!/bin/sh +printf '%s\n' "$*" +`) + + p := NewCodeGraph(dir, bin) + result, err := p.Execute(context.Background(), map[string]any{ + "mode": "search", + "query": "Foo", + "kind": "function", + "limit": float64(99), + }) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(result, "query -p "+dir+" -l 12 -k function -- Foo") { + t.Fatalf("expected out-of-range limit to fall back to default, got: %s", result) + } +} + +func TestCodeGraphExecuteUnsupportedMode(t *testing.T) { + p := NewCodeGraph(t.TempDir(), "codegraph") + result, err := p.Execute(context.Background(), map[string]any{ + "mode": "trace", + "query": "Foo", + }) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(result, "unsupported mode") { + t.Fatalf("expected unsupported mode error, got: %s", result) + } +} + +func TestDetectCodeGraphMissingDB(t *testing.T) { + result := DetectCodeGraph(t.TempDir()) + if result.Available { + t.Fatal("expected CodeGraph to be unavailable without .codegraph/codegraph.db") + } + if result.Reason == "" { + t.Fatal("expected unavailable reason") + } +} + +func TestTruncateToolOutputPreservesUTF8(t *testing.T) { + prefix := strings.Repeat("a", codeGraphMaxOutput-1) + result := truncateToolOutput(prefix + "界") + if !strings.Contains(result, "[truncated: CodeGraph output exceeded tool limit]") { + t.Fatalf("expected truncation marker, got: %s", result) + } + if !utf8.ValidString(result) { + t.Fatalf("expected valid UTF-8, got: %q", result) + } +} + +func TestCodeGraphTimeoutMessageIncludesPartialOutput(t *testing.T) { + result := codeGraphTimeoutMessage("partial stdout\n", "partial stderr\n") + if !strings.Contains(result, "timed out") || !strings.Contains(result, "partial stdout") || !strings.Contains(result, "partial stderr") { + t.Fatalf("expected timeout message with partial output, got: %s", result) + } +} + +func writeFakeCodeGraph(t *testing.T, dir, script string) string { + t.Helper() + path := filepath.Join(dir, "codegraph") + if err := os.WriteFile(path, []byte(script), 0755); err != nil { + t.Fatal(err) + } + return path +} diff --git a/internal/tool/definitions.go b/internal/tool/definitions.go index cf24f99e..1a7c8f74 100644 --- a/internal/tool/definitions.go +++ b/internal/tool/definitions.go @@ -11,13 +11,14 @@ type Tool struct { } var ( - Unknown = Tool{name: "unknown"} - TaskDone = Tool{name: "task_done"} - CodeComment = Tool{name: "code_comment"} - FileRead = Tool{name: "file_read"} - FileFind = Tool{name: "file_find"} - FileReadDiff = Tool{name: "file_read_diff"} - CodeSearch = Tool{name: "code_search"} + Unknown = Tool{name: "unknown"} + TaskDone = Tool{name: "task_done"} + CodeComment = Tool{name: "code_comment"} + FileRead = Tool{name: "file_read"} + FileFind = Tool{name: "file_find"} + FileReadDiff = Tool{name: "file_read_diff"} + CodeSearch = Tool{name: "code_search"} + CodeGraphContext = Tool{name: "code_graph_context"} ) func OfName(name string) Tool { @@ -30,7 +31,7 @@ func OfName(name string) Tool { } func allTools() []Tool { - return []Tool{Unknown, TaskDone, CodeComment, FileRead, FileFind, FileReadDiff, CodeSearch} + return []Tool{Unknown, TaskDone, CodeComment, FileRead, FileFind, FileReadDiff, CodeSearch, CodeGraphContext} } // Name returns the tool's identifier name.