Add Intel XPU (XCCL) support across torchft #327
Conversation
This change extends torchft to run on Intel XPU accelerators alongside
CUDA. CUDA / NCCL behavior is left unchanged; ProcessGroupXCCL mirrors
the ProcessGroupNCCL structure. Test files are intentionally excluded
from this commit.
Files changed:
torchft/__init__.py
- Export ProcessGroupXCCL and ProcessGroupBabyXCCL alongside the
existing NCCL variants.
torchft/process_group.py
- ProcessGroupGloo: additionally register the backend on
torch.device("xpu") when XPU is available. CUDA registration
unchanged.
- _WorkAcceleratorTimeout.wait: replace torch.cuda.synchronize() with
torch.accelerator.synchronize() so the timeout fires correctly on
both NCCL and XCCL hosts (previously crashed on XPU with
"Torch not compiled with CUDA enabled" when a recv_checkpoint
timed out).
- ProcessGroupXCCL: mirror ProcessGroupNCCL's structure.
* errored() routes through synchronize() (the existing helper)
instead of torch.xpu.current_stream().synchronize() directly.
* abort() uses errored=errored keyword (same call shape as NCCL).
- ProcessGroupBabyXCCL: pass the required Options() argument to
BaseProcessGroupXCCL (matches the torch.distributed signature).
torchft/local_sgd.py
- _StreamingDiLoCoFragment._stream / _stop_event are now torch.Stream /
torch.Event constructed against the active accelerator. Behavior on
CUDA is unchanged (torch.accelerator.* delegates to torch.cuda.*).
- Replace inline torch.cuda.stream / nullcontext branching with
get_stream_context(self._stream).
torchft/checkpointing/http_transport.py
- _stream is built from torch.accelerator.current_accelerator().
- Use get_stream_context for the staging block.
- _to_cpu now copies both 'cuda' and 'xpu' tensors to host
non-blocking; previously only 'cuda' tensors hit that path.
torchft/checkpointing/pg_transport_bench.py
- Branch between ProcessGroupBabyNCCL and ProcessGroupBabyXCCL based
on the resolved device type.
torchft/collectives.py
- Sync stream is created via torch.Stream(device=device) (the generic
base class). get_stream_context already dispatches to the right
torch.cuda.stream / torch.xpu.stream context manager based on
stream.device.type, so device-specific Stream subclasses are not
needed at construction.
- Use torch.accelerator.current_stream() and get_stream_context
instead of torch.cuda.stream context managers.
torchft/quantization.py
- Drop top-level "import torch.cuda as cuda" and centralize FP8
capability detection in _supports_native_fp8(). On CUDA the check
is unchanged (compute capability >= (9, 0)). XPU falls back to the
int8 path until a stable XPU FP8 capability check is available.
train_ddp.py
- No source change required; origin already supports XPU via
ProcessGroupXCCL on the cuda/xpu/cpu branch.
train_diloco.py
- Add an XPU branch that calls torch.xpu.set_device(REPLICA_GROUP_ID
% torch.xpu.device_count()) (no XPU env-var equivalent of
CUDA_VISIBLE_DEVICES, so without this multiple replica groups on
one host would all collide on xpu:0). CUDA / Gloo branch and the
USE_NCCL toggle are unchanged.
End-to-end runs verified on an 8-XPU host:
- train_diloco.py with 2 replica groups: both replicas join quorum,
DiLoCo fragment sync runs at every step, loss decreases.
- train_ddp.py with 2 replica groups: both replicas join quorum and
progress through training steps with allreduce participating.
Regression tests:
- process_group_test.py: 26 passed / 27 skipped (no XPU regressions).
- manager_integ_test + local_sgd_integ_test: 35 passed / 3 skipped.
Known limitations (not addressed here):
- Upstream PyTorch lacks a _share_xpu_ IPC primitive equivalent to
_share_cuda_, so ProcessGroupBabyXCCL cannot pass XPU tensors
across the worker pipe today.
- oneCCL does not support multiple ranks within a single process
(KVS server collision); thread-based rank pools must use the
Baby* variants which run each rank in its own subprocess.
d4l3k
left a comment
There was a problem hiding this comment.
Generally this seems pretty reasonable, will wait on CI
| return "torchft-baby-nccl" | ||
|
|
||
|
|
||
| class ProcessGroupBabyXCCL(ProcessGroupBaby): |
There was a problem hiding this comment.
What's the story for XCCL? Does it support fault tolerance semantics? Do we need BabyXCCL?
There was a problem hiding this comment.
Yes. ProcessGroupXCCL supports the same fault-tolerance semantics as ProcessGroupNCCL, and ProcessGroupBabyXCCL is needed for the same reasons ProcessGroupBabyNCCL is.
XCCL follows the same flow as NCCL.
We have future Intel GPUs in the pipeline and want full collective + FT coverage from day zero, so keeping ProcessGroupXCCL and ProcessGroupBabyXCCL structurally symmetric with the NCCL to avoid future rework.
I've verified train_diloco.py and train_ddp.py on Intel PVC/BMG devices.
Thanks for the review. The lint issues from the previous CI run are fixed .
Could you please re-approve?
|
@d4l3k has imported this pull request. If you are a Meta employee, you can view this in D106871759. (Because this pull request was imported automatically, there will not be any future comments.) |
Summary
Continuation of PR-260 as per RFC-257
Extends torchft to run on Intel XPU accelerators alongside CUDA, using PyTorch's
device-agnostic accelerator API (
torch.accelerator.*) so the same code pathcovers both backends.
Validation
Validated on Intel(R) Arc(TM) Pro B60 Graphics
Test plan
Test script changes will be done as part of PR-326
CC: @tushar00jain @d4l3k @rbabukv