diff --git a/captum/attr/_utils/interpretable_input.py b/captum/attr/_utils/interpretable_input.py index 6882df6c04..d7a057f197 100644 --- a/captum/attr/_utils/interpretable_input.py +++ b/captum/attr/_utils/interpretable_input.py @@ -1,6 +1,16 @@ # pyre-strict from abc import ABC, abstractmethod -from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union +from typing import ( + Any, + Callable, + cast, + Dict, + List, + Optional, + Tuple, + TYPE_CHECKING, + Union, +) import numpy as np import PIL.Image @@ -9,6 +19,9 @@ from captum._utils.typing import TokenizerLike from torch import Tensor +if TYPE_CHECKING: + from matplotlib.pyplot import Axes, Figure + def _scatter_itp_attr_by_mask( itp_attr: Tensor, @@ -666,6 +679,19 @@ def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> Any: return self.processor_fn(perturbed_image) + def get_mask_list(self) -> List[Tensor]: + """ + Get the list of binary masks for each interpretable feature. + If mask_list is provided, return it directly. Otherwise, create + a list of binary masks from the mask tensor. + + Returns: + List[Tensor]: list of binary masks, one for each interpretable feature + """ + return self.mask_list or [ + self.mask == mask_id for mask_id in self.mask_id_to_idx.keys() + ] + def format_attr(self, itp_attr: Tensor) -> Tensor: """ Attribution for interpretable image segments @@ -711,3 +737,78 @@ def format_pixel_attr(self, itp_attr: Tensor) -> Tensor: formatted_mask.unsqueeze(0), ) return formatted_attr + + def plot_mask_overlay( + self, show: bool = False, border_width: int = 0, show_legends: bool = True + ) -> Union[None, Tuple["Figure", "Axes"]]: + """ + util to help visualize the mask segementation + + Args: + show: If True, display the plot immediately using plt.show(). + If False, return the figure and axes objects. + border_width: Width of the border around each segment in pixels. + Set to 0 to disable borders. Default is 0. + show_legends: If True, display the mask id for each segment at its + centroid. Default is True. + """ + + import matplotlib.pyplot as plt + + # scipy is not within install_requires of captum + from scipy.ndimage import binary_erosion + + fig, ax = plt.subplots() + + ax.imshow(self.image) + fig.set_facecolor("white") + + # random colors for all interpretable features + opacity = 0.6 + colors = [ + np.array([*np.random.random(3), opacity]) + for _ in range(self.n_itp_features) + ] + + mask_list = self.get_mask_list() + + for mask, color, mid in zip(mask_list, colors, self.mask_id_to_idx.keys()): + mask_np = mask.numpy().astype(bool) + h, w = mask_np.shape + mask_image = mask_np.reshape(h, w, 1) * color.reshape(1, 1, -1) + ax.imshow(mask_image) + + # Add border inside each segment using erosion + if border_width > 0: + eroded = binary_erosion(mask_np, iterations=border_width) + border = mask_np & ~eroded + border_color = np.array([*color[:3], 0.8]) + border_image = border.reshape(h, w, 1) * border_color.reshape(1, 1, -1) + ax.imshow(border_image) + + # Calculate centroid and display mask id + if show_legends and mask_np.any(): + rows, cols = np.where(mask_np) + centroid_y, centroid_x = rows.mean(), cols.mean() + ax.text( + centroid_x, + centroid_y, + str(mid), + color="white", + fontsize=10, + ha="center", + va="center", + bbox={ + "boxstyle": "round,pad=0.2", + "facecolor": "black", + "alpha": 0.6, + }, + ) + + ax.axis("off") + + if show: + plt.show() + return None + else: + return fig, ax