[AMD] Enable CUDA graph capture with rocshmem_hipmodule_init API #158
drprajap wants to merge 1 commit into
This commit integrates and tests the rocshmem_hipmodule_init API for CUDA/HIP graph capture compatibility. It removes manual management of the rocSHMEM device context via the get/set APIs and instead initializes the rocSHMEM module at JIT compilation, making it compatible with torch.cuda.graph() capture.

Signed-off-by: Dimple Prajapati <dimple.prajapati@amd.com>
Good job! We finally get rid of rocshmem_set_ctx. rocshmem_set_ctx forced developers to add a ctx argument to every entry kernel, so I had to write two implementations of a single kernel: one as a device function without the ctx argument, and one as a global function with it. Now they are the same. Thanks a lot!
Agreed on the inconsistency; good that we can unify the implementation now. It took a while to get this right: the HIP APIs have specific requirements for performing GPU operations during graph capture, so the change became a bit tricky. Let’s merge this once the rocSHMEM PR is in. Thanks for the review.
This change depends on rocSHMEM PR ROCm/rocm-systems#3165. The rocshmem submodule needs to be updated once that PR is merged.
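For readers unfamiliar with the torch.cuda.graph() capture path this change targets, here is a minimal sketch of the generic PyTorch capture-and-replay pattern. It does not use rocSHMEM or any code from this PR: the `step` function is a hypothetical stand-in for a JIT-compiled kernel launch, and the helper name `capture_and_replay` is made up for illustration. The warm-up before capture is where one-time work such as JIT compilation would be expected to happen, which is assumed here to be the reason module initialization at JIT time fits graph capture.

```python
def capture_and_replay():
    """Capture a small op in a CUDA/HIP graph and replay it.

    Returns True if the replayed result matches eager execution,
    or None when PyTorch or a GPU is not available.
    """
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None

    n = 1024
    # Static buffer: graph replay reads and writes the same memory each time.
    static_x = torch.zeros(n, device="cuda")

    def step(x):
        # Hypothetical stand-in for launching a JIT-compiled kernel.
        return x * 2.0 + 1.0

    # Warm up on a side stream so one-time setup runs before capture starts.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(3):
            step(static_x)
    torch.cuda.current_stream().wait_stream(side)

    # Capture: the work inside the block is recorded, not run eagerly.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_y = step(static_x)

    # Refill the static input in place, then replay the captured work.
    new_input = torch.arange(n, device="cuda", dtype=torch.float32)
    static_x.copy_(new_input)
    graph.replay()
    torch.cuda.synchronize()

    expected = step(new_input)
    return bool(torch.equal(static_y, expected))


if __name__ == "__main__":
    print(capture_and_replay())
```

On a machine without PyTorch or without a GPU the helper returns None instead of raising, so the sketch can be run anywhere.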