4343 TextTemplateInput ,
4444 TextTokenInput ,
4545)
46+ from captum .attr ._utils .visualization import draw_mask_border , draw_mask_legend
4647
4748if TYPE_CHECKING :
4849 from matplotlib .pyplot import Axes , Figure
@@ -326,6 +327,8 @@ def plot_image_heatmap(
326327 self ,
327328 show : bool = False ,
328329 target_token_pos : Union [int , tuple [int , int ], None ] = None ,
330+ border_width : int = 2 ,
331+ show_legends : bool = True ,
329332 ) -> Union [None , Tuple ["Figure" , "Axes" ]]:
330333 """
331334 Plot the image in the input with the overlay of salience based on
@@ -340,6 +343,12 @@ def plot_image_heatmap(
340343 the token at the given index. If tuple[int, int], like (m, n), use the
341344 summed token attribution of tokens from m to n (noninclusive)
342345 Default: None
346+ border_width (int): Width of the border around each mask segment in pixels.
347+ Set to 0 to disable borders. Only used when input has mask_list.
348+ Default: 2
349+ show_legends (bool): If True, display the mask id for each segment at its
350+ centroid. Only used when input has mask_list.
351+ Default: True
343352
344353
345354 Returns:
@@ -348,10 +357,9 @@ def plot_image_heatmap(
348357 customization.
349358 """
350359
351- if not isinstance (self .inp , ImageMaskInput ):
352- raise ValueError ("plot_image_heatmap is only available for ImageMaskInput" )
353-
354360 inp = self .inp
361+ if not isinstance (inp , ImageMaskInput ):
362+ raise ValueError ("plot_image_heatmap is only available for ImageMaskInput" )
355363
356364 import matplotlib .pyplot as plt
357365
@@ -370,27 +378,54 @@ def plot_image_heatmap(
370378 from_pos , to_pos = target_token_pos
371379 attr = self .token_attr [from_pos :to_pos ].sum (dim = 0 )
372380
381+ fig , ax = plt .subplots ()
382+ ax .imshow (inp .image )
383+
384+ # Get pixel-level attribution using format_pixel_attr
373385 pixel_attr = inp .format_pixel_attr (attr .unsqueeze (0 ))
374386 pixel_attr = pixel_attr .squeeze (0 ).cpu ().numpy ()
375387
376388 max_abs_attr_val = np .abs (pixel_attr ).max ()
377-
378- fig , ax = plt .subplots ()
379-
380- ax .imshow (inp .image )
381-
389+ alpha = 0.8
382390 heatmap = ax .imshow (
383391 pixel_attr ,
384392 vmax = max_abs_attr_val ,
385393 vmin = - max_abs_attr_val ,
386394 cmap = self ._get_plot_color_map (),
387- alpha = 0.7 ,
395+ alpha = alpha ,
388396 )
389397
390- fig .set_facecolor ("white" )
391398 cbar = fig .colorbar (heatmap , ax = ax )
392399 cbar .ax .set_ylabel ("Attribution" , rotation = - 90 , va = "bottom" )
393400
401+ # Draw borders and legends on top if mask_list is available
402+ if hasattr (inp , "get_mask_list" ):
403+ mask_list = [m .numpy ().astype (bool ) for m in inp .get_mask_list ()]
404+ mask_ids = list (inp .mask_id_to_idx .keys ())
405+ cmap = self ._get_plot_color_map ()
406+
407+ # Get attribution values for border color calculation
408+ attr_np = attr .cpu ().numpy ()
409+
410+ for i , mask in enumerate (mask_list ):
411+ if not mask .any ():
412+ continue
413+
414+ if border_width > 0 :
415+ # Calculate border color as a darker version of the salience color
416+ norm_val = (
417+ (attr_np [i ] / max_abs_attr_val + 1 ) / 2
418+ if max_abs_attr_val > 0
419+ else 0.5
420+ )
421+ rgba = np .array (cmap (norm_val ))
422+ # Create darker version by multiplying RGB by 0.6
423+ border_color = np .array ([* (rgba [:3 ] * 0.7 ), alpha ])
424+ draw_mask_border (ax , mask , border_width , border_color = border_color )
425+ if show_legends :
426+ draw_mask_legend (ax , mask , label = str (mask_ids [i ]))
427+
428+ fig .set_facecolor ("white" )
394429 ax .axis ("off" )
395430
396431 if show :
0 commit comments