From 51f2a010ca279010433fa52d7e6ce63857e7073c Mon Sep 17 00:00:00 2001 From: Oliver Aobo Yang Date: Fri, 2 Jan 2026 15:23:42 -0800 Subject: [PATCH 1/6] Add util to viz mask segmentations in ImageMaskInput (#1752) Summary: add util to visualize how the image is segmented into features Differential Revision: D89953376 --- captum/attr/_utils/interpretable_input.py | 55 ++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/captum/attr/_utils/interpretable_input.py b/captum/attr/_utils/interpretable_input.py index 6882df6c04..c6dfc826bd 100644 --- a/captum/attr/_utils/interpretable_input.py +++ b/captum/attr/_utils/interpretable_input.py @@ -1,6 +1,16 @@ # pyre-strict from abc import ABC, abstractmethod -from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union +from typing import ( + Any, + Callable, + cast, + Dict, + List, + Optional, + Tuple, + TYPE_CHECKING, + Union, +) import numpy as np import PIL.Image @@ -9,6 +19,9 @@ from captum._utils.typing import TokenizerLike from torch import Tensor +if TYPE_CHECKING: + from matplotlib.pyplot import Axes, Figure + def _scatter_itp_attr_by_mask( itp_attr: Tensor, @@ -711,3 +724,43 @@ def format_pixel_attr(self, itp_attr: Tensor) -> Tensor: formatted_mask.unsqueeze(0), ) return formatted_attr + + def plot_mask_overlay( + self, show: bool = False + ) -> Union[None, Tuple["Figure", "Axes"]]: + """ + util to help visualize the mask segementation + """ + + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + + ax.imshow(self.image) + fig.set_facecolor("white") + + # random colors for all interpretable features + opacity = 0.6 + colors = [ + np.array([*np.random.random(3), opacity]) + for _ in range(self.n_itp_features) + ] + + # Create mask_list from mask if not provided + mask_list = self.mask_list or [ + self.mask == mask_id for mask_id in self.mask_id_to_idx.keys() + ] + + for itp_idx, mask in enumerate(mask_list): + mask = mask.numpy() + h, w = mask.shape[-2:] + mask_image = mask.reshape(h, w, 1) * colors[itp_idx].reshape(1, 1, -1) + ax.imshow(mask_image) + + ax.axis("off") + + if show: + plt.show() + return None + else: + return fig, ax From 4d3f564d8966cdc2f32c724ee66a6c5c036e8020 Mon Sep 17 00:00:00 2001 From: Oliver Aobo Yang Date: Fri, 2 Jan 2026 15:23:42 -0800 Subject: [PATCH 2/6] Add border to mask overlay in MaskImageInput Summary: as title Differential Revision: D89954220 --- captum/attr/_utils/interpretable_input.py | 27 ++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/captum/attr/_utils/interpretable_input.py b/captum/attr/_utils/interpretable_input.py index c6dfc826bd..38ef5a3543 100644 --- a/captum/attr/_utils/interpretable_input.py +++ b/captum/attr/_utils/interpretable_input.py @@ -726,14 +726,23 @@ def format_pixel_attr(self, itp_attr: Tensor) -> Tensor: return formatted_attr def plot_mask_overlay( - self, show: bool = False + self, show: bool = False, border_width: int = 0 ) -> Union[None, Tuple["Figure", "Axes"]]: """ util to help visualize the mask segementation + + Args: + show: If True, display the plot immediately using plt.show(). + If False, return the figure and axes objects. + border_width: Width of the border around each segment in pixels. + Set to 0 to disable borders. Default is 0. """ import matplotlib.pyplot as plt + # scipy is not within install_requires of captum + from scipy.ndimage import binary_erosion + fig, ax = plt.subplots() ax.imshow(self.image) @@ -751,12 +760,20 @@ def plot_mask_overlay( self.mask == mask_id for mask_id in self.mask_id_to_idx.keys() ] - for itp_idx, mask in enumerate(mask_list): - mask = mask.numpy() - h, w = mask.shape[-2:] - mask_image = mask.reshape(h, w, 1) * colors[itp_idx].reshape(1, 1, -1) + for mask, color in zip(mask_list, colors): + mask = mask.numpy().astype(bool) + h, w = mask.shape + mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) ax.imshow(mask_image) + # Add border inside each segment using erosion + if border_width > 0: + eroded = binary_erosion(mask, iterations=border_width) + border = mask & ~eroded + border_color = np.array([*color[:3], 0.8]) + border_image = border.reshape(h, w, 1) * border_color.reshape(1, 1, -1) + ax.imshow(border_image) + ax.axis("off") if show: From 77542e9763f1a436f59fa44fb5078bd94cc9c651 Mon Sep 17 00:00:00 2001 From: Oliver Aobo Yang Date: Fri, 2 Jan 2026 15:23:42 -0800 Subject: [PATCH 3/6] Add legends to mask overlay in MaskImageInput (#1753) Summary: as title Reviewed By: styusuf Differential Revision: D89954477 --- captum/attr/_utils/interpretable_input.py | 35 ++++++++++++++++++----- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/captum/attr/_utils/interpretable_input.py b/captum/attr/_utils/interpretable_input.py index 38ef5a3543..5461a1e9d1 100644 --- a/captum/attr/_utils/interpretable_input.py +++ b/captum/attr/_utils/interpretable_input.py @@ -726,7 +726,7 @@ def format_pixel_attr(self, itp_attr: Tensor) -> Tensor: return formatted_attr def plot_mask_overlay( - self, show: bool = False, border_width: int = 0 + self, show: bool = False, border_width: int = 0, show_legends: bool = True ) -> Union[None, Tuple["Figure", "Axes"]]: """ util to help visualize the mask segementation @@ -736,6 +736,8 @@ def plot_mask_overlay( If False, return the figure and axes objects. border_width: Width of the border around each segment in pixels. Set to 0 to disable borders. Default is 0. + show_legends: If True, display the mask id for each segment at its + centroid. Default is True. """ import matplotlib.pyplot as plt @@ -760,20 +762,39 @@ def plot_mask_overlay( self.mask == mask_id for mask_id in self.mask_id_to_idx.keys() ] - for mask, color in zip(mask_list, colors): - mask = mask.numpy().astype(bool) - h, w = mask.shape - mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) + for mask, color, mid in zip(mask_list, colors, self.mask_id_to_idx.keys()): + mask_np = mask.numpy().astype(bool) + h, w = mask_np.shape + mask_image = mask_np.reshape(h, w, 1) * color.reshape(1, 1, -1) ax.imshow(mask_image) # Add border inside each segment using erosion if border_width > 0: - eroded = binary_erosion(mask, iterations=border_width) - border = mask & ~eroded + eroded = binary_erosion(mask_np, iterations=border_width) + border = mask_np & ~eroded border_color = np.array([*color[:3], 0.8]) border_image = border.reshape(h, w, 1) * border_color.reshape(1, 1, -1) ax.imshow(border_image) + # Calculate centroid and display mask id + if show_legends and mask_np.any(): + rows, cols = np.where(mask_np) + centroid_y, centroid_x = rows.mean(), cols.mean() + ax.text( + centroid_x, + centroid_y, + str(mid), + color="white", + fontsize=10, + ha="center", + va="center", + bbox={ + "boxstyle": "round,pad=0.2", + "facecolor": "black", + "alpha": 0.6, + }, + ) + ax.axis("off") if show: From 2fc1ffddd25ceb50c1576f970d572e81c32a630e Mon Sep 17 00:00:00 2001 From: Oliver Aobo Yang Date: Fri, 2 Jan 2026 15:23:42 -0800 Subject: [PATCH 4/6] Add get_mask_list() helper method to ImageMaskInput (#1755) Summary: Extract the mask list generation logic into a reusable method, centralizing the logic for creating binary masks from either mask_list or mask tensor Reviewed By: sarahtranfb Differential Revision: D90034703 --- captum/attr/_utils/interpretable_input.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/captum/attr/_utils/interpretable_input.py b/captum/attr/_utils/interpretable_input.py index 5461a1e9d1..d7a057f197 100644 --- a/captum/attr/_utils/interpretable_input.py +++ b/captum/attr/_utils/interpretable_input.py @@ -679,6 +679,19 @@ def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> Any: return self.processor_fn(perturbed_image) + def get_mask_list(self) -> List[Tensor]: + """ + Get the list of binary masks for each interpretable feature. + If mask_list is provided, return it directly. Otherwise, create + a list of binary masks from the mask tensor. + + Returns: + List[Tensor]: list of binary masks, one for each interpretable feature + """ + return self.mask_list or [ + self.mask == mask_id for mask_id in self.mask_id_to_idx.keys() + ] + def format_attr(self, itp_attr: Tensor) -> Tensor: """ Attribution for interpretable image segments @@ -757,10 +770,7 @@ def plot_mask_overlay( for _ in range(self.n_itp_features) ] - # Create mask_list from mask if not provided - mask_list = self.mask_list or [ - self.mask == mask_id for mask_id in self.mask_id_to_idx.keys() - ] + mask_list = self.get_mask_list() for mask, color, mid in zip(mask_list, colors, self.mask_id_to_idx.keys()): mask_np = mask.numpy().astype(bool) From 2a775c8cd010c3c854bb29222f195dea261b19a4 Mon Sep 17 00:00:00 2001 From: Oliver Aobo Yang Date: Fri, 2 Jan 2026 15:23:42 -0800 Subject: [PATCH 5/6] Refactor mask visualization utils from ImageMaskInput to visualization module (#1756) Summary: Extract border and legend drawing logic from `plot_mask_overlay` into reusable utility functions `draw_mask_border` and `draw_mask_legend` in the visualization Differential Revision: D90036627 --- captum/attr/_utils/interpretable_input.py | 28 +----- captum/attr/_utils/visualization.py | 107 ++++++++++++++++++++++ 2 files changed, 111 insertions(+), 24 deletions(-) diff --git a/captum/attr/_utils/interpretable_input.py b/captum/attr/_utils/interpretable_input.py index d7a057f197..d1280b0d3d 100644 --- a/captum/attr/_utils/interpretable_input.py +++ b/captum/attr/_utils/interpretable_input.py @@ -17,6 +17,7 @@ import torch from captum._utils.typing import TokenizerLike +from captum.attr._utils.visualization import draw_mask_border, draw_mask_legend from torch import Tensor if TYPE_CHECKING: @@ -755,9 +756,6 @@ def plot_mask_overlay( import matplotlib.pyplot as plt - # scipy is not within install_requires of captum - from scipy.ndimage import binary_erosion - fig, ax = plt.subplots() ax.imshow(self.image) @@ -780,30 +778,12 @@ def plot_mask_overlay( # Add border inside each segment using erosion if border_width > 0: - eroded = binary_erosion(mask_np, iterations=border_width) - border = mask_np & ~eroded border_color = np.array([*color[:3], 0.8]) - border_image = border.reshape(h, w, 1) * border_color.reshape(1, 1, -1) - ax.imshow(border_image) + draw_mask_border(ax, mask_np, border_width, border_color) # Calculate centroid and display mask id - if show_legends and mask_np.any(): - rows, cols = np.where(mask_np) - centroid_y, centroid_x = rows.mean(), cols.mean() - ax.text( - centroid_x, - centroid_y, - str(mid), - color="white", - fontsize=10, - ha="center", - va="center", - bbox={ - "boxstyle": "round,pad=0.2", - "facecolor": "black", - "alpha": 0.6, - }, - ) + if show_legends: + draw_mask_legend(ax, mask_np, label=str(mid)) ax.axis("off") diff --git a/captum/attr/_utils/visualization.py b/captum/attr/_utils/visualization.py index 5f8e89bdb3..ad4d24cf2f 100644 --- a/captum/attr/_utils/visualization.py +++ b/captum/attr/_utils/visualization.py @@ -38,6 +38,113 @@ HAS_IPYTHON = False +def draw_mask_border( + ax: Axes, + mask: npt.NDArray[np.bool_], + border_width: int = 1, + border_color: Union[str, npt.NDArray[np.floating[Any]]] = "black", +) -> None: + """ + Draw a border inside a mask region using binary erosion. + + This function generates a border by eroding the mask and taking the difference + between the original mask and the eroded version, then displays it on the axes. + + Args: + ax: Matplotlib axes object to draw on. + mask: 2D boolean numpy array representing the mask region. + Shape should be (height, width). + border_width: Width of the border in pixels. + Default: 1 + border_color: Color for the border. Can be a string color name (e.g., + "black", "red") or an RGBA array of shape (4,) with values + typically in [0, 1]. + Default: "black" + + Example:: + >>> mask = np.array([[True, True, True], + ... [True, True, True], + ... [True, True, True]]) + >>> fig, ax = plt.subplots() + >>> draw_mask_border(ax, mask) # Uses default black border + >>> draw_mask_border(ax, mask, border_width=2, border_color="red") + """ + if not mask.any(): + return + + from scipy.ndimage import binary_erosion + + # Convert string color to RGBA array + if isinstance(border_color, str): + rgba = colors.to_rgba(border_color) + border_color_array = np.array(rgba) + else: + border_color_array = border_color + + eroded = binary_erosion(mask, iterations=border_width) + border = mask & ~eroded + h, w = mask.shape + border_image = border.reshape(h, w, 1) * border_color_array.reshape(1, 1, -1) + ax.imshow(border_image) + + +def draw_mask_legend( + ax: Axes, + mask: npt.NDArray[np.bool_], + label: str, + fontsize: int = 10, + text_color: str = "white", + bbox_facecolor: str = "black", + bbox_alpha: float = 0.6, +) -> None: + """ + Draw a label at the centroid of a mask region. + + This function calculates the centroid (center of mass) of a boolean mask + and places a text label at that position. + + Args: + ax: Matplotlib axes object to draw on. + mask: 2D boolean numpy array representing the mask region. + Shape should be (height, width). + label: Text string to display at the centroid. + fontsize: Font size for the label text. + Default: 10 + text_color: Color of the label text. + Default: "white" + bbox_facecolor: Background color of the text bounding box. + Default: "black" + bbox_alpha: Transparency of the text bounding box. + Default: 0.6 + + Example:: + >>> mask = np.array([[False, True, True], + ... [False, True, True], + ... [False, False, False]]) + >>> fig, ax = plt.subplots() + >>> draw_mask_legend(ax, mask, label="1") + """ + if not mask.any(): + return + + rows, cols = np.where(mask) + centroid_y, centroid_x = rows.mean(), cols.mean() + ax.text( + centroid_x, + centroid_y, + label, + color=text_color, + fontsize=fontsize, + ha="center", + va="center", + bbox={ + "boxstyle": "round,pad=0.2", + "facecolor": bbox_facecolor, + "alpha": bbox_alpha, + }, + ) + + class ImageVisualizationMethod(Enum): heat_map = 1 blended_heat_map = 2 From 4c16deaf22e07d1c1f3194cf07332a3a40c13bb5 Mon Sep 17 00:00:00 2001 From: Oliver Aobo Yang Date: Fri, 2 Jan 2026 15:23:42 -0800 Subject: [PATCH 6/6] Add mask-based visualization support to LLMAttributionResult.plot_image_heatmap (#1757) Summary: Enable `plot_image_heatmap` to leverage the new `draw_mask_border` and `draw_mask_legend` utilities when the input has mask segments. This provides better visualization for image attribution with discrete regions, displaying attribution colors per segment with borders and labels derived from the salience values. Differential Revision: D90038941 --- captum/attr/_core/llm_attr.py | 55 ++++++++++++++++++++++++++++------- 1 file changed, 45 insertions(+), 10 deletions(-) diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py index 33a4bde39c..0208940898 100644 --- a/captum/attr/_core/llm_attr.py +++ b/captum/attr/_core/llm_attr.py @@ -43,6 +43,7 @@ TextTemplateInput, TextTokenInput, ) +from captum.attr._utils.visualization import draw_mask_border, draw_mask_legend if TYPE_CHECKING: from matplotlib.pyplot import Axes, Figure @@ -326,6 +327,8 @@ def plot_image_heatmap( self, show: bool = False, target_token_pos: Union[int, tuple[int, int], None] = None, + border_width: int = 2, + show_legends: bool = True, ) -> Union[None, Tuple["Figure", "Axes"]]: """ Plot the image in the input with the overlay of salience based on @@ -340,6 +343,12 @@ def plot_image_heatmap( the token at the given index. If tuple[int, int], like (m, n), use the summed token attribution of tokens from m to n (noninclusive) Default: None + border_width (int): Width of the border around each mask segment in pixels. + Set to 0 to disable borders. Only used when input has mask_list. + Default: 2 + show_legends (bool): If True, display the mask id for each segment at its + centroid. Only used when input has mask_list. + Default: True Returns: @@ -348,10 +357,9 @@ def plot_image_heatmap( customization. """ - if not isinstance(self.inp, ImageMaskInput): - raise ValueError("plot_image_heatmap is only available for ImageMaskInput") - inp = self.inp + if not isinstance(inp, ImageMaskInput): + raise ValueError("plot_image_heatmap is only available for ImageMaskInput") import matplotlib.pyplot as plt @@ -370,27 +378,54 @@ def plot_image_heatmap( from_pos, to_pos = target_token_pos attr = self.token_attr[from_pos:to_pos].sum(dim=0) + fig, ax = plt.subplots() + ax.imshow(inp.image) + + # Get pixel-level attribution using format_pixel_attr pixel_attr = inp.format_pixel_attr(attr.unsqueeze(0)) pixel_attr = pixel_attr.squeeze(0).cpu().numpy() max_abs_attr_val = np.abs(pixel_attr).max() - - fig, ax = plt.subplots() - - ax.imshow(inp.image) - + alpha = 0.8 heatmap = ax.imshow( pixel_attr, vmax=max_abs_attr_val, vmin=-max_abs_attr_val, cmap=self._get_plot_color_map(), - alpha=0.7, + alpha=alpha, ) - fig.set_facecolor("white") cbar = fig.colorbar(heatmap, ax=ax) cbar.ax.set_ylabel("Attribution", rotation=-90, va="bottom") + # Draw borders and legends on top if mask_list is available + if hasattr(inp, "get_mask_list"): + mask_list = [m.numpy().astype(bool) for m in inp.get_mask_list()] + mask_ids = list(inp.mask_id_to_idx.keys()) + cmap = self._get_plot_color_map() + + # Get attribution values for border color calculation + attr_np = attr.cpu().numpy() + + for i, mask in enumerate(mask_list): + if not mask.any(): + continue + + if border_width > 0: + # Calculate border color as a darker version of the salience color + norm_val = ( + (attr_np[i] / max_abs_attr_val + 1) / 2 + if max_abs_attr_val > 0 + else 0.5 + ) + rgba = np.array(cmap(norm_val)) + # Create darker version by multiplying RGB by 0.6 + border_color = np.array([*(rgba[:3] * 0.7), alpha]) + draw_mask_border(ax, mask, border_width, border_color=border_color) + if show_legends: + draw_mask_legend(ax, mask, label=str(mask_ids[i])) + + fig.set_facecolor("white") ax.axis("off") if show: