Skip to content
Merged
167 changes: 136 additions & 31 deletions kernel-builder/skills/cuda-kernels/SKILL.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
---
name: cuda-kernels
description: "Provides guidance for writing and benchmarking optimized CUDA kernels for NVIDIA GPUs (H100, A100, T4) targeting HuggingFace diffusers and transformers libraries. Supports models like LTX-Video, Stable Diffusion, LLaMA, Mistral, and Qwen. Includes integration with HuggingFace Kernels Hub (get_kernel) for loading pre-compiled kernels. Includes benchmarking scripts to compare kernel performance against baseline implementations."
description: "Provides guidance for writing and benchmarking optimized CUDA kernels for NVIDIA GPUs (H100, A100, T4) targeting HuggingFace diffusers and transformers libraries. Kernels must be kernel-builder/ABI3-compliant: no pybind11, no setup.py, TORCH_LIBRARY_EXPAND bindings only. Supports models like LTX-Video, Stable Diffusion, LLaMA, Mistral, and Qwen. Includes integration with HuggingFace Kernels Hub (get_kernel) for loading pre-compiled kernels. Includes benchmarking scripts to compare kernel performance against baseline implementations."
disable-model-invocation: false
user-invocable: true
allowed-tools: "Read, Grep, Glob, Bash"
Expand All @@ -11,6 +11,71 @@ argument-hint: "kernel type: attention, rmsnorm, rope, adaln, geglu, benchmark,

This skill provides patterns and guidance for developing optimized CUDA kernels targeting NVIDIA GPUs (H100, A100, T4) for use with HuggingFace **diffusers** and **transformers** libraries.

## Hard Constraints — Read Before Writing Any Code

Kernels MUST build with [kernel-builder](https://github.com/huggingface/kernels) and meet the [Kernel Hub requirements](https://huggingface.co/docs/kernels/kernel-requirements). kernel-builder compiles against the **Python limited API (ABI3)** so a single binary works for Python 3.9+ across versions. Several patterns that are standard in generic PyTorch-extension tutorials are therefore **hard build failures** here. Do not use them, even if PyTorch documentation or your training data suggests them.

### Disallowed patterns — never generate these

| ❌ Never use | Why it fails | ✅ Use instead |
|---|---|---|
| pybind11 in any form: `#include <torch/extension.h>`, `#include <pybind11/...>`, `PYBIND11_MODULE(...)`, `py::arg`, any `py::` symbol | pybind11 is incompatible with the limited API (ABI3); the build does not compile | `TORCH_LIBRARY_EXPAND` in `torch-ext/torch_binding.cpp` (see below). Note: `torch/extension.h` transitively includes pybind11 — include `torch/torch.h` + `torch/library.h` instead |
| Hand-written `setup.py` / `pyproject.toml` using `torch.utils.cpp_extension` (`CUDAExtension`, `BuildExtension`, `cpp_extension.load`, `load_inline`) | setuptools extensions are not ABI3 and bypass `build.toml`; kernel-builder owns the build | `build.toml` + `nix run .#build-and-copy -L`. For an editable dev install, generate the project files with `kernel-builder create-pyproject -f` — never write them by hand |
| `TORCH_LIBRARY(my_kernel, m)`, `TORCH_LIBRARY_FRAGMENT(...)`, or `TORCH_LIBRARY_IMPL(...)` with a hardcoded namespace | kernel-builder suffixes the op namespace with a per-build hash (e.g. `_my_kernel_a1b2c3d`); a hardcoded name never resolves | `TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops)` from the generated `registration.h` |
| Hardcoded `torch.ops.my_kernel.fn(...)` calls in Python | Same namespace mangling — the op namespace name is only known at build time | `from ._ops import ops` then `ops.fn(...)` |
| Hand-written `PyMODINIT_FUNC PyInit__...` or any manual CPython module init | Generated by `REGISTER_EXTENSION`; duplicating it breaks module loading | `REGISTER_EXTENSION(TORCH_EXTENSION_NAME)` exactly once, in `torch_binding.cpp` |
| Non-limited CPython API calls (`PyArg_ParseTuple`, direct `PyObject*` manipulation) | Violates ABI3 | Stay within the torch C++ API: `torch::Tensor`, `TORCH_CHECK`, `at::cuda::*` |
| Absolute imports of your own package inside `torch-ext/` (`from my_kernel.utils import x`) | The package directory is renamed when loaded from the Hub; absolute imports break | Relative imports only: `from .utils import x`, `from ._ops import ops` |
| Runtime Python deps beyond `torch` (and `einops` if truly needed) | Hub compliance restricts kernel dependencies; imports of numpy, triton, packaging, etc. are rejected | Standard library + `torch` only |
| Python-side `@torch.library.custom_op` as the primary binding | The op must be registered in C++ so it ships in the compiled extension | C++ registration via `TORCH_LIBRARY_EXPAND`; Python-side `torch.library.register_fake` is only for adding a fake/meta impl (see torch.compile section) |

### The only supported binding pattern

`registration.h` and `_ops.py` are **generated by kernel-builder** — reference them, never write them yourself.

**`torch-ext/torch_binding.h`:**
```cpp
#pragma once
#include <torch/torch.h>

void my_kernel_forward(torch::Tensor &out, torch::Tensor const &input);
```

**`torch-ext/torch_binding.cpp`:**
```cpp
#include <torch/library.h>
#include "registration.h"
#include "torch_binding.h"

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("my_kernel_forward(Tensor! out, Tensor input) -> ()");
ops.impl("my_kernel_forward", torch::kCUDA, &my_kernel_forward);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
```

**`torch-ext/my_kernel/__init__.py`:**
```python
import torch
from ._ops import ops

def my_kernel(x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
ops.my_kernel_forward(out, x)
return out
```

### Pre-flight checklist before declaring a kernel done

1. `grep -rn "pybind11\|PYBIND11\|torch/extension.h\|py::" torch-ext/` returns nothing.
2. `grep -rn "TORCH_LIBRARY(\|TORCH_LIBRARY_FRAGMENT\|PyInit" torch-ext/` returns nothing (only `TORCH_LIBRARY_EXPAND` is allowed).
3. No `setup.py` exists unless generated by `kernel-builder create-pyproject`.
4. `kernel-builder check-config` passes — `[general]` needs a **dash-separated** `name` (never underscores) and a `license`, plus `[torch]` (binding sources) and `[kernel.<name>]` sections.
5. The kernel directory is a git repository with all files committed (Nix refuses non-git builds).
6. The build succeeds: `nix run .#build-and-copy -L`.
7. ABI compliance passes: `kernel-builder check-abi` (after building).

## Quick Start

### Diffusers (Video/Image Generation)
Expand Down Expand Up @@ -91,11 +156,10 @@ Use this skill when:

## Working Example

A complete working example is available at `examples/ltx_video/`. This demonstrates:
- Custom CUDA kernels (RMSNorm, RoPE 3D, GEGLU, AdaLN)
- Build system setup with setup.py, build.toml, and flake.nix
- PyTorch C++ bindings and Python API
- Benchmarking script for comparing optimized vs baseline performance
Complete working examples ship with the kernels repo under `examples/kernels/` (also at [github.com/huggingface/kernels](https://github.com/huggingface/kernels/tree/main/examples/kernels)):
- `relu/` — the canonical minimal kernel: build.toml, flake.nix, `TORCH_LIBRARY_EXPAND` bindings, Python API, `layers/`, tests
- `relu-backprop-compile/` — backward pass + `torch.compile` support (fake-op registration)
- `silu-and-mul/` — activation kernel following the same layout

## Benchmarking Kernels

Expand Down Expand Up @@ -183,13 +247,15 @@ The vectorized RMSNorm kernel achieves **2.67x average speedup** over PyTorch ba
│ └── t4-optimization-guide.md # T4 (Turing) optimization deep dive
└── SKILL.md # This file

examples/ltx_video/ # Complete working example
├── kernel_src/
│ └── rmsnorm.cu # Vectorized RMSNorm kernel (2.67x faster)
├── torch-ext/ # PyTorch bindings
├── generate_video.py # Full benchmark script
├── benchmark_rmsnorm.py # Isolated kernel benchmark
└── setup.py # pip install -e .
examples/kernels/relu/ # Canonical working example (kernels repo)
├── build.toml # kernel-builder build configuration
├── flake.nix # Nix build entry point
├── CARD.md # Kernel card template (becomes README.md)
├── relu_cuda/relu.cu # CUDA kernel source
├── torch-ext/
│ ├── torch_binding.h / .cpp # TORCH_LIBRARY_EXPAND bindings
│ └── relu/__init__.py # Python API (+ optional layers/)
└── tests/test_relu.py # Kernel tests (nix run .#ci-test)
```

## GPU Architecture Reference
Expand Down Expand Up @@ -291,28 +357,67 @@ All kernels support three precision modes:

## Building Kernels

### Scaffold a new kernel project

Start new kernels with `kernel-builder init` instead of creating files by hand — it generates the compliant layout in one shot:

```bash
kernel-builder init --name my-username/my-kernel
```

This creates `build.toml` (valid dash-separated name, license, `[general.hub] repo-id` already wired), `flake.nix`, `torch-ext/` with compilable `torch_binding.{h,cpp}` and the Python package, a `<name>_cuda/` kernel source dir, `tests/`, `benchmarks/`, `example.py`, and `CARD.md` — and it initializes a git repository (required for builds). Then replace the stub kernel with your own sources and update the `src` lists in `build.toml`.

### With Nix (Recommended)
```bash
nix run .#build-and-copy --max-jobs 2 --cores 8 -L
```

### With pip/uv
### Build and publish to the Hub in one go
```bash
uv pip install -e .
kernel-builder build-and-upload
```
The target repo is set by `repo-id` under `[general.hub]` and `version` under `[general]` in `build.toml`. Uploads go to a **`kernel`-type** Hub repository (not a model repo); the owning user/org needs kernel-creation access ("Request Kernels Creation" at [huggingface.co/settings/account](https://huggingface.co/settings/account)).

### Editable install for local development
Never hand-write a `setup.py` (it leads to `torch.utils.cpp_extension`/pybind11, which cannot build under ABI3). Let kernel-builder generate the project files:
```bash
kernel-builder create-pyproject -f
pip install wheel
pip install --no-build-isolation -e .
```

### build.toml Configuration
```toml
[general]
name = "ltx_kernels"
# Name MUST be dash-separated lowercase (my-kernel), never underscores —
# `kernel-builder check-config` rejects underscores. The Python package
# lives at torch-ext/<name with dashes replaced by underscores>.
name = "ltx-kernels"
backends = ["cuda"]
version = 1
license = "Apache-2.0" # required field

[general.hub]
# Hub repo for `kernel-builder build-and-upload`; with `version` this
# selects the version branch (e.g. v1).
repo-id = "my-username/ltx-kernels"

[torch]
src = [
"torch-ext/torch_binding.cpp",
"torch-ext/torch_binding.h"
]

[kernel.your_kernel]
backend = "cuda"
src = ["kernel_src/your_kernel.cu"]
cuda-capabilities = ["9.0"]
depends = ["torch"]
# Only constrain cuda-capabilities when the kernel truly requires it —
# do not over-specify.
```

The kernel directory **must be a git repository with files committed** (`git init && git add -A && git commit`) — Nix refuses to build non-git kernels ("Kernel is not in a git repository").

## Library Integration

### HuggingFace Kernels Hub (get_kernel)
Expand All @@ -324,8 +429,9 @@ Load pre-compiled, optimized kernels directly from HuggingFace Hub without local
```python
from kernels import get_kernel, has_kernel

# Check availability and load
if has_kernel("kernels-community/activation"):
# Check availability and load — Hub loads REQUIRE version= (or revision=);
# a bare get_kernel(repo_id) raises ValueError.
if has_kernel("kernels-community/activation", version=1):
activation = get_kernel("kernels-community/activation", version=1)

# Use the kernel
Expand All @@ -335,9 +441,11 @@ if has_kernel("kernels-community/activation"):
```

**Key functions:**
- `get_kernel(repo_id, version=None)` - Download and load kernel from Hub
- `has_kernel(repo_id)` - Check if compatible build exists
- `get_local_kernel(path)` - Load from local directory (development)
- `get_kernel(repo_id, version=N)` - Download and load kernel from Hub; `version=` (major version) or `revision=` (branch/tag/commit) is **required**
- `has_kernel(repo_id, version=N)` - Check if compatible build exists
- `get_local_kernel(Path("path/to/kernel-project"))` - Load a local build (looks in `<path>` and `<path>/build`) — use during development

**Testing local builds through the `get_kernel()` code path:** set `LOCAL_KERNELS="org/name=/path/to/kernel-project"` and call `get_kernel("org/name")` unchanged — the override short-circuits the Hub entirely (no download, no version needed), so integration code can be tested verbatim against a local build.

**Popular community kernels:**
- `kernels-community/activation` - GELU, SiLU, etc.
Expand Down Expand Up @@ -514,19 +622,16 @@ torch._dynamo.exc.Unsupported: Attempted to call function marked as skipped
**Workaround options:**
1. Use `--use-optimized-kernels` without `--compile` (6% speedup)
2. Use `--compile` without custom kernels (34% speedup)
3. Register kernel as custom op (advanced, requires `torch.library`)
3. Add a fake/meta implementation for the C++-registered op (see below)

**To register as custom op (for torch.compile compatibility):**
**To make the op torch.compile-compatible:** ops registered via `TORCH_LIBRARY_EXPAND` in C++ are already proper custom ops — do NOT re-wrap them with `@torch.library.custom_op` in Python. Just register a fake (meta) implementation using the generated `_ops.py` helpers:
```python
import torch
from ._ops import ops, add_op_namespace_prefix

@torch.library.custom_op("ltx_kernels::rmsnorm", mutates_args={"out"})
def rmsnorm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, eps: float) -> None:
ops.rmsnorm_forward(out, input.contiguous(), weight.contiguous(), eps)

@rmsnorm.register_fake
@torch.library.register_fake(add_op_namespace_prefix("rmsnorm_forward"))
def _(out, input, weight, eps):
pass # No shape changes
return None # out-variant op: no shape changes
```

## See Also
Expand All @@ -550,7 +655,7 @@ def _(out, input, weight, eps):
### Reference
- [troubleshooting.md](references/troubleshooting.md) - Common issues and solutions
- [kernel-templates.md](references/kernel-templates.md) - Complete kernel templates
- [examples/ltx_video/](../../../examples/ltx_video/) - Full LTX-Video example directory
- [examples/kernels/relu/](../../../examples/kernels/relu/) - Canonical working kernel example (bindings, layers, tests)

### External Resources
- [HuggingFace Kernels Documentation](https://huggingface.co/docs/kernels/en/index)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,21 @@ TF32: Best throughput for FP32-like accuracy (A100 specific)

```toml
[general]
name = "ltx_kernels"
name = "ltx-kernels" # dash-separated; underscores are rejected
backends = ["cuda"]
version = 1
license = "Apache-2.0"

[torch]
src = [
"torch-ext/torch_binding.cpp",
"torch-ext/torch_binding.h"
]

[kernel.your_kernel]
backend = "cuda"
src = ["kernel_src/your_kernel.cu"]
depends = ["torch"]
cuda-capabilities = ["8.0"] # sm_80 for A100
```

Expand Down Expand Up @@ -300,12 +309,12 @@ ncu --set full -o a100_metrics.ncu-rep python your_script.py
## Working Example

```bash
cd examples/ltx_video
cd <your-kernel-project>

# Build for A100
# Ensure build.toml includes cuda-capabilities = ["8.0"]
uv pip install -e .
# Leave cuda-capabilities unspecified in build.toml unless the kernel
# truly requires specific architectures (A100 is `cuda-capabilities = ["8.0"]`).
nix run .#build-and-copy -L # Build kernels with kernel-builder

# Run benchmark
python generate_video.py --use-optimized-kernels
# Run the kernel's test suite
nix run .#ci-test
```
48 changes: 30 additions & 18 deletions kernel-builder/skills/cuda-kernels/references/diffusers-h100.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ Use this skill when:

## Working Example

A complete working example is available at `examples/ltx_video/`. This demonstrates:
- Custom CUDA kernels (RMSNorm, RoPE 3D, GEGLU, AdaLN)
- Build system setup with setup.py, build.toml, and flake.nix
- PyTorch C++ bindings and Python API
- Benchmarking script for comparing optimized vs baseline performance
Complete working examples ship with the kernels repo under `examples/kernels/` (e.g. `relu/`, `relu-backprop-compile/`). They demonstrate:
- Custom CUDA kernels with the canonical project layout
- Build system setup with build.toml and flake.nix
- PyTorch C++ bindings (`TORCH_LIBRARY_EXPAND`) and Python API
- Kernel tests runnable via `nix run .#ci-test`

## Benchmarking Kernels

Expand Down Expand Up @@ -114,7 +114,7 @@ The vectorized RMSNorm kernel achieves **2.67x average speedup** over PyTorch ba
## Project Structure

```
.claude/skills/h100-diffusers-kernels/
skills/cuda-kernels/
├── scripts/
│ ├── benchmark_example.py # End-to-end video generation benchmark
│ ├── benchmark_rmsnorm.py # Isolated RMSNorm micro-benchmark
Expand All @@ -126,13 +126,12 @@ The vectorized RMSNorm kernel achieves **2.67x average speedup** over PyTorch ba
│ └── h100-optimization-guide.md # H100 optimization deep dive
└── SKILL.md # This file

examples/ltx_video/ # Complete working example
├── kernel_src/
│ └── rmsnorm.cu # Vectorized RMSNorm kernel (2.67x faster)
├── torch-ext/ # PyTorch bindings
├── generate_video.py # Full benchmark script
├── benchmark_rmsnorm.py # Isolated kernel benchmark
└── setup.py # pip install -e .
examples/kernels/relu/ # Canonical working example (kernels repo)
├── build.toml # kernel-builder build configuration
├── flake.nix # Nix build entry point
├── relu_cuda/relu.cu # CUDA kernel source
├── torch-ext/ # TORCH_LIBRARY_EXPAND bindings + Python API
└── tests/ # Kernel tests
```

## H100 Architecture Reference
Expand Down Expand Up @@ -225,21 +224,34 @@ All kernels support three precision modes:
nix run .#build-and-copy --max-jobs 2 --cores 8 -L
```

### With pip/uv
### Editable install for local development
Never hand-write a `setup.py` (it leads to `torch.utils.cpp_extension`/pybind11, which cannot build under ABI3). Let kernel-builder generate the project files:
```bash
uv pip install -e .
kernel-builder create-pyproject -f
pip install wheel
pip install --no-build-isolation -e .
```

### build.toml Configuration
```toml
[general]
name = "ltx_kernels"
# Dash-separated lowercase name (underscores are rejected); license required.
name = "ltx-kernels"
backends = ["cuda"]
version = 1
license = "Apache-2.0"

[torch]
src = [
"torch-ext/torch_binding.cpp",
"torch-ext/torch_binding.h"
]

[kernel.your_kernel]
backend = "cuda"
src = ["kernel_src/your_kernel.cu"]
cuda-capabilities = ["9.0"]
depends = ["torch"]
# Only constrain cuda-capabilities when the kernel truly requires it.
```

## Diffusers Integration
Expand Down Expand Up @@ -406,4 +418,4 @@ def _(out, input, weight, eps):
- [troubleshooting.md](references/troubleshooting.md) - Common issues and solutions
- [kernel-templates.md](references/kernel-templates.md) - Complete kernel templates
- [h100-optimization-guide.md](references/h100-optimization-guide.md) - Deep dive on H100 optimizations
- [examples/ltx_video/](../../../examples/ltx_video/) - Full LTX-Video example directory
- [examples/kernels/relu/](../../../../examples/kernels/relu/) - Canonical working kernel example
Original file line number Diff line number Diff line change
Expand Up @@ -306,9 +306,9 @@ See SKILL.md "Common Issues and Solutions" for:
For a self-contained, runnable example that demonstrates all patterns above:

```bash
cd examples/ltx_video
uv pip install -e . # Build kernels
python ../../.claude/skills/h100-diffusers-kernels/references/ltx_kernel_injection_example.py
cd <your-kernel-project> # kernel-builder project with your kernels
nix run .#build-and-copy -L # Build kernels with kernel-builder
python path/to/skills/cuda-kernels/scripts/ltx_kernel_injection_example.py

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this run command still correct?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you rather kernel-builder build-and-copy -L?

The other stuff, e.g., cd <> and python ... are okay because the users are supposed to change them. Or did you mean something else?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant the python path/to/skills etc. I would expect a placeholder in place of skills/cuda-kernels/scripts/ltx_kernel_injection_example.py.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if I follow. ltx_kernel_injection_example.py is the script we want the agent to refer to here.

```

This example:
Expand Down
Loading
Loading