diff --git a/ds4_cuda.cu b/ds4_cuda.cu index 188b341ad..c9f0c8dec 100644 --- a/ds4_cuda.cu +++ b/ds4_cuda.cu @@ -5972,15 +5972,17 @@ __global__ static void router_select_kernel( uint32_t hash_rows, uint32_t n_tokens, int has_bias, - int hash_mode) { + int hash_mode, + uint32_t n_expert, + float scale) { uint32_t t = blockIdx.x; if (t >= n_tokens || threadIdx.x != 0) return; - const float *log = logits + (uint64_t)t * 256; - float *prob = probs + (uint64_t)t * 256; + const float *log = logits + (uint64_t)t * n_expert; + float *prob = probs + (uint64_t)t * n_expert; int32_t *sel = selected + (uint64_t)t * 6; float *w = weights + (uint64_t)t * 6; - for (int i = 0; i < 256; i++) prob[i] = sqrtf(softplus_dev(log[i])); + for (uint32_t i = 0; i < n_expert; i++) prob[i] = sqrtf(softplus_dev(log[i])); if (hash_mode) { int32_t tok = tokens ? tokens[t] : token_scalar; @@ -5989,12 +5991,12 @@ __global__ static void router_select_kernel( for (int i = 0; i < 6; i++) sel[i] = row[i]; } else { for (int i = 0; i < 6; i++) sel[i] = -1; - for (int i = 0; i < 256; i++) { + for (uint32_t i = 0; i < n_expert; i++) { float score = prob[i] + (has_bias ? bias[i] : 0.0f); for (int j = 0; j < 6; j++) { if (sel[j] < 0 || score > prob[sel[j]] + (has_bias ? bias[sel[j]] : 0.0f)) { for (int k = 5; k > j; k--) sel[k] = sel[k - 1]; - sel[j] = i; + sel[j] = (int32_t)i; break; } } @@ -6004,12 +6006,12 @@ __global__ static void router_select_kernel( float sum = 0.0f; for (int i = 0; i < 6; i++) { int e = sel[i]; - float v = (e >= 0 && e < 256) ? prob[e] : 0.0f; + float v = (e >= 0 && (uint32_t)e < n_expert) ? 
prob[e] : 0.0f; w[i] = v; sum += v; } sum = fmaxf(sum, 6.103515625e-5f); - for (int i = 0; i < 6; i++) w[i] = w[i] / sum * 1.5f; + for (int i = 0; i < 6; i++) w[i] = w[i] / sum * scale; } __global__ static void router_select_parallel_kernel( @@ -9530,14 +9532,14 @@ extern "C" int ds4_gpu_directional_steering_project_tensor( } extern "C" int ds4_gpu_router_select_tensor(ds4_gpu_tensor *selected, ds4_gpu_tensor *weights, ds4_gpu_tensor *probs, const void *model_map, uint64_t model_size, uint64_t bias_offset, uint64_t hash_offset, uint32_t hash_rows, uint32_t token, uint32_t n_expert, uint32_t n_expert_used, float expert_weight_scale, uint32_t n_expert_groups, uint32_t n_group_used, bool has_bias, bool hash_mode, const ds4_gpu_tensor *logits) { if (!selected || !weights || !probs || !logits || !model_map || n_expert_groups > 1u || n_group_used > 0u) return 0; - if (n_expert != 256u || n_expert_used != 6u || fabsf(expert_weight_scale - 1.5f) > 1.0e-6f) return 0; + if ((n_expert != 256u && n_expert != 384u) || n_expert_used != 6u) return 0; int32_t tok = (int32_t)token; int ok = 1; const float *bias = NULL; const int32_t *hash = NULL; if (ok && has_bias && !hash_mode) { - if (bias_offset > model_size || model_size - bias_offset < 256u * sizeof(float)) ok = 0; - else bias = (const float *)cuda_model_range_ptr(model_map, bias_offset, 256u * sizeof(float), "router_bias"); + if (bias_offset > model_size || model_size - bias_offset < (uint64_t)n_expert * sizeof(float)) ok = 0; + else bias = (const float *)cuda_model_range_ptr(model_map, bias_offset, (uint64_t)n_expert * sizeof(float), "router_bias"); if (!bias) ok = 0; } if (ok && hash_mode) { @@ -9547,7 +9549,14 @@ extern "C" int ds4_gpu_router_select_tensor(ds4_gpu_tensor *selected, ds4_gpu_te if (!hash) ok = 0; } if (ok) { - if (getenv("DS4_CUDA_NO_WARP_ROUTER_SELECT") == NULL && + if (n_expert != 256u) { + /* Generalized (e.g. 
PRO, 384 experts): the warp/parallel kernels + * hardcode 256 experts, so use the n_expert-parametrized serial + * kernel. Correct for any expert count; slower but only 1 token. */ + router_select_kernel<<<1, 1>>>((int32_t *)selected->ptr, (float *)weights->ptr, (float *)probs->ptr, + bias, hash, (const float *)logits->ptr, NULL, tok, hash_rows, 1, + has_bias && !hash_mode, hash_mode, n_expert, expert_weight_scale); + } else if (getenv("DS4_CUDA_NO_WARP_ROUTER_SELECT") == NULL && getenv("DS4_CUDA_NO_PARALLEL_ROUTER_SELECT") == NULL) { dim3 block(32, 4, 1); router_select_warp_topk_kernel<<<1, block>>>((int32_t *)selected->ptr, (float *)weights->ptr, (float *)probs->ptr, @@ -9560,18 +9569,18 @@ extern "C" int ds4_gpu_router_select_tensor(ds4_gpu_tensor *selected, ds4_gpu_te } else { router_select_kernel<<<1, 1>>>((int32_t *)selected->ptr, (float *)weights->ptr, (float *)probs->ptr, bias, hash, (const float *)logits->ptr, NULL, tok, hash_rows, 1, - has_bias && !hash_mode, hash_mode); + has_bias && !hash_mode, hash_mode, 256u, expert_weight_scale); } ok = cuda_ok(cudaGetLastError(), "router_select launch"); } return ok; } extern "C" int ds4_gpu_router_select_batch_tensor(ds4_gpu_tensor *selected, ds4_gpu_tensor *weights, ds4_gpu_tensor *probs, const void *model_map, uint64_t model_size, uint64_t bias_offset, uint64_t hash_offset, uint32_t hash_rows, uint32_t n_expert_groups, uint32_t n_group_used, bool has_bias, bool hash_mode, const ds4_gpu_tensor *logits, const ds4_gpu_tensor *tokens, uint32_t n_expert, uint32_t n_expert_used, float expert_weight_scale, uint32_t n_tokens) { - if (n_expert != 256u || n_expert_used != 6u || fabsf(expert_weight_scale - 1.5f) > 1.0e-6f) return 0; + if ((n_expert != 256u && n_expert != 384u) || n_expert_used != 6u) return 0; if (!selected || !weights || !probs || !logits || !tokens || !model_map || n_tokens == 0 || n_expert_groups > 1u || n_group_used > 0u || - logits->bytes < (uint64_t)n_tokens * 256u * sizeof(float) || - probs->bytes < 
(uint64_t)n_tokens * 256u * sizeof(float) || + logits->bytes < (uint64_t)n_tokens * (uint64_t)n_expert * sizeof(float) || + probs->bytes < (uint64_t)n_tokens * (uint64_t)n_expert * sizeof(float) || selected->bytes < (uint64_t)n_tokens * 6u * sizeof(int32_t) || weights->bytes < (uint64_t)n_tokens * 6u * sizeof(float)) { return 0; @@ -9579,8 +9588,8 @@ extern "C" int ds4_gpu_router_select_batch_tensor(ds4_gpu_tensor *selected, ds4_ const float *bias = NULL; const int32_t *hash = NULL; if (has_bias && !hash_mode) { - if (bias_offset > model_size || model_size - bias_offset < 256u * sizeof(float)) return 0; - bias = (const float *)cuda_model_range_ptr(model_map, bias_offset, 256u * sizeof(float), "router_bias"); + if (bias_offset > model_size || model_size - bias_offset < (uint64_t)n_expert * sizeof(float)) return 0; + bias = (const float *)cuda_model_range_ptr(model_map, bias_offset, (uint64_t)n_expert * sizeof(float), "router_bias"); if (!bias) return 0; } if (hash_mode) { @@ -9589,7 +9598,26 @@ extern "C" int ds4_gpu_router_select_batch_tensor(ds4_gpu_tensor *selected, ds4_ hash = (const int32_t *)cuda_model_range_ptr(model_map, hash_offset, hash_bytes, "router_hash"); if (!hash) return 0; } - if (getenv("DS4_CUDA_NO_WARP_ROUTER_SELECT") == NULL && + if (n_expert != 256u) { + /* Generalized prefill router for non-256 expert counts (e.g. PRO=384). + * The warp/parallel kernels hardcode 256 experts; the serial kernel is + * n_expert-parametrized and correct for any count. One thread per token + * (slower) but unblocks PRO; optimize with a parallel 384 kernel later. 
*/ + router_select_kernel<<<n_tokens, 1>>>((int32_t *)selected->ptr, + (float *)weights->ptr, + (float *)probs->ptr, + bias, + hash, + (const float *)logits->ptr, + (const int32_t *)tokens->ptr, + 0, + hash_rows, + n_tokens, + has_bias && !hash_mode, + hash_mode, + n_expert, + expert_weight_scale); + } else if (getenv("DS4_CUDA_NO_WARP_ROUTER_SELECT") == NULL && getenv("DS4_CUDA_NO_PARALLEL_ROUTER_SELECT") == NULL) { dim3 block(32, 4, 1); router_select_warp_topk_kernel<<<(n_tokens + 3u) / 4u, block>>>((int32_t *)selected->ptr, @@ -9629,7 +9657,9 @@ extern "C" int ds4_gpu_router_select_batch_tensor(ds4_gpu_tensor *selected, ds4_ hash_rows, n_tokens, has_bias && !hash_mode, - hash_mode); + hash_mode, + 256u, + expert_weight_scale); } return cuda_ok(cudaGetLastError(), "router_select launch"); }