diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py
index 33a4bde39c..0208940898 100644
--- a/captum/attr/_core/llm_attr.py
+++ b/captum/attr/_core/llm_attr.py
@@ -43,6 +43,7 @@
     TextTemplateInput,
     TextTokenInput,
 )
+from captum.attr._utils.visualization import draw_mask_border, draw_mask_legend
 
 if TYPE_CHECKING:
     from matplotlib.pyplot import Axes, Figure
@@ -326,6 +327,8 @@ def plot_image_heatmap(
         self,
         show: bool = False,
         target_token_pos: Union[int, tuple[int, int], None] = None,
+        border_width: int = 2,
+        show_legends: bool = True,
     ) -> Union[None, Tuple["Figure", "Axes"]]:
         """
         Plot the image in the input with the overlay of salience based on
@@ -340,6 +343,12 @@ def plot_image_heatmap(
                 the token at the given index. If tuple[int, int], like (m, n),
                 use the summed token attribution of tokens from m to n (noninclusive)
                 Default: None
+            border_width (int): Width of the border around each mask segment in pixels.
+                Set to 0 to disable borders. Drawn for every non-empty segment.
+                Default: 2
+            show_legends (bool): If True, display the mask id for each segment at its
+                centroid. Drawn for every non-empty segment.
+                Default: True
 
         Returns:
 
@@ -348,10 +357,9 @@ def plot_image_heatmap(
             customization.
         """
 
-        if not isinstance(self.inp, ImageMaskInput):
-            raise ValueError("plot_image_heatmap is only available for ImageMaskInput")
-
         inp = self.inp
+        if not isinstance(inp, ImageMaskInput):
+            raise ValueError("plot_image_heatmap is only available for ImageMaskInput")
 
         import matplotlib.pyplot as plt
 
@@ -370,27 +378,53 @@ def plot_image_heatmap(
             from_pos, to_pos = target_token_pos
             attr = self.token_attr[from_pos:to_pos].sum(dim=0)
 
+        fig, ax = plt.subplots()
+        ax.imshow(inp.image)
+
+        # Get pixel-level attribution using format_pixel_attr
         pixel_attr = inp.format_pixel_attr(attr.unsqueeze(0))
         pixel_attr = pixel_attr.squeeze(0).cpu().numpy()
         max_abs_attr_val = np.abs(pixel_attr).max()
-
-        fig, ax = plt.subplots()
-
-        ax.imshow(inp.image)
-
+        alpha = 0.8
         heatmap = ax.imshow(
             pixel_attr,
             vmax=max_abs_attr_val,
             vmin=-max_abs_attr_val,
             cmap=self._get_plot_color_map(),
-            alpha=0.7,
+            alpha=alpha,
         )
 
-        fig.set_facecolor("white")
-
         cbar = fig.colorbar(heatmap, ax=ax)
         cbar.ax.set_ylabel("Attribution", rotation=-90, va="bottom")
 
+        # Draw borders and legends on top of the heatmap, one per mask segment
+        if hasattr(inp, "get_mask_list"):
+            mask_list = [m.cpu().numpy().astype(bool) for m in inp.get_mask_list()]
+            mask_ids = list(inp.mask_id_to_idx.keys())
+            cmap = self._get_plot_color_map()
+
+            # Get attribution values for border color calculation
+            attr_np = attr.cpu().numpy()
+
+            for i, mask in enumerate(mask_list):
+                if not mask.any():
+                    continue
+
+                if border_width > 0:
+                    # Calculate border color as a darker version of the salience color
+                    norm_val = (
+                        (attr_np[i] / max_abs_attr_val + 1) / 2
+                        if max_abs_attr_val > 0
+                        else 0.5
+                    )
+                    rgba = np.array(cmap(norm_val))
+                    # Create darker version by multiplying RGB by 0.7
+                    border_color = np.array([*(rgba[:3] * 0.7), alpha])
+                    draw_mask_border(ax, mask, border_width, border_color=border_color)
+                if show_legends:
+                    draw_mask_legend(ax, mask, label=str(mask_ids[i]))
+
+        fig.set_facecolor("white")
         ax.axis("off")
 
         if show:
diff --git a/captum/attr/_utils/interpretable_input.py b/captum/attr/_utils/interpretable_input.py
index 6882df6c04..d1280b0d3d 100644
--- a/captum/attr/_utils/interpretable_input.py
+++ b/captum/attr/_utils/interpretable_input.py
@@ -1,14 +1,28 @@
 # pyre-strict
 from abc import ABC, abstractmethod
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    TYPE_CHECKING,
+    Union,
+)
 
 import numpy as np
 import PIL.Image
 import torch
 
 from captum._utils.typing import TokenizerLike
+from captum.attr._utils.visualization import draw_mask_border, draw_mask_legend
 from torch import Tensor
 
+if TYPE_CHECKING:
+    from matplotlib.pyplot import Axes, Figure
+
 
 def _scatter_itp_attr_by_mask(
     itp_attr: Tensor,
@@ -666,6 +680,19 @@ def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> Any:
 
         return self.processor_fn(perturbed_image)
 
+    def get_mask_list(self) -> List[Tensor]:
+        """
+        Get the list of binary masks for each interpretable feature.
+        If mask_list is provided, return it directly. Otherwise, create
+        a list of binary masks from the mask tensor.
+
+        Returns:
+            List[Tensor]: list of binary masks, one for each interpretable feature
+        """
+        return self.mask_list or [
+            self.mask == mask_id for mask_id in self.mask_id_to_idx.keys()
+        ]
+
     def format_attr(self, itp_attr: Tensor) -> Tensor:
         """
         Attribution for interpretable image segments
@@ -711,3 +738,60 @@ def format_pixel_attr(self, itp_attr: Tensor) -> Tensor:
             formatted_mask.unsqueeze(0),
         )
         return formatted_attr
+
+    def plot_mask_overlay(
+        self, show: bool = False, border_width: int = 0, show_legends: bool = True
+    ) -> Union[None, Tuple["Figure", "Axes"]]:
+        """
+        Visualize the mask segmentation as colored overlays on the image.
+
+        Args:
+            show: If True, display the plot immediately using plt.show().
+                If False, return the figure and axes objects.
+            border_width: Width of the border around each segment in pixels.
+                Set to 0 to disable borders. Default is 0.
+            show_legends: If True, display the mask id for each segment at its
+                centroid. Default is True.
+
+        Returns:
+            None if show is True, otherwise the (Figure, Axes) tuple of the plot.
+        """
+
+        import matplotlib.pyplot as plt
+
+        fig, ax = plt.subplots()
+
+        ax.imshow(self.image)
+        fig.set_facecolor("white")
+
+        # random colors for all interpretable features
+        opacity = 0.6
+        colors = [
+            np.array([*np.random.random(3), opacity])
+            for _ in range(self.n_itp_features)
+        ]
+
+        mask_list = self.get_mask_list()
+
+        for mask, color, mid in zip(mask_list, colors, self.mask_id_to_idx.keys()):
+            mask_np = mask.cpu().numpy().astype(bool)
+            h, w = mask_np.shape
+            mask_image = mask_np.reshape(h, w, 1) * color.reshape(1, 1, -1)
+            ax.imshow(mask_image)
+
+            # Add border inside each segment using erosion
+            if border_width > 0:
+                border_color = np.array([*color[:3], 0.8])
+                draw_mask_border(ax, mask_np, border_width, border_color)
+
+            # Calculate centroid and display mask id
+            if show_legends:
+                draw_mask_legend(ax, mask_np, label=str(mid))
+
+        ax.axis("off")
+
+        if show:
+            plt.show()
+            return None
+        else:
+            return fig, ax
diff --git a/captum/attr/_utils/visualization.py b/captum/attr/_utils/visualization.py
index 5f8e89bdb3..ad4d24cf2f 100644
--- a/captum/attr/_utils/visualization.py
+++ b/captum/attr/_utils/visualization.py
@@ -38,6 +38,115 @@
     HAS_IPYTHON = False
 
 
+def draw_mask_border(
+    ax: Axes,
+    mask: npt.NDArray[np.bool_],
+    border_width: int = 1,
+    border_color: Union[str, npt.NDArray[np.floating[Any]]] = "black",
+) -> None:
+    """
+    Draw a border inside a mask region using binary erosion.
+
+    This function generates a border by eroding the mask and taking the difference
+    between the original mask and the eroded version, then displays it on the axes.
+
+    Args:
+        ax: Matplotlib axes object to draw on.
+        mask: 2D boolean numpy array representing the mask region.
+            Shape should be (height, width).
+        border_width: Width of the border in pixels. Values <= 0 draw nothing.
+            Default: 1
+        border_color: Color for the border. Can be a string color name (e.g.,
+            "black", "red") or an RGBA array of shape (4,) with values
+            typically in [0, 1].
+            Default: "black"
+
+    Example::
+        >>> mask = np.array([[True, True, True],
+        ...                  [True, True, True],
+        ...                  [True, True, True]])
+        >>> fig, ax = plt.subplots()
+        >>> draw_mask_border(ax, mask)  # Uses default black border
+        >>> draw_mask_border(ax, mask, border_width=2, border_color="red")
+    """
+    # scipy treats iterations < 1 as "erode until nothing changes", which would
+    # paint the entire mask as border, so a non-positive width draws nothing.
+    if border_width <= 0 or not mask.any():
+        return
+
+    from scipy.ndimage import binary_erosion
+
+    # Convert string color to RGBA array
+    if isinstance(border_color, str):
+        rgba = colors.to_rgba(border_color)
+        border_color_array = np.array(rgba)
+    else:
+        border_color_array = np.asarray(border_color)
+
+    eroded = binary_erosion(mask, iterations=border_width)
+    border = mask & ~eroded
+    h, w = mask.shape
+    border_image = border.reshape(h, w, 1) * border_color_array.reshape(1, 1, -1)
+    ax.imshow(border_image)
+
+
+def draw_mask_legend(
+    ax: Axes,
+    mask: npt.NDArray[np.bool_],
+    label: str,
+    fontsize: int = 10,
+    text_color: str = "white",
+    bbox_facecolor: str = "black",
+    bbox_alpha: float = 0.6,
+) -> None:
+    """
+    Draw a label at the centroid of a mask region.
+
+    This function calculates the centroid (center of mass) of a boolean mask
+    and places a text label at that position.
+
+    Args:
+        ax: Matplotlib axes object to draw on.
+        mask: 2D boolean numpy array representing the mask region.
+            Shape should be (height, width).
+        label: Text string to display at the centroid.
+        fontsize: Font size for the label text.
+            Default: 10
+        text_color: Color of the label text.
+            Default: "white"
+        bbox_facecolor: Background color of the text bounding box.
+            Default: "black"
+        bbox_alpha: Transparency of the text bounding box.
+            Default: 0.6
+
+    Example::
+        >>> mask = np.array([[False, True, True],
+        ...                  [False, True, True],
+        ...                  [False, False, False]])
+        >>> fig, ax = plt.subplots()
+        >>> draw_mask_legend(ax, mask, label="1")
+    """
+    if not mask.any():
+        return
+
+    rows, cols = np.where(mask)
+    centroid_y, centroid_x = rows.mean(), cols.mean()
+    ax.text(
+        centroid_x,
+        centroid_y,
+        label,
+        color=text_color,
+        fontsize=fontsize,
+        ha="center",
+        va="center",
+        bbox={
+            "boxstyle": "round,pad=0.2",
+            "facecolor": bbox_facecolor,
+            "alpha": bbox_alpha,
+        },
+    )
+
+
 class ImageVisualizationMethod(Enum):
     heat_map = 1
     blended_heat_map = 2