Skip to content

Commit da7e880

Browse files
aobo-yfacebook-github-bot
authored andcommitted
Add util to viz mask segmentations in ImageMaskInput (meta-pytorch#1752)
Summary: add util to visualize how the image is segmented into features Differential Revision: D89953376
1 parent 4fade3a commit da7e880

1 file changed

Lines changed: 54 additions & 1 deletion

File tree

captum/attr/_utils/interpretable_input.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
11
# pyre-strict
22
from abc import ABC, abstractmethod
3-
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
3+
from typing import (
4+
Any,
5+
Callable,
6+
cast,
7+
Dict,
8+
List,
9+
Optional,
10+
Tuple,
11+
TYPE_CHECKING,
12+
Union,
13+
)
414

515
import numpy as np
616
import PIL.Image
@@ -9,6 +19,9 @@
919
from captum._utils.typing import TokenizerLike
1020
from torch import Tensor
1121

22+
if TYPE_CHECKING:
23+
from matplotlib.pyplot import Axes, Figure
24+
1225

1326
def _scatter_itp_attr_by_mask(
1427
itp_attr: Tensor,
@@ -711,3 +724,43 @@ def format_pixel_attr(self, itp_attr: Tensor) -> Tensor:
711724
formatted_mask.unsqueeze(0),
712725
)
713726
return formatted_attr
727+
728+
def plot_mask_overlay(
729+
self, show: bool = False
730+
) -> Union[None, Tuple["Figure", "Axes"]]:
731+
"""
732+
util to help visualize the mask segementation
733+
"""
734+
735+
import matplotlib.pyplot as plt
736+
737+
fig, ax = plt.subplots()
738+
739+
ax.imshow(self.image)
740+
fig.set_facecolor("white")
741+
742+
# random colors for all interpretable features
743+
opacity = 0.6
744+
colors = [
745+
np.array([*np.random.random(3), opacity])
746+
for _ in range(self.n_itp_features)
747+
]
748+
749+
# Create mask_list from mask if not provided
750+
mask_list = self.mask_list or [
751+
self.mask == mask_id for mask_id in self.mask_id_to_idx.keys()
752+
]
753+
754+
for itp_idx, mask in enumerate(mask_list):
755+
mask = mask.numpy()
756+
h, w = mask.shape[-2:]
757+
mask_image = mask.reshape(h, w, 1) * colors[itp_idx].reshape(1, 1, -1)
758+
ax.imshow(mask_image)
759+
760+
ax.axis("off")
761+
762+
if show:
763+
plt.show()
764+
return None
765+
else:
766+
return fig, ax

0 commit comments

Comments
 (0)