diff --git a/pkg/orchestrator/gke/gke_job_orchestrator_test.go b/pkg/orchestrator/gke/gke_job_orchestrator_test.go index 2aff1f94cf..9cdbcb1f99 100644 --- a/pkg/orchestrator/gke/gke_job_orchestrator_test.go +++ b/pkg/orchestrator/gke/gke_job_orchestrator_test.go @@ -447,11 +447,12 @@ func TestGeneratePathwaysManifest(t *testing.T) { } mockResponses := map[string][]shell.CommandResult{ - "gcloud compute machine-types describe n2-standard-2 --zone=us-central1 --format=json": {{ExitCode: 0, Stdout: `{"guestCpus": 2}`}}, + "gcloud compute machine-types describe n2-standard-2 --zone=us-central1-a --format=json": {{ExitCode: 0, Stdout: `{"guestCpus": 2}`}}, } mockExec := NewMockExecutor(mockResponses) orc := newTestGKEOrchestrator(mockExec) orc.projectID = "mock-project" + orc.clusterZones = []string{"us-central1-a"} orc.clusterDesc.NodePools = []gkeJobNodePool{ {Name: "default-pool", Config: gkeNodePoolConfig{MachineType: "n2-standard-2"}}, } diff --git a/pkg/orchestrator/gke/resource_resolver.go b/pkg/orchestrator/gke/resource_resolver.go index 03bb86a05e..b68de0b2ae 100644 --- a/pkg/orchestrator/gke/resource_resolver.go +++ b/pkg/orchestrator/gke/resource_resolver.go @@ -48,6 +48,32 @@ func (g *GKEOrchestrator) FetchMachineCapacity(machineType, zone string) (int, e return 0, fmt.Errorf("no accelerators or guestCpus found for machine type %s in zone %s", machineType, zone) } +// getZonesForMachineType finds all candidate zones hosting the specified machineType in the cluster's node pools. +func (g *GKEOrchestrator) getZonesForMachineType(machineType string) []string { + seen := make(map[string]struct{}) + for _, np := range g.clusterDesc.NodePools { + // Filter node pools that match the target machine type. + if np.Config.MachineType != machineType { + continue + } + // If the node pool doesn't define custom locations, it inherits cluster-wide locations. + locs := np.Locations + if len(locs) == 0 { + locs = g.clusterZones + } + for _, loc := range locs { + seen[loc] = struct{}{} + } + } + + // Return a deduplicated slice of all matched zones. + var zones []string + for loc := range seen { + zones = append(zones, loc) + } + return zones +} + func (g *GKEOrchestrator) FetchMachineCapabilities(machineType, zone string) (MachineTypeCap, error) { cacheKey := machineType + ":" + zone @@ -60,8 +86,18 @@ func (g *GKEOrchestrator) FetchMachineCapabilities(machineType, zone string) (Ma isRegion := len(strings.Split(zone, "-")) < 3 zonesToTry := []string{zone} - if isRegion && len(g.clusterZones) > 0 { - zonesToTry = g.clusterZones + if isRegion { + npZones := g.getZonesForMachineType(machineType) + if len(npZones) > 0 { + zonesToTry = npZones + } else { + if !g.clusterDesc.Autoscaling.EnableNodeAutoprovisioning { + return MachineTypeCap{}, fmt.Errorf("failed to fetch machine capabilities for %s: no node pool matching machine type found and GKE Node Auto-Provisioning is disabled", machineType) + } + if len(g.clusterZones) > 0 { + zonesToTry = g.clusterZones + } + } } var lastErr error diff --git a/pkg/orchestrator/gke/resource_resolver_test.go b/pkg/orchestrator/gke/resource_resolver_test.go index 61e07ea0a4..8bb3399c42 100644 --- a/pkg/orchestrator/gke/resource_resolver_test.go +++ b/pkg/orchestrator/gke/resource_resolver_test.go @@ -103,6 +103,7 @@ func TestFetchMachineCapacity_AllZonesFail(t *testing.T) { g := newTestGKEOrchestrator(nil) g.machineTypeClient = &MockMachineTypeClient{FailAll: true} g.clusterZones = []string{"europe-west2-a", "europe-west2-c", "europe-west2-b"} + g.clusterDesc.Autoscaling.EnableNodeAutoprovisioning = true _, err := g.FetchMachineCapacity("tpu7x-1", "europe-west2") @@ -347,6 +348,7 @@ func TestResolveAcceleratorShorthand(t *testing.T) { mockExecutor := NewMockExecutor(mockResponses) orc := newTestGKEOrchestrator(mockExecutor) orc.projectID = "mock-project" + orc.clusterZones = []string{"us-central1-a"} if len(tt.nodePools) > 0 { for _, mt := range tt.nodePools { orc.clusterDesc.NodePools = append(orc.clusterDesc.NodePools, gkeJobNodePool{ @@ -826,3 +828,120 @@ func TestValidateConsumptionForStaticCluster(t *testing.T) { }) } } + +type ZoneSelectiveMockMachineTypeClient struct { + AllowedZone string + MT *compute.MachineType +} + +func (m *ZoneSelectiveMockMachineTypeClient) GetMachineType(project, zone, machineType string) (*compute.MachineType, error) { + if zone != m.AllowedZone { + return nil, fmt.Errorf("machine type not found in zone %s", zone) + } + return m.MT, nil +} + +func TestFetchMachineCapabilities_NodePoolSpecificZones(t *testing.T) { + g := newTestGKEOrchestrator(nil) + g.projectID = "test-project" + g.clusterZones = []string{"us-central1-a", "us-central1-b"} + g.clusterDesc.NodePools = []gkeJobNodePool{ + { + Name: "tpu-np-0", + Config: gkeNodePoolConfig{ + MachineType: "tpu-v5-lite-podslice", + }, + Locations: []string{"us-central1-c"}, + }, + } + + g.machineTypeClient = &ZoneSelectiveMockMachineTypeClient{ + AllowedZone: "us-central1-c", + MT: &compute.MachineType{ + GuestCpus: 4, + MemoryMb: 16384, + }, + } + + cap, err := g.FetchMachineCapabilities("tpu-v5-lite-podslice", "us-central1") + if err != nil { + t.Fatalf("FetchMachineCapabilities failed: %v", err) + } + + if cap.GuestCpus != 4 { + t.Errorf("cap.GuestCpus = %d, want 4", cap.GuestCpus) + } +} + +func TestFetchMachineCapabilities_NoPools_NAPDisabled_Fails(t *testing.T) { + g := newTestGKEOrchestrator(nil) + g.projectID = "test-project" + g.clusterZones = []string{"us-central1-a", "us-central1-b"} + g.clusterDesc.Autoscaling.EnableNodeAutoprovisioning = false + + _, err := g.FetchMachineCapabilities("tpu-v5-lite-podslice", "us-central1") + if err == nil { + t.Fatalf("Expected error, got nil") + } + + expectedErr := "no node pool matching machine type found and GKE Node Auto-Provisioning is disabled" + if !strings.Contains(err.Error(), expectedErr) { + t.Errorf("expected error containing %q, got: %v", expectedErr, err) + } +} + +func TestFetchMachineCapabilities_NoPools_NAPEnabled_Fallback(t *testing.T) { + g := newTestGKEOrchestrator(nil) + g.projectID = "test-project" + g.clusterZones = []string{"us-central1-a", "us-central1-b"} + g.clusterDesc.Autoscaling.EnableNodeAutoprovisioning = true + + g.machineTypeClient = &ZoneSelectiveMockMachineTypeClient{ + AllowedZone: "us-central1-b", + MT: &compute.MachineType{ + GuestCpus: 8, + MemoryMb: 32768, + }, + } + + cap, err := g.FetchMachineCapabilities("tpu-v5-lite-podslice", "us-central1") + if err != nil { + t.Fatalf("FetchMachineCapabilities failed: %v", err) + } + + if cap.GuestCpus != 8 { + t.Errorf("cap.GuestCpus = %d, want 8", cap.GuestCpus) + } +} + +func TestFetchMachineCapabilities_NodePoolEmptyLocations_ClusterZonesFallback(t *testing.T) { + g := newTestGKEOrchestrator(nil) + g.projectID = "test-project" + g.clusterZones = []string{"us-central1-b"} + g.clusterDesc.NodePools = []gkeJobNodePool{ + { + Name: "tpu-np-0", + Config: gkeNodePoolConfig{ + MachineType: "tpu-v5-lite-podslice", + }, + Locations: []string{}, // Empty/inherited locations + }, + } + + g.machineTypeClient = &ZoneSelectiveMockMachineTypeClient{ + AllowedZone: "us-central1-b", + MT: &compute.MachineType{ + GuestCpus: 4, + MemoryMb: 16384, + }, + } + + cap, err := g.FetchMachineCapabilities("tpu-v5-lite-podslice", "us-central1") + if err != nil { + t.Fatalf("FetchMachineCapabilities failed: %v", err) + } + + if cap.GuestCpus != 4 { + t.Errorf("cap.GuestCpus = %d, want 4", cap.GuestCpus) + } +} diff --git a/pkg/orchestrator/gke/types.go b/pkg/orchestrator/gke/types.go index 8b031834f8..5bbf1bd68e 100644 --- a/pkg/orchestrator/gke/types.go +++ b/pkg/orchestrator/gke/types.go @@ -256,6 +256,7 @@ type gkeJobNodePool struct { Name string `json:"name"` Config gkeNodePoolConfig `json:"config"` InitialNodeCount int `json:"initialNodeCount"` + Locations []string `json:"locations,omitempty"` Autoscaling gkeAutoscaling `json:"autoscaling"` PlacementPolicy *gkePlacementPolicy `json:"placementPolicy,omitempty"` }