|
1 | 1 | # pyre-strict |
2 | 2 | from abc import ABC, abstractmethod |
3 | | -from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union |
| 3 | +from typing import ( |
| 4 | + Any, |
| 5 | + Callable, |
| 6 | + cast, |
| 7 | + Dict, |
| 8 | + List, |
| 9 | + Optional, |
| 10 | + Tuple, |
| 11 | + TYPE_CHECKING, |
| 12 | + Union, |
| 13 | +) |
4 | 14 |
|
5 | 15 | import numpy as np |
6 | 16 | import PIL.Image |
|
9 | 19 | from captum._utils.typing import TokenizerLike |
10 | 20 | from torch import Tensor |
11 | 21 |
|
| 22 | +if TYPE_CHECKING: |
| 23 | + from matplotlib.pyplot import Axes, Figure |
| 24 | + |
12 | 25 |
|
13 | 26 | def _scatter_itp_attr_by_mask( |
14 | 27 | itp_attr: Tensor, |
@@ -711,3 +724,43 @@ def format_pixel_attr(self, itp_attr: Tensor) -> Tensor: |
711 | 724 | formatted_mask.unsqueeze(0), |
712 | 725 | ) |
713 | 726 | return formatted_attr |
| 727 | + |
| 728 | + def plot_mask_overlay( |
| 729 | + self, show: bool = False |
| 730 | + ) -> Union[None, Tuple["Figure", "Axes"]]: |
| 731 | + """ |
| 732 | + util to help visualize the mask segementation |
| 733 | + """ |
| 734 | + |
| 735 | + import matplotlib.pyplot as plt |
| 736 | + |
| 737 | + fig, ax = plt.subplots() |
| 738 | + |
| 739 | + ax.imshow(self.image) |
| 740 | + fig.set_facecolor("white") |
| 741 | + |
| 742 | + # random colors for all interpretable features |
| 743 | + opacity = 0.6 |
| 744 | + colors = [ |
| 745 | + np.array([*np.random.random(3), opacity]) |
| 746 | + for _ in range(self.n_itp_features) |
| 747 | + ] |
| 748 | + |
| 749 | + # Create mask_list from mask if not provided |
| 750 | + mask_list = self.mask_list or [ |
| 751 | + self.mask == mask_id for mask_id in self.mask_id_to_idx.keys() |
| 752 | + ] |
| 753 | + |
| 754 | + for itp_idx, mask in enumerate(mask_list): |
| 755 | + mask = mask.numpy() |
| 756 | + h, w = mask.shape[-2:] |
| 757 | + mask_image = mask.reshape(h, w, 1) * colors[itp_idx].reshape(1, 1, -1) |
| 758 | + ax.imshow(mask_image) |
| 759 | + |
| 760 | + ax.axis("off") |
| 761 | + |
| 762 | + if show: |
| 763 | + plt.show() |
| 764 | + return None |
| 765 | + else: |
| 766 | + return fig, ax |
0 commit comments