Native Windows build of vLLM 0.21.0 — no WSL, no Docker, no Linux VM.
Latest build (cu128 / Python 3.13 / Blackwell): now built for RTX 50-series (Blackwell, sm_120) in addition to 30-/40-series, with
`TORCH_CUDA_ARCH_LIST=8.6;8.9;12.0`, Python 3.13, PyTorch 2.11.0+cu128, CUDA 12.8. This release also fixes the OpenAI API server on Windows (`vllm serve` now starts; previously only the in-process `LLM()` API worked). See What's new.
Ships with 10 KV cache compression methods: the 6 Multi-TurboQuant
methods (isoquant/planarquant/turboquant25/35) plus the 4 new
upstream TurboQuant variants that landed in v0.19.2rc0 (turboquant_k8v4,
turboquant_4bit_nc, turboquant_k3v4_nc, turboquant_3bit_nc).
vLLM is the most popular open-source LLM serving engine, but it officially only supports Linux. This repo provides a pre-built wheel (just download and install) plus a complete patchset for compiling vLLM v0.21.0 natively on Windows with full CUDA acceleration, Triton support, and Multi-TurboQuant integration.
| Release | vLLM | PyTorch | Triton | KV compression | Download |
|---|---|---|---|---|---|
| v0.21.0-win-cu128 (latest) | 0.21.0 | 2.11.0+cu128 | 3.6.0 | Multi-TurboQuant (6) + upstream TurboQuant (4) + fp8 — Python 3.13, Blackwell sm_120 | Download |
| v0.21.0-win | 0.21.0 | 2.11.0+cu126 | 3.6.0 | Multi-TurboQuant (6) + upstream TurboQuant (4) + fp8 (Python 3.10) | Download |
| v0.19.1-win | 0.19.1 | 2.10.0+cu126 | 3.6.0 | Multi-TurboQuant (6 methods) + fp8 | Download |
| v0.19.0-win | 0.19.0 | 2.10.0+cu126 | 3.6.0 | Multi-TurboQuant (6 methods) + fp8 | Download |
| v0.17.1-win | 0.17.1 | 2.10.0+cu126 | 3.6.0 | TurboQuant (2 recipes) | Download |
| v0.14.2-win | 0.14.2 | 2.9.1+cu126 | n/a | fp8 only | Download |
This is a rebuild of the same vLLM 0.21.0 source for RTX 50-series (Blackwell) plus a set of Windows API-server fixes. Thanks to @Dhrhciebcy for the report that surfaced both the Blackwell gap and the API-server bug.
- Blackwell (sm_120) support: built with `TORCH_CUDA_ARCH_LIST=8.6;8.9;12.0` on CUDA 12.8 + PyTorch 2.11.0+cu128 + Python 3.13, so the wheel carries sm_86 / sm_89 / sm_120 kernels (verified with `cuobjdump`). The older `v0.21.0-win` wheel (cu126, sm_86 only) fails on a 5090 with `no kernel image is available for execution on the device`. That is a compute-capability gap, not a Python-version problem.
- The OpenAI API server now works on Windows. Previously only the in-process `LLM()` path worked; `vllm serve` / `api_server` crashed. Four Windows-only bugs are fixed: (1) the bare `import uvloop` (Unix-only) in five entrypoints now falls back to `asyncio`; (2) `wait_for_engine_startup()` registered process sentinels (Windows HANDLEs, not sockets) with a `zmq.Poller`, which failed with `not a socket`, so they are now skipped on win32 in favor of exit-code liveness checks; (3) pyzmq needs `loop.add_reader`, which is absent from the Windows Proactor loop, so the patch sets `WindowsSelectorEventLoopPolicy` (no tornado); (4) `loop.add_signal_handler` raises `NotImplementedError` on Windows, so the patch falls back to `signal.signal`. winloop is no longer needed.
- Two Blackwell-only kernels are skipped on Windows (they don't compile under MSVC and aren't usable here anyway): QuTLASS (NVFP4/MXFP4 microscaling quant, which uses GCC inline-PTX `asm`) and the MiniMax multi-GPU all-reduce RMS fusion (needs real multi-GPU comm; Windows uses `FakeProcessGroup`). Their vLLM callers are `hasattr`-guarded, so FP4 and MiniMax just degrade gracefully. Everything mainstream (FP16/BF16, AWQ, GPTQ/Marlin, FP8, and all 10 KV-cache compression methods) is unaffected.
- Dependency note: vLLM gates `llguidance` and `xgrammar` on `platform_machine == "x86_64"`, but Windows reports `AMD64`, so pip silently skips them and vLLM then fails to import. `install.bat` installs them explicitly; if installing manually, run `pip install "llguidance>=1.3.0,<1.4.0" "xgrammar>=0.2.0,<1.0.0"`.
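The event-loop and signal-handler fallbacks described in the API-server fix follow a common pattern. The following is a minimal sketch of that pattern, not the patch itself: the helper names (`choose_loop_backend`, `apply_windows_loop_policy`, `install_stop_handler`) are hypothetical, and only standard `asyncio` / `signal` calls are used.

```python
import asyncio
import signal
import sys


def choose_loop_backend() -> str:
    """Return "uvloop" when it can be imported, otherwise "asyncio"."""
    try:
        import uvloop  # noqa: F401  (Unix-only, so this fails on Windows)
    except ImportError:
        return "asyncio"
    return "uvloop"


def apply_windows_loop_policy() -> bool:
    """On Windows, switch to the selector loop so loop.add_reader exists."""
    if sys.platform != "win32":
        return False
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    return True


def install_stop_handler(loop, callback) -> str:
    """Prefer loop.add_signal_handler; fall back to signal.signal."""
    try:
        loop.add_signal_handler(signal.SIGINT, callback)
        return "loop"
    except NotImplementedError:
        # Windows event loops do not implement add_signal_handler.
        signal.signal(signal.SIGINT, lambda signum, frame: callback())
        return "signal"
```

Each helper reports which path it took, which makes the fallback easy to log at startup.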
- vLLM v0.21.0 base — 1,157 upstream commits since v0.19.1, including the new native TurboQuant attention backend (PR #38479), DeepGEMM extension, fastsafetensors prefetch helpers, and v1 engine maturity.
- PyTorch 2.11.0 + CUDA 12.6 (was 2.10.0). New compiler flags needed for MSVC: `/Usmall` to dodge the `rpcndr.h` macro that collides with PyTorch's new `bool small` parameter name, and `/Zc:__cplusplus` so CUTLASS's `is_unsigned_v` (C++17) actually sees the standard `__cplusplus` value.
- Upstream TurboQuant coexists with Multi-TurboQuant: the patch registers our 6 method names alongside upstream's 4 in `CacheDType`. Backend dispatch in `vllm/platforms/cuda.py` routes `turboquant_*` to the new `TurboQuantBackend`; ours stay on the existing `TritonAttention` backend with the dispatch hooks from the v4 patch.
- CUTLASS 4.4.2 (vendored + vllm-flash-attn submodule) is now patched inline with two MSVC fixes (`memsetDevice` host/device mismatch, four `static constexpr dim3 get_block_shape()` violations). The patches ship as `cutlass-windows.patch` and `vllm-flash-attn-cutlass-windows.patch` inside `vllm-source/`; `CMakeLists.txt` applies them automatically after `FetchContent_MakeAvailable`, so no manual intervention is needed.
- flashinfer is now silently skipped on Windows: upstream defaults `VLLM_USE_FLASHINFER_SAMPLER=True`, which then unconditionally runs `import flashinfer` (no Windows wheel). The patch flips the default to `False` on `win32` so the Triton fallback is used transparently.
- Smoke-tested end-to-end on RTX 3090, Qwen3-14B-AWQ-4bit with both `kv_cache_dtype=auto` (9.7 tok/s) and `turboquant35` (0.73 tok/s, consistent with v0.19.x).
- Multi-TurboQuant integration: 6 KV cache compression methods (`isoquant3`, `isoquant4`, `planarquant3`, `planarquant4`, `turboquant25`, `turboquant35`) with real uint8 packed storage, which fits 2× as many KV cache tokens at the same `gpu_memory_utilization`.
- Custom Windows safetensors reader: numpy memory-mapping + chunked GPU streaming. Loads a 14B model in seconds and works on systems with the Windows pagefile disabled.
- All 140 CUDA targets compile clean with MSVC 2022 + CUDA 12.6 + Ninja. 36 source files patched + 3 new files (the TQ dispatch helper and the two CUTLASS patches).
- Tests included: end-to-end validation suite that proves each TQ method actually compresses (not a placebo) and each one produces unique output from FP16.
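To illustrate the memory-mapped, chunked reading idea behind the custom safetensors reader, here is a dependency-free sketch. It is not the patch's reader: it swaps numpy for the standard-library `mmap` module and yields byte chunks instead of copying them to the GPU, and the function names are hypothetical. The file layout it relies on (an 8-byte little-endian header length, a JSON header, then raw tensor data addressed by `data_offsets` relative to the data section) is the safetensors format.

```python
import json
import mmap
import struct


def read_header(path):
    """Return (header dict, byte offset where the tensor data section starts)."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    return header, 8 + header_len


def iter_tensor_chunks(path, name, chunk_bytes=1 << 20):
    """Yield the raw bytes of one tensor in fixed-size chunks via mmap."""
    header, data_start = read_header(path)
    begin, end = header[name]["data_offsets"]
    start, stop = data_start + begin, data_start + end
    with open(path, "rb") as f:
        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
            for offset in range(start, stop, chunk_bytes):
                # Only this slice is materialized; the file is never read whole.
                yield bytes(mm[offset:min(offset + chunk_bytes, stop)])
```

In a real loader each chunk would be copied into a preallocated device tensor, so the peak host memory stays near `chunk_bytes` rather than the size of the checkpoint.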
Single 24 GB RTX 3090, Qwen3-14B AWQ-4bit, gpu_memory_utilization=0.5:
| KV dtype | Cache tokens | Concurrency @ 512 | vs FP16 |
|---|---|---|---|
| auto (fp16) | 16,336 | 31.91× | 1.00× |
| isoquant3/4, planarquant3/4, turboquant25/35 | 32,672 | 63.94× | 2.00× |
Full benchmarks → docs/benchmarks.md
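The 2× token count follows from the per-token footprint: storing each KV element in one byte instead of two halves the bytes per token, so the same memory budget holds twice as many tokens. A back-of-the-envelope sketch follows; the layer, head, and dimension values are assumptions for a Qwen3-14B-like shape rather than numbers taken from this repo, and per-block metadata is ignored.

```python
def kv_bytes_per_token(layers: int, kv_heads: int, head_dim: int, bytes_per_elem: int) -> int:
    """Bytes of KV cache one token occupies across all layers."""
    # K and V each store kv_heads * head_dim elements per layer.
    return 2 * layers * kv_heads * head_dim * bytes_per_elem


def tokens_that_fit(budget_bytes: int, per_token_bytes: int) -> int:
    """Whole tokens that fit in a fixed KV cache budget."""
    return budget_bytes // per_token_bytes


# Assumed model shape (hypothetical values for illustration).
LAYERS, KV_HEADS, HEAD_DIM = 40, 8, 128

fp16_bytes = kv_bytes_per_token(LAYERS, KV_HEADS, HEAD_DIM, bytes_per_elem=2)
u8_bytes = kv_bytes_per_token(LAYERS, KV_HEADS, HEAD_DIM, bytes_per_elem=1)

# Use the budget that holds the FP16 row of the table above.
budget = 16_336 * fp16_bytes
print(tokens_that_fit(budget, fp16_bytes))  # 16336
print(tokens_that_fit(budget, u8_bytes))    # 32672
```

The ratio is independent of the assumed shape; only the absolute byte counts change with the model.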
Download vllm-0.21.0+cu128-cp313-cp313-win_amd64.whl from the Releases page, then:
```bat
:: Create a Python 3.13 venv
py -3.13 -m venv venv
venv\Scripts\activate

:: Install PyTorch 2.11.0 with CUDA 12.8 (cu128 = Blackwell support)
pip install torch==2.11.0 ^
    --index-url https://download.pytorch.org/whl/cu128

:: Install Triton for Windows
pip install triton-windows==3.6.0.post26

:: Install the pre-built vLLM wheel
pip install vllm-0.21.0+cu128-cp313-cp313-win_amd64.whl

:: Structured-output backends vLLM gates on x86_64 (Windows = AMD64, so pip
:: skips them and vLLM won't import without these)
pip install "llguidance>=1.3.0,<1.4.0" "xgrammar>=0.2.0,<1.0.0"

:: Install Multi-TurboQuant for the 6 KV cache compression methods
pip install git+https://github.com/aivrar/multi-turboquant.git
```

Or just run `install.bat` for a fully self-contained, one-click portable
Python install: it downloads Python 3.13, PyTorch cu128, and the vLLM wheel
itself (no manual download or folder creation needed). If you already have the
`.whl` locally, drop it in `dist-v5\` next to `install.bat` and the script uses
that instead of downloading.
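After a manual install, it can help to confirm that the packages pip may have skipped are actually importable before loading a model. A small sketch, assuming the import names `torch`, `triton`, `vllm`, `llguidance`, and `xgrammar` (the helper name `missing_modules` is hypothetical):

```python
import importlib.util

# Import names assumed to correspond to the packages installed above.
REQUIRED = ("torch", "triton", "vllm", "llguidance", "xgrammar")


def missing_modules(names=REQUIRED):
    """Return the top-level modules that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All required modules are importable.")
```

`find_spec` only locates the modules without importing them, so the check is fast and does not initialize CUDA.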
Requires Visual Studio 2022 (Community is fine), CUDA 12.8, and a Python 3.13
venv. Building all three arches (`8.6;8.9;12.0`) takes roughly 3-4 hours at
`MAX_JOBS=2` (the CUDA compile dominates; see notes below). Keep `MAX_JOBS=2`
and leave sccache disabled: higher job counts and sccache both cause
intermittent MSVC `cl.exe` crashes (`0xC000001D`) on the heavy multi-arch CUDA
kernels.
```bat
git clone https://github.com/vllm-project/vllm.git vllm-source
cd vllm-source && git checkout v0.21.0 && cd ..
git apply vllm-windows-v5.patch --directory vllm-source
build.bat
```

The patch also drops `cutlass-windows.patch` and
`vllm-flash-attn-cutlass-windows.patch` into `vllm-source/`. The build's
`CMakeLists.txt` applies them automatically to the FetchContent-managed
`.deps/cutlass-src/` and `.deps/vllm-flash-attn-src/csrc/cutlass` after
the first configure, so you don't need a separate step.
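If you drive the build from a script, the two settings mentioned above can be pinned in the environment passed to the build. This is a sketch under the assumption that `build.bat` honors `TORCH_CUDA_ARCH_LIST` and `MAX_JOBS` from the calling environment; check the script before relying on it. The helper name `build_env` is hypothetical.

```python
import os


def build_env(arch_list=("8.6", "8.9", "12.0"), max_jobs=2, base=None):
    """Return a copy of the environment with the build settings pinned."""
    env = dict(os.environ if base is None else base)
    env["TORCH_CUDA_ARCH_LIST"] = ";".join(arch_list)
    env["MAX_JOBS"] = str(max_jobs)  # higher values risk cl.exe crashes
    return env


if __name__ == "__main__":
    env = build_env()
    print(env["TORCH_CUDA_ARCH_LIST"], env["MAX_JOBS"])
```

The returned dictionary can be handed to `subprocess.run(["cmd", "/c", "build.bat"], env=env, check=True)`. Dropping an architecture from `arch_list` shortens the build if you only target one GPU generation.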
Full instructions, including all the env vars and prerequisites: → docs/install.md
```python
import os

# Required on Windows
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# CUDA + torch DLL search paths
os.add_dll_directory(r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8\bin")
os.add_dll_directory(r"C:\path\to\venv\Lib\site-packages\torch\lib")

# Both uvloop and flashinfer fallbacks are baked into the wheel.
# Multi-GPU host? Don't forget CUDA_DEVICE_ORDER + CUDA_VISIBLE_DEVICES
# so vLLM lands on the GPU you actually want.

from vllm import LLM, SamplingParams

llm = LLM(
    model=r"E:\models\Qwen3-14B-AWQ-4bit",
    dtype="float16",
    kv_cache_dtype="isoquant4",  # 2× KV cache capacity, near-FP16 quality
    max_model_len=2048,
    gpu_memory_utilization=0.85,
    enforce_eager=True,
    trust_remote_code=True,
)

outputs = llm.generate(
    ["Explain CUDA streams in three sentences:"],
    SamplingParams(temperature=0.7, max_tokens=200),
)
print(outputs[0].outputs[0].text)
```

For OpenAI-compatible HTTP serving and more usage patterns: → docs/usage.md
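On a multi-GPU host, the environment setup from the embedding example can be wrapped in one helper that also pins the device. This is a sketch with a hypothetical function name (`configure_runtime`); `CUDA_DEVICE_ORDER` and `CUDA_VISIBLE_DEVICES` are standard CUDA environment variables, and the helper must run before `torch` or `vllm` is imported.

```python
import os
import sys


def configure_runtime(dll_dirs=(), gpu_index=0):
    """Set CUDA-related environment variables and register DLL directories.

    Returns the DLL directories that were actually registered.
    """
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
    # Enumerate GPUs in PCI bus order so indices match nvidia-smi.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_index)

    added = []
    if sys.platform == "win32":  # os.add_dll_directory exists only on Windows
        for directory in dll_dirs:
            if os.path.isdir(directory):
                os.add_dll_directory(directory)
                added.append(directory)
    return added
```

Skipping directories that do not exist avoids the `FileNotFoundError` that `os.add_dll_directory` raises for a wrong CUDA or venv path, and the return value shows which paths took effect.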
vLLM v0.21.0 on Windows ships with integrated support for ten KV cache
compression dtypes. The four turboquant_* entries are the new upstream
TurboQuant attention backend (PR #38479, landed in v0.19.2rc0); the six
others come from our Multi-TurboQuant
library and run on the patched TritonAttention backend.
| Method | Bits | Family | Calibration | Use case |
|---|---|---|---|---|
| turboquant_k8v4 | 8.25 / 4.25 | upstream | none | Mixed-precision K/V |
| turboquant_4bit_nc | 4.25 | upstream | none | Upstream default |
| turboquant_k3v4_nc | 3.25 / 4.25 | upstream | none | More aggressive K |
| turboquant_3bit_nc | 3.25 | upstream | none | Most aggressive upstream |
| isoquant4 | 4.25 | quaternion 4D rotation | none | Recommended default (ours) |
| planarquant4 | 4.25 | Givens 2D rotation | none | Same memory, simpler transform |
| isoquant3 | 3.25 | quaternion 4D rotation | none | More aggressive |
| planarquant3 | 3.25 | Givens 2D rotation | none | More aggressive |
| turboquant35 | 3.25 | WHT + MSE codebook + QJL | runtime | Calibrated outliers |
| turboquant25 | 2.25 | WHT + MSE codebook + QJL | runtime | Most compression |
Just pass the method name as kv_cache_dtype when constructing an
LLM (or --kv-cache-dtype to vllm serve). Upstream turboquant_*
names are routed by vllm/platforms/cuda.py to the new
TurboQuantBackend (separate cache layout + Triton encode/decode);
ours stay on TritonAttention with the dispatch hooks from the v4
patch.
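When launching the server from a script, it is easy to mistype one of the ten method names. The sketch below validates the name and builds the `vllm serve` argument list; the helper names are hypothetical, and `backend_family` only mirrors the routing described above (the real dispatch lives in `vllm/platforms/cuda.py`).

```python
UPSTREAM = (
    "turboquant_k8v4",
    "turboquant_4bit_nc",
    "turboquant_k3v4_nc",
    "turboquant_3bit_nc",
)
MULTI_TQ = (
    "isoquant3",
    "isoquant4",
    "planarquant3",
    "planarquant4",
    "turboquant25",
    "turboquant35",
)


def backend_family(kv_cache_dtype: str) -> str:
    """Name the backend the text above says a dtype is routed to."""
    if kv_cache_dtype in UPSTREAM:
        return "TurboQuantBackend"
    if kv_cache_dtype in MULTI_TQ:
        return "TritonAttention"
    return "unchanged"


def serve_command(model: str, kv_cache_dtype: str = "auto", max_model_len: int = 2048):
    """Build an argument list for `vllm serve`, rejecting unknown dtypes."""
    allowed = ("auto", "fp8") + UPSTREAM + MULTI_TQ
    if kv_cache_dtype not in allowed:
        raise ValueError(f"unknown kv_cache_dtype: {kv_cache_dtype!r}")
    return [
        "vllm", "serve", model,
        "--kv-cache-dtype", kv_cache_dtype,
        "--max-model-len", str(max_model_len),
    ]
```

The returned list can be passed directly to `subprocess.run`, which avoids shell-quoting problems with Windows model paths.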
Trade-off (ours): throughput drops ~30-300× with our 6 methods enabled because the encode/decode runs in PyTorch (no fused Triton kernel yet). Memory savings are real, throughput cost is the price. Best for offline / long-context / batch workloads. The upstream variants use fused Triton kernels and don't pay this cost. See docs/turboquant.md for the full picture.
vllm-windows-v5.patch is a unified diff against vllm-project/vllm
at tag v0.21.0. It touches 36 files + 3 new files (the TQ
dispatch helper plus two CUTLASS-vendor patches):
- Build system (4): CMakeLists, cmake/utils, setup.py, requirements/cuda.txt (with `/Usmall` + `/Zc:__cplusplus` for MSVC, `fastsafetensors` and `flashinfer` commented out, auto-apply of cutlass-windows patches)
- CUDA kernels (17): MSVC compatibility for keyword operators, designated initializers, `__builtin_clz`, variable templates with attributes, nested constexpr lambdas, deeply nested `else if`, `__attribute__((aligned))`, `std::isinf`, `__int128_t`, the new `persistent_topk.cuh` `__forceinline` swap, `fused_silu_mul_block_quant.cu` `quant_type_max_v<T>()` call-syntax, and the `topk_softplus_sqrt_kernels.cu` preprocessor-in-macro-arg refactor
- Runtime Python (9): `fcntl` → `msvcrt`, ZMQ IPC → TCP, fork → spawn, NCCL → FakeProcessGroup, custom safetensors reader for small pagefile systems, `uvloop` fallback, `VLLM_USE_FLASHINFER_SAMPLER` default-False on Windows
- Multi-TurboQuant integration (4 + 1 new): 6 new `CacheDType` literals, dtype mapping, attention backend dispatch, plus the new `vllm/v1/attention/ops/multi_turboquant_kv.py` (295 lines)
- CUTLASS patches (2 new files): `cutlass-windows.patch` (5 files in CUTLASS 4.4.2: `cuda_host_adapter.hpp` + 4 SM100/SM103 headers with `static constexpr dim3` violations) and `vllm-flash-attn-cutlass-windows.patch` (5 files in the vendored CUTLASS submodule under vllm-flash-attn).
Full per-file breakdown → PATCHES.md
All changes are guarded by #ifdef _MSC_VER, sys.platform == "win32",
if(MSVC ...), or similar conditionals. Zero impact on Linux builds.
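As an illustration of the `sys.platform == "win32"` guard style, here is a sketch of a cross-platform non-blocking file lock that uses `msvcrt` on Windows and `fcntl` elsewhere. It shows the general pattern behind the `fcntl` → `msvcrt` change rather than the patch's actual code; the function names are hypothetical.

```python
import sys

if sys.platform == "win32":
    import msvcrt

    def try_lock(fd: int) -> bool:
        """Try to lock one byte at the current file position without blocking."""
        try:
            msvcrt.locking(fd, msvcrt.LK_NBLCK, 1)
            return True
        except OSError:
            return False

    def unlock(fd: int) -> None:
        msvcrt.locking(fd, msvcrt.LK_UNLCK, 1)

else:
    import fcntl

    def try_lock(fd: int) -> bool:
        """Try to take an exclusive advisory lock without blocking."""
        try:
            fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
            return True
        except OSError:
            return False

    def unlock(fd: int) -> None:
        fcntl.flock(fd, fcntl.LOCK_UN)
```

Because the branch is chosen at import time, the Linux code path is byte-for-byte what it would be without the guard, which is the property the paragraph above describes.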
| Page | Topic |
|---|---|
| docs/install.md | Install the wheel or build from source |
| docs/usage.md | Python embedding + HTTP server |
| docs/turboquant.md | Multi-TurboQuant deep dive |
| docs/benchmarks.md | Real numbers, all 6 methods |
| docs/build.md | Patch internals + iterating on builds |
| docs/architecture.md | How the integration works |
| docs/troubleshooting.md | Common errors + fixes |
| tests/README.md | End-to-end test scripts |
| Component | Minimum | Recommended |
|---|---|---|
| OS | Windows 10 21H2 (x64) | Windows 10 22H2 / Windows 11 |
| GPU | NVIDIA SM 8.0+ (RTX 30/40/50, A100, H100) | RTX 3090 / 4090 / A6000 |
| VRAM | 12 GB | 24 GB |
| RAM | 16 GB | 32+ GB |
| CUDA driver | R570+ (Blackwell needs R570+) | latest |
| Python | 3.13.x | 3.13.11 |
| Compiler (build only) | VS 2022 Community + Win 10 SDK | Same |
| CUDA Toolkit (build only) | 12.8 (first toolkit with sm_120) | 12.8 |
For build-from-source, you also need a Windows pagefile (system managed is fine). Without it, large allocations during compilation can fail. See docs/troubleshooting.md → OSError 1455.
- RTX 3090 (24 GB, SM 8.6, driver 596.36) — build + smoke test (generation + api_server)
- Qwen2.5-0.5B-Instruct (smoke test), Qwen3-14B-abliterated-AWQ-4bit
- Qwen3.5-9B-abliterated-GPTQ-4bit (text-only)
- Windows 10 Pro 22H2
- Visual Studio 2022 Community 17.13 (MSVC 14.43)
- CUDA Toolkit 12.8
- Python 3.13.11
- RTX 50-series (Blackwell sm_120): kernels compiled & verified via
cuobjdump; runtime confirmation pending community hardware
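Whether a wheel can run on a given GPU comes down to CUDA binary compatibility: a kernel compiled for sm_XY runs on devices with the same major version X and a minor version of at least Y. The sketch below applies that rule to a `TORCH_CUDA_ARCH_LIST`-style string. It ignores the PTX JIT fallback, which only applies when a build embeds PTX, and the function names are hypothetical.

```python
def parse_arch_list(arch_list: str):
    """Turn "8.6;8.9;12.0" into [(8, 6), (8, 9), (12, 0)]."""
    return [
        tuple(int(part) for part in item.split("."))
        for item in arch_list.split(";")
        if item
    ]


def has_kernel_image(device_cc, built_arches) -> bool:
    """True if any built arch is binary-compatible with the device."""
    major, minor = device_cc
    return any(
        built_major == major and built_minor <= minor
        for built_major, built_minor in built_arches
    )


cu128_wheel = parse_arch_list("8.6;8.9;12.0")
print(has_kernel_image((12, 0), cu128_wheel))           # RTX 5090 on the cu128 wheel
print(has_kernel_image((12, 0), parse_arch_list("8.6")))  # RTX 5090 on an sm_86-only wheel
```

With PyTorch installed, the device tuple can be read from `torch.cuda.get_device_capability()`. The second call returns `False`, which corresponds to the `no kernel image is available` error reported for the older wheel.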
kv_cache_dtype=auto (FlashAttention 2): 20 tokens in 2.06 s,
9.7 tok/s with max_model_len=512, gpu_memory_utilization=0.92.
First model load completes in ~24 s after the safetensors cache warms.
kv_cache_dtype=turboquant35 (Triton attention + Multi-TurboQuant
PyTorch-fallback encode/decode): 20 tokens in 27.39 s, 0.73 tok/s —
in line with the v0.19.x figure (0.92 tok/s for 5 tokens). All other
Multi-TurboQuant methods (isoquant3/4, planarquant3/4,
turboquant25) should behave the same as in v0.19.x; rerun
tests/test_tq_real.py for a full sweep.
Older Multi-TurboQuant timings on the same hardware (5 decoded tokens,
gpu_memory_utilization=0.5):
| Method | Preset | Time (5 tok) | Output tok/s | Status |
|---|---|---|---|---|
| isoquant3 | no_calibration_symmetric | 41.5s | 0.12 | PASS |
| isoquant4 | no_calibration_quality | 53.0s | 0.09 | PASS |
| planarquant3 | k_only_planar | 40.5s | 0.12 | PASS |
| planarquant4 | k_only_planar | 53.0s | 0.09 | PASS |
| turboquant25 | max_compression | 6.7s | 0.74 | PASS |
| turboquant35 | speed | 5.4s | 0.92 | PASS |
turboquant25/35 are ~8× faster than the iso/planar family on the
PyTorch-fallback path. Reproduce with:
```bat
set TQ_METHOD=isoquant3
%VLLM_PYTHON% tests\test_tq_diag.py
```

- Single GPU only. NCCL doesn't ship with PyTorch on Windows; the patch wires up `FakeProcessGroup` for single-rank operation. Multi-GPU needs separate vLLM instances + external load balancing.
- No FlashInfer. No Windows wheel. The patch defaults `VLLM_USE_FLASHINFER_SAMPLER=False` on `win32` so vLLM falls back to the Triton sampler transparently.
- No FlashAttention 3, no FlashAttention 4 (CuteDSL). FA3 has MSVC-incompatible PTX macros, FA4 needs `nvidia-cutlass-dsl` (no Windows wheel). FlashAttention 2 works fine.
- No fastsafetensors. Linux-only (`io_uring`). The patched `weight_utils.py` keeps the in-tree numpy-mmap + chunked-GPU-stream reader from v0.19.x for the safetensors path.
- No DeepGEMM, no Quack, no Tilelang, no TokenSpeed-MLA, no NIXL. None ship Windows wheels; CMake skips DeepGEMM automatically when the target arch is below SM 9.0.
- Our 6 Multi-TurboQuant methods are still on the PyTorch-fallback encode/decode. The memory savings are real, and so is the throughput cost (`turboquant35` ≈ 0.73 tok/s on Qwen3-14B). The upstream `turboquant_*` variants don't pay this cost: they use the fused Triton store/decode kernels that landed in PR #38479.
- Triton JIT cold-start latency. First inference with Triton kernels (e.g. Qwen3.5 GDN layers) takes ~1-2 minutes for compilation.
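The FlashInfer item above amounts to a platform-dependent default for one environment variable. A sketch of that idea follows, assuming the variable is read as `1` / `true` for enabled; the parsing rule and the helper name are assumptions, not the patch's code.

```python
import os
import sys


def flashinfer_sampler_enabled(environ=None, platform=None) -> bool:
    """Resolve VLLM_USE_FLASHINFER_SAMPLER with a platform-dependent default."""
    environ = os.environ if environ is None else environ
    platform = sys.platform if platform is None else platform
    # Default off on Windows, where no flashinfer wheel exists; on elsewhere.
    default = "0" if platform == "win32" else "1"
    value = environ.get("VLLM_USE_FLASHINFER_SAMPLER", default)
    return value.strip().lower() in ("1", "true")
```

An explicit setting still wins over the default, so the behavior on Linux and for users who set the variable themselves is unchanged.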
| Project | Role |
|---|---|
| vLLM | The original engine |
| PyTorch | Tensor library + CUDA bindings |
| CUDA Toolkit | NVIDIA |
| FlashAttention | FA2 kernels |
| triton-windows | Triton compiler ported to Windows |
| Multi-TurboQuant | KV cache compression methods (ours) |
| Upstream TurboQuant | TurboQuant attention backend (vLLM PR #38479) |
| CUTLASS | GEMM kernels (CUTLASS 4.4.2 with MSVC patches) |
| TurboQuant paper | Walsh-Hadamard quantization |
Built with the help of Claude.
MIT. See LICENSE.