Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fbgemm_gpu/test/sparse/pack_segments_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ def test_pack_segments_meta_backend(
return_presence_mask=True,
)

# pyre-fixme[6]: In call `tuple.__new__`, for 1st positional argument, expected `Iterable[int]` but got `Iterable[bool | float | int]`.
assert presence_mask.size() == torch.Size([lengths.numel(), max_length])

@unittest.skipIf(*gpu_unavailable)
Expand Down
12 changes: 12 additions & 0 deletions fbgemm_gpu/test/sparse/permute_indices_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,14 @@ def test_permute_indices(
lengths = torch.cat(length_splits, dim=1)
else:
lengths = torch.randint(low=1, high=L, size=(T, B)).type(index_dtype)
# pyre-fixme[6]: For 1st param expected `list[int] | Size |
# tuple[int, ...]` but got `bool | float | int`.
weights = torch.rand(lengths.sum().item()).float() if has_weight else None
indices = torch.randint(
low=1,
high=int(1e5),
# pyre-fixme[6]: Expected `int | tuple[int, ...]` for 3rd
# param but got `tuple[float | int]`.
size=(lengths.sum().item(),),
).type(index_dtype)
if is_1D:
Expand Down Expand Up @@ -189,6 +193,8 @@ def test_permute_indices_non_contiguous(
indices = torch.randint(
low=1,
high=int(1e5),
# pyre-fixme[6]: Expected `int | tuple[int, ...]` for 3rd
# param but got `tuple[float | int]`.
size=(lengths.sum().item(),),
).type(index_dtype)

Expand Down Expand Up @@ -247,6 +253,8 @@ def test_permute_indices_scripted_with_none_weights(
indices = torch.randint(
low=1,
high=int(1e5),
# pyre-fixme[6]: Expected `int | tuple[int, ...]` for 3rd
# param but got `tuple[float | int]`.
size=(lengths.sum().item(),),
).type(index_dtype)
permute_list = list(range(1))
Expand Down Expand Up @@ -284,10 +292,14 @@ def test_permute_indices_with_repeats(
) -> None:
index_dtype = torch.int64 if long_index else torch.int32
lengths = torch.randint(low=1, high=L, size=(T, B)).type(index_dtype)
# pyre-fixme[6]: For 1st param expected `list[int] | Size |
# tuple[int, ...]` but got `bool | float | int`.
weights = torch.rand(lengths.sum().item()).float() if has_weight else None
indices = torch.randint(
low=1,
high=int(1e5),
# pyre-fixme[6]: Expected `int | tuple[int, ...]` for 3rd
# param but got `tuple[float | int]`.
size=(lengths.sum().item(),),
).type(index_dtype)
permute_list = list(range(T))
Expand Down
Loading