Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/retrieval/two_tower_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def train(
checkpoint_pg = dist.new_group(backend="gloo")
# Copy sharded state_dict to CPU.
cpu_state_dict = state_dict_to_device(
model.state_dict(), pg=checkpoint_pg, device=torch.device("cpu")
model.state_dict(), pg=checkpoint_pg, device=torch.device("cpu") # pyre-fixme[6]: Expected `ProcessGroup` but got `ProcessGroup | int | None`.
)

ebc_cpu = EmbeddingBagCollection(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def share_tensor_via_shm(
if dist.get_backend() == "gloo":
gloo_pg = dist.group.WORLD
else:
# pyre-fixme[9]: Expected `ProcessGroup | None` but got `ProcessGroup | int | None`.
gloo_pg = dist.new_group(backend="gloo")

torch.multiprocessing.set_sharing_strategy("file_system")
Expand Down
1 change: 1 addition & 0 deletions torchrec/distributed/batched_embedding_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1409,6 +1409,7 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
or optimizer_state_value.nelement() == 1 # single value state
)
# pyrefly: ignore [no-matching-overload]
# pyre-fixme[6]: Incompatible parameter type
optimizer_states_keys_by_table[table_config.name] = list(
optimizer_states.keys()
)
Expand Down
4 changes: 4 additions & 0 deletions torchrec/distributed/dist_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,7 @@ def __init__(
)

self._output_tensors.append(output_tensor)
# pyre-fixme[6]: Expected `Work` but got `Work | Unknown | None`.
self._awaitables.append(awaitable)

def clear_inputs(self) -> int:
Expand Down Expand Up @@ -2133,6 +2134,7 @@ def __init__(
)
)

# pyre-fixme[8]: Expected `Work` but got `Work | Unknown | None`.
self._dist_values_req: dist.Work = dist.all_to_all_single(
self._dist_values,
jt.values(),
Expand Down Expand Up @@ -2203,6 +2205,7 @@ def __init__(
)

with record_function("## all2all_data:ids ##"):
# pyre-fixme[8]: Expected `Work` but got `Work | Unknown | None`.
self._values_awaitable: dist.Work = dist.all_to_all_single(
output=self._dist_values,
input=input,
Expand Down Expand Up @@ -2245,6 +2248,7 @@ def __init__(
)

with record_function("## all2all_data:ids splits ##"):
# pyre-fixme[8]: Expected `Work` but got `Work | Unknown | None`.
self._num_ids_awaitable: dist.Work = dist.all_to_all_single(
output=self._output_splits,
input=splits,
Expand Down
1 change: 1 addition & 0 deletions torchrec/distributed/embedding_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def create_virtual_table_local_metadata(
else sum(weight_count_per_rank[:my_rank])
)
# pyrefly: ignore[no-matching-overload, missing-attribute]
# pyre-fixme[6]: Expected `Iterable[int]` but got `Size | int`.
local_metadata.shard_sizes = list(param.size())
# pyrefly: ignore[missing-attribute]
local_metadata.shard_offsets = [
Expand Down
2 changes: 2 additions & 0 deletions torchrec/distributed/embedding_tower_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
else:
yield from ()

# pyre-fixme[14]: `named_modules` overrides method defined in `Module` inconsistently.
def named_modules(
self,
memo: Optional[Set[nn.Module]] = None,
Expand Down Expand Up @@ -872,6 +873,7 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
)
)

# pyre-fixme[14]: `named_modules` overrides method defined in `Module` inconsistently.
def named_modules(
self,
memo: Optional[Set[nn.Module]] = None,
Expand Down
3 changes: 3 additions & 0 deletions torchrec/distributed/embeddingbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1752,6 +1752,7 @@ def input_dist(
"""
if isinstance(features, TensorDict):
# pyrefly: ignore [no-matching-overload]
# pyre-fixme[6]: Expected `Iterable` but got `_TensorDictKeysView`.
feature_keys = list(features.keys())
if len(self._features_order) > 0:
feature_keys = [feature_keys[i] for i in self._features_order]
Expand Down Expand Up @@ -2358,6 +2359,7 @@ def __init__(

# Get all fused optimizers and combine them.
optims = []
# pyre-fixme[6]: Expected `EmbeddingBag` but got `BaseEmbeddingLookup | Module | Unknown`.
for _, module in self._lookup.named_modules():
if isinstance(module, FusedOptimizerModule):
# modify param keys to match EmbeddingBag
Expand Down Expand Up @@ -2438,6 +2440,7 @@ def state_dict(
destination[new_key] = item
return destination

# pyre-fixme[14]: `named_modules` overrides method defined in `Module` inconsistently.
def named_modules(
self,
memo: Optional[Set[nn.Module]] = None,
Expand Down
1 change: 1 addition & 0 deletions torchrec/distributed/planner/proposers.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def load(
]
# `List[Tuple[int]]`.
# pyrefly: ignore[no-matching-overload]
# pyre-fixme[6]: Expected `list[list[int]]` but got `list[tuple[int, ...]]`.
self._proposals = list(itertools.product(*sharding_options_by_fqn_indices))

def _reset(self) -> None:
Expand Down
8 changes: 4 additions & 4 deletions torchrec/distributed/planner/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,16 +544,16 @@ def _log_max_perf_and_max_hbm(
f"# {'Estimated Sharding Distribution' : <{self._width-2}}#"
)
self._stats_table.append(
f"# {'Sparse only Max HBM: '+_generate_rank_hbm_stats(sparse_hbm, max) : <{self._width-3}}#"
f"# {'Sparse only Max HBM: '+_generate_rank_hbm_stats(sparse_hbm, max) : <{self._width-3}}#" # pyre-fixme[6]
)
self._stats_table.append(
f"# {'Sparse only Min HBM: '+_generate_rank_hbm_stats(sparse_hbm, min) : <{self._width-3}}#"
f"# {'Sparse only Min HBM: '+_generate_rank_hbm_stats(sparse_hbm, min) : <{self._width-3}}#" # pyre-fixme[6]
)
self._stats_table.append(
f"# {'Max HBM: '+_generate_rank_hbm_stats(used_hbm, max) : <{self._width-3}}#"
f"# {'Max HBM: '+_generate_rank_hbm_stats(used_hbm, max) : <{self._width-3}}#" # pyre-fixme[6]
)
self._stats_table.append(
f"# {'Min HBM: '+_generate_rank_hbm_stats(used_hbm, min) : <{self._width-3}}#"
f"# {'Min HBM: '+_generate_rank_hbm_stats(used_hbm, min) : <{self._width-3}}#" # pyre-fixme[6]
)
self._stats_table.append(
f"# {'Mean HBM: '+_generate_rank_hbm_stats(used_hbm, statistics.mean) : <{self._width-3}}#"
Expand Down
4 changes: 2 additions & 2 deletions torchrec/distributed/quant_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,12 +475,12 @@ def sharded_tbes_weights_spec(
), "Cannot have any two of ShardedQuantEmbeddingBagCollection, ShardedQuantEmbeddingCollection, ShardedQuantManagedCollisionEmbeddingCollection, ShardedQuantFeatureProcessedEmbeddingBagCollection and ShardedQuantManagedCollisionEmbeddingBagCollection are true"
tbes_configs: Dict[
IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig
] = module.tbes_configs()
] = module.tbes_configs() # pyre-fixme[29]: Expected a callable.
table_shardings: Dict[str, str] = {}

sharding_type_device_group_to_sharding_infos: Dict[
Tuple[str, str], List[EmbeddingShardingInfo]
] = module.sharding_type_device_group_to_sharding_infos()
] = module.sharding_type_device_group_to_sharding_infos() # pyre-fixme[29]: Expected a callable.

for (
(sharding_type, _),
Expand Down
2 changes: 1 addition & 1 deletion torchrec/distributed/train_pipeline/backward_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def __call__(

# Find the awaitable matching our sharding type, skipping DP (NoWait)
for w, st in zip( # pyrefly: ignore[no-matching-overload]
awaitables, sharding_types
awaitables, sharding_types # pyre-fixme[6]: Expected `Iterable[str]` but got `list[str] | Unknown | None`.
):
if isinstance(w, NoWait):
continue
Expand Down
1 change: 1 addition & 0 deletions torchrec/distributed/train_pipeline/postproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def set_context(self, context: TrainPipelineContext) -> None:
def get_context(self) -> TrainPipelineContext:
return self._context

# pyre-fixme[14]: `named_modules` overrides method defined in `Module` inconsistently.
def named_modules(
self,
memo: Optional[Set[torch.nn.Module]] = None,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/distributed/train_pipeline/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ def _pipeline_detach_model(
for _, child_module in mod.named_modules():
if not hasattr(child_module, "_input_dists"):
continue
# pyre-fixme[16]: `Module` has no attribute `_input_dists`.
for input_dist in child_module._input_dists:
if hasattr(input_dist, "_dist"):
kjt_dists.append(input_dist._dist)
Expand Down Expand Up @@ -592,6 +593,7 @@ def _override_input_dist_forwards(
if not hasattr(child_module, "_input_dists"):
continue

# pyre-fixme[16]: `Module` has no attribute `_input_dists`.
for input_dist in child_module._input_dists:
if hasattr(input_dist, "_dist"):
assert isinstance(input_dist._dist, KJTAllToAll)
Expand Down
1 change: 1 addition & 0 deletions torchrec/fx/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def trace(
if isinstance(root, torch.nn.Module):
for prefix, module in root.named_modules():
# TODO(T140754678): Remove this workaround to _fx_path
# pyre-fixme[8]: Expected `Module | Tensor` but got `str | Unknown`.
module._fx_path = prefix

dmp = root
Expand Down
2 changes: 1 addition & 1 deletion torchrec/inference/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def _quantize_fp_module(
# handle the fp ebc separately
_quantize_fp_module(
model,
m,
m, # pyre-fixme[6]: Expected `FeatureProcessedEmbeddingBagCollection` but got `Module | Unknown`.
n,
weight_dtype=fp_weight_dtype,
# Pass in per_fp_table_weight_dtype if it is provided, perhaps
Expand Down
2 changes: 1 addition & 1 deletion torchrec/inference/state_dict_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def state_dict_to_device(
pg (ProcessGroup): Process Group used for comms
device (torch.device): device to put state_dict on
"""
ret = {}
ret: Dict[str, Union[torch.Tensor, ShardedTensor]] = {}
all_keys = state_dict_all_gather_keys(state_dict, pg)
for key in all_keys:
if key in state_dict:
Expand Down
4 changes: 4 additions & 0 deletions torchrec/metrics/cpu_offloaded_metric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def __init__(
target=self._compute_loop, name=metric_compute_thread_name, daemon=True
)

# pyre-fixme[8]: Expected `ProcessGroup` but got `ProcessGroup | int | None`.
self.cpu_process_group: dist.ProcessGroup = dist.new_group(backend="gloo")
self.comms_module: CPUCommsRecMetricModule = CPUCommsRecMetricModule(
*args,
Expand Down Expand Up @@ -848,11 +849,14 @@ def state_dict(
Args are identical to torch.nn.Module.state_dict().
"""
# pyrefly: ignore[no-matching-overload]
# pyre-fixme[7]: Expected `dict[str, Any]` but got `dict[str, Any] | None`.
# pyre-fixme[6]: Expected `dict[str, Any]` but got `dict[str, Any] | None`.
return self.comms_module.state_dict(
*args, destination=destination, prefix=prefix, keep_vars=keep_vars
)

@override
# pyre-fixme[14]: `load_state_dict` overrides method defined in `RecMetricModule` inconsistently.
def load_state_dict(
self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False
) -> None:
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/metric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,8 +610,10 @@ def load_pre_compute_states(
for task, metric_computation in zip(
# `Union[Module, Tensor]`.
# `Union[Module, Tensor]`.
# pyre-fixme[6]: Expected `Iterable` but got `Module | Tensor`.
metric._tasks,
# `Union[Module, Tensor]`.
# pyre-fixme[6]: Expected `Iterable` but got `Module | Tensor`.
metric._metrics_computations,
):
state = states[task.name]
Expand Down
1 change: 1 addition & 0 deletions torchrec/metrics/multi_label_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def _prepare_inputs(
# Expand weights to match
if weights is not None:
# pyrefly: ignore[no-matching-overload]
# pyre-fixme[6]: Expected `SymInt | int` but got `int | None`.
weights = weights.flatten().unsqueeze(-1).expand(-1, self._num_labels)
else:
weights = torch.ones_like(predictions, dtype=torch.float32)
Expand Down
5 changes: 4 additions & 1 deletion torchrec/metrics/rec_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ def _fused_tasks_iter(self, compute_scope: str) -> ComputeIterType:
self._tasks,
metric_report.value,
(
self._metrics_computations[0].has_valid_update
self._metrics_computations[0].has_valid_update # pyre-fixme[6]: Expected `Iterable` but got `Module | Tensor | repeat[int]`.
if self._should_validate_update
else itertools.repeat(1)
), # has_valid_update > 0 means the update is valid
Expand Down Expand Up @@ -978,9 +978,12 @@ def state_dict(
# We need to flush the cached output to ensure checkpointing correctness.
self._check_fused_update(force=True)
# pyrefly: ignore[no-matching-overload]
# pyre-fixme[6]: Expected `dict[str, Any]` but got `dict[str, Tensor] | None`.
destination = super().state_dict(
destination=destination, prefix=prefix, keep_vars=keep_vars
)
# pyre-fixme[7]: Expected `dict[str, Tensor]` but got `dict[str, Tensor] | None`.
# pyre-fixme[6]: Expected `dict[str, Any]` but got `dict[str, Tensor] | None`.
return self._metrics_computations.state_dict(
destination=destination,
prefix=f"{prefix}_metrics_computations.",
Expand Down
2 changes: 2 additions & 0 deletions torchrec/modules/lazy_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ def init_weights(m):
# fmt: off
# `LazyModuleMixin` inconsistently.
# pyrefly: ignore[bad-override]
# pyre-fixme[14]: `_infer_parameters` overrides method defined in
# `LazyModuleMixin` inconsistently.
def _infer_parameters(self: _LazyExtensionProtocol, module, args, kwargs) -> None:
r"""Infers the size and initializes the parameters according to the provided input batch.

Expand Down
1 change: 1 addition & 0 deletions torchrec/sparse/tensor_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def maybe_td_to_kjt(
if isinstance(features, TensorDict):
if keys is None:
# pyrefly: ignore[no-matching-overload]
# pyre-fixme[6]: Expected `Iterable` but got `_TensorDictKeysView`.
keys = list(features.keys())
values = torch.cat([features[key]._values for key in keys], dim=0)
lengths = torch.cat(
Expand Down
Loading