def reduce(self, op: str, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
    """
    Reduce ``self.feats`` over ``dim`` and then, unless dim 0 was reduced,
    once more per segment along dim 0 with ``torch.segment_reduce``.

    Args:
        op: One of ``'mean'``, ``'sum'`` or ``'prod'``.
        dim: Dimension(s) of ``self.feats`` to reduce first. ``None`` reduces
            everything.
        keepdim: Forwarded to the first (dense) reduction.

    Returns:
        The dense reduction when ``dim`` is ``None`` or contains 0; otherwise
        the per-segment reduction of that result.

    Raises:
        ValueError: If ``op`` is not a supported operation.
        RuntimeError: If no set of segment lengths summing to the number of
            rows can be obtained from ``self.seqlen``, ``self.coords`` or
            ``self.layout``.
    """
    if isinstance(dim, int):
        dim = (dim,)

    if op == 'mean':
        red = self.feats.mean(dim=dim, keepdim=keepdim)
    elif op == 'sum':
        red = self.feats.sum(dim=dim, keepdim=keepdim)
    elif op == 'prod':
        red = self.feats.prod(dim=dim, keepdim=keepdim)
    else:
        raise ValueError(f"Unsupported reduce operation: {op}")

    if dim is None or 0 in dim:
        return red

    # Defensive cache validation. The original patch note reports that a
    # cached `seqlen` can be inherited (cache shared by reference) from a
    # tensor at a different scale, breaking sum(seqlen) == feats.shape[0];
    # CUDA's segment_reduce reportedly consumes the mismatch silently while
    # the CPU path raises. The check costs one host sync per call.
    lengths = self.seqlen
    n_data = red.shape[0]
    if int(lengths.sum().item()) != n_data:
        import warnings  # local import: the file's import block is not visible here

        # Evict the stale entries first, so that anything recomputed below
        # (and by later readers) cannot be served from the stale cache.
        if hasattr(self, '_spatial_cache') and hasattr(self, '_scale'):
            try:
                slot = self._spatial_cache.get(str(self._scale), {})
                for k in ('seqlen', 'cum_seqlen', 'batch_boardcast_map', 'layout'):
                    slot.pop(k, None)
            except (AttributeError, TypeError, KeyError) as err:
                # Best effort: eviction failing must not abort the reduction,
                # but it should not go unnoticed either.
                warnings.warn(
                    f"VarLenTensor.reduce: could not evict stale spatial cache: {err!r}",
                    RuntimeWarning,
                )
        elif hasattr(self, '_cache'):
            try:
                for k in ('seqlen', 'cum_seqlen', 'batch_boardcast_map'):
                    self._cache.pop(k, None)
            except (AttributeError, TypeError, KeyError) as err:
                warnings.warn(
                    f"VarLenTensor.reduce: could not evict stale cache: {err!r}",
                    RuntimeWarning,
                )

        fresh = None
        coords = getattr(self, 'coords', None)
        if coords is not None and coords.shape[0] == n_data:
            # Column 0 of coords is used as the batch index. bincount already
            # sizes its output to max(index) + 1, and minlength=1 covers the
            # empty case, so no separate .max().item() sync is needed. The
            # counts sum to coords.shape[0] == n_data by construction.
            # NOTE(review): trailing empty batch items yield fewer segments
            # than the real batch size - confirm whether that can occur.
            fresh = torch.bincount(coords[:, 0].long(), minlength=1).to(
                dtype=torch.long, device=red.device
            )

        if fresh is None:
            # No usable coords: fall back to the slice extents in the layout.
            fresh = torch.tensor(
                [l.stop - l.start for l in self.layout],
                dtype=torch.long, device=red.device,
            )
            fresh_total = int(fresh.sum().item())
            if fresh_total != n_data:
                raise RuntimeError(
                    f"VarLenTensor.reduce: cannot reconcile seqlen "
                    f"sum({fresh_total}) with data.size(0)={n_data}. "
                    f"layout has {len(self.layout)} segments."
                )
        lengths = fresh

    # Per the original patch note, torch.segment_reduce has no Metal kernel
    # and the automatic fallback proved unreliable on large workloads, so on
    # MPS the reduction runs explicitly on CPU and is copied back. Other
    # devices take the direct path.
    if red.device.type == 'mps':
        reduced = torch.segment_reduce(red.cpu(), reduce=op, lengths=lengths.cpu())
        return reduced.to(red.device)
    return torch.segment_reduce(red, reduce=op, lengths=lengths)

def mean(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
    """Shorthand for ``self.reduce(op='mean', dim=dim, keepdim=keepdim)``."""
    return self.reduce(op='mean', dim=dim, keepdim=keepdim)