From da7e88052f2617a42e1560adf3379e490a60dbba Mon Sep 17 00:00:00 2001 From: Oliver Aobo Yang Date: Fri, 2 Jan 2026 12:32:09 -0800 Subject: [PATCH 1/4] Add util to viz mask segmentations in ImageMaskInput (#1752) Summary: add util to visualize how the image is segmented into features Differential Revision: D89953376 --- captum/attr/_utils/interpretable_input.py | 55 ++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/captum/attr/_utils/interpretable_input.py b/captum/attr/_utils/interpretable_input.py index 6882df6c04..c6dfc826bd 100644 --- a/captum/attr/_utils/interpretable_input.py +++ b/captum/attr/_utils/interpretable_input.py @@ -1,6 +1,16 @@ # pyre-strict from abc import ABC, abstractmethod -from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union +from typing import ( + Any, + Callable, + cast, + Dict, + List, + Optional, + Tuple, + TYPE_CHECKING, + Union, +) import numpy as np import PIL.Image @@ -9,6 +19,9 @@ from captum._utils.typing import TokenizerLike from torch import Tensor +if TYPE_CHECKING: + from matplotlib.pyplot import Axes, Figure + def _scatter_itp_attr_by_mask( itp_attr: Tensor, @@ -711,3 +724,43 @@ def format_pixel_attr(self, itp_attr: Tensor) -> Tensor: formatted_mask.unsqueeze(0), ) return formatted_attr + + def plot_mask_overlay( + self, show: bool = False + ) -> Union[None, Tuple["Figure", "Axes"]]: + """ + util to help visualize the mask segmentation + """ + + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + + ax.imshow(self.image) + fig.set_facecolor("white") + + # random colors for all interpretable features + opacity = 0.6 + colors = [ + np.array([*np.random.random(3), opacity]) + for _ in range(self.n_itp_features) + ] + + # Create mask_list from mask if not provided + mask_list = self.mask_list or [ + self.mask == mask_id for mask_id in self.mask_id_to_idx.keys() + ] + + for itp_idx, mask in enumerate(mask_list): + mask = mask.numpy() + h, w = mask.shape[-2:] + 
mask_image = mask.reshape(h, w, 1) * colors[itp_idx].reshape(1, 1, -1) + ax.imshow(mask_image) + + ax.axis("off") + + if show: + plt.show() + return None + else: + return fig, ax From be31136aa6e1e0a2c607c935b4947e663164e01a Mon Sep 17 00:00:00 2001 From: Oliver Aobo Yang Date: Fri, 2 Jan 2026 12:32:09 -0800 Subject: [PATCH 2/4] Add border to mask overlay in ImageMaskInput Summary: as title Differential Revision: D89954220 --- captum/attr/_utils/interpretable_input.py | 27 ++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/captum/attr/_utils/interpretable_input.py b/captum/attr/_utils/interpretable_input.py index c6dfc826bd..38ef5a3543 100644 --- a/captum/attr/_utils/interpretable_input.py +++ b/captum/attr/_utils/interpretable_input.py @@ -726,14 +726,23 @@ def format_pixel_attr(self, itp_attr: Tensor) -> Tensor: return formatted_attr def plot_mask_overlay( - self, show: bool = False + self, show: bool = False, border_width: int = 0 ) -> Union[None, Tuple["Figure", "Axes"]]: """ util to help visualize the mask segmentation + + Args: + show: If True, display the plot immediately using plt.show(). + If False, return the figure and axes objects. + border_width: Width of the border around each segment in pixels. + Set to 0 to disable borders. Default is 0. 
""" import matplotlib.pyplot as plt + # scipy is not within install_requires of captum + from scipy.ndimage import binary_erosion + fig, ax = plt.subplots() ax.imshow(self.image) @@ -751,12 +760,20 @@ def plot_mask_overlay( self.mask == mask_id for mask_id in self.mask_id_to_idx.keys() ] - for itp_idx, mask in enumerate(mask_list): - mask = mask.numpy() - h, w = mask.shape[-2:] - mask_image = mask.reshape(h, w, 1) * colors[itp_idx].reshape(1, 1, -1) + for mask, color in zip(mask_list, colors): + mask = mask.numpy().astype(bool) + h, w = mask.shape + mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) ax.imshow(mask_image) + # Add border inside each segment using erosion + if border_width > 0: + eroded = binary_erosion(mask, iterations=border_width) + border = mask & ~eroded + border_color = np.array([*color[:3], 0.8]) + border_image = border.reshape(h, w, 1) * border_color.reshape(1, 1, -1) + ax.imshow(border_image) + ax.axis("off") if show: From 7cce150f572532935ff117c4b3d91bc95f174f18 Mon Sep 17 00:00:00 2001 From: Oliver Aobo Yang Date: Fri, 2 Jan 2026 12:32:09 -0800 Subject: [PATCH 3/4] Add legends to mask overlay in ImageMaskInput (#1753) Summary: as title Reviewed By: styusuf Differential Revision: D89954477 --- captum/attr/_utils/interpretable_input.py | 35 ++++++++++++++++++----- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/captum/attr/_utils/interpretable_input.py b/captum/attr/_utils/interpretable_input.py index 38ef5a3543..5461a1e9d1 100644 --- a/captum/attr/_utils/interpretable_input.py +++ b/captum/attr/_utils/interpretable_input.py @@ -726,7 +726,7 @@ def format_pixel_attr(self, itp_attr: Tensor) -> Tensor: return formatted_attr def plot_mask_overlay( - self, show: bool = False, border_width: int = 0 + self, show: bool = False, border_width: int = 0, show_legends: bool = True ) -> Union[None, Tuple["Figure", "Axes"]]: """ util to help visualize the mask segmentation @@ -736,6 +736,8 @@ def plot_mask_overlay( If False, return 
the figure and axes objects. border_width: Width of the border around each segment in pixels. Set to 0 to disable borders. Default is 0. + show_legends: If True, display the mask id for each segment at its + centroid. Default is True. """ import matplotlib.pyplot as plt @@ -760,20 +762,39 @@ def plot_mask_overlay( self.mask == mask_id for mask_id in self.mask_id_to_idx.keys() ] - for mask, color in zip(mask_list, colors): - mask = mask.numpy().astype(bool) - h, w = mask.shape - mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) + for mask, color, mid in zip(mask_list, colors, self.mask_id_to_idx.keys()): + mask_np = mask.numpy().astype(bool) + h, w = mask_np.shape + mask_image = mask_np.reshape(h, w, 1) * color.reshape(1, 1, -1) ax.imshow(mask_image) # Add border inside each segment using erosion if border_width > 0: - eroded = binary_erosion(mask, iterations=border_width) - border = mask & ~eroded + eroded = binary_erosion(mask_np, iterations=border_width) + border = mask_np & ~eroded border_color = np.array([*color[:3], 0.8]) border_image = border.reshape(h, w, 1) * border_color.reshape(1, 1, -1) ax.imshow(border_image) + # Calculate centroid and display mask id + if show_legends and mask_np.any(): + rows, cols = np.where(mask_np) + centroid_y, centroid_x = rows.mean(), cols.mean() + ax.text( + centroid_x, + centroid_y, + str(mid), + color="white", + fontsize=10, + ha="center", + va="center", + bbox={ + "boxstyle": "round,pad=0.2", + "facecolor": "black", + "alpha": 0.6, + }, + ) + ax.axis("off") if show: From 83055e66bd025f467e667286ec6c19c5c86f5b8a Mon Sep 17 00:00:00 2001 From: Oliver Aobo Yang Date: Fri, 2 Jan 2026 12:32:09 -0800 Subject: [PATCH 4/4] Add get_mask_list() helper method to ImageMaskInput Summary: Extract the mask list generation logic into a reusable method, centralizing the logic for creating binary masks from either mask_list or mask tensor Differential Revision: D90034703 --- captum/attr/_utils/interpretable_input.py | 18 
++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/captum/attr/_utils/interpretable_input.py b/captum/attr/_utils/interpretable_input.py index 5461a1e9d1..d7a057f197 100644 --- a/captum/attr/_utils/interpretable_input.py +++ b/captum/attr/_utils/interpretable_input.py @@ -679,6 +679,19 @@ def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> Any: return self.processor_fn(perturbed_image) + def get_mask_list(self) -> List[Tensor]: + """ + Get the list of binary masks for each interpretable feature. + If mask_list is provided, return it directly. Otherwise, create + a list of binary masks from the mask tensor. + + Returns: + List[Tensor]: list of binary masks, one for each interpretable feature + """ + return self.mask_list or [ + self.mask == mask_id for mask_id in self.mask_id_to_idx.keys() + ] + def format_attr(self, itp_attr: Tensor) -> Tensor: """ Attribution for interpretable image segments @@ -757,10 +770,7 @@ def plot_mask_overlay( for _ in range(self.n_itp_features) ] - # Create mask_list from mask if not provided - mask_list = self.mask_list or [ - self.mask == mask_id for mask_id in self.mask_id_to_idx.keys() - ] + mask_list = self.get_mask_list() for mask, color, mid in zip(mask_list, colors, self.mask_id_to_idx.keys()): mask_np = mask.numpy().astype(bool)