Fix stale seqlen cache in cascade reduce() (MPS crash, CUDA latent)#167
Open
Kyle-Wang0211 wants to merge 1 commit into
Open
Fix stale seqlen cache in cascade reduce() (MPS crash, CUDA latent)#167Kyle-Wang0211 wants to merge 1 commit into
Kyle-Wang0211 wants to merge 1 commit into
Conversation
VarLenTensor.reduce() in trellis2/modules/sparse/basic.py reads self.seqlen from a per-scale _spatial_cache slot that, in the cascade Shape-SLat sampler's CFG-rescale path (x_0_pos.std(...)), can hold a value left over from the parent tensor's scale. When that happens, sum(seqlen) != feats.shape[0], crashing the CPU-fallback segment_reduce on Apple Silicon (MPS) and likely silently miscomputing on CUDA. Two defensive layers added to reduce(): 1. Cache validation: if sum(lengths) != data.size(0), recompute from coords and evict stale seqlen/cum_seqlen/batch_boardcast_map/layout from the per-scale cache. 2. Explicit MPS -> CPU fallback for segment_reduce, bypassing PyTorch's noisier auto-fallback path (also ~5x faster per cascade step). Verified end-to-end on M3 Pro 18GB with --pipeline-type 1024_cascade --texture-size 2048 --steps 16: 24 min generation + 25s bake, output 21 MB GLB / 242 MB OBJ with 2K PBR. See PR description for full benchmark + cross-platform context. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
|
@Kyle-Wang0211 please read the following Contributor License Agreement(CLA). If you agree with the CLA, please reply with the following information.
Contributor License AgreementContribution License AgreementThis Contribution License Agreement (“Agreement”) is agreed to by the party signing below (“You”),
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
VarLenTensor.reduce()intrellis2/modules/sparse/basic.pyreadsself.seqlenfrom a per-scale
_spatial_cacheslot that, in the cascade Shape-SLat sampler'sCFG-rescale path (
x_0_pos.std(...)), can hold a value left over from the parenttensor's scale. When that happens,
sum(seqlen) != feats.shape[0].torch.segment_reducehas no Metal kernel and PyTorch falls back to CPU. The CPU path strictly enforcessum(lengths) == data.size(0)and raisesRuntimeError. The cascade Shape-SLat pass aborts at ~step 3.std()in cascade mode would return numerically wrong values silently. Easy to verify by addingassert lengths.sum() == data.size(0)before the existing call. We do not have CUDA hardware to confirm.Closes the Apple-Silicon side of #74.
Repro (MPS)
Aborts ~step 3 of the cascade Shape-SLat pass with:
Root cause hypothesis
SparseTensor.seqlenis cached in_spatial_cache[str(self._scale)]['seqlen'](basic.py:543-548).SparseTensor.replace()(basic.py:671-676) constructs the new tensor withspatial_cache=self._spatial_cache— same dict reference, same scale key.__merge_sparse_cache()(basic.py:705-715) merges across two sparse operands.flow_euler.sample_oncederivesx_0_posfrom_pred_to_xstart(x_t, t, pred_pos), an elementwise chain overx_tand the model output. The result inheritsx_t's_spatial_cache.classifier_free_guidance_mixin.py:23) callsx_0_pos.std(dim=list(range(1, x_0_pos.ndim)), keepdim=True), which lands inVarLenTensor.reduce.self.seqlenreturns the inherited cached value, which no longer matchesx_0_pos.feats.shape[0]for this scale.The base Shape-SLat pass does not exercise the same scale interaction and stays valid.
Fix
Two defensive layers in
VarLenTensor.reduce():sum(lengths) != data.size(0), recomputelengthsfromcoords(the authoritative source) and evict the staleseqlen / cum_seqlen / batch_boardcast_map / layoutfrom the per-scale cache. On CUDA this branch is a no-op in the happy path; it only fires in the same stale-cache case that would silently miscompute today.segment_reduce— bypasses the auto fallback's noisier device-sync path. Also gives a ~5× speedup per cascade step (358s → ~50s) in our benchmark, independent of correctness.Benchmark (M3 Pro 18GB,
astronaut.jpeg,--pipeline-type 1024_cascade --texture-size 2048 --steps 16)Output: 21 MB GLB / 242 MB OBJ, 132K vert / 192K face, 2K PBR (base color + metallic + roughness + alpha).
Not verified
lengthsinstead of a stale one.1536_cascade— only1024_cascadewas end-to-end-verified. The cache logic is scale-agnostic but the 1536 path has separate reports (see related).Related cascade-pass reports
1536_cascadeabort, RTX 4090,spatial2channel.py:77OOB1024_cascade/1536_cascade, RTX 5090,max(): Expected reduction dim 0 to have non-zero sizesparse_structure_resolution=64edge casesegment_reducekernel (open)(感谢 TRELLIS.2 团队的开源——repro 在 macOS 上 100% 可复现,方便的话可以在 CUDA 上跑一下 cascade 输出比对,验证 silent miscompute 的影响范围。)