# Patch summary (review preamble; text before the first "diff --git" header is not part of any hunk).
#
# Files touched: general/gpu.py (class CGPUInfo) and requirements.txt.
#
# What the patch does, as shown by the hunks below:
#   * Adds class attributes rocm, rocmAvailable and gpuVendor ('nvidia', 'amd' or 'unknown').
#   * Jetson path: sets gpuVendor = 'nvidia' after jtop starts.
#   * Non-Jetson path: pynvml is now marked loaded only when nvmlDeviceGetCount() is > 0, and
#     pynvml import/init failures are logged at debug level instead of error level.
#   * If gpuVendor is not 'nvidia' afterwards, tries `from pyrsmi import rocml`, calls
#     smi_initialize() and smi_get_device_count(); with at least one device it stores the module
#     on self.rocml and sets rocmAvailable / gpuVendor = 'amd', otherwise it calls smi_shutdown().
#   * anygpuLoaded, the "No GPU detected" check, self.cuda / self.rocm, cudaDevice (which can now
#     be 'rocm') and the getStatus() gate are all extended to take the ROCm flags into account.
#   * deviceGetCount, deviceGetHandleByIndex, deviceGetName, systemGetDriverVersion,
#     deviceGetUtilizationRates, deviceGetMemoryInfo and deviceGetTemperature each gain an
#     `elif self.rocmAvailable` branch; for AMD the device index is used as the handle and the
#     temperature branch always returns -1.
#   * The final fallback of deviceGetTemperature changes from `return 0` to `return -1`.
#   * close() additionally calls self.rocml.smi_shutdown() when rocmAvailable is set.
#   * requirements.txt adds `pyrsmi; platform_machine != 'aarch64'`.
#
# NOTE(review): the new class docstring line says "AMD (via rocm-smi)" but the code imports
#   pyrsmi (`from pyrsmi import rocml`) -- confirm which wording is intended.
# NOTE(review): the AMD branch of systemGetDriverVersion calls smi_get_kernel_version() without a
#   try/except, unlike the other AMD branches, and __init__ logs its result via logger.info --
#   confirm whether a failure there can escape __init__ (surrounding code is not visible here).
# NOTE(review): when nvmlInit() succeeds but the device count is 0, the visible hunk only logs
#   'No NVIDIA GPUs detected.'; no NVML shutdown call appears in the visible hunks -- confirm.
# NOTE(review): `if torch.cuda.is_available() and self.rocm: self.rocmAvailable = True` looks
#   redundant, since the only visible assignment to self.rocm copies self.rocmAvailable -- verify.
# NOTE(review): the 'No GPU detected' branch no longer resets self.jtopLoaded; the guarding
#   condition already requires `not self.jtopLoaded`, so this appears behaviour-neutral -- verify.
# NOTE(review): line breaks and indentation of this patch appear to have been collapsed; it
#   presumably will not apply as-is without restoring the original layout.
diff --git a/general/gpu.py b/general/gpu.py index 64f1ea7..389996a 100644 --- a/general/gpu.py +++ b/general/gpu.py @@ -32,11 +32,14 @@ def is_jetson() -> bool: class CGPUInfo: """ This class is responsible for getting information from GPU (ONLY). + Supports both AMD (via rocm-smi) and NVIDIA (via pynvml) GPUs. """ cuda = False + rocm = False pynvmlLoaded = False jtopLoaded = False cudaAvailable = False + rocmAvailable = False torchDevice = 'cpu' cudaDevice = 'cpu' cudaDevicesFound = 0 @@ -47,6 +50,7 @@ class CGPUInfo: gpusUtilization = [] gpusVRAM = [] gpusTemperature = [] + gpuVendor = 'unknown' # 'nvidia', 'amd', or 'unknown' def __init__(self): if IS_JETSON: @@ -56,36 +60,64 @@ def __init__(self): self.jtopInstance = jtop() self.jtopInstance.start() self.jtopLoaded = True + self.gpuVendor = 'nvidia' logger.info('jtop initialized on Jetson device.') except ImportError as e: logger.error('jtop is not installed. ' + str(e)) except Exception as e: logger.error('Could not initialize jtop. ' + str(e)) else: - # Try to import pynvml for non-Jetson devices + # Try NVIDIA first try: import pynvml self.pynvml = pynvml self.pynvml.nvmlInit() - self.pynvmlLoaded = True - logger.info('pynvml (NVIDIA) initialized.') + # Check if NVIDIA GPUs are actually available + device_count = self.pynvml.nvmlDeviceGetCount() + if device_count and device_count > 0: + self.pynvmlLoaded = True + self.gpuVendor = 'nvidia' + logger.info('pynvml (NVIDIA) initialized.') + else: + logger.debug('No NVIDIA GPUs detected.') except ImportError as e: - logger.error('pynvml is not installed. ' + str(e)) + logger.debug('pynvml is not installed. ' + str(e)) except Exception as e: - logger.error('Could not init pynvml (NVIDIA). ' + str(e)) + logger.debug('Could not init pynvml (NVIDIA). 
' + str(e)) - self.anygpuLoaded = self.pynvmlLoaded or self.jtopLoaded + # If NVIDIA not available, try to detect AMD GPUs using pyrsmi + if self.gpuVendor != 'nvidia': + try: + from pyrsmi import rocml + rocml.smi_initialize() + logger.debug('pyrsmi initialized') + device_count = rocml.smi_get_device_count() + logger.debug(f'pyrsmi device count: {device_count}') + if device_count and device_count > 0: + self.rocml = rocml + self.rocmAvailable = True + self.gpuVendor = 'amd' + logger.info(f'AMD ROCm detected via pyrsmi ({device_count} GPU(s)).') + else: + logger.debug(f'No AMD GPUs detected (device_count={device_count})') + rocml.smi_shutdown() + except ImportError as e: + logger.debug(f'pyrsmi is not installed: {e}') + except Exception as e: + logger.error(f'Could not detect AMD GPU via pyrsmi: {e}') + + self.anygpuLoaded = self.pynvmlLoaded or self.jtopLoaded or self.rocmAvailable try: self.torchDevice = comfy.model_management.get_torch_device_name(comfy.model_management.get_torch_device()) except Exception as e: logger.error('Could not pick default device. 
' + str(e)) - if self.pynvmlLoaded and not self.jtopLoaded and not self.deviceGetCount(): + if (self.pynvmlLoaded or self.rocmAvailable) and not self.jtopLoaded and not self.deviceGetCount(): logger.warning('No GPU detected, disabling GPU monitoring.') self.anygpuLoaded = False self.pynvmlLoaded = False - self.jtopLoaded = False + self.rocmAvailable = False if self.anygpuLoaded: if self.deviceGetCount() > 0: @@ -110,18 +142,23 @@ def __init__(self): self.gpusVRAM.append(True) self.gpusTemperature.append(True) - self.cuda = True + self.cuda = self.pynvmlLoaded or self.jtopLoaded + self.rocm = self.rocmAvailable logger.info(self.systemGetDriverVersion()) else: - logger.warning('No GPU with CUDA detected.') + logger.warning('No GPU with CUDA/ROCm detected.') else: logger.warning('No GPU monitoring libraries available.') - self.cudaDevice = 'cpu' if self.torchDevice == 'cpu' else 'cuda' + self.cudaDevice = 'cpu' if self.torchDevice == 'cpu' else ('cuda' if self.cuda else 'rocm' if self.rocm else 'cpu') self.cudaAvailable = torch.cuda.is_available() + + # Check for ROCm availability (torch uses 'cuda' for both CUDA and ROCm) + if torch.cuda.is_available() and self.rocm: + self.rocmAvailable = True - if self.cuda and self.cudaAvailable and self.torchDevice == 'cpu': - logger.warning('CUDA is available, but torch is using CPU.') + if (self.cuda or self.rocm) and self.cudaAvailable and self.torchDevice == 'cpu': + logger.warning('GPU is available, but torch is using CPU.') def getInfo(self): logger.debug('Getting GPUs info...') @@ -149,7 +186,7 @@ def getStatus(self): else: gpuType = self.cudaDevice - if self.anygpuLoaded and self.cuda and self.cudaAvailable: + if self.anygpuLoaded and (self.cuda or self.rocm) and (self.cudaAvailable or self.rocmAvailable): for deviceIndex in range(self.cudaDevicesFound): deviceHandle = self.deviceGetHandleByIndex(deviceIndex) @@ -205,6 +242,8 @@ def getStatus(self): def deviceGetCount(self): if self.pynvmlLoaded: return 
self.pynvml.nvmlDeviceGetCount() + elif self.rocmAvailable: + return self.rocml.smi_get_device_count() elif self.jtopLoaded: # For Jetson devices, we assume there's one GPU return 1 @@ -214,6 +253,9 @@ def deviceGetCount(self): def deviceGetHandleByIndex(self, index): if self.pynvmlLoaded: return self.pynvml.nvmlDeviceGetHandleByIndex(index) + elif self.rocmAvailable: + # For AMD, the device index itself acts as the handle + return index elif self.jtopLoaded: return index # On Jetson, index acts as handle else: @@ -235,6 +277,12 @@ def deviceGetName(self, deviceHandle, deviceIndex): logger.error(f"UnicodeDecodeError: {e}") return gpuName + elif self.rocmAvailable: + try: + return self.rocml.smi_get_device_name(deviceIndex) + except Exception as e: + logger.debug('Could not get AMD GPU name. ' + str(e)) + return f'AMD GPU {deviceIndex}' elif self.jtopLoaded: # Access the GPU name from self.jtopInstance.gpu try: @@ -250,6 +298,9 @@ def deviceGetName(self, deviceHandle, deviceIndex): def systemGetDriverVersion(self): if self.pynvmlLoaded: return f'NVIDIA Driver: {self.pynvml.nvmlSystemGetDriverVersion()}' + elif self.rocmAvailable: + version = self.rocml.smi_get_kernel_version() + return f'AMD ROCm Driver: {version}' elif self.jtopLoaded: # No direct method to get driver version from jtop return 'NVIDIA Driver: unknown' @@ -259,6 +310,12 @@ def systemGetDriverVersion(self): def deviceGetUtilizationRates(self, deviceHandle): if self.pynvmlLoaded: return self.pynvml.nvmlDeviceGetUtilizationRates(deviceHandle).gpu + elif self.rocmAvailable: + try: + return self.rocml.smi_get_device_utilization(deviceHandle) + except Exception as e: + logger.debug('Could not get AMD GPU utilization. 
' + str(e)) + return -1 elif self.jtopLoaded: # GPU utilization from jtop stats try: @@ -274,6 +331,14 @@ def deviceGetMemoryInfo(self, deviceHandle): if self.pynvmlLoaded: mem = self.pynvml.nvmlDeviceGetMemoryInfo(deviceHandle) return {'total': mem.total, 'used': mem.used} + elif self.rocmAvailable: + try: + total = self.rocml.smi_get_device_memory_total(deviceHandle) + used = self.rocml.smi_get_device_memory_used(deviceHandle) + return {'total': total, 'used': used} + except Exception as e: + logger.debug('Could not get AMD GPU memory info. ' + str(e)) + return {'total': 1, 'used': 1} elif self.jtopLoaded: mem_data = self.jtopInstance.memory['RAM'] total = mem_data['tot'] @@ -285,6 +350,9 @@ def deviceGetMemoryInfo(self, deviceHandle): def deviceGetTemperature(self, deviceHandle): if self.pynvmlLoaded: return self.pynvml.nvmlDeviceGetTemperature(deviceHandle, self.pynvml.NVML_TEMPERATURE_GPU) + elif self.rocmAvailable: + # AMD ROCm temperature reading is not available via pyrsmi, returning -1 to disable it + return -1 elif self.jtopLoaded: try: temperature = self.jtopInstance.stats.get('Temp gpu', -1) @@ -293,8 +361,10 @@ def deviceGetTemperature(self, deviceHandle): logger.error('Could not get GPU temperature. ' + str(e)) return -1 else: - return 0 + return -1 def close(self): if self.jtopLoaded and self.jtopInstance is not None: self.jtopInstance.close() + if self.rocmAvailable and hasattr(self, 'rocml'): + self.rocml.smi_shutdown() diff --git a/requirements.txt b/requirements.txt index cb86b30..5f6f5be 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ torch numpy Pillow pynvml; platform_machine != 'aarch64' +pyrsmi; platform_machine != 'aarch64' py-cpuinfo piexif jetson-stats; platform_machine == 'aarch64'