Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace uses of np.ndarray with npt.NDArray (#1387) #1389

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 23 additions & 20 deletions captum/attr/_utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import matplotlib

import numpy as np
import numpy.typing as npt
from matplotlib import cm, colors, pyplot as plt
from matplotlib.axes import Axes
from matplotlib.collections import LineCollection
Expand Down Expand Up @@ -47,11 +48,11 @@ class VisualizeSign(Enum):
all = 4


def _prepare_image(attr_visual: ndarray) -> ndarray:
def _prepare_image(attr_visual: npt.NDArray) -> npt.NDArray:
return np.clip(attr_visual.astype(int), 0, 255)


def _normalize_scale(attr: ndarray, scale_factor: float) -> ndarray:
def _normalize_scale(attr: npt.NDArray, scale_factor: float) -> npt.NDArray:
assert scale_factor != 0, "Cannot normalize by scale factor = 0"
if abs(scale_factor) < 1e-5:
warnings.warn(
Expand All @@ -64,7 +65,9 @@ def _normalize_scale(attr: ndarray, scale_factor: float) -> ndarray:
return np.clip(attr_norm, -1, 1)


def _cumulative_sum_threshold(values: ndarray, percentile: Union[int, float]) -> float:
def _cumulative_sum_threshold(
values: npt.NDArray, percentile: Union[int, float]
) -> float:
# given values should be non-negative
assert percentile >= 0 and percentile <= 100, (
"Percentile for thresholding must be " "between 0 and 100 inclusive."
Expand All @@ -76,11 +79,11 @@ def _cumulative_sum_threshold(values: ndarray, percentile: Union[int, float]) ->


def _normalize_attr(
attr: ndarray,
attr: npt.NDArray,
sign: str,
outlier_perc: Union[int, float] = 2,
reduction_axis: Optional[int] = None,
) -> ndarray:
) -> npt.NDArray:
attr_combined = attr
if reduction_axis is not None:
attr_combined = np.sum(attr, axis=reduction_axis)
Expand Down Expand Up @@ -130,7 +133,7 @@ def _initialize_cmap_and_vmin_vmax(

def _visualize_original_image(
plt_axis: Axes,
original_image: Optional[ndarray],
original_image: Optional[npt.NDArray],
**kwargs: Any,
) -> None:
assert (
Expand All @@ -143,7 +146,7 @@ def _visualize_original_image(

def _visualize_heat_map(
plt_axis: Axes,
norm_attr: ndarray,
norm_attr: npt.NDArray,
cmap: Union[str, Colormap],
vmin: float,
vmax: float,
Expand All @@ -155,8 +158,8 @@ def _visualize_heat_map(

def _visualize_blended_heat_map(
plt_axis: Axes,
original_image: ndarray,
norm_attr: ndarray,
original_image: npt.NDArray,
norm_attr: npt.NDArray,
cmap: Union[str, Colormap],
vmin: float,
vmax: float,
Expand All @@ -176,8 +179,8 @@ def _visualize_blended_heat_map(
def _visualize_masked_image(
plt_axis: Axes,
sign: str,
original_image: ndarray,
norm_attr: ndarray,
original_image: npt.NDArray,
norm_attr: npt.NDArray,
**kwargs: Any,
) -> None:
assert VisualizeSign[sign].value != VisualizeSign.all.value, (
Expand All @@ -190,8 +193,8 @@ def _visualize_masked_image(
def _visualize_alpha_scaling(
plt_axis: Axes,
sign: str,
original_image: ndarray,
norm_attr: ndarray,
original_image: npt.NDArray,
norm_attr: npt.NDArray,
**kwargs: Any,
) -> None:
assert VisualizeSign[sign].value != VisualizeSign.all.value, (
Expand All @@ -210,8 +213,8 @@ def _visualize_alpha_scaling(


def visualize_image_attr(
attr: ndarray,
original_image: Optional[ndarray] = None,
attr: npt.NDArray,
original_image: Optional[npt.NDArray] = None,
method: str = "heat_map",
sign: str = "absolute_value",
plt_fig_axis: Optional[Tuple[Figure, Axes]] = None,
Expand Down Expand Up @@ -417,8 +420,8 @@ def visualize_image_attr(


def visualize_image_attr_multiple(
attr: ndarray,
original_image: Union[None, ndarray],
attr: npt.NDArray,
original_image: Union[None, npt.NDArray],
methods: List[str],
signs: List[str],
titles: Optional[List[str]] = None,
Expand Down Expand Up @@ -526,9 +529,9 @@ def visualize_image_attr_multiple(


def visualize_timeseries_attr(
attr: ndarray,
data: ndarray,
x_values: Optional[ndarray] = None,
attr: npt.NDArray,
data: npt.NDArray,
x_values: Optional[npt.NDArray] = None,
method: str = "overlay_individual",
sign: str = "absolute_value",
channel_labels: Optional[List[str]] = None,
Expand Down
4 changes: 2 additions & 2 deletions tests/attr/test_gradient_shap.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from typing import cast, Tuple

import numpy as np
import numpy.typing as npt
import torch
from captum._utils.typing import Tensor
from captum.attr._core.gradient_shap import GradientShap
from captum.attr._core.integrated_gradients import IntegratedGradients
from numpy import ndarray
from tests.attr.helpers.attribution_delta_util import (
assert_attribution_delta,
assert_delta,
Expand Down Expand Up @@ -132,7 +132,7 @@ def generate_baselines_with_inputs(inputs: Tensor) -> Tensor:
inp_shape = cast(Tuple[int, ...], inputs.shape)
return torch.arange(0.0, inp_shape[1] * 2.0).reshape(2, inp_shape[1])

def generate_baselines_returns_array() -> ndarray:
def generate_baselines_returns_array() -> npt.NDArray:
return np.arange(0.0, num_in * 4.0).reshape(4, num_in)

# 10-class classification model
Expand Down
3 changes: 2 additions & 1 deletion tests/utils/models/linear_models/_test_linear_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import captum._utils.models.linear_model.model as pytorch_model_module
import numpy as np
import numpy.typing as npt
import sklearn.datasets as datasets
import torch
from tests.helpers.evaluate_linear_model import evaluate
Expand Down Expand Up @@ -107,7 +108,7 @@ def compare_to_sk_learn(
o_sklearn["l1_reg"] = alpha * sklearn_h.norm(p=1, dim=-1)

rel_diff = cast(
np.ndarray,
npt.NDArray,
# pyre-fixme[6]: For 1st argument expected `int` but got `Union[int, Tensor]`.
(sum(o_sklearn.values()) - sum(o_pytorch.values())),
) / abs(sum(o_sklearn.values()))
Expand Down
Loading