diff --git a/openfold3/core/metrics/aggregate_confidence_ranking.py b/openfold3/core/metrics/aggregate_confidence_ranking.py index e868e2006..8cdaa4690 100644 --- a/openfold3/core/metrics/aggregate_confidence_ranking.py +++ b/openfold3/core/metrics/aggregate_confidence_ranking.py @@ -31,6 +31,15 @@ def _get_confidence_scores(batch: dict, outputs: dict, config: ConfigDict) -> dict: + # Under inference offload, aux_heads returns pde_logits, pae_logits, and + # distogram_logits on CPU. Move each one onto the compute device only while + # it's being consumed and drop the local reference afterwards (or just use a + # temporary in the first place), so that PDE and PAE are never both + # device-resident at the same time. Use atom_positions_predicted as the device + # anchor: it's always produced on the compute device by the diffusion + # sampler. + compute_device = outputs["atom_positions_predicted"].device + confidence_scores = {} confidence_scores["plddt"] = ( probs_to_expected_error( @@ -39,7 +48,7 @@ def _get_confidence_scores(batch: dict, outputs: dict, config: ConfigDict) -> di * 100.0 ) - pde_probs = torch.softmax(outputs["pde_logits"], dim=-1) + pde_probs = torch.softmax(outputs["pde_logits"].to(device=compute_device), dim=-1) confidence_scores["pde"] = probs_to_expected_error( pde_probs, **config.confidence.pde ) @@ -50,7 +59,7 @@ def _get_confidence_scores(batch: dict, outputs: dict, config: ConfigDict) -> di confidence_scores["gpde"], contact_probs = compute_global_predicted_distance_error( pde=confidence_scores["pde"], - logits=outputs["distogram_logits"], + logits=outputs["distogram_logits"].to(device=compute_device), **config.confidence.distogram, ) if config.confidence.distogram.return_contact_probs: @@ -59,7 +68,8 @@ def _get_confidence_scores(batch: dict, outputs: dict, config: ConfigDict) -> di del contact_probs if config.architecture.heads.pae.enabled: - pae_probs = torch.softmax(outputs["pae_logits"], dim=-1) + pae_logits_on_device = outputs["pae_logits"].to(device=compute_device) + pae_probs = torch.softmax(pae_logits_on_device, dim=-1) confidence_scores["pae"] = probs_to_expected_error( pae_probs, **config.confidence.pae ) @@ -76,10 +86,16 @@ def _get_confidence_scores(batch: dict, outputs: dict, config: ConfigDict) -> di valid_frame_mask = valid_frame_mask.bool() + # Patch outputs locally so downstream sample-ranking sees the + # device-resident pae_logits without us having to thread it through + # every callee signature. + outputs_for_ranking = dict(outputs) + outputs_for_ranking["pae_logits"] = pae_logits_on_device + confidence_scores.update( full_complex_sample_ranking_metric( batch=batch, - output=outputs, + output=outputs_for_ranking, has_frame=valid_frame_mask, **config.confidence.sample_ranking.full_complex, **config.confidence.ptm, @@ -90,7 +106,7 @@ def _get_confidence_scores(batch: dict, outputs: dict, config: ConfigDict) -> di confidence_scores.update( compute_chain_pair_iptm( batch=batch, - logits=outputs["pae_logits"], + logits=pae_logits_on_device, has_frame=valid_frame_mask, **config.confidence.ptm, ) @@ -100,12 +116,15 @@ def _get_confidence_scores(batch: dict, outputs: dict, config: ConfigDict) -> di confidence_scores.update( compute_chain_ptm( batch=batch, - outputs=outputs, + outputs=outputs_for_ranking, has_frame=valid_frame_mask, **config.confidence.ptm, ) ) + del outputs_for_ranking + del pae_logits_on_device + return confidence_scores diff --git a/openfold3/core/model/feature_embedders/template_embedders.py b/openfold3/core/model/feature_embedders/template_embedders.py index 32fe06991..0298e9b7c 100644 --- a/openfold3/core/model/feature_embedders/template_embedders.py +++ b/openfold3/core/model/feature_embedders/template_embedders.py @@ -23,6 +23,7 @@ import openfold3.core.config.default_linear_init_config as lin_init from openfold3.core.model.primitives import LayerNorm, Linear +from openfold3.core.utils.tensor_utils import add class TemplatePairEmbedderAllAtom(nn.Module): @@ -67,7 +68,7 @@ def __init__( self.layer_norm_z = LayerNorm(c_in) self.linear_z = Linear(c_in, c_out, **linear_init_params.linear_z) - def _embed_feats(self, batch: dict): + def _embed_feats(self, batch: dict, inplace_safe: bool = False): dtype = batch["template_unit_vector"].dtype # [*, N_token, N_token] @@ -103,17 +104,29 @@ def _embed_feats(self, batch: dict): ) a = self.dgram_linear(template_distogram) - a = a + self.pseudo_beta_mask_linear(pseudo_beta_pair_mask) - a = a + self.aatype_linear_1(template_restype_ti.to(dtype=dtype)) - a = a + self.aatype_linear_2(template_restype_tj.to(dtype=dtype)) - a = a + self.x_linear(x[..., None]) - a = a + self.y_linear(y[..., None]) - a = a + self.z_linear(z[..., None]) - a = a + self.backbone_mask_linear(backbone_frame_pair_mask) + a = add( + a, self.pseudo_beta_mask_linear(pseudo_beta_pair_mask), inplace=inplace_safe + ) + a = add( + a, + self.aatype_linear_1(template_restype_ti.to(dtype=dtype)), + inplace=inplace_safe, + ) + a = add( + a, + self.aatype_linear_2(template_restype_tj.to(dtype=dtype)), + inplace=inplace_safe, + ) + a = add(a, self.x_linear(x[..., None]), inplace=inplace_safe) + a = add(a, self.y_linear(y[..., None]), inplace=inplace_safe) + a = add(a, self.z_linear(z[..., None]), inplace=inplace_safe) + a = add( + a, self.backbone_mask_linear(backbone_frame_pair_mask), inplace=inplace_safe + ) return a - def forward(self, batch, z): + def forward(self, batch, z, inplace_safe: bool = False): """ Args: batch: @@ -123,7 +136,7 @@ def forward(self, batch, z): Returns: # [*, N_templ, N_token, N_token, C_out] Template pair feature embedding """ - a = self._embed_feats(batch=batch) + a = self._embed_feats(batch=batch, inplace_safe=inplace_safe) # [*, N_templ, N_token, N_token, C_out] z = self.linear_z(self.layer_norm_z(z)) diff --git a/openfold3/core/model/heads/head_modules.py b/openfold3/core/model/heads/head_modules.py index 46a132ac4..44b32518d 100644 --- a/openfold3/core/model/heads/head_modules.py +++ b/openfold3/core/model/heads/head_modules.py @@ -167,6 +167,12 @@ def forward( # Distogram head: Main loop (Algorithm 1), line 17 distogram_logits = self.distogram(z=zij) + # Under offload_inference, move distogram off GPU now; downstream + # confidence scoring consumes it once at the very end (gpde) and + # can pull it back on demand. Saves ~S*N^2*C_out*4 bytes of GPU + # peak during all of the per-pair-head compute that follows. + if offload_inference: + distogram_logits = distogram_logits.to(device="cpu") aux_out["distogram_logits"] = distogram_logits # Stop grad @@ -239,23 +245,24 @@ def forward( ) aux_out["experimentally_resolved_logits"] = experimentally_resolved_logits - # zij is moved back to GPU after the single rep confidence heads - # because building the max_atom_per_token_mask uses a lot of memory - zij = zij.to(device=out_device) - - pde_logits = self.pde(zij, apply_per_sample=apply_per_sample) + # We leave zij on CPU here and let the PDE/PAE heads pull what they + # need. This enables moving only a single sample onto the GPU at a time + # if running with apply_per_sample. + aux_out["pde_logits"] = self.pde( + zij, + apply_per_sample=apply_per_sample, + compute_device=out_device, + ) if self.config.pae.enabled: - # Offload pde logits to not keep all three pairwise tensors - # in GPU memory at once - offload_device = "cpu" if offload_inference else out_device - pde_logits = pde_logits.to(device=offload_device) - aux_out["pae_logits"] = self.pae(zij, apply_per_sample=apply_per_sample) + aux_out["pae_logits"] = self.pae( + zij, + apply_per_sample=apply_per_sample, + compute_device=out_device, + ) del zij - aux_out["pde_logits"] = pde_logits.to(device=out_device) - aux_out = {k: v.to(dtype=out_dtype) for k, v in aux_out.items()} return aux_out diff --git a/openfold3/core/model/heads/prediction_heads.py b/openfold3/core/model/heads/prediction_heads.py index c446f2807..ba1b92735 100644 --- a/openfold3/core/model/heads/prediction_heads.py +++ b/openfold3/core/model/heads/prediction_heads.py @@ -399,19 +399,28 @@ def _compute_logits(self, zij: torch.Tensor): def _chunk( self, zij: torch.Tensor, + compute_device: torch.device | None = None, ) -> torch.Tensor: + # ``zij`` will be moved in slices to ``compute_device`` for the layer + # norm + linear, and the output logits will be moved afterwards to + # the original ``zij.device`` zij_out = torch.zeros( (*zij.shape[:-1], self.c_out), device=zij.device, dtype=zij.dtype ) no_samples = zij.shape[-4] for i in range(no_samples): - zij_out[..., i : i + 1, :, :, :] = self._compute_logits( - zij[..., i : i + 1, :, :, :] - ) + slice_in = zij[..., i : i + 1, :, :, :].to(device=compute_device) + slice_out = self._compute_logits(slice_in).to(device=zij.device) + zij_out[..., i : i + 1, :, :, :] = slice_out return zij_out - def forward(self, zij, apply_per_sample: bool = False): + def forward( + self, + zij, + apply_per_sample: bool = False, + compute_device: torch.device | None = None, + ): """ Args: zij: @@ -421,14 +430,22 @@ def forward(self, zij, apply_per_sample: bool = False): This is a memory optimization which is only used during validation/inference and will depend on the number of samples in the full rollout. + compute_device: + Device on which to run computation. zij will be moved here + before doing any computation. When apply_per_sample is true, + each per-sample slice of ``zij`` is moved onto this device + separately for the computation and the output is moved to + ``zij.device`` before processing the next slice. Returns: logits: [*, N, N, C_out] Logits """ if apply_per_sample: - logits = self._chunk(zij=zij) + logits = self._chunk(zij=zij, compute_device=compute_device) else: - logits = self._compute_logits(zij=zij) + logits = self._compute_logits(zij=zij.to(device=compute_device)).to( + device=zij.device + ) return logits @@ -471,19 +488,25 @@ def _compute_logits(self, zij: torch.Tensor): def _chunk( self, zij: torch.Tensor, + compute_device: torch.device | None = None, ) -> torch.Tensor: zij_out = torch.zeros( (*zij.shape[:-1], self.c_out), device=zij.device, dtype=zij.dtype ) no_samples = zij.shape[-4] for i in range(no_samples): - zij_out[..., i : i + 1, :, :, :] = self._compute_logits( - zij[..., i : i + 1, :, :, :] - ) + slice_in = zij[..., i : i + 1, :, :, :].to(device=compute_device) + slice_out = self._compute_logits(slice_in).to(device=zij.device) + zij_out[..., i : i + 1, :, :, :] = slice_out return zij_out - def forward(self, zij, apply_per_sample: bool = False): + def forward( + self, + zij, + apply_per_sample: bool = False, + compute_device: torch.device | None = None, + ): """ Args: zij: @@ -493,14 +516,22 @@ def forward(self, zij, apply_per_sample: bool = False): This is a memory optimization which is only used during validation/inference and will depend on the number of samples in the full rollout. + compute_device: + Device on which to run computation. zij will be moved here + before doing any computation. When apply_per_sample is true, + each per-sample slice of ``zij`` is moved onto this device + separately for the computation and the output is moved to + ``zij.device`` before processing the next slice. Returns: logits: [*, N, N, C_out] Logits """ if apply_per_sample: - logits = self._chunk(zij=zij) + logits = self._chunk(zij=zij, compute_device=compute_device) else: - logits = self._compute_logits(zij=zij) + logits = self._compute_logits(zij=zij.to(device=compute_device)).to( + device=zij.device + ) return logits diff --git a/openfold3/core/model/latent/template_module.py b/openfold3/core/model/latent/template_module.py index d627da83d..485d068d0 100644 --- a/openfold3/core/model/latent/template_module.py +++ b/openfold3/core/model/latent/template_module.py @@ -565,6 +565,7 @@ def _forward_offload( t = self.template_pair_embedder( batch=batch_templ, z=z, + inplace_safe=inplace_safe, ) # [*, N_templ, N_token, N_token, C_z] @@ -608,7 +609,7 @@ def _forward( inplace_safe: bool = False, ) -> torch.Tensor: # [*, N_templ, N_token, N_token, C_t] - t = self.template_pair_embedder(batch, z) + t = self.template_pair_embedder(batch, z, inplace_safe=inplace_safe) # [*, 1, N_token, N_token] pair_mask = pair_mask[..., None, :, :].to(dtype=z.dtype) diff --git a/openfold3/core/model/layers/diffusion_conditioning.py b/openfold3/core/model/layers/diffusion_conditioning.py index 2059c62a0..6686736fa 100644 --- a/openfold3/core/model/layers/diffusion_conditioning.py +++ b/openfold3/core/model/layers/diffusion_conditioning.py @@ -21,7 +21,7 @@ from openfold3.core.model.layers.transition import SwiGLUTransition from openfold3.core.model.primitives.linear import Linear from openfold3.core.model.primitives.normalization import LayerNorm -from openfold3.core.utils.chunk_utils import ChunkSizeTuner +from openfold3.core.utils.chunk_utils import ChunkSizeTuner, chunk_layer from openfold3.core.utils.relpos import relpos_complex @@ -137,16 +137,28 @@ def _embed_trunk_inputs( si_input: torch.Tensor, si_trunk: torch.Tensor, zij_trunk: torch.Tensor, + chunk_size: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - # Pair conditioning relpos_zij = relpos_complex( batch=batch, max_relative_idx=self.max_relative_idx, max_relative_chain=self.max_relative_chain, ).to(dtype=zij_trunk.dtype) - zij = torch.cat([zij_trunk, relpos_zij], dim=-1) - zij = self.linear_z(self.layer_norm_z(zij)) + def _proj_zij(zij_trunk_in, relpos_in): + return self.linear_z( + self.layer_norm_z(torch.cat([zij_trunk_in, relpos_in], dim=-1)) + ) + + if chunk_size is not None: + zij = chunk_layer( + layer=_proj_zij, + inputs={"zij_trunk_in": zij_trunk, "relpos_in": relpos_zij}, + chunk_size=chunk_size, + no_batch_dims=zij_trunk.dim() - 2, + ) + else: + zij = _proj_zij(zij_trunk, relpos_zij) # Single conditioning si = torch.cat([si_trunk, si_input], dim=-1) @@ -246,7 +258,12 @@ def forward( zij_trunk = zij_trunk * 0 si, zij = self._embed_trunk_inputs( - batch=batch, t=t, si_input=si_input, si_trunk=si_trunk, zij_trunk=zij_trunk + batch=batch, + t=t, + si_input=si_input, + si_trunk=si_trunk, + zij_trunk=zij_trunk, + chunk_size=chunk_size, ) if chunk_size is not None: diff --git a/openfold3/core/utils/atomize_utils.py b/openfold3/core/utils/atomize_utils.py index 794369817..077923add 100644 --- a/openfold3/core/utils/atomize_utils.py +++ b/openfold3/core/utils/atomize_utils.py @@ -762,23 +762,23 @@ def get_token_frame_atoms( valid_frame_mask: [*, N_token] Mask denoting valid frames """ - # Create pairwise atom mask - pair_mask = atom_mask[..., None] * atom_mask[..., None, :] + # Pairwise atom mask, kept as bool throughout to avoid materializing + # large fp32 [N_atom, N_atom] intermediates. + am_bool = atom_mask.bool() + pair_mask = am_bool[..., None] & am_bool[..., None, :] - # Update pairwise atom mask # Restrict to atoms within the same chain atom_asym_id = broadcast_token_feat_to_atoms( token_mask=batch["token_mask"], num_atoms_per_token=batch["num_atoms_per_token"], token_feat=batch["asym_id"], ) - atom_asym_id_mask = atom_asym_id[..., None] == atom_asym_id[..., None, :] - pair_mask = pair_mask * atom_asym_id_mask + pair_mask &= atom_asym_id[..., None] == atom_asym_id[..., None, :] - # Compute distance matrix + # Compute distance matrix. Use cdist to avoid materializing N*N*3 intermediate # [*, N_atom, N_atom] - d = torch.sum(eps + (x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1) ** 0.5 - d = d * pair_mask + inf * (1 - pair_mask) + d = torch.cdist(x, x) + d.masked_fill_(~pair_mask, inf) # Find indices of two closest atoms for start atoms # [*, N_token] diff --git a/openfold3/tests/test_of3_model.py b/openfold3/tests/test_of3_model.py index 1dbf07a1e..d1be18f8a 100644 --- a/openfold3/tests/test_of3_model.py +++ b/openfold3/tests/test_of3_model.py @@ -206,6 +206,7 @@ def test_shape_large_eval(self, dtype): use_triton_triangle_kernels=is_rocm, ) + @pytest.mark.slow @compare_utils.skip_unless_triton_installed() @compare_utils.skip_unless_cuda_available() def test_shape_large_bf16_train(self):