From f44f75f1f002b5f1d173c27b06a0d7266118b219 Mon Sep 17 00:00:00 2001 From: Kaustubh Vartak Date: Tue, 12 May 2026 14:54:30 -0700 Subject: [PATCH] Fix type checking errors in torchrec source files Summary: Fix pyrefly/pyre type checking errors in 22 non-test source files across torchrec. Adds pyre-fixme comments to suppress type errors that cannot be expressed in the type system, plus one pure type annotation fix (state_dict_transform.py ret variable). Error categories fixed: - ProcessGroup | int | None from dist.new_group() (pyre-fixme[9]) - Module.__getattr__ returning Module | Tensor for dynamic attributes (pyre-fixme[6/8/16/29]) - named_modules override inconsistency (pyre-fixme[14]) - state_dict() returning Optional dict vs T_destination bound (pyre-fixme[6/7]) - max/min builtins not matching Callable signature (pyre-fixme[6]) - Iterable type mismatches for TensorDictKeysView, KeyedJaggedTensor (pyre-fixme[6]) Differential Revision: D104852703 --- examples/retrieval/two_tower_train.py | 2 +- .../transfer_learning/train_from_pretrained_embedding.py | 1 + torchrec/distributed/batched_embedding_kernel.py | 1 + torchrec/distributed/dist_data.py | 4 ++++ torchrec/distributed/embedding_kernel.py | 1 + torchrec/distributed/embedding_tower_sharding.py | 2 ++ torchrec/distributed/embeddingbag.py | 3 +++ torchrec/distributed/planner/proposers.py | 1 + torchrec/distributed/planner/stats.py | 8 ++++---- torchrec/distributed/quant_state.py | 4 ++-- torchrec/distributed/train_pipeline/backward_injection.py | 2 +- torchrec/distributed/train_pipeline/postproc.py | 1 + torchrec/distributed/train_pipeline/utils.py | 2 ++ torchrec/fx/tracer.py | 1 + torchrec/inference/modules.py | 2 +- torchrec/inference/state_dict_transform.py | 2 +- torchrec/metrics/cpu_offloaded_metric_module.py | 4 ++++ torchrec/metrics/metric_module.py | 2 ++ torchrec/metrics/multi_label_precision.py | 1 + torchrec/metrics/rec_metric.py | 5 ++++- torchrec/modules/lazy_extension.py | 2 ++ 
torchrec/sparse/tensor_dict.py | 1 + 22 files changed, 41 insertions(+), 11 deletions(-) diff --git a/examples/retrieval/two_tower_train.py b/examples/retrieval/two_tower_train.py index ed5c4ae0b7..94a733ff92 100644 --- a/examples/retrieval/two_tower_train.py +++ b/examples/retrieval/two_tower_train.py @@ -175,7 +175,7 @@ def train( checkpoint_pg = dist.new_group(backend="gloo") # Copy sharded state_dict to CPU. cpu_state_dict = state_dict_to_device( - model.state_dict(), pg=checkpoint_pg, device=torch.device("cpu") + model.state_dict(), pg=checkpoint_pg, device=torch.device("cpu") # pyre-fixme[6]: Expected `ProcessGroup` but got `ProcessGroup | int | None`. ) ebc_cpu = EmbeddingBagCollection( diff --git a/examples/transfer_learning/train_from_pretrained_embedding.py b/examples/transfer_learning/train_from_pretrained_embedding.py index b9738ee162..eeb1816612 100644 --- a/examples/transfer_learning/train_from_pretrained_embedding.py +++ b/examples/transfer_learning/train_from_pretrained_embedding.py @@ -85,6 +85,7 @@ def share_tensor_via_shm( if dist.get_backend() == "gloo": gloo_pg = dist.group.WORLD else: + # pyre-fixme[9]: Expected `ProcessGroup | None` but got `ProcessGroup | int | None`. 
gloo_pg = dist.new_group(backend="gloo") torch.multiprocessing.set_sharing_strategy("file_system") diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index 244e531adf..42541053ec 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -1409,6 +1409,7 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata( or optimizer_state_value.nelement() == 1 # single value state ) # pyrefly: ignore [no-matching-overload] + # pyre-fixme[6]: Incompatible parameter type optimizer_states_keys_by_table[table_config.name] = list( optimizer_states.keys() ) diff --git a/torchrec/distributed/dist_data.py b/torchrec/distributed/dist_data.py index 447b4f8193..f8354d0509 100644 --- a/torchrec/distributed/dist_data.py +++ b/torchrec/distributed/dist_data.py @@ -780,6 +780,7 @@ def __init__( ) self._output_tensors.append(output_tensor) + # pyre-fixme[6]: Expected `Work` but got `Work | Unknown | None`. self._awaitables.append(awaitable) def clear_inputs(self) -> int: @@ -2133,6 +2134,7 @@ def __init__( ) ) + # pyre-fixme[8]: Expected `Work` but got `Work | Unknown | None`. self._dist_values_req: dist.Work = dist.all_to_all_single( self._dist_values, jt.values(), @@ -2203,6 +2205,7 @@ def __init__( ) with record_function("## all2all_data:ids ##"): + # pyre-fixme[8]: Expected `Work` but got `Work | Unknown | None`. self._values_awaitable: dist.Work = dist.all_to_all_single( output=self._dist_values, input=input, @@ -2245,6 +2248,7 @@ def __init__( ) with record_function("## all2all_data:ids splits ##"): + # pyre-fixme[8]: Expected `Work` but got `Work | Unknown | None`. 
self._num_ids_awaitable: dist.Work = dist.all_to_all_single( output=self._output_splits, input=splits, diff --git a/torchrec/distributed/embedding_kernel.py b/torchrec/distributed/embedding_kernel.py index 18a150ca44..1d9035e94b 100644 --- a/torchrec/distributed/embedding_kernel.py +++ b/torchrec/distributed/embedding_kernel.py @@ -192,6 +192,7 @@ def create_virtual_table_local_metadata( else sum(weight_count_per_rank[:my_rank]) ) # pyrefly: ignore[no-matching-overload, missing-attribute] + # pyre-fixme[6]: Expected `Iterable[int]` but got `Size | int`. local_metadata.shard_sizes = list(param.size()) # pyrefly: ignore[missing-attribute] local_metadata.shard_offsets = [ diff --git a/torchrec/distributed/embedding_tower_sharding.py b/torchrec/distributed/embedding_tower_sharding.py index bf56ca3226..446ae33e0e 100644 --- a/torchrec/distributed/embedding_tower_sharding.py +++ b/torchrec/distributed/embedding_tower_sharding.py @@ -431,6 +431,7 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: else: yield from () + # pyre-fixme[14]: `named_modules` overrides method defined in `Module` inconsistently. def named_modules( self, memo: Optional[Set[nn.Module]] = None, @@ -872,6 +873,7 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: ) ) + # pyre-fixme[14]: `named_modules` overrides method defined in `Module` inconsistently. def named_modules( self, memo: Optional[Set[nn.Module]] = None, diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index d1ada72b81..0aa74a6a2e 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -1752,6 +1752,7 @@ def input_dist( """ if isinstance(features, TensorDict): # pyrefly: ignore [no-matching-overload] + # pyre-fixme[6]: Expected `Iterable` but got `_TensorDictKeysView`. 
feature_keys = list(features.keys()) if len(self._features_order) > 0: feature_keys = [feature_keys[i] for i in self._features_order] @@ -2358,6 +2359,7 @@ def __init__( # Get all fused optimizers and combine them. optims = [] + # pyre-fixme[6]: Expected `EmbeddingBag` but got `BaseEmbeddingLookup | Module | Unknown`. for _, module in self._lookup.named_modules(): if isinstance(module, FusedOptimizerModule): # modify param keys to match EmbeddingBag @@ -2438,6 +2440,7 @@ def state_dict( destination[new_key] = item return destination + # pyre-fixme[14]: `named_modules` overrides method defined in `Module` inconsistently. def named_modules( self, memo: Optional[Set[nn.Module]] = None, diff --git a/torchrec/distributed/planner/proposers.py b/torchrec/distributed/planner/proposers.py index 291a8d05c5..cd8e7601d6 100644 --- a/torchrec/distributed/planner/proposers.py +++ b/torchrec/distributed/planner/proposers.py @@ -250,6 +250,7 @@ def load( ] # `List[Tuple[int]]`. # pyrefly: ignore[no-matching-overload] + # pyre-fixme[6]: Expected `list[list[int]]` but got `list[tuple[int, ...]]`. 
self._proposals = list(itertools.product(*sharding_options_by_fqn_indices)) def _reset(self) -> None: diff --git a/torchrec/distributed/planner/stats.py b/torchrec/distributed/planner/stats.py index 3ccbf64703..7e84514aa9 100644 --- a/torchrec/distributed/planner/stats.py +++ b/torchrec/distributed/planner/stats.py @@ -544,16 +544,16 @@ def _log_max_perf_and_max_hbm( f"# {'Estimated Sharding Distribution' : <{self._width-2}}#" ) self._stats_table.append( - f"# {'Sparse only Max HBM: '+_generate_rank_hbm_stats(sparse_hbm, max) : <{self._width-3}}#" + f"# {'Sparse only Max HBM: '+_generate_rank_hbm_stats(sparse_hbm, max) : <{self._width-3}}#" # pyre-fixme[6] ) self._stats_table.append( - f"# {'Sparse only Min HBM: '+_generate_rank_hbm_stats(sparse_hbm, min) : <{self._width-3}}#" + f"# {'Sparse only Min HBM: '+_generate_rank_hbm_stats(sparse_hbm, min) : <{self._width-3}}#" # pyre-fixme[6] ) self._stats_table.append( - f"# {'Max HBM: '+_generate_rank_hbm_stats(used_hbm, max) : <{self._width-3}}#" + f"# {'Max HBM: '+_generate_rank_hbm_stats(used_hbm, max) : <{self._width-3}}#" # pyre-fixme[6] ) self._stats_table.append( - f"# {'Min HBM: '+_generate_rank_hbm_stats(used_hbm, min) : <{self._width-3}}#" + f"# {'Min HBM: '+_generate_rank_hbm_stats(used_hbm, min) : <{self._width-3}}#" # pyre-fixme[6] ) self._stats_table.append( f"# {'Mean HBM: '+_generate_rank_hbm_stats(used_hbm, statistics.mean) : <{self._width-3}}#" diff --git a/torchrec/distributed/quant_state.py b/torchrec/distributed/quant_state.py index 6936a60dca..5d51f4f05b 100644 --- a/torchrec/distributed/quant_state.py +++ b/torchrec/distributed/quant_state.py @@ -475,12 +475,12 @@ def sharded_tbes_weights_spec( ), "Cannot have any two of ShardedQuantEmbeddingBagCollection, ShardedQuantEmbeddingCollection, ShardedQuantManagedCollisionEmbeddingCollection, ShardedQuantFeatureProcessedEmbeddingBagCollection and ShardedQuantManagedCollisionEmbeddingBagCollection are true" tbes_configs: Dict[ 
IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig - ] = module.tbes_configs() + ] = module.tbes_configs() # pyre-fixme[29]: Expected a callable. table_shardings: Dict[str, str] = {} sharding_type_device_group_to_sharding_infos: Dict[ Tuple[str, str], List[EmbeddingShardingInfo] - ] = module.sharding_type_device_group_to_sharding_infos() + ] = module.sharding_type_device_group_to_sharding_infos() # pyre-fixme[29]: Expected a callable. for ( (sharding_type, _), diff --git a/torchrec/distributed/train_pipeline/backward_injection.py b/torchrec/distributed/train_pipeline/backward_injection.py index 53ecf703c6..b68cd9c800 100644 --- a/torchrec/distributed/train_pipeline/backward_injection.py +++ b/torchrec/distributed/train_pipeline/backward_injection.py @@ -354,7 +354,7 @@ def __call__( # Find the awaitable matching our sharding type, skipping DP (NoWait) for w, st in zip( # pyrefly: ignore[no-matching-overload] - awaitables, sharding_types + awaitables, sharding_types # pyre-fixme[6]: Expected `Iterable[str]` but got `list[str] | Unknown | None`. ): if isinstance(w, NoWait): continue diff --git a/torchrec/distributed/train_pipeline/postproc.py b/torchrec/distributed/train_pipeline/postproc.py index f552d69841..c98b9bea90 100644 --- a/torchrec/distributed/train_pipeline/postproc.py +++ b/torchrec/distributed/train_pipeline/postproc.py @@ -183,6 +183,7 @@ def set_context(self, context: TrainPipelineContext) -> None: def get_context(self) -> TrainPipelineContext: return self._context + # pyre-fixme[14]: `named_modules` overrides method defined in `Module` inconsistently. 
def named_modules( self, memo: Optional[Set[torch.nn.Module]] = None, diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py index a1db49acb1..ffa1e00527 100644 --- a/torchrec/distributed/train_pipeline/utils.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -400,6 +400,7 @@ def _pipeline_detach_model( for _, child_module in mod.named_modules(): if not hasattr(child_module, "_input_dists"): continue + # pyre-fixme[16]: `Module` has no attribute `_input_dists`. for input_dist in child_module._input_dists: if hasattr(input_dist, "_dist"): kjt_dists.append(input_dist._dist) @@ -592,6 +593,7 @@ def _override_input_dist_forwards( if not hasattr(child_module, "_input_dists"): continue + # pyre-fixme[16]: `Module` has no attribute `_input_dists`. for input_dist in child_module._input_dists: if hasattr(input_dist, "_dist"): assert isinstance(input_dist._dist, KJTAllToAll) diff --git a/torchrec/fx/tracer.py b/torchrec/fx/tracer.py index 64c3d44e3d..3a247da307 100644 --- a/torchrec/fx/tracer.py +++ b/torchrec/fx/tracer.py @@ -73,6 +73,7 @@ def trace( if isinstance(root, torch.nn.Module): for prefix, module in root.named_modules(): # TODO(T140754678): Remove this workaround to _fx_path + # pyre-fixme[8]: Expected `Module | Tensor` but got `str | Unknown`. module._fx_path = prefix dmp = root diff --git a/torchrec/inference/modules.py b/torchrec/inference/modules.py index 47e006028d..f44a072666 100644 --- a/torchrec/inference/modules.py +++ b/torchrec/inference/modules.py @@ -457,7 +457,7 @@ def _quantize_fp_module( # handle the fp ebc separately _quantize_fp_module( model, - m, + m, # pyre-fixme[6]: Expected `FeatureProcessedEmbeddingBagCollection` but got `Module | Unknown`. 
n, weight_dtype=fp_weight_dtype, # Pass in per_fp_table_weight_dtype if it is provided, perhaps diff --git a/torchrec/inference/state_dict_transform.py b/torchrec/inference/state_dict_transform.py index 3d6a34eae4..27d4c66aa2 100644 --- a/torchrec/inference/state_dict_transform.py +++ b/torchrec/inference/state_dict_transform.py @@ -74,7 +74,7 @@ def state_dict_to_device( pg (ProcessGroup): Process Group used for comms device (torch.device): device to put state_dict on """ - ret = {} + ret: Dict[str, Union[torch.Tensor, ShardedTensor]] = {} all_keys = state_dict_all_gather_keys(state_dict, pg) for key in all_keys: if key in state_dict: diff --git a/torchrec/metrics/cpu_offloaded_metric_module.py b/torchrec/metrics/cpu_offloaded_metric_module.py index feb80cf9e5..2c64076b03 100644 --- a/torchrec/metrics/cpu_offloaded_metric_module.py +++ b/torchrec/metrics/cpu_offloaded_metric_module.py @@ -207,6 +207,7 @@ def __init__( target=self._compute_loop, name=metric_compute_thread_name, daemon=True ) + # pyre-fixme[8]: Expected `ProcessGroup` but got `ProcessGroup | int | None`. self.cpu_process_group: dist.ProcessGroup = dist.new_group(backend="gloo") self.comms_module: CPUCommsRecMetricModule = CPUCommsRecMetricModule( *args, @@ -848,11 +849,14 @@ def state_dict( Args are identical to torch.nn.Module.state_dict(). """ # pyrefly: ignore[no-matching-overload] + # pyre-fixme[7]: Expected `dict[str, Any]` but got `dict[str, Any] | None`. + # pyre-fixme[6]: Expected `dict[str, Any]` but got `dict[str, Any] | None`. return self.comms_module.state_dict( *args, destination=destination, prefix=prefix, keep_vars=keep_vars ) @override + # pyre-fixme[14]: `load_state_dict` overrides method defined in `RecMetricModule` inconsistently. 
def load_state_dict( self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False ) -> None: diff --git a/torchrec/metrics/metric_module.py b/torchrec/metrics/metric_module.py index 375fa1f345..1baf705a66 100644 --- a/torchrec/metrics/metric_module.py +++ b/torchrec/metrics/metric_module.py @@ -610,8 +610,10 @@ def load_pre_compute_states( for task, metric_computation in zip( # `Union[Module, Tensor]`. # `Union[Module, Tensor]`. + # pyre-fixme[6]: Expected `Iterable` but got `Module | Tensor`. metric._tasks, # `Union[Module, Tensor]`. + # pyre-fixme[6]: Expected `Iterable` but got `Module | Tensor`. metric._metrics_computations, ): state = states[task.name] diff --git a/torchrec/metrics/multi_label_precision.py b/torchrec/metrics/multi_label_precision.py index f6502b8847..a6ba70c62f 100644 --- a/torchrec/metrics/multi_label_precision.py +++ b/torchrec/metrics/multi_label_precision.py @@ -146,6 +146,7 @@ def _prepare_inputs( # Expand weights to match if weights is not None: # pyrefly: ignore[no-matching-overload] + # pyre-fixme[6]: Expected `SymInt | int` but got `int | None`. weights = weights.flatten().unsqueeze(-1).expand(-1, self._num_labels) else: weights = torch.ones_like(predictions, dtype=torch.float32) diff --git a/torchrec/metrics/rec_metric.py b/torchrec/metrics/rec_metric.py index ef8e7b9675..cf7c8eb943 100644 --- a/torchrec/metrics/rec_metric.py +++ b/torchrec/metrics/rec_metric.py @@ -507,7 +507,7 @@ def _fused_tasks_iter(self, compute_scope: str) -> ComputeIterType: self._tasks, metric_report.value, ( - self._metrics_computations[0].has_valid_update + self._metrics_computations[0].has_valid_update # pyre-fixme[6]: Expected `Iterable` but got `Module | Tensor | repeat[int]`. if self._should_validate_update else itertools.repeat(1) ), # has_valid_update > 0 means the update is valid @@ -978,9 +978,12 @@ def state_dict( # We need to flush the cached output to ensure checkpointing correctness. 
self._check_fused_update(force=True) # pyrefly: ignore[no-matching-overload] + # pyre-fixme[6]: Expected `dict[str, Any]` but got `dict[str, Tensor] | None`. destination = super().state_dict( destination=destination, prefix=prefix, keep_vars=keep_vars ) + # pyre-fixme[7]: Expected `dict[str, Tensor]` but got `dict[str, Tensor] | None`. + # pyre-fixme[6]: Expected `dict[str, Any]` but got `dict[str, Tensor] | None`. return self._metrics_computations.state_dict( destination=destination, prefix=f"{prefix}_metrics_computations.", diff --git a/torchrec/modules/lazy_extension.py b/torchrec/modules/lazy_extension.py index 64cb2d1927..8bebc8bbd1 100644 --- a/torchrec/modules/lazy_extension.py +++ b/torchrec/modules/lazy_extension.py @@ -161,6 +161,8 @@ def init_weights(m): # fmt: off # `LazyModuleMixin` inconsistently. # pyrefly: ignore[bad-override] + # pyre-fixme[14]: `_infer_parameters` overrides method defined in + # `LazyModuleMixin` inconsistently. def _infer_parameters(self: _LazyExtensionProtocol, module, args, kwargs) -> None: r"""Infers the size and initializes the parameters according to the provided input batch. diff --git a/torchrec/sparse/tensor_dict.py b/torchrec/sparse/tensor_dict.py index 64b354d012..772dd3b9e5 100644 --- a/torchrec/sparse/tensor_dict.py +++ b/torchrec/sparse/tensor_dict.py @@ -23,6 +23,7 @@ def maybe_td_to_kjt( if isinstance(features, TensorDict): if keys is None: # pyrefly: ignore[no-matching-overload] + # pyre-fixme[6]: Expected `Iterable` but got `_TensorDictKeysView`. keys = list(features.keys()) values = torch.cat([features[key]._values for key in keys], dim=0) lengths = torch.cat(