Skip to content

Commit aa3dfd5

Browse files
aobo-yfacebook-github-bot
authored andcommitted
Add mask-based visualization support to LLMAttributionResult.plot_image_heatmap (meta-pytorch#1757)
Summary: Enable `plot_image_heatmap` to leverage the new `draw_mask_border` and `draw_mask_legend` utilities when the input has mask segments. This provides better visualization for image attribution with discrete regions, displaying attribution colors per segment with borders and labels derived from the salience values. Differential Revision: D90038941
1 parent 7fc0bb9 commit aa3dfd5

1 file changed

Lines changed: 45 additions & 10 deletions

File tree

captum/attr/_core/llm_attr.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
TextTemplateInput,
4444
TextTokenInput,
4545
)
46+
from captum.attr._utils.visualization import draw_mask_border, draw_mask_legend
4647

4748
if TYPE_CHECKING:
4849
from matplotlib.pyplot import Axes, Figure
@@ -326,6 +327,8 @@ def plot_image_heatmap(
326327
self,
327328
show: bool = False,
328329
target_token_pos: Union[int, tuple[int, int], None] = None,
330+
border_width: int = 2,
331+
show_legends: bool = True,
329332
) -> Union[None, Tuple["Figure", "Axes"]]:
330333
"""
331334
Plot the image in the input with the overlay of salience based on
@@ -340,6 +343,12 @@ def plot_image_heatmap(
340343
the token at the given index. If tuple[int, int], like (m, n), use the
341344
summed token attribution of tokens from m to n (noninclusive)
342345
Default: None
346+
border_width (int): Width of the border around each mask segment in pixels.
347+
Set to 0 to disable borders. Only used when input has mask_list.
348+
Default: 2
349+
show_legends (bool): If True, display the mask id for each segment at its
350+
centroid. Only used when input has mask_list.
351+
Default: True
343352
344353
345354
Returns:
@@ -348,10 +357,9 @@ def plot_image_heatmap(
348357
customization.
349358
"""
350359

351-
if not isinstance(self.inp, ImageMaskInput):
352-
raise ValueError("plot_image_heatmap is only available for ImageMaskInput")
353-
354360
inp = self.inp
361+
if not isinstance(inp, ImageMaskInput):
362+
raise ValueError("plot_image_heatmap is only available for ImageMaskInput")
355363

356364
import matplotlib.pyplot as plt
357365

@@ -370,27 +378,54 @@ def plot_image_heatmap(
370378
from_pos, to_pos = target_token_pos
371379
attr = self.token_attr[from_pos:to_pos].sum(dim=0)
372380

381+
fig, ax = plt.subplots()
382+
ax.imshow(inp.image)
383+
384+
# Get pixel-level attribution using format_pixel_attr
373385
pixel_attr = inp.format_pixel_attr(attr.unsqueeze(0))
374386
pixel_attr = pixel_attr.squeeze(0).cpu().numpy()
375387

376388
max_abs_attr_val = np.abs(pixel_attr).max()
377-
378-
fig, ax = plt.subplots()
379-
380-
ax.imshow(inp.image)
381-
389+
alpha = 0.8
382390
heatmap = ax.imshow(
383391
pixel_attr,
384392
vmax=max_abs_attr_val,
385393
vmin=-max_abs_attr_val,
386394
cmap=self._get_plot_color_map(),
387-
alpha=0.7,
395+
alpha=alpha,
388396
)
389397

390-
fig.set_facecolor("white")
391398
cbar = fig.colorbar(heatmap, ax=ax)
392399
cbar.ax.set_ylabel("Attribution", rotation=-90, va="bottom")
393400

401+
# Draw borders and legends on top if mask_list is available
402+
if hasattr(inp, "get_mask_list"):
403+
mask_list = [m.numpy().astype(bool) for m in inp.get_mask_list()]
404+
mask_ids = list(inp.mask_id_to_idx.keys())
405+
cmap = self._get_plot_color_map()
406+
407+
# Get attribution values for border color calculation
408+
attr_np = attr.cpu().numpy()
409+
410+
for i, mask in enumerate(mask_list):
411+
if not mask.any():
412+
continue
413+
414+
if border_width > 0:
415+
# Calculate border color as a darker version of the salience color
416+
norm_val = (
417+
(attr_np[i] / max_abs_attr_val + 1) / 2
418+
if max_abs_attr_val > 0
419+
else 0.5
420+
)
421+
rgba = np.array(cmap(norm_val))
422+
# Create darker version by multiplying RGB by 0.6
423+
border_color = np.array([*(rgba[:3] * 0.7), alpha])
424+
draw_mask_border(ax, mask, border_width, border_color=border_color)
425+
if show_legends:
426+
draw_mask_legend(ax, mask, label=str(mask_ids[i]))
427+
428+
fig.set_facecolor("white")
394429
ax.axis("off")
395430

396431
if show:

0 commit comments

Comments
 (0)