Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 102 additions & 1 deletion captum/attr/_utils/interpretable_input.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
# pyre-strict
from abc import ABC, abstractmethod
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from typing import (
Any,
Callable,
cast,
Dict,
List,
Optional,
Tuple,
TYPE_CHECKING,
Union,
)

import numpy as np
import PIL.Image
Expand All @@ -9,6 +19,9 @@
from captum._utils.typing import TokenizerLike
from torch import Tensor

if TYPE_CHECKING:
from matplotlib.pyplot import Axes, Figure


def _scatter_itp_attr_by_mask(
itp_attr: Tensor,
Expand Down Expand Up @@ -666,6 +679,19 @@ def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> Any:

return self.processor_fn(perturbed_image)

def get_mask_list(self) -> List[Tensor]:
"""
Get the list of binary masks for each interpretable feature.
If mask_list is provided, return it directly. Otherwise, create
a list of binary masks from the mask tensor.

Returns:
List[Tensor]: list of binary masks, one for each interpretable feature
"""
return self.mask_list or [
self.mask == mask_id for mask_id in self.mask_id_to_idx.keys()
]

def format_attr(self, itp_attr: Tensor) -> Tensor:
"""
Attribution for interpretable image segments
Expand Down Expand Up @@ -711,3 +737,78 @@ def format_pixel_attr(self, itp_attr: Tensor) -> Tensor:
formatted_mask.unsqueeze(0),
)
return formatted_attr

def plot_mask_overlay(
self, show: bool = False, border_width: int = 0, show_legends: bool = True
) -> Union[None, Tuple["Figure", "Axes"]]:
"""
util to help visualize the mask segementation

Args:
show: If True, display the plot immediately using plt.show().
If False, return the figure and axes objects.
border_width: Width of the border around each segment in pixels.
Set to 0 to disable borders. Default is 0.
show_legends: If True, display the mask id for each segment at its
centroid. Default is True.
"""

import matplotlib.pyplot as plt

# scipy is not within install_requires of captum
from scipy.ndimage import binary_erosion

fig, ax = plt.subplots()

ax.imshow(self.image)
fig.set_facecolor("white")

# random colors for all interpretable features
opacity = 0.6
colors = [
np.array([*np.random.random(3), opacity])
for _ in range(self.n_itp_features)
]

mask_list = self.get_mask_list()

for mask, color, mid in zip(mask_list, colors, self.mask_id_to_idx.keys()):
mask_np = mask.numpy().astype(bool)
h, w = mask_np.shape
mask_image = mask_np.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)

# Add border inside each segment using erosion
if border_width > 0:
eroded = binary_erosion(mask_np, iterations=border_width)
border = mask_np & ~eroded
border_color = np.array([*color[:3], 0.8])
border_image = border.reshape(h, w, 1) * border_color.reshape(1, 1, -1)
ax.imshow(border_image)

# Calculate centroid and display mask id
if show_legends and mask_np.any():
rows, cols = np.where(mask_np)
centroid_y, centroid_x = rows.mean(), cols.mean()
ax.text(
centroid_x,
centroid_y,
str(mid),
color="white",
fontsize=10,
ha="center",
va="center",
bbox={
"boxstyle": "round,pad=0.2",
"facecolor": "black",
"alpha": 0.6,
},
)

ax.axis("off")

if show:
plt.show()
return None
else:
return fig, ax