Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion python/triton_dist/kernels/nvidia/gemm_perf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def get_tflops_approx(device: torch.dtype, num_ctas: int, num_warps: int, dtype:


def get_full_tflops_approx(dtype: torch.dtype, device: Optional[torch.device] = None):
    """Approximate the TFLOPS for ``dtype`` using every multiprocessor of a device.

    Args:
        dtype: Element type forwarded unchanged to ``get_tflops_approx``.
        device: CUDA device to query. When ``None``, the device reported by
            ``torch.cuda.current_device()`` is used instead.

    Returns:
        The value produced by ``get_tflops_approx`` when ``num_ctas`` is the
        device's multiprocessor count and ``num_warps`` is 4.
    """
    if device is None:
        device = torch.cuda.current_device()
    sm_count = torch.cuda.get_device_properties(device).multi_processor_count
    # num_ctas is the full multiprocessor count; num_warps is fixed at 4.
    return get_tflops_approx(device, sm_count, 4, dtype)

Expand All @@ -135,7 +136,9 @@ def get_tensorcore_dtype_support(device_id=0):
(8, 9): [torch.float16, torch.bfloat16, torch.float32, torch.int8, torch.float8_e4m3fn,
torch.float8_e5m2], # Ada L40S/RTX 40xx
# Hopper
(9, 0): [torch.float16, torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2, torch.int8]
(9, 0): [torch.float16, torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2, torch.int8],
# Blackwell
(10, 0): [torch.float16, torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2, torch.int8],
}
return DTYPE_MAP.get(cap, [torch.float16, torch.float32])

Expand Down Expand Up @@ -177,6 +180,8 @@ def get_tensorcore_tflops_by_device_name(dtype, device_id=0):
return 989 * (2 / dtype.itemsize)
if device_name == "NVIDIA H20":
return 148 * (2 / dtype.itemsize)
if device_name == "NVIDIA B200":
return 2250 * (2 / dtype.itemsize)

logging.warning(
f"device {device_name} not listed here. calculate tflops by estimation, or you can report it to developers.")
Expand Down Expand Up @@ -206,6 +211,7 @@ def get_dram_gbps_by_device_name(device_name: str):
"NVIDIA H100 SXM": 3958,
"NVIDIA H100 NVL": 3341,
"NVIDIA H800": 3350,
"NVIDIA B200": 8000,
}
return _DRAM_GBPS[device_name]

Expand Down