Skip to content

Commit 2a775c8

Browse files
aobo-yfacebook-github-bot
authored andcommitted
Refactor mask visualization utils from ImageMaskInput to visualization module (meta-pytorch#1756)
Summary: Extract border and legend drawing logic from `plot_mask_overlay` into reusable utility functions `draw_mask_border` and `draw_mask_legend` in the visualization Differential Revision: D90036627
1 parent 2fc1ffd commit 2a775c8

2 files changed

Lines changed: 111 additions & 24 deletions

File tree

captum/attr/_utils/interpretable_input.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import torch
1818

1919
from captum._utils.typing import TokenizerLike
20+
from captum.attr._utils.visualization import draw_mask_border, draw_mask_legend
2021
from torch import Tensor
2122

2223
if TYPE_CHECKING:
@@ -755,9 +756,6 @@ def plot_mask_overlay(
755756

756757
import matplotlib.pyplot as plt
757758

758-
# scipy is not within install_requires of captum
759-
from scipy.ndimage import binary_erosion
760-
761759
fig, ax = plt.subplots()
762760

763761
ax.imshow(self.image)
@@ -780,30 +778,12 @@ def plot_mask_overlay(
780778

781779
# Add border inside each segment using erosion
782780
if border_width > 0:
783-
eroded = binary_erosion(mask_np, iterations=border_width)
784-
border = mask_np & ~eroded
785781
border_color = np.array([*color[:3], 0.8])
786-
border_image = border.reshape(h, w, 1) * border_color.reshape(1, 1, -1)
787-
ax.imshow(border_image)
782+
draw_mask_border(ax, mask_np, border_width, border_color)
788783

789784
# Calculate centroid and display mask id
790-
if show_legends and mask_np.any():
791-
rows, cols = np.where(mask_np)
792-
centroid_y, centroid_x = rows.mean(), cols.mean()
793-
ax.text(
794-
centroid_x,
795-
centroid_y,
796-
str(mid),
797-
color="white",
798-
fontsize=10,
799-
ha="center",
800-
va="center",
801-
bbox={
802-
"boxstyle": "round,pad=0.2",
803-
"facecolor": "black",
804-
"alpha": 0.6,
805-
},
806-
)
785+
if show_legends:
786+
draw_mask_legend(ax, mask_np, label=str(mid))
807787

808788
ax.axis("off")
809789

captum/attr/_utils/visualization.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,113 @@
3838
HAS_IPYTHON = False
3939

4040

41+
def draw_mask_border(
42+
ax: Axes,
43+
mask: npt.NDArray[np.bool_],
44+
border_width: int = 1,
45+
border_color: Union[str, npt.NDArray[np.floating[Any]]] = "black",
46+
) -> None:
47+
"""
48+
Draw a border inside a mask region using binary erosion.
49+
50+
This function generates a border by eroding the mask and taking the difference
51+
between the original mask and the eroded version, then displays it on the axes.
52+
53+
Args:
54+
ax: Matplotlib axes object to draw on.
55+
mask: 2D boolean numpy array representing the mask region.
56+
Shape should be (height, width).
57+
border_width: Width of the border in pixels.
58+
Default: 1
59+
border_color: Color for the border. Can be a string color name (e.g.,
60+
"black", "red") or an RGBA array of shape (4,) with values
61+
typically in [0, 1].
62+
Default: "black"
63+
64+
Example::
65+
>>> mask = np.array([[True, True, True],
66+
... [True, True, True],
67+
... [True, True, True]])
68+
>>> fig, ax = plt.subplots()
69+
>>> draw_mask_border(ax, mask) # Uses default black border
70+
>>> draw_mask_border(ax, mask, border_width=2, border_color="red")
71+
"""
72+
if not mask.any():
73+
return
74+
75+
from scipy.ndimage import binary_erosion
76+
77+
# Convert string color to RGBA array
78+
if isinstance(border_color, str):
79+
rgba = colors.to_rgba(border_color)
80+
border_color_array = np.array(rgba)
81+
else:
82+
border_color_array = border_color
83+
84+
eroded = binary_erosion(mask, iterations=border_width)
85+
border = mask & ~eroded
86+
h, w = mask.shape
87+
border_image = border.reshape(h, w, 1) * border_color_array.reshape(1, 1, -1)
88+
ax.imshow(border_image)
89+
90+
91+
def draw_mask_legend(
92+
ax: Axes,
93+
mask: npt.NDArray[np.bool_],
94+
label: str,
95+
fontsize: int = 10,
96+
text_color: str = "white",
97+
bbox_facecolor: str = "black",
98+
bbox_alpha: float = 0.6,
99+
) -> None:
100+
"""
101+
Draw a label at the centroid of a mask region.
102+
103+
This function calculates the centroid (center of mass) of a boolean mask
104+
and places a text label at that position.
105+
106+
Args:
107+
ax: Matplotlib axes object to draw on.
108+
mask: 2D boolean numpy array representing the mask region.
109+
Shape should be (height, width).
110+
label: Text string to display at the centroid.
111+
fontsize: Font size for the label text.
112+
Default: 10
113+
text_color: Color of the label text.
114+
Default: "white"
115+
bbox_facecolor: Background color of the text bounding box.
116+
Default: "black"
117+
bbox_alpha: Transparency of the text bounding box.
118+
Default: 0.6
119+
120+
Example::
121+
>>> mask = np.array([[False, True, True],
122+
... [False, True, True],
123+
... [False, False, False]])
124+
>>> fig, ax = plt.subplots()
125+
>>> draw_mask_legend(ax, mask, label="1")
126+
"""
127+
if not mask.any():
128+
return
129+
130+
rows, cols = np.where(mask)
131+
centroid_y, centroid_x = rows.mean(), cols.mean()
132+
ax.text(
133+
centroid_x,
134+
centroid_y,
135+
label,
136+
color=text_color,
137+
fontsize=fontsize,
138+
ha="center",
139+
va="center",
140+
bbox={
141+
"boxstyle": "round,pad=0.2",
142+
"facecolor": bbox_facecolor,
143+
"alpha": bbox_alpha,
144+
},
145+
)
146+
147+
41148
class ImageVisualizationMethod(Enum):
42149
heat_map = 1
43150
blended_heat_map = 2

0 commit comments

Comments
 (0)