Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 82 additions & 31 deletions captum/attr/_utils/interpretable_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,13 @@ class ImageMaskInput(InterpretableInput):
and end with same attributions. When mask is None, the entire image is
considered as one interpretable feature.
Default: None
mask_list (List[Tensor], optional): a list of binary masks that group
the image pixels into interpretable segment features. Each mask defines
one feature and must have the same shape as the image, i.e.
(height, width). A True value means the pixel belongs to the feature.
Unlike mask, mask_list allows a pixel to belong to multiple features.
If mask_list is not None, mask is ignored.
Default: None
baseline (Tuple[int, int, int], optional): the baseline RGB value for
the “absent” image pixels.
Default: (255, 255, 255)
Expand Down Expand Up @@ -573,6 +580,7 @@ class ImageMaskInput(InterpretableInput):

image: PIL.Image.Image
mask: Tensor
mask_list: List[Tensor]
baseline: Tuple[int, int, int]
processor_fn: Callable[[PIL.Image.Image], Any]
n_itp_features: int
Expand All @@ -584,6 +592,7 @@ def __init__(
self,
image: PIL.Image.Image,
mask: Optional[Tensor] = None,
mask_list: Optional[List[Tensor]] = None,
baseline: Tuple[int, int, int] = (255, 255, 255),
processor_fn: Callable[[PIL.Image.Image], Any] = lambda x: x,
) -> None:
Expand All @@ -593,29 +602,46 @@ def __init__(
self.image = image
self.baseline = baseline

# Create a dummy mask if None is provided
if mask is None:
# Create a mask with all zeros (entire image as one segment)
image_shape = (image.size[1], image.size[0]) # (height, width)
mask = torch.zeros(image_shape, dtype=torch.int32)
image_shape = (image.size[1], image.size[0]) # (height, width)

if mask_list is not None:
# Validate that all masks in mask_list have the correct shape
for i, m in enumerate(mask_list):
assert m.shape == image_shape, (
f"mask_list[{i}] shape {m.shape} must match "
f"image shape {image_shape}"
)

mask_ids = list(range(len(mask_list)))
# Create a dummy mask for compatibility
mask = torch.empty(0)
else:
# Validate that mask size matches image size
image_shape = (image.size[1], image.size[0]) # (height, width)
assert (
mask.shape == image_shape
), f"mask shape {mask.shape} must match image shape {image_shape}"
# Create a dummy mask if None is provided
if mask is None:
# Create a mask with all zeros (entire image as one segment)
mask = torch.zeros(image_shape, dtype=torch.int32)
else:
# Validate that mask size matches image size
assert (
mask.shape == image_shape
), f"mask shape {mask.shape} must match image shape {image_shape}"

mask_ids = torch.unique(mask).cpu().tolist()

# dummy mask_list
mask_list = []

self.mask = mask
mask_ids = torch.unique(mask)
self.mask_list = mask_list
self.n_itp_features = len(mask_ids)
self.mask_id_to_idx = {int(mid): i for i, mid in enumerate(mask_ids)}

self.original_model_inputs = processor_fn(image)

# temporarily for compatibility with AttributionResult
# which use the values for plot legends
self.values = [f"image_feature_{mid}" for mid in mask_ids]

self.original_model_inputs = processor_fn(image)

def to_tensor(self) -> Tensor:
    """
    Return the unperturbed interpretable representation: a single row
    holding 1.0 for each of the ``n_itp_features`` features.
    """
    all_present = [1.0 for _ in range(self.n_itp_features)]
    return torch.tensor([all_present])

Expand All @@ -625,10 +651,16 @@ def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> Any:

img_array = np.array(self.image)

for mask_id, itp_idx in self.mask_id_to_idx.items():
if perturbed_tensor[0][itp_idx] == 0:
mask_positions = self.mask == mask_id
img_array[mask_positions] = self.baseline
if self.mask_list:
for itp_idx, feature_mask in enumerate(self.mask_list):
if perturbed_tensor[0][itp_idx] == 0:
mask_positions = feature_mask.bool()
img_array[mask_positions] = self.baseline
else:
for mask_id, itp_idx in self.mask_id_to_idx.items():
if perturbed_tensor[0][itp_idx] == 0:
mask_positions = self.mask == mask_id
img_array[mask_positions] = self.baseline

perturbed_image = PIL.Image.fromarray(img_array.astype("uint8"))

Expand All @@ -643,20 +675,39 @@ def format_attr(self, itp_attr: Tensor) -> Tensor:

def format_pixel_attr(self, itp_attr: Tensor) -> Tensor:
    """
    Map interpretable-feature attributions back onto the image pixels.

    Args:
        itp_attr (Tensor): attributions whose last dimension indexes the
            interpretable features, i.e. shape
            (*output_dims, n_itp_features).

    Returns:
        Tensor: pixel-level attributions.
        When mask_list is used, the shape is (*output_dims, height, width)
        and each pixel holds the sum of the attributions of every feature
        whose mask covers it (pixels covered by no mask stay 0).
        Otherwise the result of _scatter_itp_attr_by_mask for the target
        shape (1, height, width).
    """
    device = itp_attr.device

    # PIL reports size as (width, height); tensors here are (height, width)
    width, height = self.image.size
    image_shape = (height, width)

    if self.mask_list:
        output_dims = itp_attr.shape[:-1]
        # Allocate the accumulator in the attribution dtype so that e.g.
        # half-precision attributions are not silently upcast.
        formatted_attr = torch.zeros(
            (*output_dims, *image_shape),
            dtype=itp_attr.dtype,
            device=device,
        )

        for itp_idx, feature_mask in enumerate(self.mask_list):
            mask_bool = feature_mask.bool().to(device)
            # (*output_dims,) -> (*output_dims, 1, 1) to broadcast over the
            # (height, width) mask
            feature_attr = itp_attr[..., itp_idx].unsqueeze(-1).unsqueeze(-1)
            # Overlapping features accumulate on the shared pixels
            formatted_attr = formatted_attr + feature_attr * mask_bool

        return formatted_attr

    # Single segmentation mask: map the (possibly non-contiguous) mask IDs
    # to contiguous feature indices before scattering
    formatted_mask = torch.zeros_like(self.mask, device=device)
    for mask_id, itp_idx in self.mask_id_to_idx.items():
        formatted_mask[self.mask == mask_id] = itp_idx

    return _scatter_itp_attr_by_mask(
        itp_attr,
        (1, *image_shape),
        formatted_mask.unsqueeze(0),
    )
159 changes: 159 additions & 0 deletions tests/attr/test_interpretable_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,3 +526,162 @@ def test_format_pixel_attr_with_non_continuous_mask(self) -> None:
self.assertTrue(torch.all(result[0, :, :5] == segment_0_value))
self.assertTrue(torch.all(result[0, :, 5:10] == segment_10_value))
self.assertTrue(torch.all(result[0, :, 10:] == segment_20_value))

# Tests for mask_list functionality

def test_init_mask_list_ignores_mask(self) -> None:
    """When both mask and mask_list are passed, mask_list defines the features."""
    side = 10
    image = self._create_test_image(width=side, height=side)

    # Segment-ID mask with three vertical bands (IDs 0, 1, 2)
    segment_mask = torch.zeros((side, side), dtype=torch.int32)
    segment_mask[:, 3:7] = 1
    segment_mask[:, 7:] = 2

    # Two binary masks: left half and right half of the image
    left_half = torch.zeros((side, side), dtype=torch.bool)
    left_half[:, :5] = True
    right_half = torch.zeros((side, side), dtype=torch.bool)
    right_half[:, 5:] = True
    binary_masks = [left_half, right_half]

    mm_input = ImageMaskInput(
        processor_fn=self._simple_processor,
        image=image,
        mask=segment_mask,
        mask_list=binary_masks,
    )

    # The 3-segment mask must be ignored in favour of the 2 binary masks
    self.assertEqual(mm_input.n_itp_features, 2)
    self.assertEqual(len(mm_input.mask_list), 2)

def test_to_tensor_with_mask_list(self) -> None:
    """to_tensor yields one 1.0 entry per mask in mask_list."""
    height, width = 10, 15
    image = self._create_test_image(width=width, height=height)

    # Three disjoint vertical bands, five columns each
    band_masks = []
    for start, stop in ((0, 5), (5, 10), (10, width)):
        band = torch.zeros((height, width), dtype=torch.bool)
        band[:, start:stop] = True
        band_masks.append(band)

    mm_input = ImageMaskInput(
        processor_fn=self._simple_processor,
        image=image,
        mask_list=band_masks,
    )

    result = mm_input.to_tensor()

    assertTensorAlmostEqual(self, result, torch.tensor([[1.0, 1.0, 1.0]]))

def test_to_model_input_with_mask_list(self) -> None:
    """A feature perturbed to 0 is painted with the baseline; a feature
    left at 1 keeps its original pixels."""
    # Setup: create image with 2 halves (left red, right green)
    img_array = np.zeros((10, 10, 3), dtype=np.uint8)
    img_array[:, :5] = [255, 0, 0]  # Left half red
    img_array[:, 5:] = [0, 255, 0]  # Right half green
    image = PIL.Image.fromarray(img_array)

    mask1 = torch.zeros((10, 10), dtype=torch.bool)
    mask1[:, :5] = True  # Left half
    mask2 = torch.zeros((10, 10), dtype=torch.bool)
    mask2[:, 5:] = True  # Right half

    mm_input = ImageMaskInput(
        processor_fn=self._simple_processor,
        image=image,
        mask_list=[mask1, mask2],
        baseline=(255, 255, 255),
    )

    # Execute: keep feature 0 (left half, value 1.0) and
    # remove feature 1 (right half, value 0.0)
    perturbed_tensor = torch.tensor([[1.0, 0.0]])
    result = mm_input.to_model_input(perturbed_tensor)

    out_array = result["pixel_values"].numpy().astype(np.uint8)
    # Left half must be untouched red on all three channels
    self.assertTrue(np.all(out_array[:, :5, 0] == 255))
    self.assertTrue(np.all(out_array[:, :5, 1] == 0))
    self.assertTrue(np.all(out_array[:, :5, 2] == 0))
    # Right half must be white (baseline)
    self.assertTrue(np.all(out_array[:, 5:] == 255))

def test_to_model_input_with_mask_list_overlapping(self) -> None:
    """Removing one of two overlapping features blanks its whole mask,
    including the overlap, and leaves the rest of the image untouched."""
    # Setup: create red image with overlapping masks
    image = self._create_test_image(color=(255, 0, 0))
    mask1 = torch.zeros((10, 10), dtype=torch.bool)
    mask1[:, :7] = True  # Left 7 columns
    mask2 = torch.zeros((10, 10), dtype=torch.bool)
    mask2[:, 3:] = True  # Right 7 columns (overlap at columns 3-6)

    mm_input = ImageMaskInput(
        processor_fn=self._simple_processor,
        image=image,
        mask_list=[mask1, mask2],
        baseline=(255, 255, 255),
    )

    # Execute: remove first feature (mask1), keep second (mask2)
    perturbed_tensor = torch.tensor([[0.0, 1.0]])
    result = mm_input.to_model_input(perturbed_tensor)

    out_array = result["pixel_values"].numpy().astype(np.uint8)
    # Left 7 columns (covered by mask1) should be white, even though
    # columns 3-6 also belong to the kept mask2
    self.assertTrue(np.all(out_array[:, :7] == 255))
    # Columns 7-9 belong only to the kept mask2 and must stay red
    self.assertTrue(np.all(out_array[:, 7:, 0] == 255))
    self.assertTrue(np.all(out_array[:, 7:, 1] == 0))
    self.assertTrue(np.all(out_array[:, 7:, 2] == 0))

def test_format_pixel_attr_with_mask_list(self) -> None:
    """With disjoint masks each pixel receives its own feature's attribution."""
    height, width = 5, 10
    half = 5
    image = self._create_test_image(width=width, height=height)

    left_mask = torch.zeros((height, width), dtype=torch.bool)
    left_mask[:, :half] = True
    right_mask = torch.zeros((height, width), dtype=torch.bool)
    right_mask[:, half:] = True

    mm_input = ImageMaskInput(
        processor_fn=self._simple_processor,
        image=image,
        mask_list=[left_mask, right_mask],
    )

    itp_attr = torch.tensor([[0.3, 0.7]])
    pixel_attr = mm_input.format_pixel_attr(itp_attr)

    self.assertEqual(pixel_attr.shape, (1, 5, 10))
    # Left half carries the first feature's value
    assertTensorAlmostEqual(
        self, pixel_attr[0, :, :half], torch.full((5, 5), 0.3)
    )
    # Right half carries the second feature's value
    assertTensorAlmostEqual(
        self, pixel_attr[0, :, half:], torch.full((5, 5), 0.7)
    )

def test_format_pixel_attr_with_mask_list_overlapping(self) -> None:
    """Pixels shared by several masks receive the sum of their attributions."""
    height, width = 5, 10
    image = self._create_test_image(width=width, height=height)

    first_mask = torch.zeros((height, width), dtype=torch.bool)
    first_mask[:, :7] = True  # columns 0-6
    second_mask = torch.zeros((height, width), dtype=torch.bool)
    second_mask[:, 3:] = True  # columns 3-9, overlapping on 3-6

    mm_input = ImageMaskInput(
        processor_fn=self._simple_processor,
        image=image,
        mask_list=[first_mask, second_mask],
    )

    pixel_attr = mm_input.format_pixel_attr(torch.tensor([[0.3, 0.5]]))

    self.assertEqual(pixel_attr.shape, (1, 5, 10))
    # (column range, expected value): first only, overlap (sum), second only
    expected_bands = (
        (slice(None, 3), 3, 0.3),
        (slice(3, 7), 4, 0.8),
        (slice(7, None), 3, 0.5),
    )
    for columns, n_columns, value in expected_bands:
        assertTensorAlmostEqual(
            self,
            pixel_attr[0, :, columns],
            torch.full((5, n_columns), value),
        )