-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsetup.py
More file actions
49 lines (44 loc) · 1.45 KB
/
Copy pathsetup.py
File metadata and controls
49 lines (44 loc) · 1.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""
Setup script to build the custom CUDA extension for normalization.
Build with:
python setup.py build_ext --inplace
Or to skip CUDA version check (if you know what you're doing):
TORCH_CUDA_ARCH_LIST="9.0" python setup.py build_ext --inplace
"""
import os
import warnings
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# Set CUDA_HOME if not already set
if 'CUDA_HOME' not in os.environ:
# Try to find CUDA installation
for cuda_path in ['/usr/local/cuda-12.6', '/usr/local/cuda-13.0', '/usr/local/cuda']:
if os.path.exists(cuda_path):
os.environ['CUDA_HOME'] = cuda_path
print(f"Using CUDA_HOME: {cuda_path}")
break
# Pin the GPU architecture(s) to compile for, unless the caller already chose.
# GB200 is compute capability 10.0 and H100 is 9.0; the default targets H100 only.
_cuda_arch_list = os.environ.setdefault('TORCH_CUDA_ARCH_LIST', '9.0')
print(f"Using TORCH_CUDA_ARCH_LIST: {_cuda_arch_list}")
# Translation units compiled into the single `normalize_cuda` extension module.
_SOURCES = [
    'src/normalize_cuda.cpp',
    'src/normalize_cuda_kernel.cu',
    'src/normalize_cuda_kernel_optimized.cu',
]

# Per-compiler flags: host C++ via 'cxx', device code via 'nvcc'.
_EXTRA_COMPILE_ARGS = {
    'cxx': ['-O3'],
    'nvcc': ['-O3', '--use_fast_math'],
}

setup(
    name='normalize_cuda',
    ext_modules=[
        CUDAExtension(
            name='normalize_cuda',
            sources=_SOURCES,
            extra_compile_args=_EXTRA_COMPILE_ARGS,
        ),
    ],
    # PyTorch's build_ext knows how to drive nvcc for the .cu sources.
    cmdclass={'build_ext': BuildExtension},
)