Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 85 additions & 15 deletions general/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@ def is_jetson() -> bool:
class CGPUInfo:
"""
This class is responsible for getting information from GPU (ONLY).
Supports both AMD (via rocm-smi) and NVIDIA (via pynvml) GPUs.
"""
cuda = False
rocm = False
pynvmlLoaded = False
jtopLoaded = False
cudaAvailable = False
rocmAvailable = False
torchDevice = 'cpu'
cudaDevice = 'cpu'
cudaDevicesFound = 0
Expand All @@ -47,6 +50,7 @@ class CGPUInfo:
gpusUtilization = []
gpusVRAM = []
gpusTemperature = []
gpuVendor = 'unknown' # 'nvidia', 'amd', or 'unknown'

def __init__(self):
if IS_JETSON:
Expand All @@ -56,36 +60,64 @@ def __init__(self):
self.jtopInstance = jtop()
self.jtopInstance.start()
self.jtopLoaded = True
self.gpuVendor = 'nvidia'
logger.info('jtop initialized on Jetson device.')
except ImportError as e:
logger.error('jtop is not installed. ' + str(e))
except Exception as e:
logger.error('Could not initialize jtop. ' + str(e))
else:
# Try to import pynvml for non-Jetson devices
# Try NVIDIA first
try:
import pynvml
self.pynvml = pynvml
self.pynvml.nvmlInit()
self.pynvmlLoaded = True
logger.info('pynvml (NVIDIA) initialized.')
# Check if NVIDIA GPUs are actually available
device_count = self.pynvml.nvmlDeviceGetCount()
if device_count and device_count > 0:
self.pynvmlLoaded = True
self.gpuVendor = 'nvidia'
logger.info('pynvml (NVIDIA) initialized.')
else:
logger.debug('No NVIDIA GPUs detected.')
except ImportError as e:
logger.error('pynvml is not installed. ' + str(e))
logger.debug('pynvml is not installed. ' + str(e))
except Exception as e:
logger.error('Could not init pynvml (NVIDIA). ' + str(e))
logger.debug('Could not init pynvml (NVIDIA). ' + str(e))

self.anygpuLoaded = self.pynvmlLoaded or self.jtopLoaded
# If NVIDIA not available, try to detect AMD GPUs using pyrsmi
if self.gpuVendor != 'nvidia':
try:
from pyrsmi import rocml
rocml.smi_initialize()
logger.debug('pyrsmi initialized')
device_count = rocml.smi_get_device_count()
logger.debug(f'pyrsmi device count: {device_count}')
if device_count and device_count > 0:
self.rocml = rocml
self.rocmAvailable = True
self.gpuVendor = 'amd'
logger.info(f'AMD ROCm detected via pyrsmi ({device_count} GPU(s)).')
else:
logger.debug(f'No AMD GPUs detected (device_count={device_count})')
rocml.smi_shutdown()
except ImportError as e:
logger.debug(f'pyrsmi is not installed: {e}')
except Exception as e:
logger.error(f'Could not detect AMD GPU via pyrsmi: {e}')

self.anygpuLoaded = self.pynvmlLoaded or self.jtopLoaded or self.rocmAvailable

try:
self.torchDevice = comfy.model_management.get_torch_device_name(comfy.model_management.get_torch_device())
except Exception as e:
logger.error('Could not pick default device. ' + str(e))

if self.pynvmlLoaded and not self.jtopLoaded and not self.deviceGetCount():
if (self.pynvmlLoaded or self.rocmAvailable) and not self.jtopLoaded and not self.deviceGetCount():
logger.warning('No GPU detected, disabling GPU monitoring.')
self.anygpuLoaded = False
self.pynvmlLoaded = False
self.jtopLoaded = False
self.rocmAvailable = False

if self.anygpuLoaded:
if self.deviceGetCount() > 0:
Expand All @@ -110,18 +142,23 @@ def __init__(self):
self.gpusVRAM.append(True)
self.gpusTemperature.append(True)

self.cuda = True
self.cuda = self.pynvmlLoaded or self.jtopLoaded
self.rocm = self.rocmAvailable
logger.info(self.systemGetDriverVersion())
else:
logger.warning('No GPU with CUDA detected.')
logger.warning('No GPU with CUDA/ROCm detected.')
else:
logger.warning('No GPU monitoring libraries available.')

self.cudaDevice = 'cpu' if self.torchDevice == 'cpu' else 'cuda'
self.cudaDevice = 'cpu' if self.torchDevice == 'cpu' else ('cuda' if self.cuda else 'rocm' if self.rocm else 'cpu')
self.cudaAvailable = torch.cuda.is_available()

# Check for ROCm availability (torch uses 'cuda' for both CUDA and ROCm)
if torch.cuda.is_available() and self.rocm:
self.rocmAvailable = True

if self.cuda and self.cudaAvailable and self.torchDevice == 'cpu':
logger.warning('CUDA is available, but torch is using CPU.')
if (self.cuda or self.rocm) and self.cudaAvailable and self.torchDevice == 'cpu':
logger.warning('GPU is available, but torch is using CPU.')

def getInfo(self):
logger.debug('Getting GPUs info...')
Expand Down Expand Up @@ -149,7 +186,7 @@ def getStatus(self):
else:
gpuType = self.cudaDevice

if self.anygpuLoaded and self.cuda and self.cudaAvailable:
if self.anygpuLoaded and (self.cuda or self.rocm) and (self.cudaAvailable or self.rocmAvailable):
for deviceIndex in range(self.cudaDevicesFound):
deviceHandle = self.deviceGetHandleByIndex(deviceIndex)

Expand Down Expand Up @@ -205,6 +242,8 @@ def getStatus(self):
def deviceGetCount(self):
if self.pynvmlLoaded:
return self.pynvml.nvmlDeviceGetCount()
elif self.rocmAvailable:
return self.rocml.smi_get_device_count()
elif self.jtopLoaded:
# For Jetson devices, we assume there's one GPU
return 1
Expand All @@ -214,6 +253,9 @@ def deviceGetCount(self):
def deviceGetHandleByIndex(self, index):
if self.pynvmlLoaded:
return self.pynvml.nvmlDeviceGetHandleByIndex(index)
elif self.rocmAvailable:
# For AMD, the device index itself acts as the handle
return index
elif self.jtopLoaded:
return index # On Jetson, index acts as handle
else:
Expand All @@ -235,6 +277,12 @@ def deviceGetName(self, deviceHandle, deviceIndex):
logger.error(f"UnicodeDecodeError: {e}")

return gpuName
elif self.rocmAvailable:
try:
return self.rocml.smi_get_device_name(deviceIndex)
except Exception as e:
logger.debug('Could not get AMD GPU name. ' + str(e))
return f'AMD GPU {deviceIndex}'
elif self.jtopLoaded:
# Access the GPU name from self.jtopInstance.gpu
try:
Expand All @@ -250,6 +298,9 @@ def deviceGetName(self, deviceHandle, deviceIndex):
def systemGetDriverVersion(self):
if self.pynvmlLoaded:
return f'NVIDIA Driver: {self.pynvml.nvmlSystemGetDriverVersion()}'
elif self.rocmAvailable:
version = self.rocml.smi_get_kernel_version()
return f'AMD ROCm Driver: {version}'
elif self.jtopLoaded:
# No direct method to get driver version from jtop
return 'NVIDIA Driver: unknown'
Expand All @@ -259,6 +310,12 @@ def systemGetDriverVersion(self):
def deviceGetUtilizationRates(self, deviceHandle):
if self.pynvmlLoaded:
return self.pynvml.nvmlDeviceGetUtilizationRates(deviceHandle).gpu
elif self.rocmAvailable:
try:
return self.rocml.smi_get_device_utilization(deviceHandle)
except Exception as e:
logger.debug('Could not get AMD GPU utilization. ' + str(e))
return -1
elif self.jtopLoaded:
# GPU utilization from jtop stats
try:
Expand All @@ -274,6 +331,14 @@ def deviceGetMemoryInfo(self, deviceHandle):
if self.pynvmlLoaded:
mem = self.pynvml.nvmlDeviceGetMemoryInfo(deviceHandle)
return {'total': mem.total, 'used': mem.used}
elif self.rocmAvailable:
try:
total = self.rocml.smi_get_device_memory_total(deviceHandle)
used = self.rocml.smi_get_device_memory_used(deviceHandle)
return {'total': total, 'used': used}
except Exception as e:
logger.debug('Could not get AMD GPU memory info. ' + str(e))
return {'total': 1, 'used': 1}
elif self.jtopLoaded:
mem_data = self.jtopInstance.memory['RAM']
total = mem_data['tot']
Expand All @@ -285,6 +350,9 @@ def deviceGetMemoryInfo(self, deviceHandle):
def deviceGetTemperature(self, deviceHandle):
if self.pynvmlLoaded:
return self.pynvml.nvmlDeviceGetTemperature(deviceHandle, self.pynvml.NVML_TEMPERATURE_GPU)
elif self.rocmAvailable:
# AMD ROCm temperature reading is not available via pyrsmi, returning -1 to disable it
return -1
elif self.jtopLoaded:
try:
temperature = self.jtopInstance.stats.get('Temp gpu', -1)
Expand All @@ -293,8 +361,10 @@ def deviceGetTemperature(self, deviceHandle):
logger.error('Could not get GPU temperature. ' + str(e))
return -1
else:
return 0
return -1

def close(self):
    """
    Release the monitoring backends held by this instance.

    Closes the jtop session when jtop was loaded and shuts down ROCm SMI
    when the AMD backend was loaded. The method is safe to call more than
    once: after the first call the corresponding "loaded" flags are cleared,
    so later calls do nothing.

    Any exception raised by a backend while closing is still propagated to
    the caller, but it no longer prevents the remaining backend from being
    released.
    """
    try:
        # getattr guards against instances where jtop never got as far as
        # creating jtopInstance.
        if self.jtopLoaded and getattr(self, 'jtopInstance', None) is not None:
            self.jtopInstance.close()
    finally:
        # Clear the flag even on failure so a retry does not close the same
        # jtop session twice.
        self.jtopLoaded = False
        try:
            if self.rocmAvailable and hasattr(self, 'rocml'):
                self.rocml.smi_shutdown()
        finally:
            self.rocmAvailable = False
            # Same expression __init__ uses to derive anygpuLoaded; pynvml is
            # not shut down here, so it alone decides the result now.
            self.anygpuLoaded = self.pynvmlLoaded or self.jtopLoaded or self.rocmAvailable
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ torch
numpy
Pillow
pynvml; platform_machine != 'aarch64'
pyrsmi; platform_machine != 'aarch64'
py-cpuinfo
piexif
jetson-stats; platform_machine == 'aarch64'