diff --git a/python/triton_dist/kernels/nvidia/gemm_perf_model.py b/python/triton_dist/kernels/nvidia/gemm_perf_model.py index 62c4f2a92..ee23d0223 100644 --- a/python/triton_dist/kernels/nvidia/gemm_perf_model.py +++ b/python/triton_dist/kernels/nvidia/gemm_perf_model.py @@ -120,6 +120,7 @@ def get_tflops_approx(device: torch.dtype, num_ctas: int, num_warps: int, dtype: def get_full_tflops_approx(dtype: torch.dtype, device: Optional[torch.device] = None): + device = torch.cuda.current_device() if device is None else device prop = torch.cuda.get_device_properties(device) return get_tflops_approx(device, prop.multi_processor_count, 4, dtype) @@ -135,7 +136,9 @@ def get_tensorcore_dtype_support(device_id=0): (8, 9): [torch.float16, torch.bfloat16, torch.float32, torch.int8, torch.float8_e4m3fn, torch.float8_e5m2], # Ada L40S/RTX 40xx # Hopper - (9, 0): [torch.float16, torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2, torch.int8] + (9, 0): [torch.float16, torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2, torch.int8], + # Blackwell + (10, 0): [torch.float16, torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2, torch.int8], } return DTYPE_MAP.get(cap, [torch.float16, torch.float32]) @@ -177,6 +180,8 @@ def get_tensorcore_tflops_by_device_name(dtype, device_id=0): return 989 * (2 / dtype.itemsize) if device_name == "NVIDIA H20": return 148 * (2 / dtype.itemsize) + if device_name == "NVIDIA B200": + return 2250 * (2 / dtype.itemsize) logging.warning( f"device {device_name} not listed here. calculate tflops by estimation, or you can report it to developers.") @@ -206,6 +211,7 @@ def get_dram_gbps_by_device_name(device_name: str): "NVIDIA H100 SXM": 3958, "NVIDIA H100 NVL": 3341, "NVIDIA H800": 3350, + "NVIDIA B200": 8000, } return _DRAM_GBPS[device_name]