Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 50 additions & 20 deletions ds4_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5972,15 +5972,17 @@ __global__ static void router_select_kernel(
uint32_t hash_rows,
uint32_t n_tokens,
int has_bias,
int hash_mode) {
int hash_mode,
uint32_t n_expert,
float scale) {
uint32_t t = blockIdx.x;
if (t >= n_tokens || threadIdx.x != 0) return;
const float *log = logits + (uint64_t)t * 256;
float *prob = probs + (uint64_t)t * 256;
const float *log = logits + (uint64_t)t * n_expert;
float *prob = probs + (uint64_t)t * n_expert;
int32_t *sel = selected + (uint64_t)t * 6;
float *w = weights + (uint64_t)t * 6;

for (int i = 0; i < 256; i++) prob[i] = sqrtf(softplus_dev(log[i]));
for (uint32_t i = 0; i < n_expert; i++) prob[i] = sqrtf(softplus_dev(log[i]));

if (hash_mode) {
int32_t tok = tokens ? tokens[t] : token_scalar;
Expand All @@ -5989,12 +5991,12 @@ __global__ static void router_select_kernel(
for (int i = 0; i < 6; i++) sel[i] = row[i];
} else {
for (int i = 0; i < 6; i++) sel[i] = -1;
for (int i = 0; i < 256; i++) {
for (uint32_t i = 0; i < n_expert; i++) {
float score = prob[i] + (has_bias ? bias[i] : 0.0f);
for (int j = 0; j < 6; j++) {
if (sel[j] < 0 || score > prob[sel[j]] + (has_bias ? bias[sel[j]] : 0.0f)) {
for (int k = 5; k > j; k--) sel[k] = sel[k - 1];
sel[j] = i;
sel[j] = (int32_t)i;
break;
}
}
Expand All @@ -6004,12 +6006,12 @@ __global__ static void router_select_kernel(
float sum = 0.0f;
for (int i = 0; i < 6; i++) {
int e = sel[i];
float v = (e >= 0 && e < 256) ? prob[e] : 0.0f;
float v = (e >= 0 && (uint32_t)e < n_expert) ? prob[e] : 0.0f;
w[i] = v;
sum += v;
}
sum = fmaxf(sum, 6.103515625e-5f);
for (int i = 0; i < 6; i++) w[i] = w[i] / sum * 1.5f;
for (int i = 0; i < 6; i++) w[i] = w[i] / sum * scale;
}

__global__ static void router_select_parallel_kernel(
Expand Down Expand Up @@ -9530,14 +9532,14 @@ extern "C" int ds4_gpu_directional_steering_project_tensor(
}
extern "C" int ds4_gpu_router_select_tensor(ds4_gpu_tensor *selected, ds4_gpu_tensor *weights, ds4_gpu_tensor *probs, const void *model_map, uint64_t model_size, uint64_t bias_offset, uint64_t hash_offset, uint32_t hash_rows, uint32_t token, uint32_t n_expert, uint32_t n_expert_used, float expert_weight_scale, uint32_t n_expert_groups, uint32_t n_group_used, bool has_bias, bool hash_mode, const ds4_gpu_tensor *logits) {
if (!selected || !weights || !probs || !logits || !model_map || n_expert_groups > 1u || n_group_used > 0u) return 0;
if (n_expert != 256u || n_expert_used != 6u || fabsf(expert_weight_scale - 1.5f) > 1.0e-6f) return 0;
if ((n_expert != 256u && n_expert != 384u) || n_expert_used != 6u) return 0;
int32_t tok = (int32_t)token;
int ok = 1;
const float *bias = NULL;
const int32_t *hash = NULL;
if (ok && has_bias && !hash_mode) {
if (bias_offset > model_size || model_size - bias_offset < 256u * sizeof(float)) ok = 0;
else bias = (const float *)cuda_model_range_ptr(model_map, bias_offset, 256u * sizeof(float), "router_bias");
if (bias_offset > model_size || model_size - bias_offset < (uint64_t)n_expert * sizeof(float)) ok = 0;
else bias = (const float *)cuda_model_range_ptr(model_map, bias_offset, (uint64_t)n_expert * sizeof(float), "router_bias");
if (!bias) ok = 0;
}
if (ok && hash_mode) {
Expand All @@ -9547,7 +9549,14 @@ extern "C" int ds4_gpu_router_select_tensor(ds4_gpu_tensor *selected, ds4_gpu_te
if (!hash) ok = 0;
}
if (ok) {
if (getenv("DS4_CUDA_NO_WARP_ROUTER_SELECT") == NULL &&
if (n_expert != 256u) {
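/* Non-256 expert counts (e.g. PRO, 384 experts): the warp/parallel
* kernels hardcode 256 experts, so fall back to the serial kernel, which
* is parametrized by n_expert and correct for any expert count (only 384
* passes the n_expert guard above today). It is slower, but this path
* routes only a single token. */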
router_select_kernel<<<1, 1>>>((int32_t *)selected->ptr, (float *)weights->ptr, (float *)probs->ptr,
bias, hash, (const float *)logits->ptr, NULL, tok, hash_rows, 1,
has_bias && !hash_mode, hash_mode, n_expert, expert_weight_scale);
} else if (getenv("DS4_CUDA_NO_WARP_ROUTER_SELECT") == NULL &&
getenv("DS4_CUDA_NO_PARALLEL_ROUTER_SELECT") == NULL) {
dim3 block(32, 4, 1);
router_select_warp_topk_kernel<<<1, block>>>((int32_t *)selected->ptr, (float *)weights->ptr, (float *)probs->ptr,
Expand All @@ -9560,27 +9569,27 @@ extern "C" int ds4_gpu_router_select_tensor(ds4_gpu_tensor *selected, ds4_gpu_te
} else {
router_select_kernel<<<1, 1>>>((int32_t *)selected->ptr, (float *)weights->ptr, (float *)probs->ptr,
bias, hash, (const float *)logits->ptr, NULL, tok, hash_rows, 1,
has_bias && !hash_mode, hash_mode);
has_bias && !hash_mode, hash_mode, 256u, expert_weight_scale);
}
ok = cuda_ok(cudaGetLastError(), "router_select launch");
}
return ok;
}
extern "C" int ds4_gpu_router_select_batch_tensor(ds4_gpu_tensor *selected, ds4_gpu_tensor *weights, ds4_gpu_tensor *probs, const void *model_map, uint64_t model_size, uint64_t bias_offset, uint64_t hash_offset, uint32_t hash_rows, uint32_t n_expert_groups, uint32_t n_group_used, bool has_bias, bool hash_mode, const ds4_gpu_tensor *logits, const ds4_gpu_tensor *tokens, uint32_t n_expert, uint32_t n_expert_used, float expert_weight_scale, uint32_t n_tokens) {
if (n_expert != 256u || n_expert_used != 6u || fabsf(expert_weight_scale - 1.5f) > 1.0e-6f) return 0;
if ((n_expert != 256u && n_expert != 384u) || n_expert_used != 6u) return 0;
if (!selected || !weights || !probs || !logits || !tokens || !model_map || n_tokens == 0 ||
n_expert_groups > 1u || n_group_used > 0u ||
logits->bytes < (uint64_t)n_tokens * 256u * sizeof(float) ||
probs->bytes < (uint64_t)n_tokens * 256u * sizeof(float) ||
logits->bytes < (uint64_t)n_tokens * (uint64_t)n_expert * sizeof(float) ||
probs->bytes < (uint64_t)n_tokens * (uint64_t)n_expert * sizeof(float) ||
selected->bytes < (uint64_t)n_tokens * 6u * sizeof(int32_t) ||
weights->bytes < (uint64_t)n_tokens * 6u * sizeof(float)) {
return 0;
}
const float *bias = NULL;
const int32_t *hash = NULL;
if (has_bias && !hash_mode) {
if (bias_offset > model_size || model_size - bias_offset < 256u * sizeof(float)) return 0;
bias = (const float *)cuda_model_range_ptr(model_map, bias_offset, 256u * sizeof(float), "router_bias");
if (bias_offset > model_size || model_size - bias_offset < (uint64_t)n_expert * sizeof(float)) return 0;
bias = (const float *)cuda_model_range_ptr(model_map, bias_offset, (uint64_t)n_expert * sizeof(float), "router_bias");
if (!bias) return 0;
}
if (hash_mode) {
Expand All @@ -9589,7 +9598,26 @@ extern "C" int ds4_gpu_router_select_batch_tensor(ds4_gpu_tensor *selected, ds4_
hash = (const int32_t *)cuda_model_range_ptr(model_map, hash_offset, hash_bytes, "router_hash");
if (!hash) return 0;
}
if (getenv("DS4_CUDA_NO_WARP_ROUTER_SELECT") == NULL &&
if (n_expert != 256u) {
/* Generalized prefill router for non-256 expert counts (e.g. PRO=384).
* The warp/parallel kernels hardcode 256 experts; the serial kernel is
* parametrized by n_expert and correct for any count. It is launched as
* one single-thread block per token, which is slower but unblocks PRO.
* TODO: optimize with a parallel 384-expert kernel later. */
router_select_kernel<<<n_tokens, 1>>>((int32_t *)selected->ptr,
(float *)weights->ptr,
(float *)probs->ptr,
bias,
hash,
(const float *)logits->ptr,
(const int32_t *)tokens->ptr,
0,
hash_rows,
n_tokens,
has_bias && !hash_mode,
hash_mode,
n_expert,
expert_weight_scale);
} else if (getenv("DS4_CUDA_NO_WARP_ROUTER_SELECT") == NULL &&
getenv("DS4_CUDA_NO_PARALLEL_ROUTER_SELECT") == NULL) {
dim3 block(32, 4, 1);
router_select_warp_topk_kernel<<<(n_tokens + 3u) / 4u, block>>>((int32_t *)selected->ptr,
Expand Down Expand Up @@ -9629,7 +9657,9 @@ extern "C" int ds4_gpu_router_select_batch_tensor(ds4_gpu_tensor *selected, ds4_
hash_rows,
n_tokens,
has_bias && !hash_mode,
hash_mode);
hash_mode,
256u,
expert_weight_scale);
}
return cuda_ok(cudaGetLastError(), "router_select launch");
}
Expand Down