diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py index 169bbee0a9..33a4bde39c 100644 --- a/captum/attr/_core/llm_attr.py +++ b/captum/attr/_core/llm_attr.py @@ -206,25 +206,12 @@ def plot_token_attr( fig.set_size_inches( max(data.shape[1] * 1.3, 6.4), max(data.shape[0] / 2.5, 4.8) ) - colors = [ - "#93003a", - "#d0365b", - "#f57789", - "#ffbdc3", - "#ffffff", - "#a4d6e1", - "#73a3ca", - "#4772b3", - "#00429d", - ] im = ax.imshow( data, vmax=max_abs_attr_val, vmin=-max_abs_attr_val, - cmap=mcolors.LinearSegmentedColormap.from_list( - name="colors", colors=colors - ), + cmap=self._get_plot_color_map(), aspect="auto", ) fig.set_facecolor("white") @@ -392,25 +379,11 @@ def plot_image_heatmap( ax.imshow(inp.image) - colors = [ - "#93003a", - "#d0365b", - "#f57789", - "#ffbdc3", - "#ffffff", - "#a4d6e1", - "#73a3ca", - "#4772b3", - "#00429d", - ] - heatmap = ax.imshow( pixel_attr, vmax=max_abs_attr_val, vmin=-max_abs_attr_val, - cmap=mcolors.LinearSegmentedColormap.from_list( - name="colors", colors=colors - ), + cmap=self._get_plot_color_map(), alpha=0.7, ) @@ -426,6 +399,22 @@ def plot_image_heatmap( else: return fig, ax + def _get_plot_color_map(self) -> mcolors.LinearSegmentedColormap: + return mcolors.LinearSegmentedColormap.from_list( + name="colors", + colors=[ + "#93003a", + "#d0365b", + "#f57789", + "#ffbdc3", + "#ffffff", + "#a4d6e1", + "#73a3ca", + "#4772b3", + "#00429d", + ], + ) + def _clean_up_pretty_token(token: str) -> str: """Remove newlines and leading/trailing whitespace from token.""" diff --git a/captum/attr/_utils/interpretable_input.py b/captum/attr/_utils/interpretable_input.py index dd34b5499a..6882df6c04 100644 --- a/captum/attr/_utils/interpretable_input.py +++ b/captum/attr/_utils/interpretable_input.py @@ -517,6 +517,13 @@ class ImageMaskInput(InterpretableInput): and end with same attributions. When mask is None, the entire image is considered as one interpretable feature. Default: None + mask_list (List[Tensor], optional): a list of binary masks to group + the image pixels as interpretable segment features. Each mask indicates + one feature and must be in the same shape as the image size. Value True + means the pixel is included in the feature. Compared to mask, mask_list + allows multiple features to be defined on the same pixel. If mask_list + is not None, mask will be ignored. + Default: None baseline (Tuple[int, int, int], optional): the baseline RGB value for the “absent” image pixels. Default: (255, 255, 255) @@ -573,6 +580,7 @@ class ImageMaskInput(InterpretableInput): image: PIL.Image.Image mask: Tensor + mask_list: List[Tensor] baseline: Tuple[int, int, int] processor_fn: Callable[[PIL.Image.Image], Any] n_itp_features: int @@ -584,6 +592,7 @@ def __init__( self, image: PIL.Image.Image, mask: Optional[Tensor] = None, + mask_list: Optional[List[Tensor]] = None, baseline: Tuple[int, int, int] = (255, 255, 255), processor_fn: Callable[[PIL.Image.Image], Any] = lambda x: x, ) -> None: @@ -593,29 +602,46 @@ def __init__( self.image = image self.baseline = baseline - # Create a dummy mask if None is provided - if mask is None: - # Create a mask with all zeros (entire image as one segment) - image_shape = (image.size[1], image.size[0]) # (height, width) - mask = torch.zeros(image_shape, dtype=torch.int32) + image_shape = (image.size[1], image.size[0]) # (height, width) + + if mask_list is not None: + # Validate that all masks in mask_list have the correct shape + for i, m in enumerate(mask_list): + assert m.shape == image_shape, ( + f"mask_list[{i}] shape {m.shape} must match " + f"image shape {image_shape}" + ) + + mask_ids = list(range(len(mask_list))) + # Create a dummy mask for compatibility + mask = torch.empty(0) else: - # Validate that mask size matches image size - image_shape = (image.size[1], image.size[0]) # (height, width) - assert ( - mask.shape == image_shape - ), f"mask shape {mask.shape} must match image shape {image_shape}" + # Create a dummy mask if None is provided + if mask is None: + # Create a mask with all zeros (entire image as one segment) + mask = torch.zeros(image_shape, dtype=torch.int32) + else: + # Validate that mask size matches image size + assert ( + mask.shape == image_shape + ), f"mask shape {mask.shape} must match image shape {image_shape}" + + mask_ids = torch.unique(mask).cpu().tolist() + + # dummy mask_list + mask_list = [] self.mask = mask - mask_ids = torch.unique(mask) + self.mask_list = mask_list self.n_itp_features = len(mask_ids) self.mask_id_to_idx = {int(mid): i for i, mid in enumerate(mask_ids)} - self.original_model_inputs = processor_fn(image) - # temporarily for compatibility with AttributionResult # which use the values for plot legends self.values = [f"image_feature_{mid}" for mid in mask_ids] + self.original_model_inputs = processor_fn(image) + def to_tensor(self) -> Tensor: return torch.tensor([[1.0] * self.n_itp_features]) @@ -625,10 +651,16 @@ def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> Any: img_array = np.array(self.image) - for mask_id, itp_idx in self.mask_id_to_idx.items(): - if perturbed_tensor[0][itp_idx] == 0: - mask_positions = self.mask == mask_id - img_array[mask_positions] = self.baseline + if self.mask_list: + for itp_idx, feature_mask in enumerate(self.mask_list): + if perturbed_tensor[0][itp_idx] == 0: + mask_positions = feature_mask.bool() + img_array[mask_positions] = self.baseline + else: + for mask_id, itp_idx in self.mask_id_to_idx.items(): + if perturbed_tensor[0][itp_idx] == 0: + mask_positions = self.mask == mask_id + img_array[mask_positions] = self.baseline perturbed_image = PIL.Image.fromarray(img_array.astype("uint8")) @@ -643,20 +675,39 @@ def format_attr(self, itp_attr: Tensor) -> Tensor: def format_pixel_attr(self, itp_attr: Tensor) -> Tensor: """ - Attribution for image pixels + Attribution for image pixels. + When mask_list is used, attributions from overlapping features are summed. Returns: 3D tensor of shape (1, height, width) """ device = itp_attr.device - - # Map mask IDs to continuous indices - image_shape = self.mask.shape - formatted_mask = torch.zeros_like(self.mask, device=device) - for mask_id, itp_idx in self.mask_id_to_idx.items(): - formatted_mask[self.mask == mask_id] = itp_idx - - formatted_attr = _scatter_itp_attr_by_mask( - itp_attr, - (1, *image_shape), - formatted_mask.unsqueeze(0), - ) - return formatted_attr + image_shape = (self.image.size[1], self.image.size[0]) # (height, width) + + if self.mask_list: + # For mask_list, sum attributions from overlapping features + output_dims = itp_attr.shape[:-1] + attr_shape = (*output_dims, *image_shape) + formatted_attr = torch.zeros(attr_shape, device=device) + + for itp_idx, feature_mask in enumerate(self.mask_list): + # Get attribution value for this feature + # itp_attr shape: (*output_dims, n_itp_features) + feature_attr = itp_attr[..., itp_idx] + # Expand to match image shape and add where mask is True + mask_bool = feature_mask.bool().to(device) + # Expand feature_attr to broadcast with mask + expanded_attr = feature_attr.unsqueeze(-1).unsqueeze(-1) + formatted_attr = formatted_attr + expanded_attr * mask_bool + + return formatted_attr + else: + # Map mask IDs to continuous indices + formatted_mask = torch.zeros_like(self.mask, device=device) + for mask_id, itp_idx in self.mask_id_to_idx.items(): + formatted_mask[self.mask == mask_id] = itp_idx + + formatted_attr = _scatter_itp_attr_by_mask( + itp_attr, + (1, *image_shape), + formatted_mask.unsqueeze(0), + ) + return formatted_attr diff --git a/tests/attr/test_interpretable_input.py b/tests/attr/test_interpretable_input.py index 03b975faa9..781d590f2b 100644 --- a/tests/attr/test_interpretable_input.py +++ b/tests/attr/test_interpretable_input.py @@ -526,3 +526,162 @@ def test_format_pixel_attr_with_non_continuous_mask(self) -> None: self.assertTrue(torch.all(result[0, :, :5] == segment_0_value)) self.assertTrue(torch.all(result[0, :, 5:10] == segment_10_value)) self.assertTrue(torch.all(result[0, :, 10:] == segment_20_value)) + + # Tests for mask_list functionality + + def test_init_mask_list_ignores_mask(self) -> None: + # Setup: provide both mask and mask_list + image = self._create_test_image(width=10, height=10) + # mask has 3 segments + mask = torch.zeros((10, 10), dtype=torch.int32) + mask[:, 3:7] = 1 + mask[:, 7:] = 2 + # mask_list has 2 masks + mask1 = torch.zeros((10, 10), dtype=torch.bool) + mask1[:, :5] = True + mask2 = torch.zeros((10, 10), dtype=torch.bool) + mask2[:, 5:] = True + + # Execute: create ImageMaskInput with both mask and mask_list + mm_input = ImageMaskInput( + processor_fn=self._simple_processor, + image=image, + mask=mask, + mask_list=[mask1, mask2], + ) + + # Assert: mask_list takes precedence, so n_itp_features should be 2 + self.assertEqual(mm_input.n_itp_features, 2) + self.assertEqual(len(mm_input.mask_list), 2) + + def test_to_tensor_with_mask_list(self) -> None: + # Setup: create ImageMaskInput with 3 masks + image = self._create_test_image(width=15, height=10) + mask1 = torch.zeros((10, 15), dtype=torch.bool) + mask1[:, :5] = True + mask2 = torch.zeros((10, 15), dtype=torch.bool) + mask2[:, 5:10] = True + mask3 = torch.zeros((10, 15), dtype=torch.bool) + mask3[:, 10:] = True + + mm_input = ImageMaskInput( + processor_fn=self._simple_processor, + image=image, + mask_list=[mask1, mask2, mask3], + ) + + # Execute: convert to tensor + result = mm_input.to_tensor() + + # Assert: verify tensor has correct number of features + expected = torch.tensor([[1.0, 1.0, 1.0]]) + assertTensorAlmostEqual(self, result, expected) + + def test_to_model_input_with_mask_list(self) -> None: + # Setup: create image with 2 halves (left red, right green) + img_array = np.zeros((10, 10, 3), dtype=np.uint8) + img_array[:, :5] = [255, 0, 0] # Left half red + img_array[:, 5:] = [0, 255, 0] # Right half green + image = PIL.Image.fromarray(img_array) + + mask1 = torch.zeros((10, 10), dtype=torch.bool) + mask1[:, :5] = True # Left half + mask2 = torch.zeros((10, 10), dtype=torch.bool) + mask2[:, 5:] = True # Right half + + mm_input = ImageMaskInput( + processor_fn=self._simple_processor, + image=image, + mask_list=[mask1, mask2], + baseline=(255, 255, 255), + ) + + # Execute: keep left half (0), remove right half (1) + perturbed_tensor = torch.tensor([[1.0, 0.0]]) + result = mm_input.to_model_input(perturbed_tensor) + + # Assert: left half should be red, right half should be white + img_array = result["pixel_values"].numpy().astype(np.uint8) + # Left half should be red + self.assertTrue(np.all(img_array[:, :5, 0] == 255)) + self.assertTrue(np.all(img_array[:, :5, 1] == 0)) + # Right half should be white (baseline) + self.assertTrue(np.all(img_array[:, 5:] == 255)) + + def test_to_model_input_with_mask_list_overlapping(self) -> None: + # Setup: create red image with overlapping masks + image = self._create_test_image(color=(255, 0, 0)) + mask1 = torch.zeros((10, 10), dtype=torch.bool) + mask1[:, :7] = True # Left 7 columns + mask2 = torch.zeros((10, 10), dtype=torch.bool) + mask2[:, 3:] = True # Right 7 columns (overlap at columns 3-6) + + mm_input = ImageMaskInput( + processor_fn=self._simple_processor, + image=image, + mask_list=[mask1, mask2], + baseline=(255, 255, 255), + ) + + # Execute: remove first feature (mask1), keep second (mask2) + perturbed_tensor = torch.tensor([[0.0, 1.0]]) + result = mm_input.to_model_input(perturbed_tensor) + + # Assert: left 7 columns (covered by mask1) should be white + # even though columns 3-6 are also in mask2 (but mask1 sets them to baseline) + img_array = result["pixel_values"].numpy().astype(np.uint8) + self.assertTrue(np.all(img_array[:, :7] == 255)) + + def test_format_pixel_attr_with_mask_list(self) -> None: + # Setup: create ImageMaskInput with 2 non-overlapping masks + image = self._create_test_image(width=10, height=5) + mask1 = torch.zeros((5, 10), dtype=torch.bool) + mask1[:, :5] = True # Left half + mask2 = torch.zeros((5, 10), dtype=torch.bool) + mask2[:, 5:] = True # Right half + + mm_input = ImageMaskInput( + processor_fn=self._simple_processor, + image=image, + mask_list=[mask1, mask2], + ) + + # Execute: format attribution + attr = torch.tensor([[0.3, 0.7]]) + result = mm_input.format_pixel_attr(attr) + + # Assert: left half should have 0.3, right half should have 0.7 + self.assertEqual(result.shape, (1, 5, 10)) + assertTensorAlmostEqual( + self, result[0, :, :5], torch.full((5, 5), 0.3) + ) # Left half + assertTensorAlmostEqual( + self, result[0, :, 5:], torch.full((5, 5), 0.7) + ) # Right half + + def test_format_pixel_attr_with_mask_list_overlapping(self) -> None: + # Setup: create ImageMaskInput with overlapping masks + image = self._create_test_image(width=10, height=5) + mask1 = torch.zeros((5, 10), dtype=torch.bool) + mask1[:, :7] = True # Left 7 columns + mask2 = torch.zeros((5, 10), dtype=torch.bool) + mask2[:, 3:] = True # Right 7 columns (overlap at columns 3-6) + + mm_input = ImageMaskInput( + processor_fn=self._simple_processor, + image=image, + mask_list=[mask1, mask2], + ) + + # Execute: format attribution with values 0.3 for mask1, 0.5 for mask2 + attr = torch.tensor([[0.3, 0.5]]) + result = mm_input.format_pixel_attr(attr) + + # Assert: overlapping region should have summed attribution + self.assertEqual(result.shape, (1, 5, 10)) + # Columns 0-2: only mask1 (0.3) + assertTensorAlmostEqual(self, result[0, :, :3], torch.full((5, 3), 0.3)) + # Columns 3-6: both masks (0.3 + 0.5 = 0.8) + assertTensorAlmostEqual(self, result[0, :, 3:7], torch.full((5, 4), 0.8)) + # Columns 7-9: only mask2 (0.5) + assertTensorAlmostEqual(self, result[0, :, 7:], torch.full((5, 3), 0.5))