Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 58 additions & 5 deletions trellis2/modules/sparse/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def __getitem__(self, idx):
def reduce(self, op: str, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
    """Reduce ``self.feats`` with ``op`` over ``dim``, then per segment along axis 0.

    Args:
        op: One of ``'mean'``, ``'sum'`` or ``'prod'``.
        dim: Dimension(s) handed to the tensor reduction. A single int is
            treated as a one-element tuple.
        keepdim: Forwarded unchanged to the tensor reduction.

    Returns:
        When ``dim`` is ``None`` or contains ``0``: the plain reduction of
        ``self.feats``. Otherwise: that reduction followed by
        ``torch.segment_reduce`` along the first axis, one output row per
        segment length.

    Raises:
        ValueError: If ``op`` is not a supported reduction.
        RuntimeError: If no set of segment lengths summing to the number of
            rows can be derived from ``seqlen``, ``coords`` or ``layout``.
    """
    if isinstance(dim, int):
        dim = (dim,)

    if op == 'mean':
        red = self.feats.mean(dim=dim, keepdim=keepdim)
    elif op == 'sum':
        # NOTE(review): this branch body and the 'prod' header below were
        # not visible in the reviewed diff; reconstructed by symmetry with
        # the 'mean' / 'prod' branches -- confirm against the full file.
        red = self.feats.sum(dim=dim, keepdim=keepdim)
    elif op == 'prod':
        red = self.feats.prod(dim=dim, keepdim=keepdim)
    else:
        raise ValueError(f"Unsupported reduce operation: {op}")

    # Axis 0 already reduced (or everything reduced): nothing left to segment.
    if dim is None or 0 in dim:
        return red

    # Defensive cache validation (per the original author's notes): in the
    # cascade Shape-SLat sampler's CFG-rescale path, x_0_pos is derived from
    # x_t via an elementwise chain that propagates _spatial_cache by
    # reference. If the inherited cached `seqlen` was computed at a different
    # scale than the current feats, sum(seqlen) == feats.shape[0] no longer
    # holds. CUDA's segment_reduce kernel reportedly consumes the mismatch
    # silently while the CPU path raises RuntimeError, so lengths are
    # recomputed here and the stale cache entries are evicted.
    lengths = self.seqlen
    n_data = red.shape[0]
    if int(lengths.sum().item()) != n_data:
        coords = getattr(self, 'coords', None)
        if coords is not None and coords.shape[0] == n_data:
            # bincount already yields max(batch_idx) + 1 bins and its counts
            # always sum to the number of rows; minlength=1 only matters for
            # empty input, where it produces a single zero-length segment.
            # NOTE(review): trailing empty batches are not represented here.
            fresh = torch.bincount(coords[:, 0].long(), minlength=1).to(
                dtype=torch.long, device=red.device
            )
            fresh_total = n_data
        else:
            # Fall back to the slice layout; sum on the host to avoid an
            # extra device synchronisation.
            seg_sizes = [seg.stop - seg.start for seg in self.layout]
            fresh = torch.tensor(seg_sizes, dtype=torch.long, device=red.device)
            fresh_total = sum(seg_sizes)

        # Best-effort eviction of the stale entries so later reads recompute
        # them. Key spelling ('batch_boardcast_map') is kept exactly as in
        # the original code -- presumably it matches the cache's real key.
        stale_keys = ('seqlen', 'cum_seqlen', 'batch_boardcast_map')
        if hasattr(self, '_spatial_cache') and hasattr(self, '_scale'):
            try:
                slot = self._spatial_cache.get(str(self._scale), {})
                for key in stale_keys + ('layout',):
                    slot.pop(key, None)
            except (AttributeError, TypeError):
                # Cache is not mapping-like; eviction is optional, the
                # freshly computed lengths below are still used.
                pass
        elif hasattr(self, '_cache'):
            try:
                for key in stale_keys:
                    self._cache.pop(key, None)
            except (AttributeError, TypeError):
                pass

        if fresh_total != n_data:
            raise RuntimeError(
                f"VarLenTensor.reduce: cannot reconcile seqlen "
                f"sum({fresh_total}) with data.size(0)={n_data}. "
                f"layout has {len(self.layout)} segments."
            )
        lengths = fresh

    # Per the original author's note: torch.segment_reduce has no Metal
    # kernel and the automatic fallback is racy on large workloads, so on
    # MPS the reduction is run explicitly on CPU and copied back.
    if red.device.type == 'mps':
        reduced = torch.segment_reduce(red.cpu(), reduce=op, lengths=lengths.cpu())
        return reduced.to(red.device)
    return torch.segment_reduce(red, reduce=op, lengths=lengths)

def mean(
    self,
    dim: Optional[Union[int, Tuple[int,...]]] = None,
    keepdim: bool = False,
) -> torch.Tensor:
    """Shorthand for ``self.reduce`` with ``op='mean'``.

    ``dim`` and ``keepdim`` are forwarded unchanged; the return value is
    whatever ``reduce`` returns.
    """
    return self.reduce('mean', dim=dim, keepdim=keepdim)
Expand Down