From 586795505a12bb864897305f96456774bc854289 Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Wed, 20 May 2026 14:50:39 -0700 Subject: [PATCH 1/8] Use `torch.cdist` in `get_token_frame_atoms` This is the memory peak for large complexes (around 5k tokens). The pairwise distance computation here materializes a 3*N_atom^2 intermediate (I'm seeing 42GB of peak memory from this call alone). `torch.cdist` uses the identity `||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b` when either input has more than 25 points (or always if you set `compute_mode="use_mm_for_euclid_dist"`). In floating point math this has some negative consequences for precision as it ends up subtracting large numbers to get a small result, but given that we then feed it into topk, this shouldn't have that big an effect. It's also about 4x faster (just for the distance computation), though I don't think that really matters here. I tested across 14 monomers (49-4563 residues, 366-36,356 atoms), so we have realistic coordinates. But because distances are only used for frame atoms when is_atomized is true, I artificially flipped all of them to set them as atomized (would've been better to use actual atomized inputs, but this is what I had on hand). `cdist` changes the output of `phi` for four atoms total in the dataset (per input mean 0.05% of atoms, max single input 0.5% of atoms) and reduces the memory usage of this computation for a 42k-atom input from 42GB to 13GB. There are some other locations that also do this distance calculation. I think these aren't run during inference, so I haven't hit them, but it might be worth investigating them as well. They don't get fed straight into a topk though, so there may be more precision and differentiability complications. 
--- openfold3/core/utils/atomize_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/openfold3/core/utils/atomize_utils.py b/openfold3/core/utils/atomize_utils.py index 794369817..850aa029b 100644 --- a/openfold3/core/utils/atomize_utils.py +++ b/openfold3/core/utils/atomize_utils.py @@ -775,9 +775,9 @@ def get_token_frame_atoms( atom_asym_id_mask = atom_asym_id[..., None] == atom_asym_id[..., None, :] pair_mask = pair_mask * atom_asym_id_mask - # Compute distance matrix + # Compute distance matrix. Use cdist to avoid materializing N*N*3 intermediate # [*, N_atom, N_atom] - d = torch.sum(eps + (x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1) ** 0.5 + d = torch.cdist(x, x) d = d * pair_mask + inf * (1 - pair_mask) # Find indices of two closest atoms for start atoms From c81781385839e564759f366f7bd6d2cec05dae56 Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Wed, 20 May 2026 14:59:02 -0700 Subject: [PATCH 2/8] Use inplace bool masks in `get_token_frame_atoms` The masking line `d = d * pair_mask + inf * (1 - pair_mask)` momentarily holds 4-5 simultaneous fp32 [N_atom, N_atom] tensors (the existing pair_mask and d, plus three new intermediates). At N~42k this expression alone accounts for ~27 GB of the function's ~33 GB peak after applying the cdist optimization. Switching to: - pair_mask kept as bool (1.7 GB at N=41863 vs 6.7 GB as fp32) - in-place AND for the chain-restriction step - inplace masked_fill_ instead of multiply-and-add drops the function's peak to ~10GB at N~42K. This is bit-exact when pair_mask is 0/1-valued (which it is here, since atom_mask is 0/1-float and atom_asym_id_mask is bool): d * 1 + 0 == d and 0 + inf * 1 == inf in IEEE 754. Verified empirically on 14 monomer inputs that `phi` and `valid_frame_mask` are bit-identical to the previous masking form. 
--- openfold3/core/utils/atomize_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/openfold3/core/utils/atomize_utils.py b/openfold3/core/utils/atomize_utils.py index 850aa029b..077923add 100644 --- a/openfold3/core/utils/atomize_utils.py +++ b/openfold3/core/utils/atomize_utils.py @@ -762,23 +762,23 @@ def get_token_frame_atoms( valid_frame_mask: [*, N_token] Mask denoting valid frames """ - # Create pairwise atom mask - pair_mask = atom_mask[..., None] * atom_mask[..., None, :] + # Pairwise atom mask, kept as bool throughout to avoid materializing + # large fp32 [N_atom, N_atom] intermediates. + am_bool = atom_mask.bool() + pair_mask = am_bool[..., None] & am_bool[..., None, :] - # Update pairwise atom mask # Restrict to atoms within the same chain atom_asym_id = broadcast_token_feat_to_atoms( token_mask=batch["token_mask"], num_atoms_per_token=batch["num_atoms_per_token"], token_feat=batch["asym_id"], ) - atom_asym_id_mask = atom_asym_id[..., None] == atom_asym_id[..., None, :] - pair_mask = pair_mask * atom_asym_id_mask + pair_mask &= atom_asym_id[..., None] == atom_asym_id[..., None, :] # Compute distance matrix. Use cdist to avoid materializing N*N*3 intermediate # [*, N_atom, N_atom] d = torch.cdist(x, x) - d = d * pair_mask + inf * (1 - pair_mask) + d.masked_fill_(~pair_mask, inf) # Find indices of two closest atoms for start atoms # [*, N_token] From 042e14d6f91b6341a3145e1f239570455f1e684b Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Tue, 26 May 2026 12:29:19 -0700 Subject: [PATCH 3/8] Only move zij to GPU in slices for PAE/PDE computation Under offload_inference, pairformer_embedding returns zij on CPU. Even if `apply_per_sample` was set, the previous code moved the whole multi-sample pair representation back to GPU (S*N_tok^2*C_z) before running the per-pair heads (PDE/PAE), which is the single largest allocation in the run around 5k tokens and 5 samples. 
Instead, plumb a compute_device kwarg through PDE/PAE's forward into `_chunk`. The output buffer is still allocated on zij.device (CPU under offload), but each iteration copies one sample's slice to compute_device, runs the layernorm+linear, copies the result slice back to zij.device, and writes it into the output buffer. The multi-sample pair-rep tensor never lives on GPU. Validated on the 4967-residue 8uq5_A monomer on MI300X: - peak max_memory_allocated: 165 GB -> 132 GB (-20%) - runtime: 1840s -> 1824s (noise) - outputs match exactly when torch is run in deterministic mode Note that to keep the API contract sensible, this adds a copy back to zij.device even in the non `apply_per_sample` case (otherwise, the return device would differ based on `apply_per_sample`, which would be a bit weird). This means that if `apply_per_sample` is False and `offload_inference` is True, whichever of PDE/PAE is called last will unnecessarily copy its full logits tensor to the CPU and then immediately back to the GPU in the following dict comprehension (if it's not last then the copy back is actually useful as it frees GPU memory for the subsequent head). I played with having an output device parameter, but it made things excessively complex. I think this is not really a sensible usage scenario (I can't think of a scenario where the offload threshold should be lower than the per-sample threshold) and the overhead of the copy is not really very large (see benchmark numbers above), so I just left the simpler thing. 
--- openfold3/core/model/heads/head_modules.py | 29 +++++----- .../core/model/heads/prediction_heads.py | 55 +++++++++++++++---- 2 files changed, 59 insertions(+), 25 deletions(-) diff --git a/openfold3/core/model/heads/head_modules.py b/openfold3/core/model/heads/head_modules.py index 46a132ac4..8185b30b0 100644 --- a/openfold3/core/model/heads/head_modules.py +++ b/openfold3/core/model/heads/head_modules.py @@ -239,23 +239,26 @@ def forward( ) aux_out["experimentally_resolved_logits"] = experimentally_resolved_logits - # zij is moved back to GPU after the single rep confidence heads - # because building the max_atom_per_token_mask uses a lot of memory - zij = zij.to(device=out_device) - - pde_logits = self.pde(zij, apply_per_sample=apply_per_sample) + # We leave zij on CPU here and let the PDE/PAE heads pull what they + # need. This enables moving only a single sample onto the GPU at a time + # if running with apply_per_sample. + aux_out["pde_logits"] = self.pde( + zij, + apply_per_sample=apply_per_sample, + compute_device=out_device, + ) if self.config.pae.enabled: - # Offload pde logits to not keep all three pairwise tensors - # in GPU memory at once - offload_device = "cpu" if offload_inference else out_device - pde_logits = pde_logits.to(device=offload_device) - aux_out["pae_logits"] = self.pae(zij, apply_per_sample=apply_per_sample) + aux_out["pae_logits"] = self.pae( + zij, + apply_per_sample=apply_per_sample, + compute_device=out_device, + ) del zij - aux_out["pde_logits"] = pde_logits.to(device=out_device) - - aux_out = {k: v.to(dtype=out_dtype) for k, v in aux_out.items()} + aux_out = { + k: v.to(device=out_device, dtype=out_dtype) for k, v in aux_out.items() + } return aux_out diff --git a/openfold3/core/model/heads/prediction_heads.py b/openfold3/core/model/heads/prediction_heads.py index c446f2807..ba1b92735 100644 --- a/openfold3/core/model/heads/prediction_heads.py +++ b/openfold3/core/model/heads/prediction_heads.py @@ -399,19 +399,28 @@ def 
_compute_logits(self, zij: torch.Tensor): def _chunk( self, zij: torch.Tensor, + compute_device: torch.device | None = None, ) -> torch.Tensor: + # ``zij`` will be moved in slices to ``compute_device`` for the layer + # norm + linear, and the output logits will be moved afterwards to + # the original ``zij.device`` zij_out = torch.zeros( (*zij.shape[:-1], self.c_out), device=zij.device, dtype=zij.dtype ) no_samples = zij.shape[-4] for i in range(no_samples): - zij_out[..., i : i + 1, :, :, :] = self._compute_logits( - zij[..., i : i + 1, :, :, :] - ) + slice_in = zij[..., i : i + 1, :, :, :].to(device=compute_device) + slice_out = self._compute_logits(slice_in).to(device=zij.device) + zij_out[..., i : i + 1, :, :, :] = slice_out return zij_out - def forward(self, zij, apply_per_sample: bool = False): + def forward( + self, + zij, + apply_per_sample: bool = False, + compute_device: torch.device | None = None, + ): """ Args: zij: @@ -421,14 +430,22 @@ def forward(self, zij, apply_per_sample: bool = False): This is a memory optimization which is only used during validation/inference and will depend on the number of samples in the full rollout. + compute_device: + Device on which to run computation. zij will be moved here + before doing any computation. When apply_per_sample is true, + each per-sample slice of ``zij`` is moved onto this device + separately for the computation and the output is moved to + ``zij.device`` before processing the next slice. 
Returns: logits: [*, N, N, C_out] Logits """ if apply_per_sample: - logits = self._chunk(zij=zij) + logits = self._chunk(zij=zij, compute_device=compute_device) else: - logits = self._compute_logits(zij=zij) + logits = self._compute_logits(zij=zij.to(device=compute_device)).to( + device=zij.device + ) return logits @@ -471,19 +488,25 @@ def _compute_logits(self, zij: torch.Tensor): def _chunk( self, zij: torch.Tensor, + compute_device: torch.device | None = None, ) -> torch.Tensor: zij_out = torch.zeros( (*zij.shape[:-1], self.c_out), device=zij.device, dtype=zij.dtype ) no_samples = zij.shape[-4] for i in range(no_samples): - zij_out[..., i : i + 1, :, :, :] = self._compute_logits( - zij[..., i : i + 1, :, :, :] - ) + slice_in = zij[..., i : i + 1, :, :, :].to(device=compute_device) + slice_out = self._compute_logits(slice_in).to(device=zij.device) + zij_out[..., i : i + 1, :, :, :] = slice_out return zij_out - def forward(self, zij, apply_per_sample: bool = False): + def forward( + self, + zij, + apply_per_sample: bool = False, + compute_device: torch.device | None = None, + ): """ Args: zij: @@ -493,14 +516,22 @@ def forward(self, zij, apply_per_sample: bool = False): This is a memory optimization which is only used during validation/inference and will depend on the number of samples in the full rollout. + compute_device: + Device on which to run computation. zij will be moved here + before doing any computation. When apply_per_sample is true, + each per-sample slice of ``zij`` is moved onto this device + separately for the computation and the output is moved to + ``zij.device`` before processing the next slice. 
Returns: logits: [*, N, N, C_out] Logits """ if apply_per_sample: - logits = self._chunk(zij=zij) + logits = self._chunk(zij=zij, compute_device=compute_device) else: - logits = self._compute_logits(zij=zij) + logits = self._compute_logits(zij=zij.to(device=compute_device)).to( + device=zij.device + ) return logits From f545f2064e70d8685e142100620b9c5c7e94931a Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Wed, 27 May 2026 11:59:55 -0700 Subject: [PATCH 4/8] Defer PDE/PAE/distogram move-to-GPU until consumed Under offload_inference, aux_heads previously moved PDE, PAE, and distogram logits back to GPU in its final dict comp, simultaneously putting both per-pair head outputs on GPU `~2*S*N_tok^2*C_out*4` bytes. This is ~64 GB at N_tok=5k, S=5, C_out=64, which is a binding memory peak (after my previous memory reduction changes). Instead, we leave them on CPU at `aux_heads` exit and have `_get_confidence_scores` move each one onto GPU only while it is being consumed (PDE for softmax->pde->gpde, PAE for softmax->pae and the sample-ranking compute, distogram for gpde). Then drop the per-call GPU reference. PDE and PAE no longer coexist on GPU. This doesn't always lower the program peak by itself (there are other points that hit the same memory peak this is removing), but unblocks subsequent optimizations. Verified outputs match exactly in torch deterministic mode and runtime is within noise. 
--- .../metrics/aggregate_confidence_ranking.py | 31 +++++++++++++++---- openfold3/core/model/heads/head_modules.py | 10 ++++-- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/openfold3/core/metrics/aggregate_confidence_ranking.py b/openfold3/core/metrics/aggregate_confidence_ranking.py index e868e2006..8cdaa4690 100644 --- a/openfold3/core/metrics/aggregate_confidence_ranking.py +++ b/openfold3/core/metrics/aggregate_confidence_ranking.py @@ -31,6 +31,15 @@ def _get_confidence_scores(batch: dict, outputs: dict, config: ConfigDict) -> dict: + # Under inference offload, aux_heads returns pde_logits, pae_logits, and + # distogram_logits on CPU. Move each one onto the compute device only while + # it's being consumed and drop the local reference afterwards (or just use a + # temporary in the first place), so that PDE and PAE are never both + # device-resident at the same time. Use atom_positions_predicted as the device + # anchor: it's always produced on the compute device by the diffusion + # sampler. 
+ compute_device = outputs["atom_positions_predicted"].device + confidence_scores = {} confidence_scores["plddt"] = ( probs_to_expected_error( @@ -39,7 +48,7 @@ def _get_confidence_scores(batch: dict, outputs: dict, config: ConfigDict) -> di * 100.0 ) - pde_probs = torch.softmax(outputs["pde_logits"], dim=-1) + pde_probs = torch.softmax(outputs["pde_logits"].to(device=compute_device), dim=-1) confidence_scores["pde"] = probs_to_expected_error( pde_probs, **config.confidence.pde ) @@ -50,7 +59,7 @@ def _get_confidence_scores(batch: dict, outputs: dict, config: ConfigDict) -> di confidence_scores["gpde"], contact_probs = compute_global_predicted_distance_error( pde=confidence_scores["pde"], - logits=outputs["distogram_logits"], + logits=outputs["distogram_logits"].to(device=compute_device), **config.confidence.distogram, ) if config.confidence.distogram.return_contact_probs: @@ -59,7 +68,8 @@ def _get_confidence_scores(batch: dict, outputs: dict, config: ConfigDict) -> di del contact_probs if config.architecture.heads.pae.enabled: - pae_probs = torch.softmax(outputs["pae_logits"], dim=-1) + pae_logits_on_device = outputs["pae_logits"].to(device=compute_device) + pae_probs = torch.softmax(pae_logits_on_device, dim=-1) confidence_scores["pae"] = probs_to_expected_error( pae_probs, **config.confidence.pae ) @@ -76,10 +86,16 @@ def _get_confidence_scores(batch: dict, outputs: dict, config: ConfigDict) -> di valid_frame_mask = valid_frame_mask.bool() + # Patch outputs locally so downstream sample-ranking sees the + # device-resident pae_logits without us having to thread it through + # every callee signature. 
+ outputs_for_ranking = dict(outputs) + outputs_for_ranking["pae_logits"] = pae_logits_on_device + confidence_scores.update( full_complex_sample_ranking_metric( batch=batch, - output=outputs, + output=outputs_for_ranking, has_frame=valid_frame_mask, **config.confidence.sample_ranking.full_complex, **config.confidence.ptm, @@ -90,7 +106,7 @@ def _get_confidence_scores(batch: dict, outputs: dict, config: ConfigDict) -> di confidence_scores.update( compute_chain_pair_iptm( batch=batch, - logits=outputs["pae_logits"], + logits=pae_logits_on_device, has_frame=valid_frame_mask, **config.confidence.ptm, ) @@ -100,12 +116,15 @@ def _get_confidence_scores(batch: dict, outputs: dict, config: ConfigDict) -> di confidence_scores.update( compute_chain_ptm( batch=batch, - outputs=outputs, + outputs=outputs_for_ranking, has_frame=valid_frame_mask, **config.confidence.ptm, ) ) + del outputs_for_ranking + del pae_logits_on_device + return confidence_scores diff --git a/openfold3/core/model/heads/head_modules.py b/openfold3/core/model/heads/head_modules.py index 8185b30b0..44b32518d 100644 --- a/openfold3/core/model/heads/head_modules.py +++ b/openfold3/core/model/heads/head_modules.py @@ -167,6 +167,12 @@ def forward( # Distogram head: Main loop (Algorithm 1), line 17 distogram_logits = self.distogram(z=zij) + # Under offload_inference, move distogram off GPU now; downstream + # confidence scoring consumes it once at the very end (gpde) and + # can pull it back on demand. Saves ~S*N^2*C_out*4 bytes of GPU + # peak during all of the per-pair-head compute that follows. 
+ if offload_inference: + distogram_logits = distogram_logits.to(device="cpu") aux_out["distogram_logits"] = distogram_logits # Stop grad @@ -257,8 +263,6 @@ def forward( del zij - aux_out = { - k: v.to(device=out_device, dtype=out_dtype) for k, v in aux_out.items() - } + aux_out = {k: v.to(dtype=out_dtype) for k, v in aux_out.items()} return aux_out From 0e74aad7ffc10b32e324025bef70fc724ed8cebe Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Wed, 27 May 2026 12:00:07 -0700 Subject: [PATCH 5/8] template_embedders: fold linear projections in-place `_embed_feats` has a chain of 8 'a = a + self._linear(x)' out-of-place additions. Each step has the accumulator a, the new linear output, and a+linear(x) all live during the add, so the peak per step is 3x `[*, N_templ, N, N, C]` tensors (~32 GB for my example N_tok=5570 monomer). `add_` folds the linear output into `a` in-place, dropping the 3x to 2x and freeing the temporary linear output immediately after fold-in. Confirmed output is bit-identical in torch deterministic mode. --- .../feature_embedders/template_embedders.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/openfold3/core/model/feature_embedders/template_embedders.py b/openfold3/core/model/feature_embedders/template_embedders.py index 32fe06991..70003d7fa 100644 --- a/openfold3/core/model/feature_embedders/template_embedders.py +++ b/openfold3/core/model/feature_embedders/template_embedders.py @@ -102,14 +102,20 @@ def _embed_feats(self, batch: dict): *template_restype.shape[:-2], n_token, -1, -1 ) + # Build the accumulator from the first projection, then fold each + # subsequent linear in-place. Out of place adds (a = a + linear(x)) keep + # a, linear(x), and a+linear(x) all live during the addition; that's + # three [*, N_templ, N, N, C] tensors co-resident. add_ frees the new + # sum's separate tensor (it's written into a) so peak per step drops + # from 3x slice to ~2x slice. 
a = self.dgram_linear(template_distogram) - a = a + self.pseudo_beta_mask_linear(pseudo_beta_pair_mask) - a = a + self.aatype_linear_1(template_restype_ti.to(dtype=dtype)) - a = a + self.aatype_linear_2(template_restype_tj.to(dtype=dtype)) - a = a + self.x_linear(x[..., None]) - a = a + self.y_linear(y[..., None]) - a = a + self.z_linear(z[..., None]) - a = a + self.backbone_mask_linear(backbone_frame_pair_mask) + a.add_(self.pseudo_beta_mask_linear(pseudo_beta_pair_mask)) + a.add_(self.aatype_linear_1(template_restype_ti.to(dtype=dtype))) + a.add_(self.aatype_linear_2(template_restype_tj.to(dtype=dtype))) + a.add_(self.x_linear(x[..., None])) + a.add_(self.y_linear(y[..., None])) + a.add_(self.z_linear(z[..., None])) + a.add_(self.backbone_mask_linear(backbone_frame_pair_mask)) return a From dabfb380182b25eedea0a25ad6324119b3a8d496 Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Tue, 2 Jun 2026 16:59:30 -0700 Subject: [PATCH 6/8] Chunk diffusion_conditioning embedding Unchunked, this cat is now the memory bottleneck. At ~5k tokens, this drops peak memory by ~26GB. 
--- .../model/layers/diffusion_conditioning.py | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/openfold3/core/model/layers/diffusion_conditioning.py b/openfold3/core/model/layers/diffusion_conditioning.py index 2059c62a0..6686736fa 100644 --- a/openfold3/core/model/layers/diffusion_conditioning.py +++ b/openfold3/core/model/layers/diffusion_conditioning.py @@ -21,7 +21,7 @@ from openfold3.core.model.layers.transition import SwiGLUTransition from openfold3.core.model.primitives.linear import Linear from openfold3.core.model.primitives.normalization import LayerNorm -from openfold3.core.utils.chunk_utils import ChunkSizeTuner +from openfold3.core.utils.chunk_utils import ChunkSizeTuner, chunk_layer from openfold3.core.utils.relpos import relpos_complex @@ -137,16 +137,28 @@ def _embed_trunk_inputs( si_input: torch.Tensor, si_trunk: torch.Tensor, zij_trunk: torch.Tensor, + chunk_size: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - # Pair conditioning relpos_zij = relpos_complex( batch=batch, max_relative_idx=self.max_relative_idx, max_relative_chain=self.max_relative_chain, ).to(dtype=zij_trunk.dtype) - zij = torch.cat([zij_trunk, relpos_zij], dim=-1) - zij = self.linear_z(self.layer_norm_z(zij)) + def _proj_zij(zij_trunk_in, relpos_in): + return self.linear_z( + self.layer_norm_z(torch.cat([zij_trunk_in, relpos_in], dim=-1)) + ) + + if chunk_size is not None: + zij = chunk_layer( + layer=_proj_zij, + inputs={"zij_trunk_in": zij_trunk, "relpos_in": relpos_zij}, + chunk_size=chunk_size, + no_batch_dims=zij_trunk.dim() - 2, + ) + else: + zij = _proj_zij(zij_trunk, relpos_zij) # Single conditioning si = torch.cat([si_trunk, si_input], dim=-1) @@ -246,7 +258,12 @@ def forward( zij_trunk = zij_trunk * 0 si, zij = self._embed_trunk_inputs( - batch=batch, t=t, si_input=si_input, si_trunk=si_trunk, zij_trunk=zij_trunk + batch=batch, + t=t, + si_input=si_input, + si_trunk=si_trunk, + zij_trunk=zij_trunk, + 
chunk_size=chunk_size, ) if chunk_size is not None: From 85c30fcb0864a2291222bd54a24db7b1d2bca575 Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Tue, 9 Jun 2026 11:26:25 -0700 Subject: [PATCH 7/8] Guard inplace adds with inplace_safe --- .../feature_embedders/template_embedders.py | 39 +++++++++++-------- .../core/model/latent/template_module.py | 3 +- 2 files changed, 25 insertions(+), 17 deletions(-) diff --git a/openfold3/core/model/feature_embedders/template_embedders.py b/openfold3/core/model/feature_embedders/template_embedders.py index 70003d7fa..0298e9b7c 100644 --- a/openfold3/core/model/feature_embedders/template_embedders.py +++ b/openfold3/core/model/feature_embedders/template_embedders.py @@ -23,6 +23,7 @@ import openfold3.core.config.default_linear_init_config as lin_init from openfold3.core.model.primitives import LayerNorm, Linear +from openfold3.core.utils.tensor_utils import add class TemplatePairEmbedderAllAtom(nn.Module): @@ -67,7 +68,7 @@ def __init__( self.layer_norm_z = LayerNorm(c_in) self.linear_z = Linear(c_in, c_out, **linear_init_params.linear_z) - def _embed_feats(self, batch: dict): + def _embed_feats(self, batch: dict, inplace_safe: bool = False): dtype = batch["template_unit_vector"].dtype # [*, N_token, N_token] @@ -102,24 +103,30 @@ def _embed_feats(self, batch: dict): *template_restype.shape[:-2], n_token, -1, -1 ) - # Build the accumulator from the first projection, then fold each - # subsequent linear in-place. Out of place adds (a = a + linear(x)) keep - # a, linear(x), and a+linear(x) all live during the addition; that's - # three [*, N_templ, N, N, C] tensors co-resident. add_ frees the new - # sum's separate tensor (it's written into a) so peak per step drops - # from 3x slice to ~2x slice. 
a = self.dgram_linear(template_distogram) - a.add_(self.pseudo_beta_mask_linear(pseudo_beta_pair_mask)) - a.add_(self.aatype_linear_1(template_restype_ti.to(dtype=dtype))) - a.add_(self.aatype_linear_2(template_restype_tj.to(dtype=dtype))) - a.add_(self.x_linear(x[..., None])) - a.add_(self.y_linear(y[..., None])) - a.add_(self.z_linear(z[..., None])) - a.add_(self.backbone_mask_linear(backbone_frame_pair_mask)) + a = add( + a, self.pseudo_beta_mask_linear(pseudo_beta_pair_mask), inplace=inplace_safe + ) + a = add( + a, + self.aatype_linear_1(template_restype_ti.to(dtype=dtype)), + inplace=inplace_safe, + ) + a = add( + a, + self.aatype_linear_2(template_restype_tj.to(dtype=dtype)), + inplace=inplace_safe, + ) + a = add(a, self.x_linear(x[..., None]), inplace=inplace_safe) + a = add(a, self.y_linear(y[..., None]), inplace=inplace_safe) + a = add(a, self.z_linear(z[..., None]), inplace=inplace_safe) + a = add( + a, self.backbone_mask_linear(backbone_frame_pair_mask), inplace=inplace_safe + ) return a - def forward(self, batch, z): + def forward(self, batch, z, inplace_safe: bool = False): """ Args: batch: @@ -129,7 +136,7 @@ def forward(self, batch, z): Returns: # [*, N_templ, N_token, N_token, C_out] Template pair feature embedding """ - a = self._embed_feats(batch=batch) + a = self._embed_feats(batch=batch, inplace_safe=inplace_safe) # [*, N_templ, N_token, N_token, C_out] z = self.linear_z(self.layer_norm_z(z)) diff --git a/openfold3/core/model/latent/template_module.py b/openfold3/core/model/latent/template_module.py index d627da83d..485d068d0 100644 --- a/openfold3/core/model/latent/template_module.py +++ b/openfold3/core/model/latent/template_module.py @@ -565,6 +565,7 @@ def _forward_offload( t = self.template_pair_embedder( batch=batch_templ, z=z, + inplace_safe=inplace_safe, ) # [*, N_templ, N_token, N_token, C_z] @@ -608,7 +609,7 @@ def _forward( inplace_safe: bool = False, ) -> torch.Tensor: # [*, N_templ, N_token, N_token, C_t] - t = 
self.template_pair_embedder(batch, z) + t = self.template_pair_embedder(batch, z, inplace_safe=inplace_safe) # [*, 1, N_token, N_token] pair_mask = pair_mask[..., None, :, :].to(dtype=z.dtype) From 137fa73ab7c231f8e54a7d3145cfa060dd8f86ad Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Tue, 9 Jun 2026 11:37:25 -0700 Subject: [PATCH 8/8] Mark large model test as slow --- openfold3/tests/test_of3_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/openfold3/tests/test_of3_model.py b/openfold3/tests/test_of3_model.py index 1dbf07a1e..d1be18f8a 100644 --- a/openfold3/tests/test_of3_model.py +++ b/openfold3/tests/test_of3_model.py @@ -206,6 +206,7 @@ def test_shape_large_eval(self, dtype): use_triton_triangle_kernels=is_rocm, ) + @pytest.mark.slow @compare_utils.skip_unless_triton_installed() @compare_utils.skip_unless_cuda_available() def test_shape_large_bf16_train(self):