@@ -526,3 +526,162 @@ def test_format_pixel_attr_with_non_continuous_mask(self) -> None:
526526 self .assertTrue (torch .all (result [0 , :, :5 ] == segment_0_value ))
527527 self .assertTrue (torch .all (result [0 , :, 5 :10 ] == segment_10_value ))
528528 self .assertTrue (torch .all (result [0 , :, 10 :] == segment_20_value ))
529+
530+ # Tests for mask_list functionality
531+
532+ def test_init_mask_list_ignores_mask (self ) -> None :
533+ # Setup: provide both mask and mask_list
534+ image = self ._create_test_image (width = 10 , height = 10 )
535+ # mask has 3 segments
536+ mask = torch .zeros ((10 , 10 ), dtype = torch .int32 )
537+ mask [:, 3 :7 ] = 1
538+ mask [:, 7 :] = 2
539+ # mask_list has 2 masks
540+ mask1 = torch .zeros ((10 , 10 ), dtype = torch .bool )
541+ mask1 [:, :5 ] = True
542+ mask2 = torch .zeros ((10 , 10 ), dtype = torch .bool )
543+ mask2 [:, 5 :] = True
544+
545+ # Execute: create ImageMaskInput with both mask and mask_list
546+ mm_input = ImageMaskInput (
547+ processor_fn = self ._simple_processor ,
548+ image = image ,
549+ mask = mask ,
550+ mask_list = [mask1 , mask2 ],
551+ )
552+
553+ # Assert: mask_list takes precedence, so n_itp_features should be 2
554+ self .assertEqual (mm_input .n_itp_features , 2 )
555+ self .assertEqual (len (mm_input .mask_list ), 2 )
556+
557+ def test_to_tensor_with_mask_list (self ) -> None :
558+ # Setup: create ImageMaskInput with 3 masks
559+ image = self ._create_test_image (width = 15 , height = 10 )
560+ mask1 = torch .zeros ((10 , 15 ), dtype = torch .bool )
561+ mask1 [:, :5 ] = True
562+ mask2 = torch .zeros ((10 , 15 ), dtype = torch .bool )
563+ mask2 [:, 5 :10 ] = True
564+ mask3 = torch .zeros ((10 , 15 ), dtype = torch .bool )
565+ mask3 [:, 10 :] = True
566+
567+ mm_input = ImageMaskInput (
568+ processor_fn = self ._simple_processor ,
569+ image = image ,
570+ mask_list = [mask1 , mask2 , mask3 ],
571+ )
572+
573+ # Execute: convert to tensor
574+ result = mm_input .to_tensor ()
575+
576+ # Assert: verify tensor has correct number of features
577+ expected = torch .tensor ([[1.0 , 1.0 , 1.0 ]])
578+ assertTensorAlmostEqual (self , result , expected )
579+
580+ def test_to_model_input_with_mask_list (self ) -> None :
581+ # Setup: create image with 2 halves (left red, right green)
582+ img_array = np .zeros ((10 , 10 , 3 ), dtype = np .uint8 )
583+ img_array [:, :5 ] = [255 , 0 , 0 ] # Left half red
584+ img_array [:, 5 :] = [0 , 255 , 0 ] # Right half green
585+ image = PIL .Image .fromarray (img_array )
586+
587+ mask1 = torch .zeros ((10 , 10 ), dtype = torch .bool )
588+ mask1 [:, :5 ] = True # Left half
589+ mask2 = torch .zeros ((10 , 10 ), dtype = torch .bool )
590+ mask2 [:, 5 :] = True # Right half
591+
592+ mm_input = ImageMaskInput (
593+ processor_fn = self ._simple_processor ,
594+ image = image ,
595+ mask_list = [mask1 , mask2 ],
596+ baseline = (255 , 255 , 255 ),
597+ )
598+
599+ # Execute: keep left half (0), remove right half (1)
600+ perturbed_tensor = torch .tensor ([[1.0 , 0.0 ]])
601+ result = mm_input .to_model_input (perturbed_tensor )
602+
603+ # Assert: left half should be red, right half should be white
604+ img_array = result ["pixel_values" ].numpy ().astype (np .uint8 )
605+ # Left half should be red
606+ self .assertTrue (np .all (img_array [:, :5 , 0 ] == 255 ))
607+ self .assertTrue (np .all (img_array [:, :5 , 1 ] == 0 ))
608+ # Right half should be white (baseline)
609+ self .assertTrue (np .all (img_array [:, 5 :] == 255 ))
610+
611+ def test_to_model_input_with_mask_list_overlapping (self ) -> None :
612+ # Setup: create red image with overlapping masks
613+ image = self ._create_test_image (color = (255 , 0 , 0 ))
614+ mask1 = torch .zeros ((10 , 10 ), dtype = torch .bool )
615+ mask1 [:, :7 ] = True # Left 7 columns
616+ mask2 = torch .zeros ((10 , 10 ), dtype = torch .bool )
617+ mask2 [:, 3 :] = True # Right 7 columns (overlap at columns 3-6)
618+
619+ mm_input = ImageMaskInput (
620+ processor_fn = self ._simple_processor ,
621+ image = image ,
622+ mask_list = [mask1 , mask2 ],
623+ baseline = (255 , 255 , 255 ),
624+ )
625+
626+ # Execute: remove first feature (mask1), keep second (mask2)
627+ perturbed_tensor = torch .tensor ([[0.0 , 1.0 ]])
628+ result = mm_input .to_model_input (perturbed_tensor )
629+
630+ # Assert: left 7 columns (covered by mask1) should be white
631+ # even though columns 3-6 are also in mask2 (but mask1 sets them to baseline)
632+ img_array = result ["pixel_values" ].numpy ().astype (np .uint8 )
633+ self .assertTrue (np .all (img_array [:, :7 ] == 255 ))
634+
635+ def test_format_pixel_attr_with_mask_list (self ) -> None :
636+ # Setup: create ImageMaskInput with 2 non-overlapping masks
637+ image = self ._create_test_image (width = 10 , height = 5 )
638+ mask1 = torch .zeros ((5 , 10 ), dtype = torch .bool )
639+ mask1 [:, :5 ] = True # Left half
640+ mask2 = torch .zeros ((5 , 10 ), dtype = torch .bool )
641+ mask2 [:, 5 :] = True # Right half
642+
643+ mm_input = ImageMaskInput (
644+ processor_fn = self ._simple_processor ,
645+ image = image ,
646+ mask_list = [mask1 , mask2 ],
647+ )
648+
649+ # Execute: format attribution
650+ attr = torch .tensor ([[0.3 , 0.7 ]])
651+ result = mm_input .format_pixel_attr (attr )
652+
653+ # Assert: left half should have 0.3, right half should have 0.7
654+ self .assertEqual (result .shape , (1 , 5 , 10 ))
655+ assertTensorAlmostEqual (
656+ self , result [0 , :, :5 ], torch .full ((5 , 5 ), 0.3 )
657+ ) # Left half
658+ assertTensorAlmostEqual (
659+ self , result [0 , :, 5 :], torch .full ((5 , 5 ), 0.7 )
660+ ) # Right half
661+
662+ def test_format_pixel_attr_with_mask_list_overlapping (self ) -> None :
663+ # Setup: create ImageMaskInput with overlapping masks
664+ image = self ._create_test_image (width = 10 , height = 5 )
665+ mask1 = torch .zeros ((5 , 10 ), dtype = torch .bool )
666+ mask1 [:, :7 ] = True # Left 7 columns
667+ mask2 = torch .zeros ((5 , 10 ), dtype = torch .bool )
668+ mask2 [:, 3 :] = True # Right 7 columns (overlap at columns 3-6)
669+
670+ mm_input = ImageMaskInput (
671+ processor_fn = self ._simple_processor ,
672+ image = image ,
673+ mask_list = [mask1 , mask2 ],
674+ )
675+
676+ # Execute: format attribution with values 0.3 for mask1, 0.5 for mask2
677+ attr = torch .tensor ([[0.3 , 0.5 ]])
678+ result = mm_input .format_pixel_attr (attr )
679+
680+ # Assert: overlapping region should have summed attribution
681+ self .assertEqual (result .shape , (1 , 5 , 10 ))
682+ # Columns 0-2: only mask1 (0.3)
683+ assertTensorAlmostEqual (self , result [0 , :, :3 ], torch .full ((5 , 3 ), 0.3 ))
684+ # Columns 3-6: both masks (0.3 + 0.5 = 0.8)
685+ assertTensorAlmostEqual (self , result [0 , :, 3 :7 ], torch .full ((5 , 4 ), 0.8 ))
686+ # Columns 7-9: only mask2 (0.5)
687+ assertTensorAlmostEqual (self , result [0 , :, 7 :], torch .full ((5 , 3 ), 0.5 ))
0 commit comments