---
name: cuda-kernels
description: "Provides guidance for writing and benchmarking optimized CUDA kernels for NVIDIA GPUs (H100, A100, T4) targeting HuggingFace diffusers and transformers libraries. Kernels must be kernel-builder/ABI3-compliant: no pybind11, no setup.py, TORCH_LIBRARY_EXPAND bindings only. Supports models like LTX-Video, Stable Diffusion, LLaMA, Mistral, and Qwen. Includes integration with HuggingFace Kernels Hub (get_kernel) for loading pre-compiled kernels. Includes benchmarking scripts to compare kernel performance against baseline implementations."
disable-model-invocation: false
user-invocable: true
allowed-tools: "Read, Grep, Glob, Bash"
argument-hint: "kernel type: attention, rmsnorm, rope, adaln, geglu, benchmark, transformers, diffusers, huggingface-kernels, get_kernel"
---
# CUDA Kernels for Diffusers & Transformers
This skill provides patterns and guidance for developing optimized CUDA kernels targeting NVIDIA GPUs (H100, A100, T4) for use with the HuggingFace **diffusers** and **transformers** libraries.
## Hard Constraints — Read Before Writing Any Code
Kernels MUST build with [kernel-builder](https://github.com/huggingface/kernels) and meet the [Kernel Hub requirements](https://huggingface.co/docs/kernels/kernel-requirements). kernel-builder compiles against the **Python limited API (ABI3)**, so a single binary works across Python versions (3.9+). Several patterns that are standard in generic PyTorch-extension tutorials are therefore **hard build failures** here. Do not use them, even if PyTorch documentation or your training data suggests them.
### Disallowed patterns — never generate these
| ❌ Never use | Why it fails | ✅ Use instead |
|---|---|---|
| pybind11 in any form: `#include <torch/extension.h>`, `#include <pybind11/...>`, `PYBIND11_MODULE(...)`, `py::arg`, any `py::` symbol | pybind11 is incompatible with the limited API (ABI3); the build does not compile | `TORCH_LIBRARY_EXPAND` in `torch-ext/torch_binding.cpp` (see below). Note: `torch/extension.h` transitively includes pybind11 — include `torch/torch.h` + `torch/library.h` instead |
| Hand-written `setup.py` / `pyproject.toml` using `torch.utils.cpp_extension` (`CUDAExtension`, `BuildExtension`, `cpp_extension.load`, `load_inline`) | setuptools extensions are not ABI3 and bypass `build.toml`; kernel-builder owns the build | `build.toml` + `nix run .#build-and-copy -L`. For an editable dev install, generate the project files with `kernel-builder create-pyproject -f` — never write them by hand |
| `TORCH_LIBRARY(my_kernel, m)`, `TORCH_LIBRARY_FRAGMENT(...)`, or `TORCH_LIBRARY_IMPL(...)` with a hardcoded namespace | kernel-builder suffixes the op namespace with a per-build hash (e.g. `_my_kernel_a1b2c3d`); a hardcoded name never resolves | `TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops)` from the generated `registration.h` |
| Hardcoded `torch.ops.my_kernel.fn(...)` calls in Python | Same namespace mangling — the op namespace name is only known at build time | `from ._ops import ops` then `ops.fn(...)` |
| Hand-written `PyMODINIT_FUNC PyInit__...` or any manual CPython module init | Generated by `REGISTER_EXTENSION`; duplicating it breaks module loading | `REGISTER_EXTENSION(TORCH_EXTENSION_NAME)` exactly once, in `torch_binding.cpp` |
| Non-limited CPython API calls (`PyArg_ParseTuple`, direct `PyObject*` manipulation) | Violates ABI3 | Stay within the torch C++ API: `torch::Tensor`, `TORCH_CHECK`, `at::cuda::*` |
| Absolute imports of your own package inside `torch-ext/` (`from my_kernel.utils import x`) | The package directory is renamed when loaded from the Hub; absolute imports break | Relative imports only: `from .utils import x`, `from ._ops import ops` |
| Runtime Python deps beyond `torch` (and `einops` if truly needed) | Hub compliance restricts kernel dependencies; imports of numpy, triton, packaging, etc. are rejected | Standard library + `torch` (plus `einops` only if truly needed) |
| Python-side `@torch.library.custom_op` as the primary binding | The op must be registered in C++ so it ships in the compiled extension | C++ registration via `TORCH_LIBRARY_EXPAND`; Python-side `torch.library.register_fake` is only for adding a fake/meta impl (see torch.compile section) |
### The only supported binding pattern
`registration.h` and `_ops.py` are **generated by kernel-builder** — reference them, never write them yourself.
**`torch-ext/torch_binding.h`:**
```cpp
#pragma once
#include <torch/torch.h>
void my_kernel_forward(torch::Tensor &out, torch::Tensor const &input);
```
**`torch-ext/torch_binding.cpp`:**
```cpp
#include <torch/library.h>
#include "registration.h"
#include "torch_binding.h"
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def("my_kernel_forward(Tensor! out, Tensor input) -> ()");
  ops.impl("my_kernel_forward", torch::kCUDA, &my_kernel_forward);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
```
**`torch-ext/my_kernel/__init__.py`:**
```python
import torch
from ._ops import ops


def my_kernel(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    ops.my_kernel_forward(out, x)
    return out
```
### Pre-flight checklist before declaring a kernel done
1. `grep -rn "pybind11\|PYBIND11\|torch/extension.h\|py::" torch-ext/` returns nothing.
2. `grep -rn "TORCH_LIBRARY(\|TORCH_LIBRARY_FRAGMENT\|PyInit" torch-ext/` returns nothing (only `TORCH_LIBRARY_EXPAND` is allowed).
3. No `setup.py` exists unless generated by `kernel-builder create-pyproject`.
4. `kernel-builder check-config` passes — `[general]` needs a **dash-separated** `name` (never underscores) and a `license`, plus `[torch]` (binding sources) and `[kernel.<name>]` sections.
5. The kernel directory is a git repository with all files committed (Nix refuses non-git builds).
6. The build succeeds: `nix run .#build-and-copy -L`.
7. ABI compliance passes: `kernel-builder check-abi` (after building).
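Checklist items 1 and 2 can also be automated. The following is a minimal sketch, not part of kernel-builder: `find_disallowed` and the `DISALLOWED` table are hypothetical names, and the regular expressions are an approximation of the two `grep` expressions above.

```python
import re
from pathlib import Path

# Patterns mirrored from checklist items 1 and 2. These regexes are this
# sketch's approximation of the grep expressions, not a kernel-builder API.
DISALLOWED = {
    "pybind11": re.compile(r"pybind11|PYBIND11|torch/extension\.h|py::"),
    "hardcoded registration": re.compile(r"TORCH_LIBRARY\(|TORCH_LIBRARY_FRAGMENT|PyInit"),
}


def find_disallowed(root):
    """Return (path, line number, label) for each disallowed pattern under root."""
    hits = []
    for path in sorted(Path(root).rglob("*")):
        if not path.is_file():
            continue
        text = path.read_text(errors="ignore")
        for lineno, line in enumerate(text.splitlines(), start=1):
            for label, pattern in DISALLOWED.items():
                if pattern.search(line):
                    hits.append((str(path), lineno, label))
    return hits
```

Calling `find_disallowed("torch-ext")` should return an empty list for a compliant tree. `TORCH_LIBRARY_EXPAND(` is not flagged because the pattern requires the opening parenthesis directly after `TORCH_LIBRARY`.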
## Quick Start
### Diffusers (Video/Image Generation)
**For benchmarking kernel performance:**
```bash
# Benchmark with optimized kernels (6% end-to-end speedup)
python generate_video.py --use-optimized-kernels
# Benchmark baseline with torch.compile (34% speedup)
python generate_video.py --no-optimized-kernels --compile
# Compare configurations (note: --compile and --use-optimized-kernels are mutually exclusive)
python generate_video.py --use-optimized-kernels && \
python generate_video.py --no-optimized-kernels --compile
```
**For a minimal diffusers integration example (~150 lines):**
```bash
python scripts/ltx_kernel_injection_example.py
```
### Transformers (LLMs)
**For a minimal transformers integration example (~120 lines):**
```bash
python scripts/transformers_injection_example.py
```
### HuggingFace Kernels Hub
**Load pre-compiled kernels from HuggingFace Hub (no local compilation):**
```python
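import torch
from kernels import get_kernel

# Load optimized activation kernels
activation = get_kernel("kernels-community/activation", version=1)

# Use the kernel (x can be any CUDA tensor; this shape is illustrative)
x = torch.randn((4, 4), dtype=torch.float16, device="cuda")
y = torch.empty_like(x)
activation.gelu_fast(y, x)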
```
**For a complete HuggingFace Kernels example:**
```bash
python scripts/huggingface_kernels_example.py
```
### Isolated Kernel Micro-benchmarks
```bash
python scripts/benchmark_rmsnorm.py
```
## Supported Libraries & Models
| Library | Supported Models | Key Kernels |
|---------|------------------|-------------|
| **diffusers** | LTX-Video, Stable Diffusion, FLUX, DiT | RMSNorm, GEGLU, RoPE, AdaLN |
| **transformers** | LLaMA, Mistral, Qwen, Falcon | RMSNorm, Attention |
| GPU | Compute Capability | Guide |
|-----|-------------------|-------|
| H100 | sm_90 | [h100-optimization-guide.md](references/h100-optimization-guide.md) |
| A100 | sm_80 | [a100-optimization-guide.md](references/a100-optimization-guide.md) |
| T4 | sm_75 | [t4-optimization-guide.md](references/t4-optimization-guide.md) |
## When This Skill Applies
Use this skill when:
- **Benchmarking kernel performance** against baseline implementations
- Writing new CUDA kernels for diffusion models or LLMs
- Optimizing existing kernels for H100, A100, or T4 architecture
- Implementing custom attention, normalization, or activation layers
- Integrating kernels with **diffusers** pipelines (LTX-Video, Stable Diffusion, FLUX, DiT)
- Integrating kernels with **transformers** models (LLaMA, Mistral, Qwen)
- Debugging kernel performance issues on NVIDIA GPUs
## Working Example
Complete working examples ship with the kernels repo under `examples/kernels/` (also at [github.com/huggingface/kernels](https://github.com/huggingface/kernels/tree/main/examples/kernels)):
- `relu/` — the canonical minimal kernel: build.toml, flake.nix, `TORCH_LIBRARY_EXPAND` bindings, Python API, `layers/`, tests
- `relu-backprop-compile/` — backward pass + `torch.compile` support (fake-op registration)
- `silu-and-mul/` — activation kernel following the same layout
## Benchmarking Kernels
Use the benchmark script to measure kernel performance:
```bash
# Full benchmark with explicit options. --compile is omitted because it is
# mutually exclusive with --use-optimized-kernels.
python scripts/benchmark_example.py \
    --use-optimized-kernels \
    --batch-size 1 \
    --num-frames 161 \
    --height 512 \
    --width 768 \
    --steps 50 \
    --warmup-iterations 2
```
### Benchmark Script Options
| Option | Default | Description |
|--------|---------|-------------|
| `--use-optimized-kernels` | auto | Use custom H100 CUDA kernels |
| `--no-optimized-kernels` | - | Use baseline implementation |
| `--compile` | false | Enable torch.compile on transformer |
| `--batch-size` | 1 | Number of videos per prompt |
| `--num-frames` | 161 | Number of frames to generate |
| `--height` | 512 | Video height in pixels |
| `--width` | 768 | Video width in pixels |
| `--steps` | 50 | Denoising steps |
| `--warmup-iterations` | 2 | Warmup runs before benchmark |
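For reference when writing a similar benchmark script, the options table maps naturally onto `argparse`. This is a hypothetical sketch and not the source of `benchmark_example.py`: the function name `parse_benchmark_args`, the shared `optimized` destination, and the use of `None` for the "auto" default are all assumptions.

```python
import argparse


def parse_benchmark_args(argv):
    """Parse flags named after the options table; raise ValueError on conflicts."""
    parser = argparse.ArgumentParser(description="Hypothetical benchmark CLI sketch")
    kernels = parser.add_mutually_exclusive_group()
    # None stands in for the table's "auto" default.
    kernels.add_argument("--use-optimized-kernels", dest="optimized",
                         action="store_true", default=None)
    kernels.add_argument("--no-optimized-kernels", dest="optimized",
                         action="store_false", default=None)
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--num-frames", type=int, default=161)
    parser.add_argument("--height", type=int, default=512)
    parser.add_argument("--width", type=int, default=768)
    parser.add_argument("--steps", type=int, default=50)
    parser.add_argument("--warmup-iterations", type=int, default=2)
    args = parser.parse_args(argv)
    # Explicitly requesting custom kernels together with torch.compile is rejected.
    if args.compile and args.optimized:
        raise ValueError("--compile and --use-optimized-kernels are mutually exclusive")
    return args
```

Rejecting the conflicting combination at parse time avoids starting a long run that cannot complete as requested.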
### Example Benchmark Results
**End-to-End Video Generation (49 frames, 30 steps, H100 80GB):**
| Configuration | Time (s) | it/s | Speedup | Notes |
|:---|:---:|:---:|:---:|:---|
| Baseline (no compile) | 2.87 | 12.58 | 1.00x | Reference |
| **Optimized Kernels** | 2.70 | 13.52 | **1.06x** | 6% faster |
| Baseline + torch.compile | 2.14 | 19.05 | 1.34x | 34% faster |
**Important:** `--use-optimized-kernels` and `--compile` are currently mutually exclusive. To run under torch.compile, the custom ops also need a fake (meta) implementation registered; see the torch.compile Compatibility section.
**Key metrics to capture:**
- **Device:** GPU model (e.g., NVIDIA H100 80GB HBM3)
- **Precision:** Data type used (e.g., bfloat16)
- **Resolution:** Width x Height (e.g., 768x512)
- **Frames:** Number of frames generated (e.g., 49, 161)
### RMSNorm Micro-benchmarks
The vectorized RMSNorm kernel achieves **2.67x average speedup** over PyTorch baseline:
| Shape | Custom (ms) | PyTorch (ms) | Speedup |
|:---|:---:|:---:|:---:|
| [1×1024×2048] | 0.019 | 0.065 | **3.37x** |
| [2×1024×2048] | 0.024 | 0.073 | **3.04x** |
| [4×1024×2048] | 0.036 | 0.093 | **2.58x** |
| [2×4096×3072] | 0.087 | 0.208 | **2.41x** |
| [4×4096×3072] | 0.157 | 0.392 | **2.49x** |
**Bandwidth efficiency:** 38% of H100's theoretical 3.35 TB/s
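A bandwidth figure like this can be estimated from a table row. The sketch below assumes 16-bit elements and counts one read of the input plus one write of the output, ignoring the weight vector; the source does not state how its 38% was derived, so this is one plausible accounting. The function name is illustrative.

```python
def effective_bandwidth_fraction(shape, time_ms, bytes_per_elem=2,
                                 peak_bytes_per_s=3.35e12):
    """Fraction of peak bandwidth, counting one read and one write per element."""
    elems = 1
    for dim in shape:
        elems *= dim
    bytes_moved = 2 * elems * bytes_per_elem   # read input + write output
    achieved = bytes_moved / (time_ms * 1e-3)  # bytes per second
    return achieved / peak_bytes_per_s


# Largest shape in the table: [4x4096x3072] at 0.157 ms
frac = effective_bandwidth_fraction((4, 4096, 3072), time_ms=0.157)
print(f"{frac:.1%}")
```

Under these assumptions the largest shape lands at roughly 38% of peak, in line with the figure quoted above.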
**Why end-to-end speedup is smaller:** RMSNorm accounts for ~5% of total compute in LTX-Video. The remaining time is spent in attention (Flash Attention/SDPA), linear projections, and VAE decode.
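Amdahl's law makes this concrete. The sketch below plugs in the figures quoted here (RMSNorm at about 5% of runtime, accelerated 2.67x); it is a simplified model, and since the 5% share is approximate the result is an order-of-magnitude estimate.

```python
def amdahl_speedup(fraction, local_speedup):
    """Overall speedup when `fraction` of the runtime is accelerated by `local_speedup`."""
    return 1.0 / ((1.0 - fraction) + fraction / local_speedup)


overall = amdahl_speedup(fraction=0.05, local_speedup=2.67)
ceiling = amdahl_speedup(fraction=0.05, local_speedup=float("inf"))
print(f"RMSNorm alone: {overall:.3f}x, upper bound: {ceiling:.3f}x")
```

Even an infinitely fast RMSNorm would cap the gain at about 1.05x under this model, so larger end-to-end improvements have to come from the other components.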
## Project Structure
```
.claude/skills/cuda-kernels/
├── scripts/
│ ├── benchmark_example.py # End-to-end video generation benchmark
│ ├── benchmark_rmsnorm.py # Isolated RMSNorm micro-benchmark
│ ├── ltx_kernel_injection_example.py # Minimal diffusers integration (~150 lines)
│ ├── transformers_injection_example.py # Minimal transformers integration (~120 lines)
│ └── huggingface_kernels_example.py # HuggingFace Kernels Hub integration
├── references/
│ ├── diffusers-integration.md # Complete diffusers integration guide
│ ├── transformers-integration.md # Complete transformers integration guide
│ ├── huggingface-kernels-integration.md # HuggingFace Kernels Hub (get_kernel) guide
│ ├── troubleshooting.md # Common issues and solutions
│ ├── kernel-templates.md # CUDA kernel templates (includes vectorized)
│ ├── h100-optimization-guide.md # H100 (Hopper) optimization deep dive
│ ├── a100-optimization-guide.md # A100 (Ampere) optimization deep dive
│ └── t4-optimization-guide.md # T4 (Turing) optimization deep dive
└── SKILL.md # This file
examples/kernels/relu/ # Canonical working example (kernels repo)
├── build.toml # kernel-builder build configuration
├── flake.nix # Nix build entry point
├── CARD.md # Kernel card template (becomes README.md)
├── relu_cuda/relu.cu # CUDA kernel source
├── torch-ext/
│ ├── torch_binding.h / .cpp # TORCH_LIBRARY_EXPAND bindings
│ └── relu/__init__.py # Python API (+ optional layers/)
└── tests/test_relu.py # Kernel tests (nix run .#ci-test)
```
## GPU Architecture Reference
### H100 (Hopper) - Primary Target
| Spec | Value | Optimization Impact |
|------|-------|---------------------|
| SMs | 132 | Grid sizing: aim for multiples of 132 |
| Threads/SM | 2048 | Max 16 blocks of 128 threads per SM |
| Shared Memory | 228 KB/SM | Large tiles possible |
| L2 Cache | 50 MB | Reuse across blocks |
| Memory BW | 3.35 TB/s | Coalesced access critical |
| Warp Size | 32 | All reductions use warp shuffles |
### Quick Comparison (H100 vs A100 vs T4)
| Spec | H100 | A100 | T4 |
|------|------|------|-----|
| SMs | 132 | 108 | 40 |
| Memory BW | 3.35 TB/s | 2.0 TB/s | 320 GB/s |
| Shared Mem/SM | 228 KB | 164 KB | 64 KB |
| BF16 Support | Yes | Yes | **No (FP16 only)** |
| Compute Cap | sm_90 | sm_80 | sm_75 |
> See detailed guides: [H100](references/h100-optimization-guide.md) | [A100](references/a100-optimization-guide.md) | [T4](references/t4-optimization-guide.md)
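The H100 figures above (132 SMs, 2048 threads per SM) give a quick way to estimate how many scheduling waves a launch needs. The sketch below is a simplified model: it considers only the per-SM thread limit and ignores register pressure, shared-memory use, and the hardware cap on resident blocks. The function and constant names are illustrative.

```python
import math

H100_SMS = 132             # SM count from the table above
MAX_THREADS_PER_SM = 2048  # thread limit per SM from the table above


def waves(num_blocks, threads_per_block, sms=H100_SMS):
    """Number of scheduling waves, considering only the per-SM thread limit."""
    blocks_per_sm = MAX_THREADS_PER_SM // threads_per_block
    resident = sms * blocks_per_sm  # blocks that can be resident at once
    return math.ceil(num_blocks / resident)


print(waves(2112, 128))  # 132 SMs * 16 blocks fit in a single wave
print(waves(2113, 128))  # one extra block forces a second, nearly empty wave
```

This is the reasoning behind sizing grids as multiples of the SM count: a grid that slightly overshoots a full wave pays for an additional wave that leaves most SMs idle.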
## Core Kernel Patterns
### Vectorized Memory Access (Critical for Performance)
**BFloat16 vectorization using `__nv_bfloat162`:**
```cuda
// Load 2 bfloat16 elements at once (32-bit load)
const __nv_bfloat162* vec_input = reinterpret_cast<const __nv_bfloat162*>(row_input);
#pragma unroll 4
for (int i = tid; i < vec_hidden; i += stride) {
    __nv_bfloat162 v = vec_input[i];
    float v0 = __bfloat162float(v.x);
    float v1 = __bfloat162float(v.y);
    sum_sq += v0 * v0 + v1 * v1;
}
```
**FP16 vectorization using `__half2`:**
```cuda
const __half2* vec_input = reinterpret_cast<const __half2*>(row_input);
__half2 v = vec_input[i];
float v0 = __half2float(v.x);
float v1 = __half2float(v.y);
```
**FP32 vectorization using `float4`:**
```cuda
const float4* vec_input = reinterpret_cast<const float4*>(row_input);
float4 v = vec_input[i];
sum_sq += v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w;
```
### Warp Shuffle Reductions
```cuda
template <typename T>
__device__ __forceinline__ T warp_reduce_sum(T val) {
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        val += __shfl_xor_sync(0xffffffff, val, offset);
    }
    return val;
}
```
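To see why five XOR steps are enough, the butterfly pattern can be simulated on the host. This Python sketch models one 32-lane warp as a list; `warp_reduce_sum_sim` is an illustrative name, and the list indexing stands in for what `__shfl_xor_sync` does in hardware.

```python
def warp_reduce_sum_sim(lanes):
    """Simulate the XOR-butterfly reduction over one 32-lane warp."""
    vals = list(lanes)
    assert len(vals) == 32
    offset = 16
    while offset > 0:
        # Each lane adds the value held by lane (i XOR offset).
        vals = [vals[i] + vals[i ^ offset] for i in range(32)]
        offset >>= 1
    return vals


result = warp_reduce_sum_sim(range(32))
print(result[0], len(set(result)))
```

After the loop every lane holds the full sum, not only lane 0, which is the practical difference between the XOR variant and a down-shuffle reduction.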
### Block Sizes for Attention
- `BLOCK_SIZE_M = 128`, `BLOCK_SIZE_N = 64`, `BLOCK_SIZE_K = 64`
- `NUM_WARPS = 8`
### Thread Configuration
For element-wise ops (RoPE, GEGLU):
```cuda
constexpr int BLOCK_SIZE = 256;
int num_blocks = (total_elements + BLOCK_SIZE - 1) / BLOCK_SIZE;
```
For reduction ops (LayerNorm, RMSNorm) with vectorization:
```cuda
// Divide by 2 for bf16/fp16 vectorized access
int threads = min(hidden_size / 2, MAX_THREADS);
threads = max(threads, WARP_SIZE);
threads = (threads + 32 - 1) / 32 * 32; // Round to warp boundary
```
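The launch arithmetic above can be checked on the host before touching CUDA. This Python sketch mirrors both formulas; `MAX_THREADS = 1024` is an assumption (the CUDA per-block thread limit), since the snippet above does not define it.

```python
WARP_SIZE = 32
MAX_THREADS = 1024  # assumption: CUDA per-block thread limit


def elementwise_blocks(total_elements, block_size=256):
    """Ceil-divide so every element is covered by some thread."""
    return (total_elements + block_size - 1) // block_size


def reduction_threads(hidden_size):
    """Threads per row for a reduction with 2-element vectorized loads."""
    threads = min(hidden_size // 2, MAX_THREADS)  # one thread per 2-element vector
    threads = max(threads, WARP_SIZE)             # never launch less than one warp
    return (threads + WARP_SIZE - 1) // WARP_SIZE * WARP_SIZE  # round up to a warp


for hidden in (96, 2048, 3072):
    print(hidden, reduction_threads(hidden))
```

Under that assumption both 2048 and 3072 saturate at 1024 threads, while a small hidden size such as 96 rounds up from 48 to 64 threads.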
## Supported Data Types
All kernels support three precision modes:
- `__half` (FP16) - Default for inference
- `__nv_bfloat16` (BF16) - Preferred for training
- `float` (FP32) - Reference/debugging
## Building Kernels
### Scaffold a new kernel project
Start new kernels with `kernel-builder init` instead of creating files by hand — it generates the compliant layout in one shot:
```bash
kernel-builder init --name my-username/my-kernel
```
This creates `build.toml` (valid dash-separated name, license, `[general.hub] repo-id` already wired), `flake.nix`, `torch-ext/` with compilable `torch_binding.{h,cpp}` and the Python package, a `<name>_cuda/` kernel source dir, `tests/`, `benchmarks/`, `example.py`, and `CARD.md` — and it initializes a git repository (required for builds). Then replace the stub kernel with your own sources and update the `src` lists in `build.toml`.
### With Nix (Recommended)
```bash
nix run .#build-and-copy --max-jobs 2 --cores 8 -L
```
### Build and publish to the Hub in one go
```bash
kernel-builder build-and-upload
```
The target repo is set by `repo-id` under `[general.hub]` and `version` under `[general]` in `build.toml`. Uploads go to a **`kernel`-type** Hub repository (not a model repo); the owning user/org needs kernel-creation access ("Request Kernels Creation" at [huggingface.co/settings/account](https://huggingface.co/settings/account)).
### Editable install for local development
Never hand-write a `setup.py` (it leads to `torch.utils.cpp_extension`/pybind11, which cannot build under ABI3). Let kernel-builder generate the project files:
```bash
kernel-builder create-pyproject -f
pip install wheel
pip install --no-build-isolation -e .
```
### build.toml Configuration
```toml
[general]
# Name MUST be dash-separated lowercase (my-kernel), never underscores —
# `kernel-builder check-config` rejects underscores. The Python package
# lives at torch-ext/<name with dashes replaced by underscores>.
name = "ltx-kernels"
backends = ["cuda"]
version = 1
license = "Apache-2.0" # required field
[general.hub]
# Hub repo for `kernel-builder build-and-upload`; with `version` this
# selects the version branch (e.g. v1).
repo-id = "my-username/ltx-kernels"
[torch]
src = [
"torch-ext/torch_binding.cpp",
"torch-ext/torch_binding.h"
]
[kernel.your_kernel]
backend = "cuda"
src = ["kernel_src/your_kernel.cu"]
depends = ["torch"]
# Only constrain cuda-capabilities when the kernel truly requires it —
# do not over-specify.
```
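The naming rule in the `[general]` comments can be expressed as a small helper. This is a sketch under assumptions: `package_dir` is a hypothetical name, and the regular expression only approximates "dash-separated lowercase"; the exact rule enforced by `kernel-builder check-config` may differ.

```python
import re

# Assumption: approximates "dash-separated lowercase"; the real check may differ.
NAME_RE = re.compile(r"^[a-z0-9]+(-[a-z0-9]+)*$")


def package_dir(name):
    """Map a build.toml name to the Python package directory under torch-ext/."""
    if not NAME_RE.match(name):
        raise ValueError(f"name must be dash-separated lowercase: {name!r}")
    return "torch-ext/" + name.replace("-", "_")


print(package_dir("ltx-kernels"))
```

For the configuration above, the package would therefore live at `torch-ext/ltx_kernels`.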
The kernel directory **must be a git repository with files committed** (`git init && git add -A && git commit`) — Nix refuses to build non-git kernels ("Kernel is not in a git repository").
## Library Integration
### HuggingFace Kernels Hub (get_kernel)
> **See [huggingface-kernels-integration.md](references/huggingface-kernels-integration.md) for the complete guide.**
Load pre-compiled, optimized kernels directly from HuggingFace Hub without local compilation:
```python
import torch
from kernels import get_kernel, has_kernel

# Check availability and load. Hub loads REQUIRE version= (or revision=);
# a bare get_kernel(repo_id) raises ValueError.
if has_kernel("kernels-community/activation", version=1):
    activation = get_kernel("kernels-community/activation", version=1)

    # Use the kernel
    x = torch.randn((4, 4), dtype=torch.float16, device="cuda")
    y = torch.empty_like(x)
    activation.gelu_fast(y, x)
```
**Key functions:**
- `get_kernel(repo_id, version=N)` - Download and load kernel from Hub; `version=` (major version) or `revision=` (branch/tag/commit) is **required**
- `has_kernel(repo_id, version=N)` - Check if compatible build exists
- `get_local_kernel(Path("path/to/kernel-project"))` - Load a local build (looks in `<path>` and `<path>/build`) — use during development
**Testing local builds through the `get_kernel()` code path:** set `LOCAL_KERNELS="org/name=/path/to/kernel-project"` and call `get_kernel("org/name")` unchanged — the override short-circuits the Hub entirely (no download, no version needed), so integration code can be tested verbatim against a local build.
**Popular community kernels:**
- `kernels-community/activation` - GELU, SiLU, etc.
- `kernels-community/flash-attn` - Flash Attention 2
- `kernels-community/triton-layer-norm` - LayerNorm, RMSNorm
### Diffusers Integration (Video/Image Generation)
> **See [diffusers-integration.md](references/diffusers-integration.md) for the complete guide.**
### Transformers Integration (LLMs)
> **See [transformers-integration.md](references/transformers-integration.md) for the complete guide.**
**Key differences from diffusers:**
- Transformers RMSNorm **always** has weights (no `elementwise_affine=False`)
- Use `'RMSNorm' in class_name` to match LlamaRMSNorm, MistralRMSNorm, etc.
- Check for `variance_epsilon` (LLaMA) or `eps` (others) for epsilon
- No `set_processor()` pattern - use Flash Attention 2 instead
**Minimal transformers pattern:**
```python
import torch
from transformers import AutoModelForCausalLM
from ltx_kernels import rmsnorm


def patch_rmsnorm(model):
    for name, module in model.named_modules():
        if 'RMSNorm' in type(module).__name__:
            # LLaMA-style modules expose variance_epsilon; others expose eps
            eps = getattr(module, 'variance_epsilon', None) or getattr(module, 'eps', 1e-6)

            def make_forward(mod, epsilon):
                def forward(x):
                    return rmsnorm(x, mod.weight, eps=epsilon)
                return forward

            module.forward = make_forward(module, eps)


model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16)
patch_rmsnorm(model)
```
### Diffusers Critical Pitfalls
#### 1. RMSNorm Weight May Be None
LTX-Video uses `elementwise_affine=False` for some RMSNorm modules:
```python
# Transformer blocks: NO WEIGHT
self.norm1 = RMSNorm(dim, elementwise_affine=False)
# Attention modules: HAS WEIGHT
self.norm_q = torch.nn.RMSNorm(..., elementwise_affine=True)
```
**Solution:** Handle both cases:
```python
has_weight = hasattr(module, 'weight') and module.weight is not None
if has_weight:
    output = rmsnorm(x, module.weight, eps=eps)
else:
    weight = torch.ones(x.shape[-1], device=x.device, dtype=x.dtype)
    output = rmsnorm(x, weight, eps=eps)
```
#### 2. Diffusers RMSNorm != torch.nn.RMSNorm
```python
# WRONG - misses diffusers RMSNorm
if isinstance(module, torch.nn.RMSNorm):
    ...

# CORRECT - catches all RMSNorm variants
if type(module).__name__ == 'RMSNorm':
    ...
```
#### 3. LTX-Video Uses GELU, Not GEGLU
LTX-Video uses `activation_fn="gelu-approximate"`. Don't patch GEGLU for LTX-Video.
#### 4. Inject Kernels BEFORE CPU Offloading
```python
pipe = LTXPipeline.from_pretrained(...)
pipe.to("cuda")
inject_optimized_kernels(pipe) # BEFORE offloading
pipe.enable_model_cpu_offload() # Now safe
```
### Minimal Integration Pattern
```python
import torch
from diffusers import LTXPipeline
from ltx_kernels import rmsnorm


def patch_rmsnorm_modules(model):
    """Patch all RMSNorm modules to use custom kernel."""
    for name, module in model.named_modules():
        if type(module).__name__ == 'RMSNorm':
            eps = getattr(module, 'eps', 1e-6)
            has_weight = hasattr(module, 'weight') and module.weight is not None
            if has_weight:
                def make_forward(mod, epsilon):
                    def forward(x):
                        return rmsnorm(x, mod.weight, eps=epsilon)
                    return forward
                module.forward = make_forward(module, eps)
            else:
                # No learned weight: pass ones so the kernel signature is unchanged
                def make_forward(epsilon):
                    def forward(x):
                        w = torch.ones(x.shape[-1], device=x.device, dtype=x.dtype)
                        return rmsnorm(x, w, eps=epsilon)
                    return forward
                module.forward = make_forward(eps)


# Usage
pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")
patch_rmsnorm_modules(pipe.transformer)
pipe.enable_model_cpu_offload()
```
## Kernel-Specific Guidelines
### RMSNorm
- Input layout: `[..., hidden_size]`
- Epsilon default: 1e-6
- **Weight may be None** if `elementwise_affine=False`
- **Vectorization:** Use `__nv_bfloat162` for BF16, `__half2` for FP16, `float4` for FP32
- **Performance:** 2.67x faster than PyTorch with vectorized implementation
- **Bandwidth:** Achieves ~38% of H100's 3.35 TB/s theoretical bandwidth
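When testing a kernel against these guidelines, a slow but obvious reference is useful. The following pure-Python sketch implements the usual RMSNorm definition (scale each row by the reciprocal root mean square, then by the weight) for a single row; `rmsnorm_ref` is an illustrative name, not a function from this project.

```python
import math


def rmsnorm_ref(row, weight=None, eps=1e-6):
    """Reference RMSNorm over one row of length hidden_size."""
    mean_sq = sum(v * v for v in row) / len(row)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    if weight is None:  # corresponds to elementwise_affine=False
        weight = [1.0] * len(row)
    return [v * inv_rms * w for v, w in zip(row, weight)]


out = rmsnorm_ref([3.0, 4.0])
print(out)
```

With no weight, the output row has a mean square of approximately 1, which makes a convenient sanity check for a CUDA implementation.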
### RoPE
- 1D: `[batch, seq, heads, head_dim]` - for text
- 3D: `[batch, t*h*w, heads, head_dim]` - for video
- LTX-Video computes its own RoPE via `LTXVideoRotaryPosEmbed`
### GEGLU vs GELU
- **GEGLU**: Input `[batch, seq, 2*hidden]` -> Output `[batch, seq, hidden]`
- **GELU**: Standard activation
- **LTX-Video uses GELU, NOT GEGLU**
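The shape difference can be illustrated with a pure-Python reference for one row. This sketch assumes the first half of the input is the value and the second half is the gate, and it uses the exact (erf-based) GELU; both choices are assumptions, so check the target model's layer before relying on them.

```python
import math


def gelu(x):
    """Exact GELU using the Gaussian error function."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def geglu_ref(row):
    """GEGLU over one row of length 2*hidden; returns a row of length hidden."""
    hidden = len(row) // 2
    # Assumption: first half is the value, second half is the gate.
    value, gate = row[:hidden], row[hidden:]
    return [v * gelu(g) for v, g in zip(value, gate)]


print(len(geglu_ref([1.0, 2.0, 0.5, -0.5])))
```

The output is half the width of the input, whereas `gelu` on its own preserves the shape.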
### AdaLN
- Formula: `norm(x) * weight * (1 + scale) + shift`
- Used in DiT blocks for conditioning
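The formula can be written out for one row as a reference. The source does not say which normalization `norm(x)` denotes, so this sketch assumes RMS normalization; `adaln_ref` and its argument layout are illustrative.

```python
import math


def adaln_ref(row, weight, scale, shift, eps=1e-6):
    """norm(x) * weight * (1 + scale) + shift for one row."""
    # Assumption: norm(x) is RMS normalization; a model may use LayerNorm instead.
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)
    return [
        v * inv_rms * w * (1.0 + sc) + sh
        for v, w, sc, sh in zip(row, weight, scale, shift)
    ]


out = adaln_ref([3.0, 4.0], weight=[1.0, 1.0], scale=[1.0, 1.0], shift=[0.5, 0.5])
print(out)
```

With `scale` and `shift` set to zero the expression reduces to a plain weighted norm, which is a useful identity test for a fused kernel.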
## Performance Profiling
```bash
# NVIDIA Nsight Systems
nsys profile -o profile python your_script.py
# NVIDIA Nsight Compute
ncu --set full -o metrics python your_script.py
```
## Common Issues
> **See [troubleshooting.md](references/troubleshooting.md) for all common issues and solutions.**
Quick fixes:
- **"NoneType has no attribute contiguous"**: RMSNorm weight is None, create ones
- **isinstance() not matching**: Use `type(module).__name__` instead
- **GEGLU not called**: Model uses GELU, not GEGLU
- **Patching doesn't persist**: Inject before `enable_model_cpu_offload()`
- **torch.compile fails with custom kernels**: See below
### torch.compile Compatibility
Custom CUDA kernels and `torch.compile` are **mutually exclusive** unless the kernel is exposed as a PyTorch custom op with a fake (meta) implementation registered.
**Error message:**
```
torch._dynamo.exc.Unsupported: Attempted to call function marked as skipped
```
**Workaround options:**
1. Use `--use-optimized-kernels` without `--compile` (6% speedup)
2. Use `--compile` without custom kernels (34% speedup)
3. Add a fake/meta implementation for the C++-registered op (see below)
**To make the op torch.compile-compatible:** ops registered via `TORCH_LIBRARY_EXPAND` in C++ are already proper custom ops — do NOT re-wrap them with `@torch.library.custom_op` in Python. Just register a fake (meta) implementation using the generated `_ops.py` helpers:
```python
import torch
from ._ops import ops, add_op_namespace_prefix


@torch.library.register_fake(add_op_namespace_prefix("rmsnorm_forward"))
def _(out, input, weight, eps):
    return None  # out-variant op: no shape changes
```
## See Also
### Scripts
- [benchmark_example.py](scripts/benchmark_example.py) - **Benchmarking script for comparing optimized vs baseline - START HERE**
- [ltx_kernel_injection_example.py](scripts/ltx_kernel_injection_example.py) - Minimal diffusers integration (~150 lines)
- [transformers_injection_example.py](scripts/transformers_injection_example.py) - Minimal transformers/LLM integration (~120 lines)
- [huggingface_kernels_example.py](scripts/huggingface_kernels_example.py) - HuggingFace Kernels Hub integration
### Integration Guides
- [huggingface-kernels-integration.md](references/huggingface-kernels-integration.md) - **HuggingFace Kernels Hub (get_kernel) - load pre-compiled kernels**
- [diffusers-integration.md](references/diffusers-integration.md) - Complete diffusers pipeline integration
- [transformers-integration.md](references/transformers-integration.md) - Complete transformers/LLM integration
### GPU Optimization Guides
- [h100-optimization-guide.md](references/h100-optimization-guide.md) - H100 (Hopper, sm_90) deep dive
- [a100-optimization-guide.md](references/a100-optimization-guide.md) - A100 (Ampere, sm_80) deep dive
- [t4-optimization-guide.md](references/t4-optimization-guide.md) - T4 (Turing, sm_75) deep dive
### Reference
- [troubleshooting.md](references/troubleshooting.md) - Common issues and solutions
- [kernel-templates.md](references/kernel-templates.md) - Complete kernel templates
- [examples/kernels/relu/](../../../examples/kernels/relu/) - Canonical working kernel example (bindings, layers, tests)
### External Resources
- [HuggingFace Kernels Documentation](https://huggingface.co/docs/kernels/en/index)
- [HuggingFace Kernels GitHub](https://github.com/huggingface/kernels)
- [Community Kernels on Hub](https://huggingface.co/kernels-community)