Skip to content

Commit e57ab6b

Browse files
aobo-yfacebook-github-bot
authored andcommitted
Add tests for mask_list in ImageMaskInput (meta-pytorch#1750)
Summary: as title Differential Revision: D89952326
1 parent 4650188 commit e57ab6b

1 file changed

Lines changed: 159 additions & 0 deletions

File tree

tests/attr/test_interpretable_input.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,3 +526,162 @@ def test_format_pixel_attr_with_non_continuous_mask(self) -> None:
526526
self.assertTrue(torch.all(result[0, :, :5] == segment_0_value))
527527
self.assertTrue(torch.all(result[0, :, 5:10] == segment_10_value))
528528
self.assertTrue(torch.all(result[0, :, 10:] == segment_20_value))
529+
530+
# Tests for mask_list functionality
531+
532+
def test_init_mask_list_ignores_mask(self) -> None:
533+
# Setup: provide both mask and mask_list
534+
image = self._create_test_image(width=10, height=10)
535+
# mask has 3 segments
536+
mask = torch.zeros((10, 10), dtype=torch.int32)
537+
mask[:, 3:7] = 1
538+
mask[:, 7:] = 2
539+
# mask_list has 2 masks
540+
mask1 = torch.zeros((10, 10), dtype=torch.bool)
541+
mask1[:, :5] = True
542+
mask2 = torch.zeros((10, 10), dtype=torch.bool)
543+
mask2[:, 5:] = True
544+
545+
# Execute: create ImageMaskInput with both mask and mask_list
546+
mm_input = ImageMaskInput(
547+
processor_fn=self._simple_processor,
548+
image=image,
549+
mask=mask,
550+
mask_list=[mask1, mask2],
551+
)
552+
553+
# Assert: mask_list takes precedence, so n_itp_features should be 2
554+
self.assertEqual(mm_input.n_itp_features, 2)
555+
self.assertEqual(len(mm_input.mask_list), 2)
556+
557+
def test_to_tensor_with_mask_list(self) -> None:
558+
# Setup: create ImageMaskInput with 3 masks
559+
image = self._create_test_image(width=15, height=10)
560+
mask1 = torch.zeros((10, 15), dtype=torch.bool)
561+
mask1[:, :5] = True
562+
mask2 = torch.zeros((10, 15), dtype=torch.bool)
563+
mask2[:, 5:10] = True
564+
mask3 = torch.zeros((10, 15), dtype=torch.bool)
565+
mask3[:, 10:] = True
566+
567+
mm_input = ImageMaskInput(
568+
processor_fn=self._simple_processor,
569+
image=image,
570+
mask_list=[mask1, mask2, mask3],
571+
)
572+
573+
# Execute: convert to tensor
574+
result = mm_input.to_tensor()
575+
576+
# Assert: verify tensor has correct number of features
577+
expected = torch.tensor([[1.0, 1.0, 1.0]])
578+
assertTensorAlmostEqual(self, result, expected)
579+
580+
def test_to_model_input_with_mask_list(self) -> None:
581+
# Setup: create image with 2 halves (left red, right green)
582+
img_array = np.zeros((10, 10, 3), dtype=np.uint8)
583+
img_array[:, :5] = [255, 0, 0] # Left half red
584+
img_array[:, 5:] = [0, 255, 0] # Right half green
585+
image = PIL.Image.fromarray(img_array)
586+
587+
mask1 = torch.zeros((10, 10), dtype=torch.bool)
588+
mask1[:, :5] = True # Left half
589+
mask2 = torch.zeros((10, 10), dtype=torch.bool)
590+
mask2[:, 5:] = True # Right half
591+
592+
mm_input = ImageMaskInput(
593+
processor_fn=self._simple_processor,
594+
image=image,
595+
mask_list=[mask1, mask2],
596+
baseline=(255, 255, 255),
597+
)
598+
599+
# Execute: keep left half (0), remove right half (1)
600+
perturbed_tensor = torch.tensor([[1.0, 0.0]])
601+
result = mm_input.to_model_input(perturbed_tensor)
602+
603+
# Assert: left half should be red, right half should be white
604+
img_array = result["pixel_values"].numpy().astype(np.uint8)
605+
# Left half should be red
606+
self.assertTrue(np.all(img_array[:, :5, 0] == 255))
607+
self.assertTrue(np.all(img_array[:, :5, 1] == 0))
608+
# Right half should be white (baseline)
609+
self.assertTrue(np.all(img_array[:, 5:] == 255))
610+
611+
def test_to_model_input_with_mask_list_overlapping(self) -> None:
612+
# Setup: create red image with overlapping masks
613+
image = self._create_test_image(color=(255, 0, 0))
614+
mask1 = torch.zeros((10, 10), dtype=torch.bool)
615+
mask1[:, :7] = True # Left 7 columns
616+
mask2 = torch.zeros((10, 10), dtype=torch.bool)
617+
mask2[:, 3:] = True # Right 7 columns (overlap at columns 3-6)
618+
619+
mm_input = ImageMaskInput(
620+
processor_fn=self._simple_processor,
621+
image=image,
622+
mask_list=[mask1, mask2],
623+
baseline=(255, 255, 255),
624+
)
625+
626+
# Execute: remove first feature (mask1), keep second (mask2)
627+
perturbed_tensor = torch.tensor([[0.0, 1.0]])
628+
result = mm_input.to_model_input(perturbed_tensor)
629+
630+
# Assert: left 7 columns (covered by mask1) should be white
631+
# even though columns 3-6 are also in mask2 (but mask1 sets them to baseline)
632+
img_array = result["pixel_values"].numpy().astype(np.uint8)
633+
self.assertTrue(np.all(img_array[:, :7] == 255))
634+
635+
def test_format_pixel_attr_with_mask_list(self) -> None:
636+
# Setup: create ImageMaskInput with 2 non-overlapping masks
637+
image = self._create_test_image(width=10, height=5)
638+
mask1 = torch.zeros((5, 10), dtype=torch.bool)
639+
mask1[:, :5] = True # Left half
640+
mask2 = torch.zeros((5, 10), dtype=torch.bool)
641+
mask2[:, 5:] = True # Right half
642+
643+
mm_input = ImageMaskInput(
644+
processor_fn=self._simple_processor,
645+
image=image,
646+
mask_list=[mask1, mask2],
647+
)
648+
649+
# Execute: format attribution
650+
attr = torch.tensor([[0.3, 0.7]])
651+
result = mm_input.format_pixel_attr(attr)
652+
653+
# Assert: left half should have 0.3, right half should have 0.7
654+
self.assertEqual(result.shape, (1, 5, 10))
655+
assertTensorAlmostEqual(
656+
self, result[0, :, :5], torch.full((5, 5), 0.3)
657+
) # Left half
658+
assertTensorAlmostEqual(
659+
self, result[0, :, 5:], torch.full((5, 5), 0.7)
660+
) # Right half
661+
662+
def test_format_pixel_attr_with_mask_list_overlapping(self) -> None:
663+
# Setup: create ImageMaskInput with overlapping masks
664+
image = self._create_test_image(width=10, height=5)
665+
mask1 = torch.zeros((5, 10), dtype=torch.bool)
666+
mask1[:, :7] = True # Left 7 columns
667+
mask2 = torch.zeros((5, 10), dtype=torch.bool)
668+
mask2[:, 3:] = True # Right 7 columns (overlap at columns 3-6)
669+
670+
mm_input = ImageMaskInput(
671+
processor_fn=self._simple_processor,
672+
image=image,
673+
mask_list=[mask1, mask2],
674+
)
675+
676+
# Execute: format attribution with values 0.3 for mask1, 0.5 for mask2
677+
attr = torch.tensor([[0.3, 0.5]])
678+
result = mm_input.format_pixel_attr(attr)
679+
680+
# Assert: overlapping region should have summed attribution
681+
self.assertEqual(result.shape, (1, 5, 10))
682+
# Columns 0-2: only mask1 (0.3)
683+
assertTensorAlmostEqual(self, result[0, :, :3], torch.full((5, 3), 0.3))
684+
# Columns 3-6: both masks (0.3 + 0.5 = 0.8)
685+
assertTensorAlmostEqual(self, result[0, :, 3:7], torch.full((5, 4), 0.8))
686+
# Columns 7-9: only mask2 (0.5)
687+
assertTensorAlmostEqual(self, result[0, :, 7:], torch.full((5, 3), 0.5))

0 commit comments

Comments
 (0)