diff --git a/Makefile b/Makefile index 92ba3adb..1247e98c 100644 --- a/Makefile +++ b/Makefile @@ -30,7 +30,7 @@ MAKE_TARGETS := build check fmt test check-vendor update-pcidb clean $(CHECK_TAR TARGETS := $(MAKE_TARGETS) DOCKER_TARGETS := $(patsubst %,docker-%, $(TARGETS)) -.PHONY: $(TARGETS) $(DOCKER_TARGETS) +.PHONY: $(TARGETS) $(DOCKER_TARGETS) docker-test GOOS ?= linux @@ -91,3 +91,12 @@ update-pcidb: clean: rm -rf nvidia-kubevirt-gpu-device-plugin && rm -rf coverage.out + +docker-test: + $(DOCKER) build --target builder \ + --build-arg GOLANG_VERSION="$(GOLANG_VERSION)" \ + --build-arg DRIVER_VERSION="$(DRIVER_VERSION)" \ + -t kubevirt-gpu-test \ + -f deployments/container/Dockerfile.distroless . + $(DOCKER) run --rm kubevirt-gpu-test \ + bash -c "CGO_ENABLED=1 CGO_CPPFLAGS='-I/usr/include' CGO_LDFLAGS='-L/usr/lib -lnvfm' go test -tags=nvfm -coverprofile=coverage.out.with-mocks $(MODULE)/..." diff --git a/deployments/container/Dockerfile.distroless b/deployments/container/Dockerfile.distroless index fdef8747..b438c081 100644 --- a/deployments/container/Dockerfile.distroless +++ b/deployments/container/Dockerfile.distroless @@ -28,9 +28,19 @@ FROM nvcr.io/nvidia/cuda:13.1.1-base-ubi9 AS builder ARG GOLANG_VERSION +ARG DRIVER_VERSION RUN yum install -y wget tar gzip make gcc glibc-devel +# Install Fabric Manager SDK for NVLink partition support +# Setup NVIDIA network repository and install FM devel package +RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo \ + && dnf clean expire-cache \ + && DRIVER_STREAM=$(echo "${DRIVER_VERSION}" | cut -d. -f1) \ + && dnf module enable -y nvidia-driver:${DRIVER_STREAM}-open/default \ + && dnf install -y nvidia-fabric-manager-devel-${DRIVER_VERSION} \ + && dnf clean all + RUN set -eux; \ \ arch="$(uname -m)"; \ @@ -58,7 +68,9 @@ RUN go mod download COPY . . 
RUN CGO_ENABLED=1 \ - go build -trimpath -o nvidia-kubevirt-gpu-device-plugin ./cmd + CGO_CPPFLAGS="-I/usr/include" \ + CGO_LDFLAGS="-L/usr/lib -lnvfm" \ + go build -trimpath -tags=nvfm -o nvidia-kubevirt-gpu-device-plugin ./cmd FROM nvcr.io/nvidia/distroless/go:v4.0.2 @@ -74,6 +86,9 @@ LABEL description="See summary" COPY --from=builder /workspace/nvidia-kubevirt-gpu-device-plugin /usr/bin/ COPY --from=builder /workspace/utils/pci.ids /usr/pci.ids +COPY --from=builder /usr/lib64/libnvfm* /usr/lib64/ + +ENV LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH USER 0:0 diff --git a/deployments/container/Makefile b/deployments/container/Makefile index c118a1b8..6bae7a09 100644 --- a/deployments/container/Makefile +++ b/deployments/container/Makefile @@ -25,7 +25,7 @@ endif GOPROXY ?= https://proxy.golang.org,direct -IMAGE_VERSION := $(VERSION) +IMAGE_VERSION := $(VERSION)-$(DRIVER_VERSION) IMAGE = $(IMAGE_NAME):$(IMAGE_VERSION) @@ -63,6 +63,7 @@ $(BUILD_TARGETS): build-%: --build-arg VERSION="$(VERSION)" \ --build-arg GIT_COMMIT="$(GIT_COMMIT)" \ --build-arg GOPROXY="$(GOPROXY)" \ + --build-arg DRIVER_VERSION="$(DRIVER_VERSION)" \ --file $(DOCKERFILE) \ $(CURDIR) ifeq ($(PUSH_ON_BUILD),true) diff --git a/manifests/nvidia-kubevirt-gpu-device-plugin.yaml b/manifests/nvidia-kubevirt-gpu-device-plugin.yaml index e756a12b..468e0df2 100644 --- a/manifests/nvidia-kubevirt-gpu-device-plugin.yaml +++ b/manifests/nvidia-kubevirt-gpu-device-plugin.yaml @@ -21,6 +21,13 @@ spec: containers: - name: nvidia-kubevirt-gpu-dp-ctr image: nvcr.io/nvidia/kubevirt-gpu-device-plugin:v1.5.0 + env: + # Fabric Manager Integration (Tech Preview) + # Set to "true" to enable fabric manager partition coordination + # Requires fabric manager daemon running on the node + # Defaults to "false" - device plugin works normally without it + - name: ENABLE_FABRIC_MANAGER + value: "false" securityContext: allowPrivilegeEscalation: false capabilities: diff --git a/pkg/device_plugin/generic_device_plugin.go 
b/pkg/device_plugin/generic_device_plugin.go index 8a9e11fe..cd9ac631 100644 --- a/pkg/device_plugin/generic_device_plugin.go +++ b/pkg/device_plugin/generic_device_plugin.go @@ -38,6 +38,7 @@ import ( "os" "path" "path/filepath" + "strconv" "strings" "time" @@ -45,6 +46,8 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" + + "kubevirt-gpu-device-plugin/pkg/fabricmanager" ) const ( @@ -56,26 +59,44 @@ const ( vgpuPrefix = "MDEV_PCI_RESOURCE_NVIDIA_COM" ) -var returnIommuMap = getIommuMap -var returnBdfToIommuMap = getBdfToIommuMap +var ( + returnIommuMap = getIommuMap + returnBdfToIommuMap = getBdfToIommuMap + + pciModuleMappingPath = "/run/nvidia-fabricmanager/gpu-pci-module-mapping.json" +) + +// isFabricManagerEnabled returns true if fabric manager integration is enabled +// via the ENABLE_FABRIC_MANAGER environment variable. +func isFabricManagerEnabled() bool { + envVar := os.Getenv("ENABLE_FABRIC_MANAGER") + enabled, _ := strconv.ParseBool(envVar) + return enabled +} // Implements the kubernetes device plugin API type GenericDevicePlugin struct { - devs []*pluginapi.Device - server *grpc.Server - socketPath string - stop chan struct{} // this channel signals to stop the DP - term chan bool // this channel detects kubelet restarts - healthy chan string - unhealthy chan string - devicePath string - deviceName string + devs []*pluginapi.Device + server *grpc.Server + socketPath string + stop chan struct{} // this channel signals to stop the DP + term chan bool // this channel detects kubelet restarts + healthy chan string + unhealthy chan string + devicePath string + deviceName string + fmEnabled bool + partitionManager *fabricmanager.PartitionManager } // Returns an initialized instance of GenericDevicePlugin func NewGenericDevicePlugin(deviceName string, devicePath string, devices []*pluginapi.Device) *GenericDevicePlugin { log.Println("Devicename " + deviceName) 
serverSock := fmt.Sprintf(pluginapi.DevicePluginPath+"kubevirt-%s.sock", deviceName) + + fmEnabled := isFabricManagerEnabled() + log.Printf("[%s] Fabric manager integration enabled: %t", deviceName, fmEnabled) + dpi := &GenericDevicePlugin{ devs: devices, socketPath: serverSock, @@ -84,6 +105,7 @@ func NewGenericDevicePlugin(deviceName string, devicePath string, devices []*plu unhealthy: make(chan string), deviceName: deviceName, devicePath: devicePath, + fmEnabled: fmEnabled, } return dpi } @@ -138,9 +160,72 @@ func (dpi *GenericDevicePlugin) Start(stop chan struct{}) error { return err } + // Attempt fabric manager connection if enabled + if dpi.fmEnabled { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + fmConfig := &fabricmanager.Config{ + AddressInfo: "/run/nvidia-fabricmanager/fmpm.sock", + AddressType: fabricmanager.AddressTypeUnix, + TimeoutMs: 5000, + MaxRetries: 3, + RetryDelay: time.Second * 2, + Debug: false, + } + fmClient := fabricmanager.NewClient(fmConfig) + + if err := fmClient.Connect(ctx); err != nil { + return fmt.Errorf("fabric manager enabled but connection failed: %w", err) + } + log.Print("Fabric manager connected successfully") + + // Load PCI-to-module mapping + pciToModule, moduleToPCI, mapErr := fabricmanager.LoadPCIModuleMapping(pciModuleMappingPath) + if mapErr != nil { + if disconnErr := fmClient.Disconnect(); disconnErr != nil { + log.Printf("WARNING: Error disconnecting from fabric manager: %v", disconnErr) + } + return fmt.Errorf("failed to load PCI module mapping: %w", mapErr) + } + log.Printf("Loaded PCI-to-module mapping with %d entries", len(pciToModule)) + + // Build device NUMA map from device list + deviceNUMAMap := make(map[string]int64, len(dpi.devs)) + for _, dev := range dpi.devs { + if dev.Topology != nil && len(dev.Topology.Nodes) > 0 { + deviceNUMAMap[dev.ID] = dev.Topology.Nodes[0].ID + } + } + + dpi.partitionManager = fabricmanager.NewPartitionManager( + fmClient, 
pciToModule, moduleToPCI, deviceNUMAMap, + ) + + // List and log partition information + partitions, err := dpi.partitionManager.GetPartitions(ctx) + if err != nil { + log.Printf("WARNING: Failed to retrieve fabric partitions: %v", err) + } else { + log.Printf("INFO: Discovered %d fabric partitions (max: %d):", partitions.NumPartitions, partitions.MaxNumPartitions) + + for _, partition := range partitions.Partitions { + status := "inactive" + if partition.IsActive { + status = "active" + } + log.Printf("INFO: Partition ID %d: %s, GPUs: %d", partition.PartitionID, status, partition.NumGPUs) + + for _, gpu := range partition.GPUs { + log.Printf("INFO: GPU %d: NVLinks=%d/%d", gpu.PhysicalID, gpu.NumNVLinksAvailable, gpu.MaxNumNVLinks) + } + } + } + } + sock, err := net.Listen("unix", dpi.socketPath) if err != nil { - log.Printf("[%s] Error creating GRPC server socket: %v", dpi.deviceName, err) + log.Printf("Error creating GRPC server socket: %v", err) return err } @@ -180,6 +265,17 @@ func (dpi *GenericDevicePlugin) Stop() error { dpi.server.Stop() dpi.server = nil + // Disconnect from fabric manager if connected + if dpi.partitionManager != nil { + if err := dpi.partitionManager.Disconnect(); err != nil { + log.Printf("[%s] WARNING: Error disconnecting from fabric manager: %v", + dpi.deviceName, err) + } else { + log.Printf("[%s] Fabric manager disconnected successfully", dpi.deviceName) + } + dpi.partitionManager = nil + } + return dpi.cleanup() } @@ -356,6 +452,23 @@ func (dpi *GenericDevicePlugin) Allocate(ctx context.Context, reqs *pluginapi.Al responses.ContainerResponses = append(responses.ContainerResponses, &response) } + // Activate fabric manager partitions per container request. + // Each container's devices are activated independently so that each set + // matches its own partition size, avoiding over-constraining multi-container pods. 
+ if dpi.partitionManager != nil { + if !dpi.partitionManager.IsConnected() { + return nil, fmt.Errorf("fabric manager is enabled but connection lost") + } + + for i, req := range reqs.ContainerRequests { + if err := dpi.partitionManager.ActivateForDevices(ctx, req.DevicesIDs); err != nil { + log.Printf("ERROR: Fabric partition activation failed for container %d: %v", i, err) + return nil, fmt.Errorf("fabric manager partition activation failed for container %d: %w", i, err) + } + log.Printf("Fabric partition activated successfully for container %d devices: %v", i, req.DevicesIDs) + } + } + return &responses, nil } @@ -383,144 +496,178 @@ func (dpi *GenericDevicePlugin) PreStartContainer(ctx context.Context, in *plugi // GetPreferredAllocation returns a preferred set of devices to allocate // from a list of available ones. This helps the Topology Manager make // topology-aware allocation decisions based on NUMA affinity. +// When fabric manager is active, it prefers devices that form a complete +// FM partition of the exact requested size with the best NUMA locality. 
func (dpi *GenericDevicePlugin) GetPreferredAllocation(ctx context.Context, in *pluginapi.PreferredAllocationRequest) (*pluginapi.PreferredAllocationResponse, error) { log.Printf("[%s] GetPreferredAllocation called with %d container request(s)", dpi.deviceName, len(in.ContainerRequests)) response := &pluginapi.PreferredAllocationResponse{} + // Build device-to-NUMA map for logging + deviceToNUMA := make(map[string]int64) + for _, dev := range dpi.devs { + if dev.Topology != nil && len(dev.Topology.Nodes) > 0 { + deviceToNUMA[dev.ID] = dev.Topology.Nodes[0].ID + } + } + for idx, req := range in.ContainerRequests { log.Printf("[%s] Container request %d: Available devices=%v, MustInclude=%v, AllocationSize=%d", dpi.deviceName, idx, req.AvailableDeviceIDs, req.MustIncludeDeviceIDs, req.AllocationSize) - // Build a map of device ID to NUMA node from our device list - deviceToNUMA := make(map[string]int64) - for _, dev := range dpi.devs { - if dev.Topology != nil && len(dev.Topology.Nodes) > 0 { - deviceToNUMA[dev.ID] = dev.Topology.Nodes[0].ID - } + var preferredDevices []string + var err error + + if dpi.partitionManager != nil { + preferredDevices, err = dpi.partitionManager.SelectPreferred(ctx, &fabricmanager.SelectPreferredRequest{ + AvailableDeviceIDs: req.AvailableDeviceIDs, + MustIncludeDeviceIDs: req.MustIncludeDeviceIDs, + AllocationSize: int(req.AllocationSize), + }) + } else { + preferredDevices, err = dpi.preferDevicesByNUMA(req) } - getNUMANode := func(deviceID string) int64 { - if node, ok := deviceToNUMA[deviceID]; ok { - return node - } - return -1 + if err != nil { + return nil, err } - // Group available devices by NUMA node while preserving iteration order - numaToDevices := make(map[int64][]string) - var nodeOrder []int64 - nodeSeen := make(map[int64]struct{}) - for _, deviceID := range req.AvailableDeviceIDs { - numaNode := getNUMANode(deviceID) - numaToDevices[numaNode] = append(numaToDevices[numaNode], deviceID) - if _, ok := nodeSeen[numaNode]; 
!ok { - nodeOrder = append(nodeOrder, numaNode) - nodeSeen[numaNode] = struct{}{} - } + log.Printf("[%s] Preferred allocation for container %d: %v (NUMA nodes: %v)", + dpi.deviceName, idx, preferredDevices, func() []int64 { + var nodes []int64 + for _, devID := range preferredDevices { + if node, ok := deviceToNUMA[devID]; ok { + nodes = append(nodes, node) + } + } + return nodes + }()) + + response.ContainerResponses = append(response.ContainerResponses, + &pluginapi.ContainerPreferredAllocationResponse{ + DeviceIDs: preferredDevices, + }) + } + + return response, nil +} + +// preferDevicesByNUMA selects preferred devices based on NUMA node locality. +// It prefers devices from the same NUMA node and falls back to kubelet-provided order. +func (dpi *GenericDevicePlugin) preferDevicesByNUMA( + req *pluginapi.ContainerPreferredAllocationRequest, +) ([]string, error) { + // Build a map of device ID to NUMA node from our device list + deviceToNUMA := make(map[string]int64) + for _, dev := range dpi.devs { + if dev.Topology != nil && len(dev.Topology.Nodes) > 0 { + deviceToNUMA[dev.ID] = dev.Topology.Nodes[0].ID + } + } + getNUMANode := func(deviceID string) int64 { + if node, ok := deviceToNUMA[deviceID]; ok { + return node } + return -1 + } - // Prefer devices from the same NUMA node - var preferredDevices []string - preferredSet := make(map[string]struct{}) - selectedPerNode := make(map[int64]int) - addDevice := func(deviceID string) { - if _, exists := preferredSet[deviceID]; exists { - return - } - preferredSet[deviceID] = struct{}{} - numaNode := getNUMANode(deviceID) - selectedPerNode[numaNode]++ - preferredDevices = append(preferredDevices, deviceID) + // Group available devices by NUMA node while preserving iteration order + numaToDevices := make(map[int64][]string) + var nodeOrder []int64 + nodeSeen := make(map[int64]struct{}) + for _, deviceID := range req.AvailableDeviceIDs { + numaNode := getNUMANode(deviceID) + numaToDevices[numaNode] = 
append(numaToDevices[numaNode], deviceID) + if _, ok := nodeSeen[numaNode]; !ok { + nodeOrder = append(nodeOrder, numaNode) + nodeSeen[numaNode] = struct{}{} } + } - // Always place must-include devices first - selectedNodeOrder := []int64{} - selectedNodeSeen := make(map[int64]struct{}) - for _, deviceID := range req.MustIncludeDeviceIDs { - if _, exists := preferredSet[deviceID]; exists { - continue - } - addDevice(deviceID) - numaNode := getNUMANode(deviceID) - if _, ok := selectedNodeSeen[numaNode]; !ok { - selectedNodeOrder = append(selectedNodeOrder, numaNode) - selectedNodeSeen[numaNode] = struct{}{} - } + // Prefer devices from the same NUMA node + var preferredDevices []string + preferredSet := make(map[string]struct{}) + selectedPerNode := make(map[int64]int) + addDevice := func(deviceID string) { + if _, exists := preferredSet[deviceID]; exists { + return } + preferredSet[deviceID] = struct{}{} + numaNode := getNUMANode(deviceID) + selectedPerNode[numaNode]++ + preferredDevices = append(preferredDevices, deviceID) + } - if len(preferredDevices) > int(req.AllocationSize) { - return nil, fmt.Errorf("number of MustIncludeDeviceIDs (%d) exceeds allocation size (%d)", - len(preferredDevices), req.AllocationSize) + // Always place must-include devices first + selectedNodeOrder := []int64{} + selectedNodeSeen := make(map[int64]struct{}) + for _, deviceID := range req.MustIncludeDeviceIDs { + if _, exists := preferredSet[deviceID]; exists { + continue + } + addDevice(deviceID) + numaNode := getNUMANode(deviceID) + if _, ok := selectedNodeSeen[numaNode]; !ok { + selectedNodeOrder = append(selectedNodeOrder, numaNode) + selectedNodeSeen[numaNode] = struct{}{} } + } - // First, try to satisfy the request from a single NUMA node (including already selected devices) - if len(preferredDevices) < int(req.AllocationSize) { - targetNode := int64(-1) - var candidateNodes []int64 - candidateNodes = append(candidateNodes, selectedNodeOrder...) 
- for _, node := range nodeOrder { - if _, seen := selectedNodeSeen[node]; seen { - continue - } - candidateNodes = append(candidateNodes, node) - } + if len(preferredDevices) > int(req.AllocationSize) { + return nil, fmt.Errorf("number of MustIncludeDeviceIDs (%d) exceeds allocation size (%d)", + len(preferredDevices), req.AllocationSize) + } - for _, numaNode := range candidateNodes { - availableOnNode := 0 - for _, deviceID := range numaToDevices[numaNode] { - if _, exists := preferredSet[deviceID]; !exists { - availableOnNode++ - } - } - totalOnNode := selectedPerNode[numaNode] + availableOnNode - if totalOnNode >= int(req.AllocationSize) { - log.Printf("[%s] Selecting NUMA node %d (have %d selected, %d available) to satisfy %d devices", - dpi.deviceName, numaNode, selectedPerNode[numaNode], availableOnNode, req.AllocationSize) - targetNode = numaNode - break - } + // First, try to satisfy the request from a single NUMA node (including already selected devices) + if len(preferredDevices) < int(req.AllocationSize) { + targetNode := int64(-1) + var candidateNodes []int64 + candidateNodes = append(candidateNodes, selectedNodeOrder...) 
+ for _, node := range nodeOrder { + if _, seen := selectedNodeSeen[node]; seen { + continue } + candidateNodes = append(candidateNodes, node) + } - if targetNode != -1 { - for _, deviceID := range numaToDevices[targetNode] { - if len(preferredDevices) >= int(req.AllocationSize) { - break - } - addDevice(deviceID) + for _, numaNode := range candidateNodes { + availableOnNode := 0 + for _, deviceID := range numaToDevices[numaNode] { + if _, exists := preferredSet[deviceID]; !exists { + availableOnNode++ } } + totalOnNode := selectedPerNode[numaNode] + availableOnNode + if totalOnNode >= int(req.AllocationSize) { + log.Printf("[%s] Selecting NUMA node %d (have %d selected, %d available) to satisfy %d devices", + dpi.deviceName, numaNode, selectedPerNode[numaNode], availableOnNode, req.AllocationSize) + targetNode = numaNode + break + } } - // If we couldn't fill the request from a single NUMA node, fall back to the kubelet-provided order - if len(preferredDevices) < int(req.AllocationSize) { - log.Printf("[%s] Using kubelet-provided device order to satisfy remaining slots (need %d more)", - dpi.deviceName, int(req.AllocationSize)-len(preferredDevices)) - for _, deviceID := range req.AvailableDeviceIDs { + if targetNode != -1 { + for _, deviceID := range numaToDevices[targetNode] { if len(preferredDevices) >= int(req.AllocationSize) { break } addDevice(deviceID) } } + } - log.Printf("[%s] Preferred allocation for container %d: %v (NUMA nodes: %v)", - dpi.deviceName, idx, preferredDevices, func() []int64 { - var nodes []int64 - for _, devID := range preferredDevices { - if node, ok := deviceToNUMA[devID]; ok { - nodes = append(nodes, node) - } - } - return nodes - }()) - - response.ContainerResponses = append(response.ContainerResponses, - &pluginapi.ContainerPreferredAllocationResponse{ - DeviceIDs: preferredDevices, - }) + // If we couldn't fill the request from a single NUMA node, fall back to the kubelet-provided order + if len(preferredDevices) < 
int(req.AllocationSize) { + log.Printf("[%s] Using kubelet-provided device order to satisfy remaining slots (need %d more)", + dpi.deviceName, int(req.AllocationSize)-len(preferredDevices)) + for _, deviceID := range req.AvailableDeviceIDs { + if len(preferredDevices) >= int(req.AllocationSize) { + break + } + addDevice(deviceID) + } } - return response, nil + return preferredDevices, nil } // Health check of GPU devices diff --git a/pkg/device_plugin/generic_device_plugin_test.go b/pkg/device_plugin/generic_device_plugin_test.go index 9f6a28a9..d0b8d9a6 100644 --- a/pkg/device_plugin/generic_device_plugin_test.go +++ b/pkg/device_plugin/generic_device_plugin_test.go @@ -40,6 +40,8 @@ import ( . "github.com/onsi/gomega" "google.golang.org/grpc" pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" + + "kubevirt-gpu-device-plugin/pkg/fabricmanager" ) var devices []*pluginapi.Device @@ -267,3 +269,360 @@ var _ = Describe("Generic Device", func() { Expect(devices[1].Health).To(Equal(pluginapi.Healthy)) }) }) + +// mockFMClient is a mock implementation of fabricmanager.Client for testing. 
+type mockFMClient struct { + partitions *fabricmanager.PartitionList + err error + activatedIDs []uint32 + activateErr error + connected bool +} + +func (m *mockFMClient) Connect(ctx context.Context) error { return nil } +func (m *mockFMClient) Disconnect() error { return nil } +func (m *mockFMClient) IsConnected() bool { return m.connected } +func (m *mockFMClient) GetPartition(ctx context.Context, partitionID uint32) (*fabricmanager.Partition, error) { + return nil, nil +} +func (m *mockFMClient) ActivatePartition(ctx context.Context, req *fabricmanager.ActivateRequest) error { + m.activatedIDs = append(m.activatedIDs, req.PartitionID) + return m.activateErr +} +func (m *mockFMClient) DeactivatePartition(ctx context.Context, partitionID uint32) error { + return nil +} +func (m *mockFMClient) GetPartitionForDevices(ctx context.Context, deviceIDs []string) (*fabricmanager.Partition, error) { + return nil, nil +} +func (m *mockFMClient) GetPartitions(ctx context.Context) (*fabricmanager.PartitionList, error) { + if m.err != nil { + return nil, m.err + } + return m.partitions, nil +} + +var _ = Describe("GetPreferredAllocation() FM Tests", func() { + buildDevice := func(id string, node int64) *pluginapi.Device { + return &pluginapi.Device{ + ID: id, + Health: pluginapi.Healthy, + Topology: &pluginapi.TopologyInfo{ + Nodes: []*pluginapi.NUMANode{ + {ID: node}, + }, + }, + } + } + + // pciToPhysical maps test PCI device names to unique physical module IDs. + pciToPhysical := map[string]uint32{ + "gpu0": 0, "gpu1": 1, "gpu2": 2, "gpu3": 3, + } + + // buildModuleMaps returns pciToModuleID and moduleIDToPCI maps for the test devices. 
+ buildModuleMaps := func() (map[string]uint32, map[uint32]string) { + forward := make(map[string]uint32) + reverse := make(map[uint32]string) + for pci, mod := range pciToPhysical { + forward[pci] = mod + reverse[mod] = pci + } + return forward, reverse + } + + // buildDeviceNUMAMap returns a map from device ID to NUMA node for the given devices. + buildDeviceNUMAMap := func(devs []*pluginapi.Device) map[string]int64 { + m := make(map[string]int64, len(devs)) + for _, dev := range devs { + if dev.Topology != nil && len(dev.Topology.Nodes) > 0 { + m[dev.ID] = dev.Topology.Nodes[0].ID + } + } + return m + } + + buildPartition := func(id uint32, pciBusIDs ...string) fabricmanager.Partition { + gpus := make([]fabricmanager.GPU, len(pciBusIDs)) + for i, bdf := range pciBusIDs { + gpus[i] = fabricmanager.GPU{PCIBusID: bdf, PhysicalID: pciToPhysical[bdf]} + } + return fabricmanager.Partition{ + PartitionID: id, + NumGPUs: uint32(len(pciBusIDs)), + GPUs: gpus, + } + } + + It("selects the single matching FM partition", func() { + mock := &mockFMClient{ + partitions: &fabricmanager.PartitionList{ + NumPartitions: 1, + Partitions: []fabricmanager.Partition{ + buildPartition(1, "gpu0", "gpu1"), + }, + }, + } + fwd, rev := buildModuleMaps() + devs := []*pluginapi.Device{ + buildDevice("gpu0", 0), + buildDevice("gpu1", 0), + buildDevice("gpu2", 1), + buildDevice("gpu3", 1), + } + dpi := &GenericDevicePlugin{ + deviceName: "test", + partitionManager: fabricmanager.NewPartitionManager(mock, fwd, rev, buildDeviceNUMAMap(devs)), + devs: devs, + } + + request := &pluginapi.PreferredAllocationRequest{ + ContainerRequests: []*pluginapi.ContainerPreferredAllocationRequest{ + { + AvailableDeviceIDs: []string{"gpu0", "gpu1", "gpu2", "gpu3"}, + AllocationSize: 2, + }, + }, + } + + resp, err := dpi.GetPreferredAllocation(context.Background(), request) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.ContainerResponses).To(HaveLen(1)) + 
Expect(resp.ContainerResponses[0].DeviceIDs).To(Equal([]string{"gpu0", "gpu1"})) + }) + + It("picks the partition with fewest NUMA nodes", func() { + mock := &mockFMClient{ + partitions: &fabricmanager.PartitionList{ + NumPartitions: 2, + Partitions: []fabricmanager.Partition{ + // Partition 1: GPUs span NUMA 0 and 1 + buildPartition(1, "gpu0", "gpu2"), + // Partition 2: GPUs both on NUMA 1 + buildPartition(2, "gpu2", "gpu3"), + }, + }, + } + fwd, rev := buildModuleMaps() + devs := []*pluginapi.Device{ + buildDevice("gpu0", 0), + buildDevice("gpu1", 0), + buildDevice("gpu2", 1), + buildDevice("gpu3", 1), + } + dpi := &GenericDevicePlugin{ + deviceName: "test", + partitionManager: fabricmanager.NewPartitionManager(mock, fwd, rev, buildDeviceNUMAMap(devs)), + devs: devs, + } + + request := &pluginapi.PreferredAllocationRequest{ + ContainerRequests: []*pluginapi.ContainerPreferredAllocationRequest{ + { + AvailableDeviceIDs: []string{"gpu0", "gpu1", "gpu2", "gpu3"}, + AllocationSize: 2, + }, + }, + } + + resp, err := dpi.GetPreferredAllocation(context.Background(), request) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.ContainerResponses).To(HaveLen(1)) + // Partition 2 is preferred (both GPUs on NUMA 1, single node) + Expect(resp.ContainerResponses[0].DeviceIDs).To(Equal([]string{"gpu2", "gpu3"})) + }) + + It("tie-breaks by lowest NUMA node ID when distinct node count is equal", func() { + mock := &mockFMClient{ + partitions: &fabricmanager.PartitionList{ + NumPartitions: 2, + Partitions: []fabricmanager.Partition{ + // Partition 1: GPUs both on NUMA 1 + buildPartition(1, "gpu2", "gpu3"), + // Partition 2: GPUs both on NUMA 0 + buildPartition(2, "gpu0", "gpu1"), + }, + }, + } + fwd, rev := buildModuleMaps() + devs := []*pluginapi.Device{ + buildDevice("gpu0", 0), + buildDevice("gpu1", 0), + buildDevice("gpu2", 1), + buildDevice("gpu3", 1), + } + dpi := &GenericDevicePlugin{ + deviceName: "test", + partitionManager: fabricmanager.NewPartitionManager(mock, fwd, rev, 
buildDeviceNUMAMap(devs)), + devs: devs, + } + + request := &pluginapi.PreferredAllocationRequest{ + ContainerRequests: []*pluginapi.ContainerPreferredAllocationRequest{ + { + AvailableDeviceIDs: []string{"gpu0", "gpu1", "gpu2", "gpu3"}, + AllocationSize: 2, + }, + }, + } + + resp, err := dpi.GetPreferredAllocation(context.Background(), request) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.ContainerResponses).To(HaveLen(1)) + // Partition 2 is preferred (NUMA 0 < NUMA 1) + Expect(resp.ContainerResponses[0].DeviceIDs).To(Equal([]string{"gpu0", "gpu1"})) + }) + + It("returns error when no partition matches the requested size", func() { + mock := &mockFMClient{ + partitions: &fabricmanager.PartitionList{ + NumPartitions: 1, + Partitions: []fabricmanager.Partition{ + // Partition has 4 GPUs, but we request 2 + buildPartition(1, "gpu0", "gpu1", "gpu2", "gpu3"), + }, + }, + } + fwd, rev := buildModuleMaps() + devs := []*pluginapi.Device{ + buildDevice("gpu0", 0), + buildDevice("gpu1", 0), + buildDevice("gpu2", 1), + buildDevice("gpu3", 1), + } + dpi := &GenericDevicePlugin{ + deviceName: "test", + partitionManager: fabricmanager.NewPartitionManager(mock, fwd, rev, buildDeviceNUMAMap(devs)), + devs: devs, + } + + request := &pluginapi.PreferredAllocationRequest{ + ContainerRequests: []*pluginapi.ContainerPreferredAllocationRequest{ + { + AvailableDeviceIDs: []string{"gpu0", "gpu1", "gpu2", "gpu3"}, + AllocationSize: 2, + }, + }, + } + + _, err := dpi.GetPreferredAllocation(context.Background(), request) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("no fabric manager partition of size 2")) + }) + + It("returns error when must-include device is not in any matching partition", func() { + mock := &mockFMClient{ + partitions: &fabricmanager.PartitionList{ + NumPartitions: 1, + Partitions: []fabricmanager.Partition{ + buildPartition(1, "gpu0", "gpu1"), + }, + }, + } + fwd, rev := buildModuleMaps() + devs := []*pluginapi.Device{ + 
buildDevice("gpu0", 0), + buildDevice("gpu1", 0), + buildDevice("gpu2", 1), + buildDevice("gpu3", 1), + } + dpi := &GenericDevicePlugin{ + deviceName: "test", + partitionManager: fabricmanager.NewPartitionManager(mock, fwd, rev, buildDeviceNUMAMap(devs)), + devs: devs, + } + + request := &pluginapi.PreferredAllocationRequest{ + ContainerRequests: []*pluginapi.ContainerPreferredAllocationRequest{ + { + AvailableDeviceIDs: []string{"gpu0", "gpu1", "gpu2", "gpu3"}, + MustIncludeDeviceIDs: []string{"gpu2"}, + AllocationSize: 2, + }, + }, + } + + _, err := dpi.GetPreferredAllocation(context.Background(), request) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("no fabric manager partition of size 2")) + }) + + It("skips partitions with unavailable GPUs", func() { + mock := &mockFMClient{ + partitions: &fabricmanager.PartitionList{ + NumPartitions: 2, + Partitions: []fabricmanager.Partition{ + // Partition 1 has gpu1 which is not available + buildPartition(1, "gpu0", "gpu1"), + // Partition 2 has both GPUs available + buildPartition(2, "gpu2", "gpu3"), + }, + }, + } + fwd, rev := buildModuleMaps() + devs := []*pluginapi.Device{ + buildDevice("gpu0", 0), + buildDevice("gpu1", 0), + buildDevice("gpu2", 1), + buildDevice("gpu3", 1), + } + dpi := &GenericDevicePlugin{ + deviceName: "test", + partitionManager: fabricmanager.NewPartitionManager(mock, fwd, rev, buildDeviceNUMAMap(devs)), + devs: devs, + } + + request := &pluginapi.PreferredAllocationRequest{ + ContainerRequests: []*pluginapi.ContainerPreferredAllocationRequest{ + { + // gpu1 is NOT in the available list + AvailableDeviceIDs: []string{"gpu0", "gpu2", "gpu3"}, + AllocationSize: 2, + }, + }, + } + + resp, err := dpi.GetPreferredAllocation(context.Background(), request) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.ContainerResponses).To(HaveLen(1)) + // Partition 2 is the only candidate + Expect(resp.ContainerResponses[0].DeviceIDs).To(Equal([]string{"gpu2", "gpu3"})) + }) + + 
It("places must-include devices first in the result", func() { + mock := &mockFMClient{ + partitions: &fabricmanager.PartitionList{ + NumPartitions: 1, + Partitions: []fabricmanager.Partition{ + buildPartition(1, "gpu0", "gpu1"), + }, + }, + } + fwd, rev := buildModuleMaps() + devs := []*pluginapi.Device{ + buildDevice("gpu0", 0), + buildDevice("gpu1", 0), + } + dpi := &GenericDevicePlugin{ + deviceName: "test", + partitionManager: fabricmanager.NewPartitionManager(mock, fwd, rev, buildDeviceNUMAMap(devs)), + devs: devs, + } + + request := &pluginapi.PreferredAllocationRequest{ + ContainerRequests: []*pluginapi.ContainerPreferredAllocationRequest{ + { + AvailableDeviceIDs: []string{"gpu0", "gpu1"}, + MustIncludeDeviceIDs: []string{"gpu1"}, + AllocationSize: 2, + }, + }, + } + + resp, err := dpi.GetPreferredAllocation(context.Background(), request) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.ContainerResponses).To(HaveLen(1)) + // gpu1 (must-include) comes first, then gpu0 + Expect(resp.ContainerResponses[0].DeviceIDs).To(Equal([]string{"gpu1", "gpu0"})) + }) +}) diff --git a/pkg/fabricmanager/README.md b/pkg/fabricmanager/README.md new file mode 100644 index 00000000..2ad0d473 --- /dev/null +++ b/pkg/fabricmanager/README.md @@ -0,0 +1,278 @@ +# Fabric Manager Client + +This package provides a high-level client for managing NVIDIA Fabric Manager partitions. It is designed to integrate with the KubeVirt GPU device plugin to coordinate GPU device allocations with fabric manager partition management. 
+ +## Overview + +The Fabric Manager Client abstracts the low-level `nvfm` package (NVIDIA Fabric Manager SDK CGO bindings) and provides: + +- High-level partition management operations +- Connection handling with retry logic +- Structured error handling +- Thread-safe operations for concurrent device allocation + +## Architecture + +``` +┌─────────────────────┐ +│ Device Plugin │ +│ (Allocate/Free) │ +└─────────┬───────────┘ + │ + │ Uses + │ +┌─────────▼───────────┐ +│ Fabric Manager │ ← This package +│ Client │ +└─────────┬───────────┘ + │ + │ Uses + │ +┌─────────▼───────────┐ +│ nvfm Package │ ← FM SDK bindings +│ (CGO bindings) │ +└─────────────────────┘ +``` + +## Basic Usage + +### Creating a Client + +```go +package main + +import ( + "context" + + "kubevirt-gpu-device-plugin/pkg/fabricmanager" +) + +func main() { + // Use default configuration + client := fabricmanager.NewClient(nil) + + // Or customize configuration + config := &fabricmanager.Config{ + AddressInfo: "localhost:6666", + AddressType: AddressTypeInet, + TimeoutMs: 10000, + MaxRetries: 5, + Debug: true, + } + client = fabricmanager.NewClient(config) + + ctx := context.Background() + + // Connect to fabric manager + if err := client.Connect(ctx); err != nil { + log.Fatalf("Failed to connect: %v", err) + } + defer client.Disconnect() +} +``` + +### Managing Partitions + +```go +// Get all available partitions +partitions, err := client.GetPartitions(ctx) +if err != nil { + return err +} + +// Find partition for specific GPU devices +deviceIDs := []string{"0000:01:00.0", "0000:02:00.0"} +partition, err := client.GetPartitionForDevices(ctx, deviceIDs) +if err != nil { + return err +} + +// Activate a partition +req := &fabricmanager.ActivateRequest{ + PartitionID: partition.PartitionID, +} +if err := client.ActivatePartition(ctx, req); err != nil { + return err +} + +// Deactivate when done +if err := client.DeactivatePartition(ctx, partition.PartitionID); err != nil { + return err +} +``` + +## 
Device Plugin Integration + +### Integration Points + +The fabric manager client should be integrated at these points in the device plugin lifecycle: + +1. **Initialization**: Create and connect the fabric manager client +2. **Device Allocation** (`Allocate()` method): Activate appropriate partition +3. **Device Cleanup**: Deactivate partitions when devices are released +4. **Shutdown**: Disconnect from fabric manager + +### Example Integration + +```go +// In device_plugin/generic_device_plugin.go + +import ( + "kubevirt-gpu-device-plugin/pkg/fabricmanager" +) + +type GenericDevicePlugin struct { + // ... existing fields ... + fmClient fabricmanager.Client +} + +func NewGenericDevicePlugin(deviceName string, devicePath string, devices []*pluginapi.Device) *GenericDevicePlugin { + // ... existing code ... + + // Initialize fabric manager client + fmClient := fabricmanager.NewClient(nil) + + dpi := &GenericDevicePlugin{ + // ... existing fields ... + fmClient: fmClient, + } + return dpi +} + +func (dpi *GenericDevicePlugin) Start(stop chan struct{}) error { + // ... existing start logic ... + + // Connect to fabric manager + ctx := context.Background() + if err := dpi.fmClient.Connect(ctx); err != nil { + log.Printf("[%s] Warning: Could not connect to fabric manager: %v", dpi.deviceName, err) + // Continue without fabric manager - graceful degradation + } + + // ... rest of start logic ... +} + +func (dpi *GenericDevicePlugin) Stop() error { + // ... existing stop logic ... + + // Disconnect from fabric manager + if dpi.fmClient != nil { + dpi.fmClient.Disconnect() + } + + return dpi.cleanup() +} + +func (dpi *GenericDevicePlugin) Allocate(ctx context.Context, reqs *pluginapi.AllocateRequest) (*pluginapi.AllocateResponse, error) { + // ... existing allocation logic ... + + // Extract device IDs from the request + var allDeviceIDs []string + for _, req := range reqs.ContainerRequests { + allDeviceIDs = append(allDeviceIDs, req.DevicesIDs...) 
+ } + + // Activate fabric partition if fabric manager is available + if dpi.fmClient != nil && dpi.fmClient.IsConnected() { + if err := dpi.activateFabricPartition(ctx, allDeviceIDs); err != nil { + log.Fatalf("[%s]: Failed to activate fabric manager partition: %v", dpi.deviceName, err) + } + } + + // ... rest of allocation logic ... +} + +func (dpi *GenericDevicePlugin) activateFabricPartition(ctx context.Context, deviceIDs []string) error { + if len(deviceIDs) == 0 { + return nil + } + + // Find appropriate partition for the devices + partition, err := dpi.fmClient.GetPartitionForDevices(ctx, deviceIDs) + if err != nil { + if fabricmanager.IsPermanent(err) { + return err // Don't retry permanent errors + } + return err + } + + // Activate the partition + req := &fabricmanager.ActivateRequest{ + PartitionID: partition.PartitionID, + } + + return dpi.fmClient.ActivatePartition(ctx, req) +} +``` + +### Error Handling Strategy + +The client provides structured error handling to help with integration decisions: + +```go +// Check if an error is retryable +if fabricmanager.IsRetryable(err) { + // Implement retry logic + return retryOperation() +} + +// Check if an error is permanent +if fabricmanager.IsPermanent(err) { + // Don't retry, handle gracefully + log.Printf("Permanent error, continuing without fabric manager: %v", err) + return nil +} +``` + +## Configuration + +### Environment Variables + +The fabric manager client can be configured via environment variables: + +- `FABRIC_MANAGER_ADDRESS`: Address of fabric manager instance +- `FABRIC_MANAGER_TIMEOUT`: Connection timeout in milliseconds +- `FABRIC_MANAGER_MAX_RETRIES`: Maximum number of connection retries + +### Build Tags + +The client depends on the `nvfm` package which uses build tags: + +- Build with `nvfm` tag when fabric manager libraries are available +- Build without tag for stub implementations + +```bash +# Build with fabric manager support +go build -tags=nvfm ./cmd/main.go + +# Build without 
fabric manager support (stub) +go build ./cmd/main.go +``` + +## Thread Safety + +The client is thread-safe and can be used concurrently from multiple goroutines. Internal operations are protected by read-write mutexes. + +## Testing + +The package includes comprehensive unit tests. For integration testing with the device plugin: + +```go +// Mock the fabric manager client for testing +type mockFabricManagerClient struct { + partitions map[string]*fabricmanager.Partition +} + +func (m *mockFabricManagerClient) GetPartitionForDevices(ctx context.Context, deviceIDs []string) (*fabricmanager.Partition, error) { + // Mock implementation +} +``` + +## Dependencies + +- `kubevirt-gpu-device-plugin/pkg/nvfm`: NVIDIA Fabric Manager SDK CGO bindings +- Standard library packages: `context`, `sync`, `time`, `fmt`, `errors` + +## License + +Copyright (c) 2026, NVIDIA CORPORATION. See license header in source files. diff --git a/pkg/fabricmanager/client.go b/pkg/fabricmanager/client.go new file mode 100644 index 00000000..9cf03c92 --- /dev/null +++ b/pkg/fabricmanager/client.go @@ -0,0 +1,357 @@ +//go:build nvfm + +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of NVIDIA CORPORATION nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY + * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package fabricmanager + +import ( + "context" + "fmt" + "sync" + "time" + + "kubevirt-gpu-device-plugin/pkg/nvfm" +) + +// Handle interface abstracts the NVFM handle operations for testing. +type Handle interface { + GetSupportedFabricPartitions() (*nvfm.FabricPartitionList, error) + ActivateFabricPartition(partitionID uint32) error + ActivateFabricPartitionWithVFs(partitionID uint32, vfs []nvfm.PCIDevice) error + DeactivateFabricPartition(partitionID uint32) error + Disconnect() error +} + +// Client provides a high-level interface for managing fabric manager partitions +// used by the device plugin to coordinate GPU allocations with fabric manager. +type Client interface { + // Connect establishes connection to the fabric manager. + Connect(ctx context.Context) error + + // Disconnect closes the connection to fabric manager. + Disconnect() error + + // IsConnected returns true if connected to fabric manager. + IsConnected() bool + + // GetPartitions retrieves all supported fabric partitions. + GetPartitions(ctx context.Context) (*PartitionList, error) + + // GetPartition retrieves information about a specific partition. 
+ GetPartition(ctx context.Context, partitionID uint32) (*Partition, error) + + // ActivatePartition activates a fabric partition for the given devices. + ActivatePartition(ctx context.Context, req *ActivateRequest) error + + // DeactivatePartition deactivates a fabric partition. + DeactivatePartition(ctx context.Context, partitionID uint32) error + + // GetPartitionForDevices finds the appropriate partition for the given GPU devices. + GetPartitionForDevices(ctx context.Context, deviceIDs []string) (*Partition, error) +} + +// AddressType represents the address type for fabric manager connections. +type AddressType int + +const ( + // AddressTypeInet represents TCP/IP connections. + AddressTypeInet AddressType = iota + // AddressTypeUnix represents Unix domain socket connections. + AddressTypeUnix + // AddressTypeVsock represents VSOCK connections. + AddressTypeVsock +) + +// Config contains configuration options for the fabric manager client. +type Config struct { + // Address information for connecting to fabric manager. + AddressInfo string + + // Address type (INET, UNIX, VSOCK). + AddressType AddressType + + // Connection timeout in milliseconds. + TimeoutMs uint32 + + // Number of connection retry attempts. + MaxRetries int + + // Delay between retry attempts. + RetryDelay time.Duration + + // Enable debug logging. + Debug bool +} + +// DefaultConfig returns a default configuration. +func DefaultConfig() *Config { + return &Config{ + AddressInfo: "localhost:6666", // Default fabric manager address + AddressType: AddressTypeInet, + TimeoutMs: 5000, + MaxRetries: 3, + RetryDelay: time.Second * 2, + Debug: false, + } +} + +// client is the concrete implementation of the Client interface. +type client struct { + config *Config + handle Handle + connected bool + mutex sync.RWMutex +} + +// NewClient creates a new fabric manager client with the given configuration. 
+func NewClient(config *Config) Client { + if config == nil { + config = DefaultConfig() + } + + return &client{ + config: config, + connected: false, + } +} + +// toNVFMAddressType converts fabricmanager.AddressType to nvfm.AddressType. +func toNVFMAddressType(addrType AddressType) nvfm.AddressType { + switch addrType { + case AddressTypeInet: + return nvfm.AddressTypeInet + case AddressTypeUnix: + return nvfm.AddressTypeUnix + case AddressTypeVsock: + return nvfm.AddressTypeVsock + default: + return nvfm.AddressTypeInet // fallback to default + } +} + +// Connect establishes connection to the fabric manager. +func (c *client) Connect(ctx context.Context) error { + c.mutex.Lock() + defer c.mutex.Unlock() + + if c.connected { + return nil + } + + // Initialize fabric manager library + if err := nvfm.Init(); err != nil { + return newClientError("connect", err, "failed to initialize fabric manager library") + } + + connectParams := nvfm.ConnectParams{ + AddressInfo: c.config.AddressInfo, + AddressType: toNVFMAddressType(c.config.AddressType), + TimeoutMs: c.config.TimeoutMs, + } + + var handle *nvfm.Handle + var err error + + // Retry connection with exponential backoff + for attempt := 0; attempt <= c.config.MaxRetries; attempt++ { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + handle, err = nvfm.Connect(connectParams) + if err == nil { + break + } + + if attempt < c.config.MaxRetries { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(c.config.RetryDelay * time.Duration(attempt+1)): + } + } + } + + if err != nil { + nvfm.Shutdown() + return newClientError("connect", err, fmt.Sprintf("failed after %d attempts", c.config.MaxRetries+1)) + } + + c.handle = handle + c.connected = true + + return nil +} + +// Disconnect closes the connection to fabric manager. 
+func (c *client) Disconnect() error { + c.mutex.Lock() + defer c.mutex.Unlock() + + if !c.connected { + return nil + } + + var err error + if c.handle != nil { + err = c.handle.Disconnect() + c.handle = nil + } + + nvfm.Shutdown() + c.connected = false + + return err +} + +// IsConnected returns true if connected to fabric manager. +func (c *client) IsConnected() bool { + c.mutex.RLock() + defer c.mutex.RUnlock() + + return c.connected +} + +// GetPartitions retrieves all supported fabric partitions. +func (c *client) GetPartitions(ctx context.Context) (*PartitionList, error) { + c.mutex.RLock() + defer c.mutex.RUnlock() + + if !c.connected { + return nil, ErrNotConnected + } + + nvfmPartitions, err := c.handle.GetSupportedFabricPartitions() + if err != nil { + return nil, newClientError("get_partitions", err, "") + } + + return fromNVFMPartitionList(nvfmPartitions), nil +} + +// GetPartition retrieves information about a specific partition. +func (c *client) GetPartition(ctx context.Context, partitionID uint32) (*Partition, error) { + partitions, err := c.GetPartitions(ctx) + if err != nil { + return nil, err + } + + for _, partition := range partitions.Partitions { + if partition.PartitionID == partitionID { + return &partition, nil + } + } + + return nil, newClientError("get_partition", ErrPartitionNotFound, fmt.Sprintf("partition ID: %d", partitionID)) +} + +// ActivatePartition activates a fabric partition for the given devices. 
+func (c *client) ActivatePartition(ctx context.Context, req *ActivateRequest) error { + c.mutex.RLock() + defer c.mutex.RUnlock() + + if !c.connected { + return ErrNotConnected + } + + if req == nil { + return newClientError("activate_partition", ErrInvalidRequest, "activation request cannot be nil") + } + + var err error + if len(req.VFDevices) > 0 { + nvfmVFs := toNVFMPCIDevices(req.VFDevices) + err = c.handle.ActivateFabricPartitionWithVFs(req.PartitionID, nvfmVFs) + } else { + err = c.handle.ActivateFabricPartition(req.PartitionID) + } + + if err != nil { + return newClientError("activate_partition", err, fmt.Sprintf("partition ID: %d", req.PartitionID)) + } + + return nil +} + +// DeactivatePartition deactivates a fabric partition. +func (c *client) DeactivatePartition(ctx context.Context, partitionID uint32) error { + c.mutex.RLock() + defer c.mutex.RUnlock() + + if !c.connected { + return ErrNotConnected + } + + err := c.handle.DeactivateFabricPartition(partitionID) + if err != nil { + return newClientError("deactivate_partition", err, fmt.Sprintf("partition ID: %d", partitionID)) + } + + return nil +} + +// GetPartitionForDevices finds the appropriate partition for the given GPU devices +// This method analyzes the device IDs (PCI BDF addresses) and finds a partition +// that contains GPUs matching those addresses. 
+func (c *client) GetPartitionForDevices(ctx context.Context, deviceIDs []string) (*Partition, error) { + if len(deviceIDs) == 0 { + return nil, ErrNoDevicesProvided + } + + partitions, err := c.GetPartitions(ctx) + if err != nil { + return nil, err + } + + // Convert device IDs to a set for faster lookup + deviceSet := make(map[string]struct{}) + for _, deviceID := range deviceIDs { + deviceSet[deviceID] = struct{}{} + } + + // Find partitions that contain all requested devices + for _, partition := range partitions.Partitions { + matchedDevices := 0 + + for _, gpu := range partition.GPUs { + if _, exists := deviceSet[gpu.PCIBusID]; exists { + matchedDevices++ + } + } + + // If all requested devices are found in this partition, return it + if matchedDevices == len(deviceIDs) { + return &partition, nil + } + } + + return nil, newClientError("get_partition_for_devices", ErrPartitionNotAvailable, fmt.Sprintf("devices: %v", deviceIDs)) +} diff --git a/pkg/fabricmanager/client_test.go b/pkg/fabricmanager/client_test.go new file mode 100644 index 00000000..dd845f42 --- /dev/null +++ b/pkg/fabricmanager/client_test.go @@ -0,0 +1,324 @@ +//go:build nvfm + +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of NVIDIA CORPORATION nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY + * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package fabricmanager + +import ( + "context" + "errors" + "testing" + "time" + + "kubevirt-gpu-device-plugin/pkg/nvfm" +) + +func TestNewClient(t *testing.T) { + tests := []struct { + name string + config *Config + expected *Config + }{ + { + name: "with nil config", + config: nil, + expected: &Config{ + AddressInfo: "localhost:6666", + AddressType: AddressTypeInet, + TimeoutMs: 5000, + MaxRetries: 3, + RetryDelay: time.Second * 2, + Debug: false, + }, + }, + { + name: "with custom config", + config: &Config{ + AddressInfo: "custom:8080", + AddressType: AddressTypeUnix, + TimeoutMs: 10000, + MaxRetries: 5, + RetryDelay: time.Second * 3, + Debug: true, + }, + expected: &Config{ + AddressInfo: "custom:8080", + AddressType: AddressTypeUnix, + TimeoutMs: 10000, + MaxRetries: 5, + RetryDelay: time.Second * 3, + Debug: true, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := NewClient(tt.config).(*client) + + if client.config.AddressInfo != tt.expected.AddressInfo { + t.Errorf("expected AddressInfo %s, got %s", tt.expected.AddressInfo, client.config.AddressInfo) + } + if client.config.AddressType != tt.expected.AddressType { + t.Errorf("expected 
AddressType %v, got %v", tt.expected.AddressType, client.config.AddressType) + } + if client.config.TimeoutMs != tt.expected.TimeoutMs { + t.Errorf("expected TimeoutMs %d, got %d", tt.expected.TimeoutMs, client.config.TimeoutMs) + } + if client.config.MaxRetries != tt.expected.MaxRetries { + t.Errorf("expected MaxRetries %d, got %d", tt.expected.MaxRetries, client.config.MaxRetries) + } + if client.config.RetryDelay != tt.expected.RetryDelay { + t.Errorf("expected RetryDelay %v, got %v", tt.expected.RetryDelay, client.config.RetryDelay) + } + if client.config.Debug != tt.expected.Debug { + t.Errorf("expected Debug %v, got %v", tt.expected.Debug, client.config.Debug) + } + + if client.connected { + t.Error("expected client to not be connected initially") + } + }) + } +} + +func TestClient_IsConnected(t *testing.T) { + client := NewClient(nil).(*client) + + if client.IsConnected() { + t.Error("expected client to not be connected initially") + } + + client.connected = true + if !client.IsConnected() { + t.Error("expected client to be connected after setting connected=true") + } +} + +func TestClient_GetPartitions_NotConnected(t *testing.T) { + client := NewClient(nil) + ctx := context.Background() + + _, err := client.GetPartitions(ctx) + + if !errors.Is(err, ErrNotConnected) { + t.Errorf("expected ErrNotConnected, got %v", err) + } +} + +func TestClient_GetPartition_NotFound(t *testing.T) { + client := NewClient(nil).(*client) + client.connected = true + client.handle = &mockHandle{ + partitions: &nvfm.FabricPartitionList{ + NumPartitions: 1, + MaxNumPartitions: 4, + Partitions: []nvfm.PartitionInfo{ + { + PartitionID: 1, + IsActive: false, + NumGPUs: 2, + }, + }, + }, + } + + ctx := context.Background() + + _, err := client.GetPartition(ctx, 999) + + var clientErr *ClientError + if !errors.As(err, &clientErr) { + t.Errorf("expected ClientError, got %T", err) + } else if !errors.Is(clientErr.Err, ErrPartitionNotFound) { + t.Errorf("expected ErrPartitionNotFound, got 
%v", clientErr.Err) + } +} + +func TestClient_ActivatePartition_InvalidRequest(t *testing.T) { + client := NewClient(nil).(*client) + client.connected = true + + ctx := context.Background() + + err := client.ActivatePartition(ctx, nil) + + var clientErr *ClientError + if !errors.As(err, &clientErr) { + t.Errorf("expected ClientError, got %T", err) + } else if !errors.Is(clientErr.Err, ErrInvalidRequest) { + t.Errorf("expected ErrInvalidRequest, got %v", clientErr.Err) + } +} + +func TestClient_GetPartitionForDevices_NoDevices(t *testing.T) { + client := NewClient(nil) + ctx := context.Background() + + _, err := client.GetPartitionForDevices(ctx, []string{}) + + if !errors.Is(err, ErrNoDevicesProvided) { + t.Errorf("expected ErrNoDevicesProvided, got %v", err) + } +} + +func TestIsRetryable(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "nil error", + err: nil, + expected: false, + }, + { + name: "timeout error", + err: nvfm.Timeout, + expected: true, + }, + { + name: "connection not valid error", + err: nvfm.ConnectionNotValid, + expected: true, + }, + { + name: "not ready error", + err: nvfm.NotReady, + expected: true, + }, + { + name: "bad param error", + err: nvfm.BadParam, + expected: false, + }, + { + name: "wrapped timeout error", + err: &ClientError{Op: "test", Err: nvfm.Timeout}, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsRetryable(tt.err) + if result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestIsPermanent(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "nil error", + err: nil, + expected: false, + }, + { + name: "bad param error", + err: nvfm.BadParam, + expected: true, + }, + { + name: "not supported error", + err: nvfm.NotSupported, + expected: true, + }, + { + name: "not connected error", + err: ErrNotConnected, + expected: true, + }, 
+ { + name: "partition not found error", + err: ErrPartitionNotFound, + expected: true, + }, + { + name: "timeout error", + err: nvfm.Timeout, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsPermanent(tt.err) + if result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} + +// mockHandle is a mock implementation of nvfm.Handle for testing +type mockHandle struct { + partitions *nvfm.FabricPartitionList + activated map[uint32]bool +} + +func (m *mockHandle) GetSupportedFabricPartitions() (*nvfm.FabricPartitionList, error) { + if m.partitions == nil { + return nil, nvfm.GenericError + } + return m.partitions, nil +} + +func (m *mockHandle) ActivateFabricPartition(partitionID uint32) error { + if m.activated == nil { + m.activated = make(map[uint32]bool) + } + m.activated[partitionID] = true + return nil +} + +func (m *mockHandle) ActivateFabricPartitionWithVFs(partitionID uint32, vfs []nvfm.PCIDevice) error { + return m.ActivateFabricPartition(partitionID) +} + +func (m *mockHandle) DeactivateFabricPartition(partitionID uint32) error { + if m.activated == nil { + m.activated = make(map[uint32]bool) + } + m.activated[partitionID] = false + return nil +} + +func (m *mockHandle) Disconnect() error { + return nil +} + + diff --git a/pkg/fabricmanager/errors.go b/pkg/fabricmanager/errors.go new file mode 100644 index 00000000..d52356a1 --- /dev/null +++ b/pkg/fabricmanager/errors.go @@ -0,0 +1,151 @@ +//go:build nvfm + +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of NVIDIA CORPORATION nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY + * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package fabricmanager + +import ( + "errors" + "fmt" + + "kubevirt-gpu-device-plugin/pkg/nvfm" +) + +var ( + // ErrNotConnected indicates the client is not connected to fabric manager. + ErrNotConnected = errors.New("not connected to fabric manager") + + // ErrPartitionNotFound indicates the requested partition was not found. + ErrPartitionNotFound = errors.New("partition not found") + + // ErrNoDevicesProvided indicates no device IDs were provided. + ErrNoDevicesProvided = errors.New("no device IDs provided") + + // ErrInvalidRequest indicates an invalid request parameter. + ErrInvalidRequest = errors.New("invalid request") + + // ErrPartitionNotAvailable indicates the partition is not available for the requested devices. 
+ ErrPartitionNotAvailable = errors.New("no partition available for devices") +) + +// ClientError wraps fabric manager errors with additional context. +type ClientError struct { + Op string // Operation that failed + Err error // Underlying error + Details string // Additional details +} + +func (e *ClientError) Error() string { + if e.Details != "" { + return fmt.Sprintf("fabricmanager: %s failed: %s (%s)", e.Op, e.Err.Error(), e.Details) + } + return fmt.Sprintf("fabricmanager: %s failed: %s", e.Op, e.Err.Error()) +} + +func (e *ClientError) Unwrap() error { + return e.Err +} + +// newClientError creates a new ClientError. +func newClientError(op string, err error, details string) error { + return &ClientError{ + Op: op, + Err: err, + Details: details, + } +} + +// IsRetryable checks if an error is retryable. +func IsRetryable(err error) bool { + if err == nil { + return false + } + + var clientErr *ClientError + if errors.As(err, &clientErr) { + err = clientErr.Err + } + + // Check nvfm errors that are retryable + if nvfmErr, ok := err.(nvfm.Return); ok { + switch nvfmErr { + case nvfm.Timeout: + return true + case nvfm.ConnectionNotValid: + return true + case nvfm.NotReady: + return true + case nvfm.ResourceNotReady: + return true + default: + return false + } + } + + return false +} + +// IsPermanent checks if an error is permanent and should not be retried. 
+func IsPermanent(err error) bool { + if err == nil { + return false + } + + var clientErr *ClientError + if errors.As(err, &clientErr) { + err = clientErr.Err + } + + // Check nvfm errors that are permanent + if nvfmErr, ok := err.(nvfm.Return); ok { + switch nvfmErr { + case nvfm.BadParam: + return true + case nvfm.NotSupported: + return true + case nvfm.VersionMismatch: + return true + case nvfm.NotConfigured: + return true + default: + return false + } + } + + // Check client-side errors + if errors.Is(err, ErrNotConnected) || + errors.Is(err, ErrPartitionNotFound) || + errors.Is(err, ErrNoDevicesProvided) || + errors.Is(err, ErrInvalidRequest) || + errors.Is(err, ErrPartitionNotAvailable) { + return true + } + + return false +} diff --git a/pkg/fabricmanager/partition.go b/pkg/fabricmanager/partition.go new file mode 100644 index 00000000..dc3b1b12 --- /dev/null +++ b/pkg/fabricmanager/partition.go @@ -0,0 +1,415 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of NVIDIA CORPORATION nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY + * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package fabricmanager + +import ( + "context" + "encoding/json" + "fmt" + "log" + "os" + "sort" + "strconv" + "strings" + "sync" +) + +// SelectPreferredRequest contains the parameters for selecting preferred devices +// from fabric manager partitions. It replaces the kubelet plugin API request type +// with primitive fields to decouple partition orchestration from the plugin API. +type SelectPreferredRequest struct { + AvailableDeviceIDs []string + MustIncludeDeviceIDs []string + AllocationSize int +} + +// PartitionManager encapsulates fabric manager partition orchestration including +// partition selection, NUMA scoring, activation, and PCI-to-module translation. +type PartitionManager struct { + mu sync.Mutex + client Client + pciToModuleID map[string]uint32 + moduleIDToPCI map[uint32]string + deviceNUMAMap map[string]int64 +} + +// NewPartitionManager creates a new PartitionManager with the given client, +// PCI-to-module mappings, and device NUMA topology map. +func NewPartitionManager( + client Client, + pciToModuleID map[string]uint32, + moduleIDToPCI map[uint32]string, + deviceNUMAMap map[string]int64, +) *PartitionManager { + return &PartitionManager{ + client: client, + pciToModuleID: pciToModuleID, + moduleIDToPCI: moduleIDToPCI, + deviceNUMAMap: deviceNUMAMap, + } +} + +// IsConnected returns true if the underlying fabric manager client is connected. 
+func (pm *PartitionManager) IsConnected() bool { + return pm.client.IsConnected() +} + +// Disconnect closes the connection to fabric manager. +func (pm *PartitionManager) Disconnect() error { + return pm.client.Disconnect() +} + +// GetPartitions retrieves all supported fabric partitions from the fabric manager. +func (pm *PartitionManager) GetPartitions(ctx context.Context) (*PartitionList, error) { + return pm.client.GetPartitions(ctx) +} + +// normalizePCIAddress normalizes a PCI BDF address to match the sysfs format +// used by the Linux kernel (4-digit lowercase hex domain). The mapping file +// produced by NVIDIA tooling may use an 8-digit domain and uppercase hex +// (e.g. "00000000:AB:00.0"), while sysfs uses "0000:ab:00.0". +func normalizePCIAddress(addr string) string { + lower := strings.ToLower(addr) + + colonIdx := strings.Index(lower, ":") + if colonIdx < 0 { + return lower + } + + domain := lower[:colonIdx] + rest := lower[colonIdx:] + + domainVal, err := strconv.ParseUint(domain, 16, 32) + if err != nil { + return lower + } + + if domainVal <= 0xFFFF { + return fmt.Sprintf("%04x%s", domainVal, rest) + } + return fmt.Sprintf("%08x%s", domainVal, rest) +} + +// LoadPCIModuleMapping reads the GPU PCI-to-module mapping JSON file produced +// by NVIDIA driver installation script. The file maps PCI BDF addresses to +// physical module IDs. Returns both forward (PCI->moduleID) and reverse +// (moduleID->PCI) maps. PCI addresses are normalized to the sysfs format +// (4-digit lowercase hex domain) to ensure consistent lookups. 
+func LoadPCIModuleMapping(path string) (map[string]uint32, map[uint32]string, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, nil, fmt.Errorf("failed to read PCI module mapping file: %w", err) + } + + var raw map[string]string + if err := json.Unmarshal(data, &raw); err != nil { + return nil, nil, fmt.Errorf("failed to parse PCI module mapping JSON: %w", err) + } + + pciToModule := make(map[string]uint32, len(raw)) + moduleToPCI := make(map[uint32]string, len(raw)) + + for pciAddr, moduleIDStr := range raw { + moduleID, err := strconv.ParseUint(moduleIDStr, 10, 32) + if err != nil { + return nil, nil, fmt.Errorf("invalid module ID %q for PCI address %s: %w", moduleIDStr, pciAddr, err) + } + normalized := normalizePCIAddress(pciAddr) + pciToModule[normalized] = uint32(moduleID) + moduleToPCI[uint32(moduleID)] = normalized + } + + return pciToModule, moduleToPCI, nil +} + +// ActivateForDevices handles fabric partition activation for the given devices. +// It converts PCI addresses to module IDs and matches against partition GPU PhysicalIDs, +// since the FM SDK returns empty PCIBusID values in partition data. 
+func (pm *PartitionManager) ActivateForDevices(ctx context.Context, deviceIDs []string) error { + if len(deviceIDs) == 0 { + return nil + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + deviceModuleIDs := make(map[uint32]struct{}, len(deviceIDs)) + for _, pciAddr := range deviceIDs { + moduleID, ok := pm.pciToModuleID[pciAddr] + if !ok { + return fmt.Errorf("no module ID mapping for PCI address %s", pciAddr) + } + deviceModuleIDs[moduleID] = struct{}{} + } + + partitions, err := pm.client.GetPartitions(ctx) + if err != nil { + return fmt.Errorf("failed to get fabric partitions: %w", err) + } + + var matchedPartition *Partition + for i, partition := range partitions.Partitions { + if int(partition.NumGPUs) != len(deviceIDs) { + continue + } + matched := 0 + for _, gpu := range partition.GPUs { + if _, ok := deviceModuleIDs[gpu.PhysicalID]; ok { + matched++ + } + } + if matched == len(deviceIDs) { + matchedPartition = &partitions.Partitions[i] + break + } + } + + if matchedPartition == nil { + return fmt.Errorf("no partition of size %d found containing all devices %v", len(deviceIDs), deviceIDs) + } + + // Deactivate any active partitions that contain any of the requested devices. + // A partition must be inactive before it can be activated, and different + // partition configurations may reference the same physical GPU. 
+ for _, partition := range partitions.Partitions { + if !partition.IsActive { + continue + } + for _, gpu := range partition.GPUs { + if _, ok := deviceModuleIDs[gpu.PhysicalID]; ok { + log.Printf("Deactivating active partition %d before activation", partition.PartitionID) + if err := pm.client.DeactivatePartition(ctx, partition.PartitionID); err != nil { + return fmt.Errorf("failed to deactivate partition %d: %w", partition.PartitionID, err) + } + break + } + } + } + + req := &ActivateRequest{ + PartitionID: matchedPartition.PartitionID, + } + + if err := pm.client.ActivatePartition(ctx, req); err != nil { + return fmt.Errorf("failed to activate partition %d: %w", matchedPartition.PartitionID, err) + } + + return nil +} + +// SelectPreferred selects preferred devices based on fabric manager partitions. +// It finds FM partitions whose size exactly matches the allocation size and whose GPUs are +// all available, then picks the partition with the best NUMA locality (fewest distinct NUMA +// nodes, tie-broken by lowest NUMA node ID). +// +// Because the FM SDK returns empty PCIBusID values in partition GPU data, this function +// uses the PCI-to-module-ID mapping loaded at startup to translate between PCI addresses +// (used by kubelet) and physical module IDs (used by FM partitions via GPU.PhysicalID). 
+func (pm *PartitionManager) SelectPreferred( + ctx context.Context, + req *SelectPreferredRequest, +) ([]string, error) { + partitions, err := pm.client.GetPartitions(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get fabric partitions: %w", err) + } + + availableModuleIDs := pciAddrsToModuleIDSet(req.AvailableDeviceIDs, pm.pciToModuleID) + mustIncludeModuleIDs := pciAddrsToModuleIDSet(req.MustIncludeDeviceIDs, pm.pciToModuleID) + + if len(availableModuleIDs) != len(req.AvailableDeviceIDs) { + log.Printf("WARNING: PCI-to-module translation: only %d/%d available devices resolved to module IDs", + len(availableModuleIDs), len(req.AvailableDeviceIDs)) + } + + var candidates []partitionCandidate + for _, partition := range partitions.Partitions { + if int(partition.NumGPUs) != req.AllocationSize { + continue + } + if !allPartitionGPUsAvailable(partition, availableModuleIDs) { + continue + } + if !partitionContainsAllModuleIDs(partition, mustIncludeModuleIDs) { + continue + } + + distinctNodes, minNUMANode := computeNUMALocality(partition, pm.moduleIDToPCI, pm.deviceNUMAMap) + candidates = append(candidates, partitionCandidate{ + partition: partition, + distinctNodes: distinctNodes, + minNUMANode: minNUMANode, + }) + } + + if len(candidates) == 0 { + return nil, fmt.Errorf("no fabric manager partition of size %d found for available devices", + req.AllocationSize) + } + + sortCandidatesByNUMALocality(candidates) + + best := candidates[0] + log.Printf("Selected FM partition %d (NUMA nodes: %d, min NUMA: %d) from %d candidates", + best.partition.PartitionID, best.distinctNodes, best.minNUMANode, len(candidates)) + + return pm.buildPreferredDeviceList(best.partition, req), nil +} + +// buildPreferredDeviceList constructs the ordered list of preferred device PCI +// addresses from the selected partition, placing must-include devices first and +// filling the remainder from available devices in request order. 
+func (pm *PartitionManager) buildPreferredDeviceList( + partition Partition, + req *SelectPreferredRequest, +) []string { + partitionPCISet := make(map[string]struct{}, len(partition.GPUs)) + for _, gpu := range partition.GPUs { + if pciAddr, ok := pm.moduleIDToPCI[gpu.PhysicalID]; ok { + partitionPCISet[pciAddr] = struct{}{} + } + } + + var preferred []string + added := make(map[string]struct{}) + + for _, id := range req.MustIncludeDeviceIDs { + if _, exists := added[id]; exists { + continue + } + added[id] = struct{}{} + preferred = append(preferred, id) + } + + for _, id := range req.AvailableDeviceIDs { + if len(preferred) >= req.AllocationSize { + break + } + if _, exists := added[id]; exists { + continue + } + if _, inPartition := partitionPCISet[id]; !inPartition { + continue + } + added[id] = struct{}{} + preferred = append(preferred, id) + } + + return preferred +} + +// partitionCandidate holds a partition and its NUMA locality score for ranking. +type partitionCandidate struct { + partition Partition + distinctNodes int + minNUMANode int64 +} + +// pciAddrsToModuleIDSet converts a slice of PCI addresses to a set of module IDs +// using the provided PCI-to-module mapping. +func pciAddrsToModuleIDSet(addrs []string, pciToModuleID map[string]uint32) map[uint32]struct{} { + result := make(map[uint32]struct{}, len(addrs)) + for _, addr := range addrs { + if moduleID, ok := pciToModuleID[addr]; ok { + result[moduleID] = struct{}{} + } + } + return result +} + +// allPartitionGPUsAvailable returns true if every GPU in the partition has its +// PhysicalID present in the available set. +func allPartitionGPUsAvailable(partition Partition, available map[uint32]struct{}) bool { + for _, gpu := range partition.GPUs { + if _, ok := available[gpu.PhysicalID]; !ok { + return false + } + } + return true +} + +// partitionContainsAllModuleIDs returns true if every module ID in required is +// present among the partition's GPU PhysicalIDs. 
+func partitionContainsAllModuleIDs(partition Partition, required map[uint32]struct{}) bool { + partitionIDs := make(map[uint32]struct{}, len(partition.GPUs)) + for _, gpu := range partition.GPUs { + partitionIDs[gpu.PhysicalID] = struct{}{} + } + for id := range required { + if _, ok := partitionIDs[id]; !ok { + return false + } + } + return true +} + +// computeNUMALocality calculates the NUMA locality score for a partition by +// translating GPU PhysicalIDs to PCI addresses and looking up their NUMA nodes. +// Returns the number of distinct NUMA nodes and the lowest NUMA node ID (-1 if +// no topology info is available). +func computeNUMALocality( + partition Partition, + moduleIDToPCI map[uint32]string, + deviceToNUMA map[string]int64, +) (distinctNodes int, minNUMANode int64) { + numaNodes := make(map[int64]struct{}) + minNUMANode = -1 + for _, gpu := range partition.GPUs { + node := int64(-1) + if pciAddr, ok := moduleIDToPCI[gpu.PhysicalID]; ok { + if n, ok := deviceToNUMA[pciAddr]; ok { + node = n + } + } + numaNodes[node] = struct{}{} + if node >= 0 && (minNUMANode == -1 || node < minNUMANode) { + minNUMANode = node + } + } + return len(numaNodes), minNUMANode +} + +// sortCandidatesByNUMALocality sorts candidates preferring fewest distinct NUMA +// nodes first, then lowest NUMA node ID as tiebreaker. Candidates without +// topology info (minNUMANode == -1) sort last. 
+func sortCandidatesByNUMALocality(candidates []partitionCandidate) { + sort.Slice(candidates, func(i, j int) bool { + if candidates[i].distinctNodes != candidates[j].distinctNodes { + return candidates[i].distinctNodes < candidates[j].distinctNodes + } + mi, mj := candidates[i].minNUMANode, candidates[j].minNUMANode + if mi == -1 && mj != -1 { + return false + } + if mj == -1 && mi != -1 { + return true + } + return mi < mj + }) +} diff --git a/pkg/fabricmanager/partition_test.go b/pkg/fabricmanager/partition_test.go new file mode 100644 index 00000000..5bdc93d5 --- /dev/null +++ b/pkg/fabricmanager/partition_test.go @@ -0,0 +1,834 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of NVIDIA CORPORATION nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY + * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR
 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

package fabricmanager

import (
	"context"
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"testing"
	"time"
)

// mockFMClient is a mock implementation of Client for testing.
// It returns canned values and records which partitions were activated and
// deactivated. It has no locking, so it is only suitable for sequential tests.
type mockFMClient struct {
	partitions     *PartitionList // returned by GetPartitions when err is nil
	err            error          // returned by GetPartitions when non-nil
	activatedID    uint32         // PartitionID of the last ActivatePartition request
	activateErr    error          // returned by ActivatePartition
	deactivatedIDs []uint32       // IDs passed to successful DeactivatePartition calls, in order
	deactivateErr  error          // returned by DeactivatePartition when non-nil
	connected      bool           // returned by IsConnected
}

func (m *mockFMClient) Connect(ctx context.Context) error { return nil }
func (m *mockFMClient) Disconnect() error                 { return nil }
func (m *mockFMClient) IsConnected() bool                 { return m.connected }
func (m *mockFMClient) GetPartition(ctx context.Context, partitionID uint32) (*Partition, error) {
	return nil, nil
}

// ActivatePartition records the requested partition ID (even when
// activateErr is set) and returns activateErr.
func (m *mockFMClient) ActivatePartition(ctx context.Context, req *ActivateRequest) error {
	m.activatedID = req.PartitionID
	return m.activateErr
}

// DeactivatePartition returns deactivateErr if set; otherwise it records the
// partition ID and succeeds.
func (m *mockFMClient) DeactivatePartition(ctx context.Context, partitionID uint32) error {
	if m.deactivateErr != nil {
		return m.deactivateErr
	}
	m.deactivatedIDs = append(m.deactivatedIDs, partitionID)
	return nil
}

func (m *mockFMClient) GetPartitionForDevices(ctx context.Context, deviceIDs []string) (*Partition, error) {
	return nil, nil
}

// GetPartitions returns err if set, otherwise the canned partition list.
func (m *mockFMClient) GetPartitions(ctx context.Context) (*PartitionList, error) {
	if m.err != nil {
		return nil, m.err
	}
	return m.partitions, nil
}

// TestLoadPCIModuleMapping covers parsing of the PCI-to-module mapping file:
// a valid file, address normalization, and the three error paths (missing
// file, malformed JSON, non-numeric module ID).
func TestLoadPCIModuleMapping(t *testing.T) {
	t.Run("parses a valid mapping file", func(t *testing.T) {
		tmpDir := t.TempDir()
		mappingFile := filepath.Join(tmpDir, "mapping.json")
		content := `{
			"0000:3b:00.0": "0",
			"0000:86:00.0": "1",
			"0000:af:00.0": "2",
			"0000:d8:00.0": "3"
		}`
		if err := os.WriteFile(mappingFile, []byte(content), 0o644); err != nil {
			t.Fatalf("failed to write mapping file: %v", err)
		}

		pciToModule, moduleToPCI, err := LoadPCIModuleMapping(mappingFile)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if len(pciToModule) != 4 {
			t.Errorf("expected 4 entries in pciToModule, got %d", len(pciToModule))
		}
		if len(moduleToPCI) != 4 {
			t.Errorf("expected 4 entries in moduleToPCI, got %d", len(moduleToPCI))
		}

		expectedPCI := map[string]uint32{
			"0000:3b:00.0": 0,
			"0000:86:00.0": 1,
			"0000:af:00.0": 2,
			"0000:d8:00.0": 3,
		}
		for addr, expected := range expectedPCI {
			if got := pciToModule[addr]; got != expected {
				t.Errorf("pciToModule[%s] = %d, want %d", addr, got, expected)
			}
		}

		expectedModule := map[uint32]string{
			0: "0000:3b:00.0",
			1: "0000:86:00.0",
			2: "0000:af:00.0",
			3: "0000:d8:00.0",
		}
		for id, expected := range expectedModule {
			if got := moduleToPCI[id]; got != expected {
				t.Errorf("moduleToPCI[%d] = %s, want %s", id, got, expected)
			}
		}
	})

	t.Run("normalizes 8-digit domain and uppercase hex to sysfs format", func(t *testing.T) {
		tmpDir := t.TempDir()
		mappingFile := filepath.Join(tmpDir, "mapping.json")
		content := `{
			"00000000:18:00.0": "2",
			"00000000:2A:00.0": "4",
			"00000000:3A:00.0": "1",
			"00000000:5D:00.0": "3",
			"00000000:9A:00.0": "6",
			"00000000:AB:00.0": "8",
			"00000000:BA:00.0": "5",
			"00000000:DB:00.0": "7"
		}`
		if err := os.WriteFile(mappingFile, []byte(content), 0o644); err != nil {
			t.Fatalf("failed to write mapping file: %v", err)
		}

		pciToModule, moduleToPCI, err := LoadPCIModuleMapping(mappingFile)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}

		// Keys should be normalized to 4-digit lowercase domain
		expectedPCI := map[string]uint32{
			"0000:18:00.0": 2,
			"0000:2a:00.0": 4,
			"0000:3a:00.0": 1,
			"0000:5d:00.0": 3,
			"0000:9a:00.0": 6,
			"0000:ab:00.0": 8,
			"0000:ba:00.0": 5,
			"0000:db:00.0": 7,
		}
		for addr, expected := range expectedPCI {
			got, ok := pciToModule[addr]
			if !ok {
				t.Errorf("pciToModule missing normalized key %s", addr)
				continue
			}
			if got != expected {
				t.Errorf("pciToModule[%s] = %d, want %d", addr, got, expected)
			}
		}

		// Reverse map values should also be normalized
		for moduleID, expectedAddr := range map[uint32]string{
			2: "0000:18:00.0",
			4: "0000:2a:00.0",
			1: "0000:3a:00.0",
			3: "0000:5d:00.0",
			6: "0000:9a:00.0",
			8: "0000:ab:00.0",
			5: "0000:ba:00.0",
			7: "0000:db:00.0",
		} {
			if got := moduleToPCI[moduleID]; got != expectedAddr {
				t.Errorf("moduleToPCI[%d] = %s, want %s", moduleID, got, expectedAddr)
			}
		}
	})

	t.Run("returns error for missing file", func(t *testing.T) {
		tmpDir := t.TempDir()
		_, _, err := LoadPCIModuleMapping(filepath.Join(tmpDir, "nonexistent.json"))
		if err == nil {
			t.Fatal("expected error, got nil")
		}
		if !strings.Contains(err.Error(), "failed to read PCI module mapping file") {
			t.Errorf("unexpected error message: %v", err)
		}
	})

	t.Run("returns error for invalid JSON", func(t *testing.T) {
		tmpDir := t.TempDir()
		mappingFile := filepath.Join(tmpDir, "bad.json")
		if err := os.WriteFile(mappingFile, []byte("not json"), 0o644); err != nil {
			t.Fatalf("failed to write file: %v", err)
		}

		_, _, err := LoadPCIModuleMapping(mappingFile)
		if err == nil {
			t.Fatal("expected error, got nil")
		}
		if !strings.Contains(err.Error(), "failed to parse PCI module mapping JSON") {
			t.Errorf("unexpected error message: %v", err)
		}
	})

	t.Run("returns error for invalid module ID value", func(t *testing.T) {
		tmpDir := t.TempDir()
		mappingFile := filepath.Join(tmpDir, "bad-id.json")
		content := `{"0000:3b:00.0": "not-a-number"}`
		if err := os.WriteFile(mappingFile, []byte(content), 0o644); err != nil {
			t.Fatalf("failed to write file: %v", err)
		}

		_, _, err := LoadPCIModuleMapping(mappingFile)
		if err == nil {
			t.Fatal("expected error, got nil")
		}
		if !strings.Contains(err.Error(), "invalid module ID") {
			t.Errorf("unexpected error message: %v", err)
		}
	})
}

// TestSelectPreferred exercises partition selection on a fixture of four GPUs
// in two 2-GPU partitions: modules 0 and 1 sit on NUMA node 0 (partition 0),
// modules 2 and 3 on NUMA node 1 (partition 1). PCIBusID is left empty in the
// fixture, so matching has to go through PhysicalID.
func TestSelectPreferred(t *testing.T) {
	pciToModuleID := map[string]uint32{
		"0000:3b:00.0": 0,
		"0000:86:00.0": 1,
		"0000:af:00.0": 2,
		"0000:d8:00.0": 3,
	}
	moduleIDToPCI := map[uint32]string{
		0: "0000:3b:00.0",
		1: "0000:86:00.0",
		2: "0000:af:00.0",
		3: "0000:d8:00.0",
	}
	deviceNUMAMap := map[string]int64{
		"0000:3b:00.0": 0,
		"0000:86:00.0": 0,
		"0000:af:00.0": 1,
		"0000:d8:00.0": 1,
	}

	partitions := &PartitionList{
		NumPartitions:    2,
		MaxNumPartitions: 4,
		Partitions: []Partition{
			{
				PartitionID: 0,
				NumGPUs:     2,
				GPUs: []GPU{
					{PhysicalID: 0, PCIBusID: ""},
					{PhysicalID: 1, PCIBusID: ""},
				},
			},
			{
				PartitionID: 1,
				NumGPUs:     2,
				GPUs: []GPU{
					{PhysicalID: 2, PCIBusID: ""},
					{PhysicalID: 3, PCIBusID: ""},
				},
			},
		},
	}

	t.Run("selects partition matching available devices by PhysicalID", func(t *testing.T) {
		mock := &mockFMClient{connected: true, partitions: partitions}
		pm := NewPartitionManager(mock, pciToModuleID, moduleIDToPCI, deviceNUMAMap)

		preferred, err := pm.SelectPreferred(context.Background(), &SelectPreferredRequest{
			AvailableDeviceIDs: []string{
				"0000:3b:00.0", "0000:86:00.0", "0000:af:00.0", "0000:d8:00.0",
			},
			AllocationSize: 2,
		})
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if len(preferred) != 2 {
			t.Fatalf("expected 2 preferred devices, got %d", len(preferred))
		}
		// Should pick partition 0 (NUMA node 0) since it has better locality
		expectedSet := map[string]struct{}{
			"0000:3b:00.0": {},
			"0000:86:00.0": {},
		}
		for _, dev := range preferred {
			if _, ok := expectedSet[dev]; !ok {
				t.Errorf("unexpected device in preferred list: %s", dev)
			}
		}
	})

	t.Run("respects must-include devices", func(t *testing.T) {
		mock := &mockFMClient{connected: true, partitions: partitions}
		pm := NewPartitionManager(mock, pciToModuleID, moduleIDToPCI, deviceNUMAMap)

		preferred, err := pm.SelectPreferred(context.Background(), &SelectPreferredRequest{
			AvailableDeviceIDs: []string{
				"0000:3b:00.0", "0000:86:00.0", "0000:af:00.0", "0000:d8:00.0",
			},
			MustIncludeDeviceIDs: []string{"0000:af:00.0"},
			AllocationSize:       2,
		})
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if len(preferred) != 2 {
			t.Fatalf("expected 2 preferred devices, got %d", len(preferred))
		}
		// Must include 0000:af:00.0 (module 2), so partition 1 is selected
		found := map[string]bool{}
		for _, dev := range preferred {
			found[dev] = true
		}
		if !found["0000:af:00.0"] {
			t.Error("expected 0000:af:00.0 in preferred list")
		}
		if !found["0000:d8:00.0"] {
			t.Error("expected 0000:d8:00.0 in preferred list")
		}
	})

	t.Run("returns error when no partition matches", func(t *testing.T) {
		mock := &mockFMClient{connected: true, partitions: partitions}
		pm := NewPartitionManager(mock, pciToModuleID, moduleIDToPCI, deviceNUMAMap)

		// Only one device is available, so no 2-GPU partition is fully available.
		_, err := pm.SelectPreferred(context.Background(), &SelectPreferredRequest{
			AvailableDeviceIDs: []string{"0000:3b:00.0"},
			AllocationSize:     2,
		})
		if err == nil {
			t.Fatal("expected error, got nil")
		}
		if !strings.Contains(err.Error(), "no fabric manager partition") {
			t.Errorf("unexpected error message: %v", err)
		}
	})

	t.Run("filters out partitions when not all GPUs are available", func(t *testing.T) {
		mock := &mockFMClient{connected: true, partitions: partitions}
		pm := NewPartitionManager(mock, pciToModuleID, moduleIDToPCI, deviceNUMAMap)

		preferred, err := pm.SelectPreferred(context.Background(), &SelectPreferredRequest{
			// Only 3 devices available — partition 1 has modules 2,3 but module 3
			// (0000:d8:00.0) is not available, so only partition 0 qualifies
			AvailableDeviceIDs: []string{
				"0000:3b:00.0", "0000:86:00.0", "0000:af:00.0",
			},
			AllocationSize: 2,
		})
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if len(preferred) != 2 {
			t.Fatalf("expected 2 preferred devices, got %d", len(preferred))
		}
		expectedSet := map[string]struct{}{
			"0000:3b:00.0": {},
			"0000:86:00.0": {},
		}
		for _, dev := range preferred {
			if _, ok := expectedSet[dev]; !ok {
				t.Errorf("unexpected device in preferred list: %s", dev)
			}
		}
	})
}

// TestActivateForDevices covers partition matching by PhysicalID, the
// deactivate-before-activate behavior, the error paths, and serialization of
// concurrent calls.
func TestActivateForDevices(t *testing.T) {
	t.Run("finds and activates the correct partition by PhysicalID", func(t *testing.T) {
		mock := &mockFMClient{
			connected: true,
			partitions: &PartitionList{
				NumPartitions: 2,
				Partitions: []Partition{
					{
						PartitionID: 0,
						NumGPUs:     2,
						GPUs: []GPU{
							{PhysicalID: 0, PCIBusID: ""},
							{PhysicalID: 1, PCIBusID: ""},
						},
					},
					{
						PartitionID: 1,
						NumGPUs:     2,
						GPUs: []GPU{
							{PhysicalID: 2, PCIBusID: ""},
							{PhysicalID: 3, PCIBusID: ""},
						},
					},
				},
			},
		}

		pm := NewPartitionManager(mock,
			map[string]uint32{
				"0000:af:00.0": 2,
				"0000:d8:00.0": 3,
			},
			map[uint32]string{
				2: "0000:af:00.0",
				3: "0000:d8:00.0",
			},
			nil,
		)

		err := pm.ActivateForDevices(context.Background(), []string{"0000:af:00.0", "0000:d8:00.0"})
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if mock.activatedID != 1 {
			t.Errorf("expected partition 1 to be activated, got %d", mock.activatedID)
		}
	})

	t.Run("returns error when PCI address has no module mapping", func(t *testing.T) {
		// The client is nil: the mapping lookup fails before any client call.
		pm := NewPartitionManager(nil, map[string]uint32{}, nil, nil)

		err := pm.ActivateForDevices(context.Background(), []string{"0000:ff:00.0"})
		if err == nil {
			t.Fatal("expected error, got nil")
		}
		if !strings.Contains(err.Error(), "no module ID mapping") {
			t.Errorf("unexpected error message: %v", err)
		}
	})

	t.Run("returns error when no partition matches the devices", func(t *testing.T) {
		mock := &mockFMClient{
			connected: true,
			partitions: &PartitionList{
				NumPartitions: 1,
				Partitions: []Partition{
					{
						PartitionID: 0,
						NumGPUs:     2,
						GPUs: []GPU{
							{PhysicalID: 0, PCIBusID: ""},
							{PhysicalID: 1, PCIBusID: ""},
						},
					},
				},
			},
		}

		pm := NewPartitionManager(mock,
			map[string]uint32{
				"0000:af:00.0": 2,
				"0000:d8:00.0": 3,
			},
			nil,
			nil,
		)

		err := pm.ActivateForDevices(context.Background(), []string{"0000:af:00.0", "0000:d8:00.0"})
		if err == nil {
			t.Fatal("expected error, got nil")
		}
		if !strings.Contains(err.Error(), "no partition of size") {
			t.Errorf("unexpected error message: %v", err)
		}
	})

	t.Run("deactivates active partition containing allocated devices before activation", func(t *testing.T) {
		mock := &mockFMClient{
			connected: true,
			partitions: &PartitionList{
				NumPartitions: 2,
				Partitions: []Partition{
					{
						PartitionID: 0,
						IsActive:    true,
						NumGPUs:     2,
						GPUs: []GPU{
							{PhysicalID: 0, PCIBusID: ""},
							{PhysicalID: 1, PCIBusID: ""},
						},
					},
					{
						PartitionID: 1,
						IsActive:    false,
						NumGPUs:     2,
						GPUs: []GPU{
							{PhysicalID: 2, PCIBusID: ""},
							{PhysicalID: 3, PCIBusID: ""},
						},
					},
				},
			},
		}

		pm := NewPartitionManager(mock,
			map[string]uint32{
				"0000:3b:00.0": 0,
				"0000:86:00.0": 1,
			},
			map[uint32]string{
				0: "0000:3b:00.0",
				1: "0000:86:00.0",
			},
			nil,
		)

		// The requested partition is itself the active one: it is expected to
		// be deactivated and then activated again.
		err := pm.ActivateForDevices(context.Background(), []string{"0000:3b:00.0", "0000:86:00.0"})
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if len(mock.deactivatedIDs) != 1 || mock.deactivatedIDs[0] != 0 {
			t.Errorf("expected partition 0 to be deactivated, got %v", mock.deactivatedIDs)
		}
		if mock.activatedID != 0 {
			t.Errorf("expected partition 0 to be activated, got %d", mock.activatedID)
		}
	})

	t.Run("does not deactivate inactive partitions", func(t *testing.T) {
		mock := &mockFMClient{
			connected: true,
			partitions: &PartitionList{
				NumPartitions: 2,
				Partitions: []Partition{
					{
						PartitionID: 0,
						IsActive:    false,
						NumGPUs:     2,
						GPUs: []GPU{
							{PhysicalID: 0, PCIBusID: ""},
							{PhysicalID: 1, PCIBusID: ""},
						},
					},
					{
						PartitionID: 1,
						IsActive:    false,
						NumGPUs:     2,
						GPUs: []GPU{
							{PhysicalID: 2, PCIBusID: ""},
							{PhysicalID: 3, PCIBusID: ""},
						},
					},
				},
			},
		}

		pm := NewPartitionManager(mock,
			map[string]uint32{
				"0000:af:00.0": 2,
				"0000:d8:00.0": 3,
			},
			map[uint32]string{
				2: "0000:af:00.0",
				3: "0000:d8:00.0",
			},
			nil,
		)

		err := pm.ActivateForDevices(context.Background(), []string{"0000:af:00.0", "0000:d8:00.0"})
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if len(mock.deactivatedIDs) != 0 {
			t.Errorf("expected no deactivations, got %v", mock.deactivatedIDs)
		}
		if mock.activatedID != 1 {
			t.Errorf("expected partition 1 to be activated, got %d", mock.activatedID)
		}
	})

	t.Run("returns error when deactivation fails", func(t *testing.T) {
		mock := &mockFMClient{
			connected:     true,
			deactivateErr: fmt.Errorf("deactivation refused"),
			partitions: &PartitionList{
				NumPartitions: 1,
				Partitions: []Partition{
					{
						PartitionID: 0,
						IsActive:    true,
						NumGPUs:     2,
						GPUs: []GPU{
							{PhysicalID: 0, PCIBusID: ""},
							{PhysicalID: 1, PCIBusID: ""},
						},
					},
				},
			},
		}

		pm := NewPartitionManager(mock,
			map[string]uint32{
				"0000:3b:00.0": 0,
				"0000:86:00.0": 1,
			},
			nil,
			nil,
		)

		err := pm.ActivateForDevices(context.Background(), []string{"0000:3b:00.0", "0000:86:00.0"})
		if err == nil {
			t.Fatal("expected error, got nil")
		}
		if !strings.Contains(err.Error(), "failed to deactivate partition 0") {
			t.Errorf("unexpected error message: %v", err)
		}
	})

	t.Run("does not deactivate active partition with no overlapping devices", func(t *testing.T) {
		mock := &mockFMClient{
			connected: true,
			partitions: &PartitionList{
				NumPartitions: 2,
				Partitions: []Partition{
					{
						PartitionID: 0,
						IsActive:    true,
						NumGPUs:     2,
						GPUs: []GPU{
							{PhysicalID: 0, PCIBusID: ""},
							{PhysicalID: 1, PCIBusID: ""},
						},
					},
					{
						PartitionID: 1,
						IsActive:    false,
						NumGPUs:     2,
						GPUs: []GPU{
							{PhysicalID: 2, PCIBusID: ""},
							{PhysicalID: 3, PCIBusID: ""},
						},
					},
				},
			},
		}

		pm := NewPartitionManager(mock,
			map[string]uint32{
				"0000:af:00.0": 2,
				"0000:d8:00.0": 3,
			},
			map[uint32]string{
				2: "0000:af:00.0",
				3: "0000:d8:00.0",
			},
			nil,
		)

		err := pm.ActivateForDevices(context.Background(), []string{"0000:af:00.0", "0000:d8:00.0"})
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if len(mock.deactivatedIDs) != 0 {
			t.Errorf("expected no deactivations, got %v", mock.deactivatedIDs)
		}
		if mock.activatedID != 1 {
			t.Errorf("expected partition 1 to be activated, got %d", mock.activatedID)
		}
	})

	t.Run("rejects partition containing all devices but with mismatched size", func(t *testing.T) {
		mock := &mockFMClient{
			connected: true,
			partitions: &PartitionList{
				NumPartitions: 1,
				Partitions: []Partition{
					{
						PartitionID: 0,
						NumGPUs:     4,
						GPUs: []GPU{
							{PhysicalID: 0, PCIBusID: ""},
							{PhysicalID: 1, PCIBusID: ""},
							{PhysicalID: 2, PCIBusID: ""},
							{PhysicalID: 3, PCIBusID: ""},
						},
					},
				},
			},
		}

		pm := NewPartitionManager(mock,
			map[string]uint32{
				"0000:3b:00.0": 0,
				"0000:86:00.0": 1,
			},
			nil,
			nil,
		)

		err := pm.ActivateForDevices(context.Background(), []string{"0000:3b:00.0", "0000:86:00.0"})
		if err == nil {
			t.Fatal("expected error, got nil")
		}
		if !strings.Contains(err.Error(), "no partition of size 2") {
			t.Errorf("unexpected error message: %v", err)
		}
	})

	t.Run("returns nil for empty device list", func(t *testing.T) {
		// Client and maps are all nil: an empty request returns before using them.
		pm := NewPartitionManager(nil, nil, nil, nil)

		err := pm.ActivateForDevices(context.Background(), []string{})
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
	})

	t.Run("serializes concurrent calls", func(t *testing.T) {
		mock := &sequencingMockFMClient{
			partitions: &PartitionList{
				NumPartitions: 2,
				Partitions: []Partition{
					{PartitionID: 0, NumGPUs: 2, GPUs: []GPU{{PhysicalID: 0}, {PhysicalID: 1}}},
					{PartitionID: 1, NumGPUs: 2, GPUs: []GPU{{PhysicalID: 2}, {PhysicalID: 3}}},
				},
			},
			delay: 10 * time.Millisecond,
		}

		pm := NewPartitionManager(mock,
			map[string]uint32{
				"0000:3b:00.0": 0, "0000:86:00.0": 1,
				"0000:af:00.0": 2, "0000:d8:00.0": 3,
			},
			map[uint32]string{
				0: "0000:3b:00.0", 1: "0000:86:00.0",
				2: "0000:af:00.0", 3: "0000:d8:00.0",
			},
			nil,
		)

		// Run two activations for disjoint device pairs at the same time; each
		// goroutine writes only to its own slot of errs.
		var wg sync.WaitGroup
		errs := make([]error, 2)
		wg.Add(2)
		go func() {
			defer wg.Done()
			errs[0] = pm.ActivateForDevices(context.Background(), []string{"0000:3b:00.0", "0000:86:00.0"})
		}()
		go func() {
			defer wg.Done()
			errs[1] = pm.ActivateForDevices(context.Background(), []string{"0000:af:00.0", "0000:d8:00.0"})
		}()
		wg.Wait()

		for i, err := range errs {
			if err != nil {
				t.Fatalf("goroutine %d returned error: %v", i, err)
			}
		}

		ops := mock.getOps()
		if len(ops) != 4 {
			t.Fatalf("expected 4 operations, got %d: %v", len(ops), ops)
		}

		// With serialization, operations are grouped per call:
		//   [GetPartitions, Activate(X), GetPartitions, Activate(Y)]
		// Without serialization, they would interleave:
		//   [GetPartitions, GetPartitions, Activate(X), Activate(Y)]
		for i := 0; i < len(ops)-1; i++ {
			if ops[i] == "GetPartitions" && ops[i+1] == "GetPartitions" {
				t.Errorf("detected interleaved operations (consecutive GetPartitions at positions %d-%d), want serialized sequences: %v", i, i+1, ops)
			}
		}
	})
}

// sequencingMockFMClient is a thread-safe mock that records the order of FM
// operations and supports an artificial delay in GetPartitions to widen the
// race window for concurrent calls.
+type sequencingMockFMClient struct { + mu sync.Mutex + ops []string + partitions *PartitionList + delay time.Duration +} + +func (m *sequencingMockFMClient) Connect(ctx context.Context) error { return nil } +func (m *sequencingMockFMClient) Disconnect() error { return nil } +func (m *sequencingMockFMClient) IsConnected() bool { return true } +func (m *sequencingMockFMClient) GetPartition(ctx context.Context, id uint32) (*Partition, error) { + return nil, nil +} +func (m *sequencingMockFMClient) GetPartitionForDevices(ctx context.Context, ids []string) (*Partition, error) { + return nil, nil +} + +func (m *sequencingMockFMClient) GetPartitions(ctx context.Context) (*PartitionList, error) { + m.mu.Lock() + m.ops = append(m.ops, "GetPartitions") + m.mu.Unlock() + if m.delay > 0 { + time.Sleep(m.delay) + } + return m.partitions, nil +} + +func (m *sequencingMockFMClient) ActivatePartition(ctx context.Context, req *ActivateRequest) error { + m.mu.Lock() + m.ops = append(m.ops, fmt.Sprintf("Activate(%d)", req.PartitionID)) + m.mu.Unlock() + return nil +} + +func (m *sequencingMockFMClient) DeactivatePartition(ctx context.Context, id uint32) error { + m.mu.Lock() + m.ops = append(m.ops, fmt.Sprintf("Deactivate(%d)", id)) + m.mu.Unlock() + return nil +} + +func (m *sequencingMockFMClient) getOps() []string { + m.mu.Lock() + defer m.mu.Unlock() + result := make([]string, len(m.ops)) + copy(result, m.ops) + return result +} diff --git a/pkg/fabricmanager/stub.go b/pkg/fabricmanager/stub.go new file mode 100644 index 00000000..15d22b11 --- /dev/null +++ b/pkg/fabricmanager/stub.go @@ -0,0 +1,174 @@ +//go:build !nvfm + +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of NVIDIA CORPORATION nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY + * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package fabricmanager + +import ( + "context" + "errors" + "time" +) + +var ErrNotSupported = errors.New("fabricmanager: package built without nvfm support") + +// AddressType represents the address type for fabric manager connections. +type AddressType int + +const ( + // AddressTypeInet represents TCP/IP connections. + AddressTypeInet AddressType = iota + // AddressTypeUnix represents Unix domain socket connections. + AddressTypeUnix + // AddressTypeVsock represents VSOCK connections. + AddressTypeVsock +) + +// Client provides a stub interface when nvfm support is not available. 
+type Client interface { + Connect(ctx context.Context) error + Disconnect() error + IsConnected() bool + GetPartitions(ctx context.Context) (*PartitionList, error) + GetPartition(ctx context.Context, partitionID uint32) (*Partition, error) + ActivatePartition(ctx context.Context, req *ActivateRequest) error + DeactivatePartition(ctx context.Context, partitionID uint32) error + GetPartitionForDevices(ctx context.Context, deviceIDs []string) (*Partition, error) +} + +// Config contains configuration options for the fabric manager client. +type Config struct { + AddressInfo string + AddressType AddressType + TimeoutMs uint32 + MaxRetries int + RetryDelay time.Duration + Debug bool +} + +// DefaultConfig returns a default configuration. +func DefaultConfig() *Config { + return &Config{ + AddressInfo: "localhost:6666", + AddressType: AddressTypeInet, + TimeoutMs: 5000, + MaxRetries: 3, + RetryDelay: time.Second * 2, + Debug: false, + } +} + +// Stub types +type ( + GPU struct { + PhysicalID uint32 + UUID string + PCIBusID string + NumNVLinksAvailable uint32 + MaxNumNVLinks uint32 + NVLinkLineRateMBps uint32 + } + + Partition struct { + PartitionID uint32 + IsActive bool + NumGPUs uint32 + GPUs []GPU + LastActivated time.Time + LastDeactivated time.Time + ActivationCount int + } + + PartitionList struct { + NumPartitions uint32 + MaxNumPartitions uint32 + Partitions []Partition + } + + ActivateRequest struct { + PartitionID uint32 + VFDevices []PCIDevice + } + + PCIDevice struct { + Domain uint32 + Bus uint32 + Device uint32 + Function uint32 + } + + ClientError struct { + Op string + Err error + Details string + } +) + +func (e *ClientError) Error() string { + if e.Details != "" { + return "fabricmanager: " + e.Op + " failed: " + e.Err.Error() + " (" + e.Details + ")" + } + return "fabricmanager: " + e.Op + " failed: " + e.Err.Error() +} + +func (e *ClientError) Unwrap() error { + return e.Err +} + +var ( + ErrNotConnected = errors.New("not connected to fabric 
manager") + ErrPartitionNotFound = errors.New("partition not found") + ErrNoDevicesProvided = errors.New("no device IDs provided") + ErrInvalidRequest = errors.New("invalid request") + ErrPartitionNotAvailable = errors.New("no partition available for devices") +) + +// Stub client implementation +type stubClient struct{} + +func NewClient(config *Config) Client { + return &stubClient{} +} + +func (c *stubClient) Connect(ctx context.Context) error { return ErrNotSupported } +func (c *stubClient) Disconnect() error { return ErrNotSupported } +func (c *stubClient) IsConnected() bool { return false } +func (c *stubClient) GetPartitions(ctx context.Context) (*PartitionList, error) { return nil, ErrNotSupported } +func (c *stubClient) GetPartition(ctx context.Context, partitionID uint32) (*Partition, error) { return nil, ErrNotSupported } +func (c *stubClient) ActivatePartition(ctx context.Context, req *ActivateRequest) error { return ErrNotSupported } +func (c *stubClient) DeactivatePartition(ctx context.Context, partitionID uint32) error { return ErrNotSupported } +func (c *stubClient) GetPartitionForDevices(ctx context.Context, deviceIDs []string) (*Partition, error) { return nil, ErrNotSupported } + +func IsRetryable(err error) bool { + return false +} + +func IsPermanent(err error) bool { + return true +} \ No newline at end of file diff --git a/pkg/fabricmanager/types.go b/pkg/fabricmanager/types.go new file mode 100644 index 00000000..06f7e9e8 --- /dev/null +++ b/pkg/fabricmanager/types.go @@ -0,0 +1,142 @@ +//go:build nvfm + +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of NVIDIA CORPORATION nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY + * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package fabricmanager + +import ( + "time" + + "kubevirt-gpu-device-plugin/pkg/nvfm" +) + +// GPU represents a GPU device within a fabric partition. +type GPU struct { + PhysicalID uint32 + UUID string + PCIBusID string + NumNVLinksAvailable uint32 + MaxNumNVLinks uint32 + NVLinkLineRateMBps uint32 +} + +// Partition represents a fabric partition with its associated GPUs. +type Partition struct { + PartitionID uint32 + IsActive bool + NumGPUs uint32 + GPUs []GPU + + // Client-side fields for management + LastActivated time.Time + LastDeactivated time.Time + ActivationCount int +} + +// PartitionList contains information about all supported fabric partitions. 
+type PartitionList struct { + NumPartitions uint32 + MaxNumPartitions uint32 + Partitions []Partition +} + +// ActivateRequest contains parameters for activating a fabric partition. +type ActivateRequest struct { + PartitionID uint32 + VFDevices []PCIDevice // Optional VF devices to activate with the partition +} + +// PCIDevice represents a PCI device (for VF activation). +type PCIDevice struct { + Domain uint32 + Bus uint32 + Device uint32 + Function uint32 +} + +// fromNVFMGPU converts an nvfm.GPUInfo to a GPU. +func fromNVFMGPU(nvfmGPU nvfm.GPUInfo) GPU { + return GPU{ + PhysicalID: nvfmGPU.PhysicalID, + UUID: nvfmGPU.UUID, + PCIBusID: nvfmGPU.PCIBusID, + NumNVLinksAvailable: nvfmGPU.NumNVLinksAvailable, + MaxNumNVLinks: nvfmGPU.MaxNumNVLinks, + NVLinkLineRateMBps: nvfmGPU.NVLinkLineRateMBps, + } +} + +// fromNVFMPartition converts an nvfm.PartitionInfo to a Partition. +func fromNVFMPartition(nvfmPartition nvfm.PartitionInfo) Partition { + gpus := make([]GPU, len(nvfmPartition.GPUs)) + for i, nvfmGPU := range nvfmPartition.GPUs { + gpus[i] = fromNVFMGPU(nvfmGPU) + } + + return Partition{ + PartitionID: nvfmPartition.PartitionID, + IsActive: nvfmPartition.IsActive, + NumGPUs: nvfmPartition.NumGPUs, + GPUs: gpus, + } +} + +// fromNVFMPartitionList converts an nvfm.FabricPartitionList to a PartitionList. +func fromNVFMPartitionList(nvfmList *nvfm.FabricPartitionList) *PartitionList { + partitions := make([]Partition, len(nvfmList.Partitions)) + for i, nvfmPartition := range nvfmList.Partitions { + partitions[i] = fromNVFMPartition(nvfmPartition) + } + + return &PartitionList{ + NumPartitions: nvfmList.NumPartitions, + MaxNumPartitions: nvfmList.MaxNumPartitions, + Partitions: partitions, + } +} + +// toNVFMPCIDevice converts a PCIDevice to an nvfm.PCIDevice. 
+func toNVFMPCIDevice(device PCIDevice) nvfm.PCIDevice { + return nvfm.PCIDevice{ + Domain: device.Domain, + Bus: device.Bus, + Device: device.Device, + Function: device.Function, + } +} + +// toNVFMPCIDevices converts a slice of PCIDevice to a slice of nvfm.PCIDevice. +func toNVFMPCIDevices(devices []PCIDevice) []nvfm.PCIDevice { + nvfmDevices := make([]nvfm.PCIDevice, len(devices)) + for i, device := range devices { + nvfmDevices[i] = toNVFMPCIDevice(device) + } + return nvfmDevices +} + diff --git a/pkg/nvfm/build.go b/pkg/nvfm/build.go new file mode 100644 index 00000000..cb93f9aa --- /dev/null +++ b/pkg/nvfm/build.go @@ -0,0 +1,5 @@ +//go:build nvfm + +package nvfm + + diff --git a/pkg/nvfm/nvfm.go b/pkg/nvfm/nvfm.go new file mode 100644 index 00000000..c7509ee3 --- /dev/null +++ b/pkg/nvfm/nvfm.go @@ -0,0 +1,349 @@ +//go:build nvfm + +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * This file provides the CGO bindings to the NVIDIA Fabric Manager SDK. The + * Fabric Manager SDK is a shared library, i.e. a set of C/C++ APIs (SDK), and + * the corresponding header files. The library and APIs are used to interface + * with FM when FM runs in the shared NVSwitch and vGPU multi-tenant modes to + * query, activate, and deactivate GPU partitions. + * + * https://docs.nvidia.com/datacenter/tesla/fabric-manager-user-guide/index.html#fabric-manager-sdk + * + * Requirements: + * - libnvfm.so library installed, it provides /usr/include/{nv_fm_agent.h,nv_fm_types.h} + * - to be linked with the flag: -lnvfm + */ + +package nvfm + +/* +#cgo CPPFLAGS: -I/usr/include +#cgo LDFLAGS: -L/usr/lib -lnvfm + +#include +#include "nv_fm_agent.h" +#include "nv_fm_types.h" +*/ +import "C" + +import ( + "errors" + "fmt" + "unsafe" +) + +// Return represents a Fabric Manager API return code. 
type Return int32

// Return codes, one per FM_ST_* status constant made available by the Fabric
// Manager SDK headers included in this file's cgo preamble. The wrappers in
// this package compare the C result against FM_ST_SUCCESS and hand back any
// other value to the caller as a Return error.
const (
	Success                        Return = C.FM_ST_SUCCESS
	BadParam                       Return = C.FM_ST_BADPARAM
	GenericError                   Return = C.FM_ST_GENERIC_ERROR
	NotSupported                   Return = C.FM_ST_NOT_SUPPORTED
	Uninitialized                  Return = C.FM_ST_UNINITIALIZED
	Timeout                        Return = C.FM_ST_TIMEOUT
	VersionMismatch                Return = C.FM_ST_VERSION_MISMATCH
	InUse                          Return = C.FM_ST_IN_USE
	NotConfigured                  Return = C.FM_ST_NOT_CONFIGURED
	ConnectionNotValid             Return = C.FM_ST_CONNECTION_NOT_VALID
	NVLinkError                    Return = C.FM_ST_NVLINK_ERROR
	ResourceBad                    Return = C.FM_ST_RESOURCE_BAD
	ResourceInUse                  Return = C.FM_ST_RESOURCE_IN_USE
	ResourceNotInUse               Return = C.FM_ST_RESOURCE_NOT_IN_USE
	ResourceExhausted              Return = C.FM_ST_RESOURCE_EXHAUSTED
	ResourceNotReady               Return = C.FM_ST_RESOURCE_NOT_READY
	PartitionExists                Return = C.FM_ST_PARTITION_EXISTS
	PartitionIDInUse               Return = C.FM_ST_PARTITION_ID_IN_USE
	PartitionIDNotInUse            Return = C.FM_ST_PARTITION_ID_NOT_IN_USE
	PartitionNameInUse             Return = C.FM_ST_PARTITION_NAME_IN_USE
	PartitionNameNotInUse          Return = C.FM_ST_PARTITION_NAME_NOT_IN_USE
	PartitionIDNameMismatch        Return = C.FM_ST_PARTITION_ID_NAME_MISMATCH
	NotReady                       Return = C.FM_ST_NOT_READY
	ResourceUsedInThisPartition    Return = C.FM_ST_RESOURCE_USED_IN_THIS_PARTITION
	ResourceUsedInOtherPartition   Return = C.FM_ST_RESOURCE_USED_IN_ANOTHER_PARTITION
	PartitionMiswiredTrunks        Return = C.FM_ST_PARTITION_MISWIRED_TRUNKS
	PartitionInsufficientTrunks    Return = C.FM_ST_PARTITION_INSUFFICIENT_TRUNKS
	PartitionMissingSwitches       Return = C.FM_ST_PARTITION_MISSING_SWITCHES
	PartitionNetworkConfigError    Return = C.FM_ST_PARTITION_NETWORK_CONFIG_ERROR
	PartitionRouteProgrammingError Return = C.FM_ST_PARTITION_ROUTE_PROGRAMMING_ERROR
)

// Error returns the error representation of a Return value.
+func (r Return) Error() string { + switch r { + case Success: + return "success" + case BadParam: + return "bad parameter" + case GenericError: + return "generic error" + case NotSupported: + return "not supported" + case Uninitialized: + return "uninitialized" + case Timeout: + return "timeout" + case VersionMismatch: + return "version mismatch" + case InUse: + return "in use" + case NotConfigured: + return "not configured" + case ConnectionNotValid: + return "connection not valid" + case NVLinkError: + return "nvlink error" + case ResourceBad: + return "resource bad" + case ResourceInUse: + return "resource in use" + case ResourceNotInUse: + return "resource not in use" + case ResourceExhausted: + return "resource exhausted" + case ResourceNotReady: + return "resource not ready" + case PartitionExists: + return "partition exists" + case PartitionIDInUse: + return "partition ID in use" + case PartitionIDNotInUse: + return "partition ID not in use" + case PartitionNameInUse: + return "partition name in use" + case PartitionNameNotInUse: + return "partition name not in use" + case PartitionIDNameMismatch: + return "partition ID name mismatch" + case NotReady: + return "not ready" + case ResourceUsedInThisPartition: + return "resource used in this partition" + case ResourceUsedInOtherPartition: + return "resource used in other partition" + case PartitionMiswiredTrunks: + return "partition miswired trunks" + case PartitionInsufficientTrunks: + return "partition insufficient trunks" + case PartitionMissingSwitches: + return "partition missing switches" + case PartitionNetworkConfigError: + return "partition network config error" + case PartitionRouteProgrammingError: + return "partition route programming error" + default: + return fmt.Sprintf("unknown return code: %d", int32(r)) + } +} + +// Handle represents a Fabric Manager API handle. +type Handle struct { + handle C.fmHandle_t +} + +// AddressType represents the address type for connections. 
+type AddressType int32 + +const ( + AddressTypeUnknown AddressType = C.NV_FM_API_ADDR_TYPE_UNKNOWN + AddressTypeInet AddressType = C.NV_FM_API_ADDR_TYPE_INET + AddressTypeUnix AddressType = C.NV_FM_API_ADDR_TYPE_UNIX + AddressTypeVsock AddressType = C.NV_FM_API_ADDR_TYPE_VSOCK +) + +// ConnectParams contains connection parameters for Fabric Manager. +type ConnectParams struct { + AddressInfo string + TimeoutMs uint32 + AddressType AddressType +} + +// PCIDevice represents PCI device information. +type PCIDevice struct { + Domain uint32 + Bus uint32 + Device uint32 + Function uint32 +} + +// GPUInfo contains information about a GPU in a fabric partition. +type GPUInfo struct { + PhysicalID uint32 + UUID string + PCIBusID string + NumNVLinksAvailable uint32 + MaxNumNVLinks uint32 + NVLinkLineRateMBps uint32 +} + +// PartitionInfo contains information about a fabric partition. +type PartitionInfo struct { + PartitionID uint32 + IsActive bool + NumGPUs uint32 + GPUs []GPUInfo +} + +// FabricPartitionList contains information about all supported fabric partitions. +type FabricPartitionList struct { + NumPartitions uint32 + MaxNumPartitions uint32 + Partitions []PartitionInfo +} + +// Init initializes the Fabric Manager API library. +func Init() error { + ret := C.fmLibInit() + if ret != C.FM_ST_SUCCESS { + return Return(ret) + } + return nil +} + +// Shutdown shuts down the Fabric Manager API library. +func Shutdown() error { + ret := C.fmLibShutdown() + if ret != C.FM_ST_SUCCESS { + return Return(ret) + } + return nil +} + +// Connect connects to a Fabric Manager instance. 
+func Connect(params ConnectParams) (*Handle, error) { + if len(params.AddressInfo) >= C.FM_MAX_STR_LENGTH { + return nil, errors.New("address info too long") + } + + var connectParams C.fmConnectParams_t + connectParams.version = C.fmConnectParams_version + connectParams.timeoutMs = C.uint(params.TimeoutMs) + connectParams.addressType = uint32(params.AddressType) + if params.AddressType == AddressTypeUnix { + connectParams.addressIsUnixSocket = 1 + } else { + connectParams.addressIsUnixSocket = 0 + } + + // Copy address info + cAddressInfo := C.CString(params.AddressInfo) + defer C.free(unsafe.Pointer(cAddressInfo)) + + // Use C.strncpy equivalent + for i, ch := range params.AddressInfo { + if i >= C.FM_MAX_STR_LENGTH-1 { + break + } + connectParams.addressInfo[i] = C.char(ch) + } + connectParams.addressInfo[len(params.AddressInfo)] = 0 + + handle := &Handle{} + ret := C.fmConnect(&connectParams, &handle.handle) + + if ret != C.FM_ST_SUCCESS { + return nil, Return(ret) + } + + return handle, nil +} + +// Disconnect disconnects from the Fabric Manager instance. +func (h *Handle) Disconnect() error { + ret := C.fmDisconnect(h.handle) + if ret != C.FM_ST_SUCCESS { + return Return(ret) + } + return nil +} + +// GetSupportedFabricPartitions retrieves all supported fabric partitions. 
+func (h *Handle) GetSupportedFabricPartitions() (*FabricPartitionList, error) { + var cPartitions C.fmFabricPartitionList_t + cPartitions.version = C.fmFabricPartitionList_version + + ret := C.fmGetSupportedFabricPartitions(h.handle, &cPartitions) + if ret != C.FM_ST_SUCCESS { + return nil, Return(ret) + } + + partitions := &FabricPartitionList{ + NumPartitions: uint32(cPartitions.numPartitions), + MaxNumPartitions: uint32(cPartitions.maxNumPartitions), + Partitions: make([]PartitionInfo, cPartitions.numPartitions), + } + + for i := uint32(0); i < partitions.NumPartitions; i++ { + cPartition := cPartitions.partitionInfo[i] + partition := PartitionInfo{ + PartitionID: uint32(cPartition.partitionId), + IsActive: cPartition.isActive != 0, + NumGPUs: uint32(cPartition.numGpus), + GPUs: make([]GPUInfo, cPartition.numGpus), + } + + for j := uint32(0); j < partition.NumGPUs; j++ { + cGPU := cPartition.gpuInfo[j] + partition.GPUs[j] = GPUInfo{ + PhysicalID: uint32(cGPU.physicalId), + UUID: C.GoString(&cGPU.uuid[0]), + PCIBusID: C.GoString(&cGPU.pciBusId[0]), + NumNVLinksAvailable: uint32(cGPU.numNvLinksAvailable), + MaxNumNVLinks: uint32(cGPU.maxNumNvLinks), + NVLinkLineRateMBps: uint32(cGPU.nvlinkLineRateMBps), + } + } + + partitions.Partitions[i] = partition + } + + return partitions, nil +} + +// ActivateFabricPartition activates a fabric partition. +func (h *Handle) ActivateFabricPartition(partitionID uint32) error { + ret := C.fmActivateFabricPartition(h.handle, C.fmFabricPartitionId_t(partitionID)) + if ret != C.FM_ST_SUCCESS { + return Return(ret) + } + return nil +} + +// ActivateFabricPartitionWithVFs activates a fabric partition with VFs. 
+func (h *Handle) ActivateFabricPartitionWithVFs(partitionID uint32, vfs []PCIDevice) error { + if len(vfs) == 0 { + return h.ActivateFabricPartition(partitionID) + } + + cVFs := make([]C.fmPciDevice_t, len(vfs)) + for i, vf := range vfs { + cVFs[i] = C.fmPciDevice_t{ + domain: C.uint(vf.Domain), + bus: C.uint(vf.Bus), + device: C.uint(vf.Device), + function: C.uint(vf.Function), + } + } + + ret := C.fmActivateFabricPartitionWithVFs( + h.handle, + C.fmFabricPartitionId_t(partitionID), + &cVFs[0], + C.uint(len(vfs)), + ) + if ret != C.FM_ST_SUCCESS { + return Return(ret) + } + return nil +} + +// DeactivateFabricPartition deactivates a fabric partition. +func (h *Handle) DeactivateFabricPartition(partitionID uint32) error { + ret := C.fmDeactivateFabricPartition(h.handle, C.fmFabricPartitionId_t(partitionID)) + if ret != C.FM_ST_SUCCESS { + return Return(ret) + } + return nil +} diff --git a/pkg/nvfm/stub.go b/pkg/nvfm/stub.go new file mode 100644 index 00000000..14105a0d --- /dev/null +++ b/pkg/nvfm/stub.go @@ -0,0 +1,30 @@ +//go:build !nvfm + +package nvfm + +import "errors" + +var ErrNotBuilt = errors.New("nvfm: package built without nvfm support") + +type ( + Return int32 + Handle struct{} + AddressType int32 + ConnectParams struct{} + PCIDevice struct{} + GPUInfo struct{} + PartitionInfo struct{} + FabricPartitionList struct{} +) + +func Init() error { return ErrNotBuilt } +func Shutdown() error { return ErrNotBuilt } +func Connect(ConnectParams) (*Handle, error) { return nil, ErrNotBuilt } +func (h *Handle) Disconnect() error { return ErrNotBuilt } +func (h *Handle) GetSupportedFabricPartitions() (*FabricPartitionList, error) { + return nil, ErrNotBuilt +} +func (h *Handle) ActivateFabricPartition(uint32) error { return ErrNotBuilt } +func (h *Handle) ActivateFabricPartitionWithVFs(uint32, []PCIDevice) error { return ErrNotBuilt } +func (h *Handle) DeactivateFabricPartition(uint32) error { return ErrNotBuilt } + diff --git a/versions.mk b/versions.mk index 
3dfaae23..2b139cd8 100644 --- a/versions.mk +++ b/versions.mk @@ -20,4 +20,6 @@ VERSION ?= v1.5.0 GOLANG_VERSION ?= 1.26.1 +DRIVER_VERSION ?= 590.48.01 + GIT_COMMIT ?= $(shell git describe --match="" --dirty --long --always --abbrev=40 2> /dev/null || echo "")