Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ issues:
linters:
- funlen
- noctx
- path: mocktail.go
linters:
- gocyclo
text: "cyclomatic complexity 16 of func `processSingleFile` is high" # The complexity is expected.

output:
show-stats: true
175 changes: 172 additions & 3 deletions mocktail.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ func main() {
}

var exported bool
var sourceFile string
var interfaceNames string
flag.BoolVar(&exported, "e", false, "generate exported mocks")
flag.StringVar(&sourceFile, "source", "", "source file containing interfaces to mock")
flag.StringVar(&interfaceNames, "interface", "", "comma-separated list of interface names to mock (used with -source), mock every interface by default (has no effect without -source)")
flag.Parse()

root := info.Dir
Expand All @@ -58,9 +62,17 @@ func main() {
log.Fatalf("Chdir: %v", err)
}

model, err := walk(root, info.Path)
if err != nil {
log.Fatalf("walk: %v", err)
var model map[string]PackageDesc
if sourceFile != "" {
model, err = processSingleFile(sourceFile, root, info.Path, interfaceNames)
if err != nil {
log.Fatalf("process single file: %v", err)
}
} else {
model, err = walk(root, info.Path)
if err != nil {
log.Fatalf("walk: %v", err)
}
}

if len(model) == 0 {
Expand All @@ -73,6 +85,163 @@ func main() {
}
}

// processSingleFile processes a single source file to extract interfaces for mocking.
func processSingleFile(sourceFile, root, moduleName, interfaceFilter string) (map[string]PackageDesc, error) {
model := make(map[string]PackageDesc)

// Convert to absolute path if relative
if !filepath.IsAbs(sourceFile) {
sourceFile = filepath.Join(os.Getenv("PWD"), sourceFile)
}

// Check if file exists
if _, err := os.Stat(sourceFile); os.IsNotExist(err) {
return nil, fmt.Errorf("source file does not exist: %s", sourceFile)
}

// Parse interface filter if provided
targetInterfaces := parseInterfaceFilter(interfaceFilter)

// Load package from source file
pkg, err := loadPackageFromFile(sourceFile, root, moduleName)
if err != nil {
return nil, fmt.Errorf("load package from file: %w", err)
}

if pkg == nil {
return model, nil // Return empty model when no packages found
}

// Process interfaces in the package
packageDesc := processPackageInterfaces(pkg, targetInterfaces)

if len(packageDesc.Interfaces) > 0 {
// Use the source file path as the key, but change the filename to match expected output location
outputDir := filepath.Dir(sourceFile)
outputKey := filepath.Join(outputDir, srcMockFile)
model[outputKey] = packageDesc
}

return model, nil
}

// parseInterfaceFilter parses the interface filter string into a map of target interfaces.
func parseInterfaceFilter(interfaceFilter string) map[string]struct{} {
if interfaceFilter == "" {
return nil
}

targetInterfaces := make(map[string]struct{})
for _, name := range strings.Split(interfaceFilter, ",") {
name = strings.TrimSpace(name)
if name != "" {
targetInterfaces[name] = struct{}{}
}
}

return targetInterfaces
}

// loadPackageFromFile loads a Go package from a source file.
func loadPackageFromFile(sourceFile, root, moduleName string) (*types.Package, error) {
// Get the package path for this file
fileDir := filepath.Dir(sourceFile)

// Load the package
pkgs, err := packages.Load(
&packages.Config{
Mode: packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesInfo | packages.NeedFiles,
Dir: fileDir,
},
".",
)
if err != nil {
return nil, fmt.Errorf("load package: %w", err)
}

if len(pkgs) == 0 {
return nil, nil
}

pkg := pkgs[0]
if pkg.Types == nil {
relDir, err := filepath.Rel(root, fileDir)
if err != nil {
return nil, fmt.Errorf("get relative directory: %w", err)
}

return nil, fmt.Errorf("package %q has no type information", path.Join(moduleName, relDir))
}

return pkg.Types, nil
}

// processPackageInterfaces processes all interfaces in a package, optionally filtering by target interfaces.
func processPackageInterfaces(pkg *types.Package, targetInterfaces map[string]struct{}) PackageDesc {
packageDesc := PackageDesc{
Pkg: pkg,
Imports: map[string]struct{}{},
}

scope := pkg.Scope()
for _, name := range scope.Names() {
obj := scope.Lookup(name)
if obj == nil {
continue
}

// If interface filter is specified, only process those interfaces
if targetInterfaces != nil {
if _, wanted := targetInterfaces[name]; !wanted {
continue
}
}

// Check if it's an interface and process it
interfaceDesc := processInterfaceType(name, obj)

if interfaceDesc != nil {
packageDesc.Interfaces = append(packageDesc.Interfaces, *interfaceDesc)
// Collect imports from the interface methods
for _, method := range interfaceDesc.Methods {
for _, imp := range getMethodImports(method, pkg.Path()) {
packageDesc.Imports[imp] = struct{}{}
}
}
}
}

return packageDesc
}

// processInterfaceType processes a single type to check if it's an interface and extract its methods.
func processInterfaceType(name string, obj types.Object) *InterfaceDesc {
// Check if it's an interface
named, ok := obj.Type().(*types.Named)
if !ok {
return nil
}

interfaceType, ok := named.Underlying().(*types.Interface)
if !ok {
return nil
}

interfaceDesc := InterfaceDesc{Name: name}

// Get all methods from the interface
for i := range interfaceType.NumMethods() {
method := interfaceType.Method(i)
interfaceDesc.Methods = append(interfaceDesc.Methods, method)
}

if len(interfaceDesc.Methods) == 0 {
return nil
}

return &interfaceDesc
}

//nolint:gocognit,gocyclo // The complexity is expected.
func walk(root, moduleName string) (map[string]PackageDesc, error) {
model := make(map[string]PackageDesc)
Expand Down
Loading