Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 45 additions & 10 deletions captum/attr/_core/llm_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
TextTemplateInput,
TextTokenInput,
)
from captum.attr._utils.visualization import draw_mask_border, draw_mask_legend

if TYPE_CHECKING:
from matplotlib.pyplot import Axes, Figure
Expand Down Expand Up @@ -326,6 +327,8 @@ def plot_image_heatmap(
self,
show: bool = False,
target_token_pos: Union[int, tuple[int, int], None] = None,
border_width: int = 2,
show_legends: bool = True,
) -> Union[None, Tuple["Figure", "Axes"]]:
"""
Plot the image in the input with the overlay of salience based on
Expand All @@ -340,6 +343,12 @@ def plot_image_heatmap(
the token at the given index. If tuple[int, int], like (m, n), use the
summed token attribution of tokens from m to n (noninclusive)
Default: None
border_width (int): Width of the border around each mask segment in pixels.
Set to 0 to disable borders. Only used when input has mask_list.
Default: 2
show_legends (bool): If True, display the mask id for each segment at its
centroid. Only used when input has mask_list.
Default: True


Returns:
Expand All @@ -348,10 +357,9 @@ def plot_image_heatmap(
customization.
"""

if not isinstance(self.inp, ImageMaskInput):
raise ValueError("plot_image_heatmap is only available for ImageMaskInput")

inp = self.inp
if not isinstance(inp, ImageMaskInput):
raise ValueError("plot_image_heatmap is only available for ImageMaskInput")

import matplotlib.pyplot as plt

Expand All @@ -370,27 +378,54 @@ def plot_image_heatmap(
from_pos, to_pos = target_token_pos
attr = self.token_attr[from_pos:to_pos].sum(dim=0)

fig, ax = plt.subplots()
ax.imshow(inp.image)

# Get pixel-level attribution using format_pixel_attr
pixel_attr = inp.format_pixel_attr(attr.unsqueeze(0))
pixel_attr = pixel_attr.squeeze(0).cpu().numpy()

max_abs_attr_val = np.abs(pixel_attr).max()

fig, ax = plt.subplots()

ax.imshow(inp.image)

alpha = 0.8
heatmap = ax.imshow(
pixel_attr,
vmax=max_abs_attr_val,
vmin=-max_abs_attr_val,
cmap=self._get_plot_color_map(),
alpha=0.7,
alpha=alpha,
)

fig.set_facecolor("white")
cbar = fig.colorbar(heatmap, ax=ax)
cbar.ax.set_ylabel("Attribution", rotation=-90, va="bottom")

# Draw borders and legends on top if mask_list is available
if hasattr(inp, "get_mask_list"):
mask_list = [m.numpy().astype(bool) for m in inp.get_mask_list()]
mask_ids = list(inp.mask_id_to_idx.keys())
cmap = self._get_plot_color_map()

# Get attribution values for border color calculation
attr_np = attr.cpu().numpy()

for i, mask in enumerate(mask_list):
if not mask.any():
continue

if border_width > 0:
# Calculate border color as a darker version of the salience color
norm_val = (
(attr_np[i] / max_abs_attr_val + 1) / 2
if max_abs_attr_val > 0
else 0.5
)
rgba = np.array(cmap(norm_val))
# Create darker version by multiplying RGB by 0.7
border_color = np.array([*(rgba[:3] * 0.7), alpha])
draw_mask_border(ax, mask, border_width, border_color=border_color)
if show_legends:
draw_mask_legend(ax, mask, label=str(mask_ids[i]))

fig.set_facecolor("white")
ax.axis("off")

if show:
Expand Down
83 changes: 82 additions & 1 deletion captum/attr/_utils/interpretable_input.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,28 @@
# pyre-strict
from abc import ABC, abstractmethod
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from typing import (
Any,
Callable,
cast,
Dict,
List,
Optional,
Tuple,
TYPE_CHECKING,
Union,
)

import numpy as np
import PIL.Image
import torch

from captum._utils.typing import TokenizerLike
from captum.attr._utils.visualization import draw_mask_border, draw_mask_legend
from torch import Tensor

if TYPE_CHECKING:
from matplotlib.pyplot import Axes, Figure


def _scatter_itp_attr_by_mask(
itp_attr: Tensor,
Expand Down Expand Up @@ -666,6 +680,19 @@ def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> Any:

return self.processor_fn(perturbed_image)

def get_mask_list(self) -> List[Tensor]:
    """
    Return one binary mask per interpretable feature.

    When ``self.mask_list`` is set and non-empty it is handed back unchanged.
    Otherwise the masks are derived from ``self.mask`` by comparing it against
    every mask id (the keys of ``self.mask_id_to_idx``, in dict order).

    Returns:
        List[Tensor]: list of binary masks, one for each interpretable feature
    """
    explicit_masks = self.mask_list
    if explicit_masks:
        return explicit_masks

    # No usable explicit list: build one mask per id from the id-valued mask
    derived_masks = []
    for mask_id in self.mask_id_to_idx:
        derived_masks.append(self.mask == mask_id)
    return derived_masks

def format_attr(self, itp_attr: Tensor) -> Tensor:
"""
Attribution for interpretable image segments
Expand Down Expand Up @@ -711,3 +738,57 @@ def format_pixel_attr(self, itp_attr: Tensor) -> Tensor:
formatted_mask.unsqueeze(0),
)
return formatted_attr

def plot_mask_overlay(
    self, show: bool = False, border_width: int = 0, show_legends: bool = True
) -> Union[None, Tuple["Figure", "Axes"]]:
    """
    Util to help visualize the mask segmentation.

    The input image is drawn first, then every interpretable feature's mask
    is overlaid in a randomly chosen, semi-transparent color.

    Args:
        show (bool): If True, display the plot immediately using plt.show().
                If False, return the figure and axes objects.
                Default: False
        border_width (int): Width of the border around each segment in pixels.
                Set to 0 to disable borders.
                Default: 0
        show_legends (bool): If True, display the mask id for each segment at
                its centroid.
                Default: True

    Returns:
        None if ``show`` is True, otherwise the tuple ``(fig, ax)`` of the
        created matplotlib figure and axes.
    """

    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()

    ax.imshow(self.image)
    fig.set_facecolor("white")

    # random colors for all interpretable features
    opacity = 0.6
    colors = [
        np.array([*np.random.random(3), opacity])
        for _ in range(self.n_itp_features)
    ]

    mask_list = self.get_mask_list()

    for mask, color, mid in zip(mask_list, colors, self.mask_id_to_idx):
        # Tensor.numpy() raises for tensors that live on a non-CPU device or
        # require grad, so detach and move to CPU before converting.
        mask_np = mask.detach().cpu().numpy().astype(bool)
        h, w = mask_np.shape
        # Pixels outside the mask become all-zero RGBA, i.e. fully transparent
        mask_image = mask_np.reshape(h, w, 1) * color.reshape(1, 1, -1)
        ax.imshow(mask_image)

        # Add border inside each segment, in the segment's own color but
        # with a higher opacity so it stands out from the fill
        if border_width > 0:
            border_color = np.array([*color[:3], 0.8])
            draw_mask_border(ax, mask_np, border_width, border_color)

        # Display the mask id at the centroid of the segment
        if show_legends:
            draw_mask_legend(ax, mask_np, label=str(mid))

    ax.axis("off")

    if show:
        plt.show()
        return None
    else:
        return fig, ax
107 changes: 107 additions & 0 deletions captum/attr/_utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,113 @@
HAS_IPYTHON = False


def draw_mask_border(
    ax: Axes,
    mask: npt.NDArray[np.bool_],
    border_width: int = 1,
    border_color: Union[str, npt.NDArray[np.floating[Any]]] = "black",
) -> None:
    """
    Draw a border inside a mask region using binary erosion.

    This function generates a border by eroding the mask and taking the difference
    between the original mask and the eroded version, then displays it on the axes.
    Nothing is drawn when the mask is empty or when ``border_width`` is not
    positive.

    Args:
        ax: Matplotlib axes object to draw on.
        mask: 2D boolean numpy array representing the mask region.
            Shape should be (height, width).
        border_width: Width of the border in pixels. Values <= 0 disable
            the border.
            Default: 1
        border_color: Color for the border. Can be a string color name (e.g.,
            "black", "red") or an RGBA array of shape (4,) with values
            typically in [0, 1].
            Default: "black"

    Example::
        >>> mask = np.array([[True, True, True],
        ...                  [True, True, True],
        ...                  [True, True, True]])
        >>> fig, ax = plt.subplots()
        >>> draw_mask_border(ax, mask)  # Uses default black border
        >>> draw_mask_border(ax, mask, border_width=2, border_color="red")
    """
    # binary_erosion interprets iterations < 1 as "repeat until the result no
    # longer changes", which would erode the mask away completely and paint
    # the whole region as border. A non-positive width must draw nothing.
    if border_width <= 0 or not mask.any():
        return

    from scipy.ndimage import binary_erosion

    # Normalize the color to an array so it can be broadcast over the image
    if isinstance(border_color, str):
        border_color_array = np.array(colors.to_rgba(border_color))
    else:
        border_color_array = np.asarray(border_color)

    eroded = binary_erosion(mask, iterations=border_width)
    # Pixels that belong to the mask but did not survive the erosion
    border = mask & ~eroded
    h, w = mask.shape
    # Non-border pixels become all-zero RGBA, i.e. fully transparent
    border_image = border.reshape(h, w, 1) * border_color_array.reshape(1, 1, -1)
    ax.imshow(border_image)


def draw_mask_legend(
    ax: Axes,
    mask: npt.NDArray[np.bool_],
    label: str,
    fontsize: int = 10,
    text_color: str = "white",
    bbox_facecolor: str = "black",
    bbox_alpha: float = 0.6,
) -> None:
    """
    Place a text label at the centroid of a mask region.

    The centroid is the mean row / column index over all True pixels of the
    mask. An empty mask has no centroid, so nothing is drawn for it.

    Args:
        ax: Matplotlib axes object to draw on.
        mask: 2D boolean numpy array representing the mask region.
            Shape should be (height, width).
        label: Text string to display at the centroid.
        fontsize: Font size for the label text.
            Default: 10
        text_color: Color of the label text.
            Default: "white"
        bbox_facecolor: Background color of the text bounding box.
            Default: "black"
        bbox_alpha: Transparency of the text bounding box.
            Default: 0.6

    Example::
        >>> mask = np.array([[False, True, True],
        ...                  [False, True, True],
        ...                  [False, False, False]])
        >>> fig, ax = plt.subplots()
        >>> draw_mask_legend(ax, mask, label="1")
    """
    if mask.any():
        row_idx, col_idx = np.where(mask)
        # Rounded, semi-transparent box behind the text
        text_box = {
            "boxstyle": "round,pad=0.2",
            "facecolor": bbox_facecolor,
            "alpha": bbox_alpha,
        }
        # Axes.text takes (x, y): column mean first, then row mean
        ax.text(
            col_idx.mean(),
            row_idx.mean(),
            label,
            color=text_color,
            fontsize=fontsize,
            ha="center",
            va="center",
            bbox=text_box,
        )


class ImageVisualizationMethod(Enum):
heat_map = 1
blended_heat_map = 2
Expand Down