Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 18 additions & 29 deletions captum/attr/_core/llm_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,25 +206,12 @@ def plot_token_attr(
fig.set_size_inches(
max(data.shape[1] * 1.3, 6.4), max(data.shape[0] / 2.5, 4.8)
)
colors = [
"#93003a",
"#d0365b",
"#f57789",
"#ffbdc3",
"#ffffff",
"#a4d6e1",
"#73a3ca",
"#4772b3",
"#00429d",
]

im = ax.imshow(
data,
vmax=max_abs_attr_val,
vmin=-max_abs_attr_val,
cmap=mcolors.LinearSegmentedColormap.from_list(
name="colors", colors=colors
),
cmap=self._get_plot_color_map(),
aspect="auto",
)
fig.set_facecolor("white")
Expand Down Expand Up @@ -392,25 +379,11 @@ def plot_image_heatmap(

ax.imshow(inp.image)

colors = [
"#93003a",
"#d0365b",
"#f57789",
"#ffbdc3",
"#ffffff",
"#a4d6e1",
"#73a3ca",
"#4772b3",
"#00429d",
]

heatmap = ax.imshow(
pixel_attr,
vmax=max_abs_attr_val,
vmin=-max_abs_attr_val,
cmap=mcolors.LinearSegmentedColormap.from_list(
name="colors", colors=colors
),
cmap=self._get_plot_color_map(),
alpha=0.7,
)

Expand All @@ -426,6 +399,22 @@ def plot_image_heatmap(
else:
return fig, ax

def _get_plot_color_map(self) -> mcolors.LinearSegmentedColormap:
    """Build the diverging colormap used by the attribution plots.

    The palette is a fixed list of nine hex colors running from dark red,
    through white at the midpoint, to dark blue. Callers pair it with a
    symmetric ``vmin``/``vmax`` so that zero attribution lands on white.

    Returns:
        A ``LinearSegmentedColormap`` named ``"colors"`` interpolated over
        the palette.
    """
    palette = [
        # red side (low end of the value range)
        "#93003a",
        "#d0365b",
        "#f57789",
        "#ffbdc3",
        # neutral midpoint
        "#ffffff",
        # blue side (high end of the value range)
        "#a4d6e1",
        "#73a3ca",
        "#4772b3",
        "#00429d",
    ]
    color_map = mcolors.LinearSegmentedColormap.from_list(
        name="colors", colors=palette
    )
    return color_map


def _clean_up_pretty_token(token: str) -> str:
"""Remove newlines and leading/trailing whitespace from token."""
Expand Down
113 changes: 82 additions & 31 deletions captum/attr/_utils/interpretable_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,13 @@ class ImageMaskInput(InterpretableInput):
and end with same attributions. When mask is None, the entire image is
considered as one interpretable feature.
Default: None
mask_list (List[Tensor], optional): a list of binary masks to group
the image pixels as interpretable segment features. Each mask indicates
one feature and must be in the same shape as the image size. Value True
means the pixel is included in the feature. Compared to mask, mask_list
allows multiple features to be defined on the same pixel. If mask_list
is not None, mask will be ignored.
Default: None
baseline (Tuple[int, int, int], optional): the baseline RGB value for
the “absent” image pixels.
Default: (255, 255, 255)
Expand Down Expand Up @@ -573,6 +580,7 @@ class ImageMaskInput(InterpretableInput):

image: PIL.Image.Image
mask: Tensor
mask_list: List[Tensor]
baseline: Tuple[int, int, int]
processor_fn: Callable[[PIL.Image.Image], Any]
n_itp_features: int
Expand All @@ -584,6 +592,7 @@ def __init__(
self,
image: PIL.Image.Image,
mask: Optional[Tensor] = None,
mask_list: Optional[List[Tensor]] = None,
baseline: Tuple[int, int, int] = (255, 255, 255),
processor_fn: Callable[[PIL.Image.Image], Any] = lambda x: x,
) -> None:
Expand All @@ -593,29 +602,46 @@ def __init__(
self.image = image
self.baseline = baseline

# Create a dummy mask if None is provided
if mask is None:
# Create a mask with all zeros (entire image as one segment)
image_shape = (image.size[1], image.size[0]) # (height, width)
mask = torch.zeros(image_shape, dtype=torch.int32)
image_shape = (image.size[1], image.size[0]) # (height, width)

if mask_list is not None:
# Validate that all masks in mask_list have the correct shape
for i, m in enumerate(mask_list):
assert m.shape == image_shape, (
f"mask_list[{i}] shape {m.shape} must match "
f"image shape {image_shape}"
)

mask_ids = list(range(len(mask_list)))
# Create a dummy mask for compatibility
mask = torch.empty(0)
else:
# Validate that mask size matches image size
image_shape = (image.size[1], image.size[0]) # (height, width)
assert (
mask.shape == image_shape
), f"mask shape {mask.shape} must match image shape {image_shape}"
# Create a dummy mask if None is provided
if mask is None:
# Create a mask with all zeros (entire image as one segment)
mask = torch.zeros(image_shape, dtype=torch.int32)
else:
# Validate that mask size matches image size
assert (
mask.shape == image_shape
), f"mask shape {mask.shape} must match image shape {image_shape}"

mask_ids = torch.unique(mask).cpu().tolist()

# dummy mask_list
mask_list = []

self.mask = mask
mask_ids = torch.unique(mask)
self.mask_list = mask_list
self.n_itp_features = len(mask_ids)
self.mask_id_to_idx = {int(mid): i for i, mid in enumerate(mask_ids)}

self.original_model_inputs = processor_fn(image)

# temporarily for compatibility with AttributionResult
# which use the values for plot legends
self.values = [f"image_feature_{mid}" for mid in mask_ids]

self.original_model_inputs = processor_fn(image)

def to_tensor(self) -> Tensor:
    """Return the unperturbed interpretable representation of the image.

    Returns:
        A tensor of shape ``(1, n_itp_features)`` filled with ones in the
        default floating-point dtype, i.e. every interpretable feature is
        marked as present.
    """
    n_features = self.n_itp_features
    return torch.ones((1, n_features))

Expand All @@ -625,10 +651,16 @@ def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> Any:

img_array = np.array(self.image)

for mask_id, itp_idx in self.mask_id_to_idx.items():
if perturbed_tensor[0][itp_idx] == 0:
mask_positions = self.mask == mask_id
img_array[mask_positions] = self.baseline
if self.mask_list:
for itp_idx, feature_mask in enumerate(self.mask_list):
if perturbed_tensor[0][itp_idx] == 0:
mask_positions = feature_mask.bool()
img_array[mask_positions] = self.baseline
else:
for mask_id, itp_idx in self.mask_id_to_idx.items():
if perturbed_tensor[0][itp_idx] == 0:
mask_positions = self.mask == mask_id
img_array[mask_positions] = self.baseline

perturbed_image = PIL.Image.fromarray(img_array.astype("uint8"))

Expand All @@ -643,20 +675,39 @@ def format_attr(self, itp_attr: Tensor) -> Tensor:

def format_pixel_attr(self, itp_attr: Tensor) -> Tensor:
    """
    Attribution for image pixels.

    Expands per-feature attributions back onto the image grid. When
    ``mask_list`` is used, every pixel receives the sum of the attributions
    of all features whose mask covers it, so overlapping features add up.
    Otherwise each pixel receives the attribution of the single feature its
    ``mask`` id maps to.

    Args:
        itp_attr (Tensor): attributions whose last dimension indexes the
            interpretable features.

    Returns:
        Tensor: pixel attributions on ``itp_attr``'s device. With
        ``mask_list`` the shape is ``(*itp_attr.shape[:-1], height, width)``;
        with ``mask`` the shape requested from the scatter helper is
        ``(1, height, width)``.
    """
    device = itp_attr.device
    # PIL reports size as (width, height); tensors here are (height, width)
    image_shape = (self.image.size[1], self.image.size[0])

    if self.mask_list:
        # Accumulate out-of-place so the result keeps normal type promotion
        # between the float zeros and itp_attr's dtype.
        formatted_attr = torch.zeros(
            (*itp_attr.shape[:-1], *image_shape), device=device
        )
        for itp_idx, feature_mask in enumerate(self.mask_list):
            mask_bool = feature_mask.bool().to(device)
            # (*output_dims,) -> (*output_dims, 1, 1) to broadcast over H, W
            feature_attr = itp_attr[..., itp_idx].unsqueeze(-1).unsqueeze(-1)
            formatted_attr = formatted_attr + feature_attr * mask_bool
        return formatted_attr

    # Move the id mask next to itp_attr first, mirroring the mask_list branch,
    # so the boolean indexing below never mixes devices.
    mask = self.mask.to(device)

    # Map the (possibly non-contiguous) mask ids to continuous feature indices
    formatted_mask = torch.zeros_like(mask)
    for mask_id, itp_idx in self.mask_id_to_idx.items():
        formatted_mask[mask == mask_id] = itp_idx

    return _scatter_itp_attr_by_mask(
        itp_attr,
        (1, *image_shape),
        formatted_mask.unsqueeze(0),
    )
Loading