├── anatome
│   ├── __init__.py
│   ├── fourier.py
│   ├── utils.py
│   ├── landscape.py
│   └── distance.py
├── setup.py
├── tests
│   ├── test_utils.py
│   ├── test_fourier.py
│   ├── test_landscape.py
│   └── test_distance.py
├── .github
│   └── workflows
│       └── action.yml
├── LICENSE
├── README.md
├── assets
│   ├── landscape2d.svg
│   └── fourier.svg
└── examples.ipynb
/anatome/__init__.py:
--------------------------------------------------------------------------------
1 | from .distance import Distance
2 | from .fourier import fourier_map
3 | from .landscape import landscape1d, landscape2d
4 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
install_requires = [
    'torch>=1.10.0',
    'torchvision>=0.11.1',
]

setup(
    name='anatome',
    version='0.0.6',
    description='Ἀνατομή is a PyTorch library to analyze representation of neural networks',
    author='Ryuichiro Hataya',
    url='https://github.com/moskomule/anatome',
    license='MIT',
    # The package evaluates builtin generics such as `tuple[...]` in annotations at runtime
    # (anatome/utils.py has no `from __future__ import annotations`), which needs Python 3.9+,
    # matching the requirement stated in README.
    python_requires='>=3.9',
    install_requires=install_requires,
    packages=find_packages()
)
16 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from anatome import utils
4 |
5 |
def test_normalize_denormalize():
    # _denormalize must undo _normalize for a per-channel mean/std
    images = torch.randn(4, 3, 5, 5)
    channel_stats = torch.full((3,), 0.5)
    normalized = utils._normalize(images, channel_stats, channel_stats)
    restored = utils._denormalize(normalized, channel_stats, channel_stats)
    assert torch.allclose(images, restored, atol=1e-4)
13 |
--------------------------------------------------------------------------------
/tests/test_fourier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | from anatome import fourier
6 |
7 |
def test_fourier_map():
    # a tiny conv net keeps the per-frequency evaluations cheap
    model = nn.Sequential(nn.Conv2d(3, 4, 3),
                          nn.ReLU(),
                          nn.AdaptiveAvgPool2d(1),
                          nn.Flatten(),
                          nn.Linear(4, 3))
    data = (torch.randn(10, 3, 16, 16),
            torch.randint(2, (10,)))

    # without `fourier_map_size`, the map has the spatial size of the input
    full_map = fourier.fourier_map(model, data, F.cross_entropy, 4)
    assert full_map.shape == (16, 16)
    # every cell must have been evaluated (cross entropy is positive, unvisited cells stay 0)
    assert (full_map > 0).all()

    # with `fourier_map_size`, the map has the requested size
    small_map = fourier.fourier_map(model, data, F.cross_entropy, 4, (2, 2))
    assert small_map.shape == (2, 2)
    assert (small_map > 0).all()
18 |
--------------------------------------------------------------------------------
/tests/test_landscape.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | from anatome import landscape
6 |
7 |
def test_landscape():
    # small MLP classifier on random data
    model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
    inputs = torch.randn(10, 4)
    targets = torch.randint(2, (10,))
    data = (inputs, targets)

    # 1d: one loss value per x-coordinate
    xs, losses_1d = landscape.landscape1d(model, data, F.cross_entropy, (-1, 1), 0.5)
    assert xs.shape == losses_1d.shape

    # 2d: both coordinate grids and the losses share one shape
    grid_x, grid_y, losses_2d = landscape.landscape2d(model, data, F.cross_entropy,
                                                      (-1, 1), (-1, 1), (0.5, 0.5))
    assert grid_x.shape == grid_y.shape == losses_2d.shape
20 |
--------------------------------------------------------------------------------
/.github/workflows/action.yml:
--------------------------------------------------------------------------------
name: pytest

on: [ push, pull_request ]

jobs:
  build:

    runs-on: ubuntu-latest
    if: "!contains(github.event.head_commit.message, 'skip test')"

    strategy:
      matrix:
        python: [ '3.9' ]
        torch: [ 'torch==1.10.0+cpu torchvision==0.11.1+cpu -f https://download.pytorch.org/whl/torch_stable.html' ]

    steps:
      # v2 of these actions runs on the deprecated Node.js 12 runtime
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python }}

      - name: install dependencies
        run: |
          python -m venv venv
          . venv/bin/activate
          pip install ${{ matrix.torch }}
          pip install -U pytest
          pip install -U .

      - name: run test
        run: |
          . venv/bin/activate
          pytest
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Ryuichiro Hataya
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # anatome 
2 |
3 | Ἀνατομή is a PyTorch library to analyze internal representation of neural networks
4 |
5 | This project is under active development and the codebase is subject to change.
6 |
7 | **v0.0.5 introduces significant changes to `distance`.**
8 |
9 | ## Installation
10 |
11 | `anatome` requires
12 |
13 | ```
14 | Python>=3.9.0
15 | PyTorch>=1.10
16 | torchvision>=0.11
17 | ```
18 |
19 | After the installation of PyTorch, install `anatome` as follows:
20 |
21 | ```
22 | pip install -U git+https://github.com/moskomule/anatome
23 | ```
24 |
25 | ## Available Tools
26 |
27 | ### Representation Similarity
28 |
29 | To measure the similarity of learned representations, `anatome.Distance` is a useful tool. Currently, the following
30 | methods are implemented.
31 |
32 | - [Raghu et al. NIPS2017 SVCCA](https://papers.nips.cc/paper/7188-svcca-singular-vector-canonical-correlation-analysis-for-deep-learning-dynamics-and-interpretability)
33 | - [Morcos et al. NeurIPS2018 PWCCA](https://papers.nips.cc/paper/7815-insights-on-representational-similarity-in-neural-networks-with-canonical-correlation)
34 | - [Kornblith et al. ICML2019 Linear CKA](http://proceedings.mlr.press/v97/kornblith19a.html)
35 | - [Ding et al. arXiv Orthogonal Procrustes distance](https://arxiv.org/abs/2108.01661)
36 |
37 | ```python
38 | import torch
39 | from torchvision.models import resnet18
40 | from anatome import Distance
41 |
42 | random_model = resnet18()
43 | learned_model = resnet18(pretrained=True)
44 | distance = Distance(random_model, learned_model, method='pwcca')
45 | with torch.no_grad():
46 | distance.forward(torch.randn(256, 3, 224, 224))
47 |
48 | # resize if necessary by specifying `size`
49 | distance.between("layer3.0.conv1", "layer3.0.conv1", size=8)
50 | ```
51 |
52 | ### Loss Landscape Visualization
53 |
54 | - [Li et al. NeurIPS2018](https://papers.nips.cc/paper/7875-visualizing-the-loss-landscape-of-neural-nets)
55 |
56 | ```python
57 | from torch.nn import functional as F
58 | from torchvision.models import resnet18
59 | from anatome import landscape2d
60 |
61 | x, y, z = landscape2d(resnet18(),
62 | data,
63 | F.cross_entropy,
64 | x_range=(-1, 1),
65 | y_range=(-1, 1),
66 | step_size=0.1)
67 | imshow(z)
68 | ```
69 |
70 | 
71 | 
72 |
73 | ### Fourier Analysis
74 |
75 | - Yin et al. NeurIPS2019 "A Fourier Perspective on Model Robustness in Computer Vision", etc.
76 |
77 | ```python
78 | from torch.nn import functional as F
79 | from torchvision.models import resnet18
80 | from anatome import fourier_map
81 |
82 | map = fourier_map(resnet18(),
83 | data,
84 | F.cross_entropy,
85 | norm=4)
86 | imshow(map)
87 | ```
88 |
89 | 
90 |
91 | ## Citation
92 |
93 | If you use this implementation in your research, please cite as:
94 |
95 | ```
96 | @software{hataya2020anatome,
97 | author={Ryuichiro Hataya},
98 | title={anatome, a PyTorch library to analyze internal representation of neural networks},
99 | url={https://github.com/moskomule/anatome},
100 | year={2020}
101 | }
102 | ```
--------------------------------------------------------------------------------
/anatome/fourier.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Callable
4 |
5 | import torch
6 | from torch import Tensor, nn
7 | from torch.nn import functional as F
8 |
9 | try:
10 | from tqdm import tqdm
11 | except ImportError:
12 | tqdm = lambda x, ncols=None: x
13 |
14 | from .utils import _denormalize, _evaluate, _normalize, ifft_shift, _irfft
15 |
16 |
def add_fourier_noise(idx: tuple[int, int],
                      images: Tensor,
                      norm: float,
                      size: tuple[int, int] = None,
                      ) -> Tensor:
    """ Return a copy of `images` perturbed by a single Fourier basis pattern.

    Args:
        idx: (row, col) of the frequency to excite; the point-reflected position
            (h - 1 - row, w - 1 - col) is excited as well
        images: original images with four dims, the last two being spatial (H, W).
            NOTE(review): values are presumably in [0, 1], since the result is clamped to it
        norm: L2 norm of the noise pattern (before any resizing)
        size: (H, W) of the frequency grid. If None, the spatial size of `images` is used;
            otherwise the pattern is built at this size and resized to the images' spatial
            size with `F.interpolate`

    Returns: noisy copy of `images`, clamped to [0, 1]

    """
    out = images.clone()

    if size is not None:
        h, w = size
    else:
        _, _, h, w = images.size()

    # (1, H, W, 2) spectrum holding (real, imag) parts; both parts are set at the two positions
    spectrum = out.new_zeros(1, h, w, 2)
    row, col = idx[0], idx[1]
    spectrum[:, row, col] = 1
    spectrum[:, h - 1 - row, w - 1 - col] = 1

    pattern = _irfft(ifft_shift(spectrum), 2, normalized=True, onesided=False)
    pattern = pattern.unsqueeze(0)
    pattern.div_(pattern.norm(p=2)).mul_(norm)
    if size is not None:
        pattern = F.interpolate(pattern, out.shape[2:])
    return out.add_(pattern).clamp_(0, 1)
50 |
51 |
@torch.no_grad()
def fourier_map(model: nn.Module,
                data: tuple[Tensor, Tensor],
                criterion: Callable[[Tensor, Tensor], Tensor],
                norm: float,
                fourier_map_size: tuple[int, int] | None = None,
                mean: list[float] | Tensor | None = None,
                std: list[float] | Tensor | None = None,
                auto_cast: bool = False
                ) -> Tensor:
    """ Compute a Fourier heat map: the loss of `model` on `data` perturbed by each Fourier basis.

    Args:
        model: Trained model
        data: Pairs of [input, target] to compute criterion; `input` has four dims whose last
            two are spatial (H, W)
        criterion: Criterion of (input, target) -> scalar value
        norm: Intensity (L2 norm) of fourier noise
        fourier_map_size: Size of map, (H, W). If None, the spatial size of `input` is used.
            Note that the computational time is dominated by HW.
        mean: If the range of input is [-1, 1], specify mean and std.
        std: If the range of input is [-1, 1], specify mean and std.
        auto_cast: If True and CUDA is available, evaluate under CUDA automatic mixed precision.

    Returns: HxW map whose (i, j) entry is the loss under Fourier noise at frequency (i, j)

    """
    input, target = data
    if fourier_map_size is None:
        _, _, h, w = input.size()
    else:
        h, w = fourier_map_size
    if mean is not None:
        _mean = torch.as_tensor(mean, device=input.device, dtype=torch.float)
        _std = torch.as_tensor(std, device=input.device, dtype=torch.float)
        input = _denormalize(input, _mean, _std)  # [0, 1]

    loss_map = torch.zeros(h, w)
    # The noise for (i, j) also excites the reflected cell (h - 1 - i, w - 1 - j), so both cells
    # share one evaluation. In row-major order the reflection of flat index k is h * w - 1 - k,
    # so visiting k <= h * w - 1 - k covers every cell exactly once for any (h, w), including
    # h > w, where an upper-triangle enumeration leaves cells unvisited.
    num_cells = h * w
    for k in tqdm(range((num_cells + 1) // 2), ncols=80):
        i, j = divmod(k, w)
        noisy_input = add_fourier_noise((i, j), input, norm, fourier_map_size)
        if mean is not None:
            noisy_input = _normalize(noisy_input, _mean, _std)  # to [-1, 1]
        loss = _evaluate(model, (noisy_input, target), criterion, auto_cast)
        loss_map[i, j] = loss
        loss_map[h - 1 - i, w - 1 - j] = loss
    return loss_map
95 |
--------------------------------------------------------------------------------
/anatome/utils.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | from typing import Callable, Optional
3 |
4 | import torch
5 | from torch import Tensor, nn
6 |
7 |
8 | def _svd(input: torch.Tensor
9 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
10 | # torch.svd style
11 | U, S, Vh = torch.linalg.svd(input, full_matrices=False)
12 | V = Vh.transpose(-2, -1)
13 | return U, S, V
14 |
15 |
16 | @torch.no_grad()
17 | def _evaluate(model: nn.Module,
18 | data: tuple[Tensor, Tensor],
19 | criterion: Callable[[Tensor, Tensor], Tensor],
20 | auto_cast: bool
21 | ) -> float:
22 | # evaluate model with given data points using the criterion
23 | with (torch.cuda.amp.autocast() if auto_cast and torch.cuda.is_available() else contextlib.nullcontext()):
24 | input, target = data
25 | return criterion(model(input), target).item()
26 |
27 |
28 | def _normalize(input: Tensor,
29 | mean: Tensor,
30 | std: Tensor
31 | ) -> Tensor:
32 | # normalize tensor in [0, 1] to [-1, 1]
33 | input = input.clone()
34 | input.add_(-mean[:, None, None]).div_(std[:, None, None])
35 | return input
36 |
37 |
38 | def _denormalize(input: Tensor,
39 | mean: Tensor,
40 | std: Tensor
41 | ) -> Tensor:
42 | # denormalize tensor in [-1, 1] to [0, 1]
43 | input = input.clone()
44 | input.mul_(std[:, None, None]).add_(mean[:, None, None])
45 | return input
46 |
47 |
def fft_shift(input: torch.Tensor,
              dims: Optional[tuple[int, ...]] = None
              ) -> torch.Tensor:
    """ Move the zero-frequency component to the center, like `np.fft.fftshift`.

    Thin wrapper around `torch.fft.fftshift`.

    Args:
        input: tensor to shift
        dims: dimensions to shift; if None, every dimension is shifted

    Returns: shifted tensor

    """
    shifted = torch.fft.fftshift(input, dim=dims)
    return shifted
62 |
63 |
def ifft_shift(input: torch.Tensor,
               dims: Optional[tuple[int, ...]] = None
               ) -> torch.Tensor:
    """ Inverse of `fft_shift`, like `np.fft.ifftshift`.

    Thin wrapper around `torch.fft.ifftshift`.

    Args:
        input: tensor to shift back
        dims: dimensions to shift; if None, every dimension is shifted

    Returns: shifted tensor

    """
    unshifted = torch.fft.ifftshift(input, dim=dims)
    return unshifted
78 |
79 |
80 | def _rfft(self: Tensor,
81 | signal_ndim: int,
82 | normalized: bool = False,
83 | onesided: bool = True
84 | ) -> Tensor:
85 | # old-day's torch.rfft
86 |
87 | if signal_ndim > 4:
88 | raise RuntimeError("signal_ndim is expected to be 1, 2, 3.")
89 |
90 | m = torch.fft.rfftn if onesided else torch.fft.fftn
91 | dim = [-3, -2, -1][3 - signal_ndim:]
92 | return torch.view_as_real(m(self, dim=dim, norm="ortho" if normalized else None))
93 |
94 |
95 | def _irfft(self: Tensor,
96 | signal_ndim: int,
97 | normalized: bool = False,
98 | onesided: bool = True,
99 | ) -> Tensor:
100 | # old-day's torch.irfft
101 |
102 | if signal_ndim > 4:
103 | raise RuntimeError("signal_ndim is expected to be 1, 2, 3.")
104 | if not torch.is_complex(self):
105 | self = torch.view_as_complex(self)
106 |
107 | m = torch.fft.irfftn if onesided else torch.fft.ifftn
108 | dim = [-3, -2, -1][3 - signal_ndim:]
109 | out = m(self, dim=dim, norm="ortho" if normalized else None)
110 | return out.real if torch.is_complex(out) else out
111 |
--------------------------------------------------------------------------------
/tests/test_distance.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torch import nn
4 |
5 | from anatome import distance
6 |
7 |
@pytest.fixture
def matrices():
    # random inputs whose first (sample) dim matches or mismatches: 10x5, 10x8, 10x20, 8x5
    shapes = [(10, 5), (10, 8), (10, 20), (8, 5)]
    return tuple(torch.randn(*shape) for shape in shapes)
11 |
12 |
@pytest.mark.parametrize('mat_size', ([10, 5], [100, 50], [500, 200]))
def test_cca_consistency(mat_size):
    # the SVD- and QR-based CCA must agree (coefficients up to sign)
    x = distance._zero_mean(torch.randn(*mat_size, dtype=torch.float64), 0)
    y = distance._zero_mean(torch.randn(*mat_size, dtype=torch.float64), 0)

    by_svd = distance.cca_by_svd(x, y)
    by_qr = distance.cca_by_qr(x, y)
    for coef_svd, coef_qr in zip(by_svd[:2], by_qr[:2]):
        torch.testing.assert_close(coef_svd.abs(), coef_qr.abs())
    torch.testing.assert_close(by_svd[2], by_qr[2])
27 |
28 |
@pytest.mark.parametrize('method', ['svd', 'qr'])
def test_cca_shape(matrices, method):
    x, y, wide, short = matrices
    # CCA of a matrix with itself must run
    distance.cca(x, x, method)

    a, b, diag = distance.cca(x, y, method)
    num_components = diag.size(0)
    assert list(a.size()) == [x.size(1), num_components]
    assert list(b.size()) == [y.size(1), num_components]

    # more features than samples: needs a larger batch
    with pytest.raises(ValueError):
        distance.cca(x, wide, method)
    # different number of samples
    with pytest.raises(ValueError):
        distance.cca(x, short, method)
    # unknown backend
    with pytest.raises(ValueError):
        distance.cca(x, y, 'wrong')
43 |
44 |
def test_cka_shape(matrices):
    x, y, wide, short = matrices
    # feature widths may differ; both the biased and debiased estimators must run
    for other, debiased in ((y, True), (wide, True), (y, False)):
        distance.linear_cka_distance(x, other, debiased)
    # different number of samples
    with pytest.raises(ValueError):
        distance.linear_cka_distance(x, short, True)
52 |
53 |
def test_opd(matrices):
    x, y, wide, short = matrices
    for other in (x, y, wide):
        distance.orthogonal_procrustes_distance(x, other)
    # different number of samples
    with pytest.raises(ValueError):
        distance.orthogonal_procrustes_distance(x, short)
61 |
62 |
@pytest.mark.parametrize('method', ['pwcca', 'svcca', 'lincka', 'opd'])
def test_similarity_hook_linear(method):
    def make_model():
        return nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 4))

    model1, model2 = make_model(), make_model()
    # '3' is not a node of the two-layer model
    with pytest.raises(RuntimeError):
        distance.Distance(model1, model2, method=method, model1_names=['3'])

    dist = distance.Distance(model1, model2, method=method)
    assert dist.convert_names(model1, None, None, False) == ['0', '1']

    with torch.no_grad():
        dist.forward(torch.randn(13, 3))
    dist.between("1", "1")
77 |
78 |
@pytest.mark.parametrize('resize_by', ['avg_pool', 'dft'])
def test_similarity_hook_conv2d(resize_by):
    # NOTE(review): `resize_by` is never passed to the code under test, so both parametrizations
    # run the same checks; confirm how `between` selects the resizing method.
    def make_model():
        return nn.Sequential(nn.Conv2d(3, 3, kernel_size=3), nn.Conv2d(3, 4, kernel_size=3))

    model1, model2 = make_model(), make_model()
    dist = distance.Distance(model1, model2, model1_names=['0', '1'], model2_names=['0', '1'], method='lincka')

    # 11x11 input -> 9x9 after layer '0' -> 7x7 after layer '1' (3x3 kernels, no padding)
    with torch.no_grad():
        dist.forward(torch.randn(13, 3, 11, 11))

    for size in (5, 7):
        dist.between('1', '1', size=size)
    # size=8 exceeds the 7x7 output of layer '1' and is expected to fail
    with pytest.raises(RuntimeError):
        dist.between('1', '1', size=8)

    dist.between('0', '1')
95 |
--------------------------------------------------------------------------------
/anatome/landscape.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from copy import deepcopy
4 | from typing import Callable
5 |
6 | import torch
7 | from torch import nn, Tensor
8 |
9 | try:
10 | from tqdm import tqdm
11 | except ImportError:
12 | tqdm = lambda x, ncols=None: x
13 |
14 | from .utils import _evaluate
15 |
16 | EPS = 1e-8
17 |
18 |
19 | def _filter_normed_random_direction(model: nn.Module
20 | ) -> list[Tensor]:
21 | # applies filter normalization proposed in Li+2018
22 | def _filter_norm(dirs: Tensor,
23 | params: Tensor
24 | ) -> Tensor:
25 | d_norm = dirs.view(dirs.size(0), -1).norm(dim=-1)
26 | p_norm = params.view(params.size(0), -1).norm(dim=-1)
27 | ones = [1 for _ in range(dirs.dim() - 1)]
28 | return dirs.mul_(p_norm.view(-1, *ones) / (d_norm.view(-1, *ones) + EPS))
29 |
30 | directions = [(params, torch.randn_like(params)) for params in model.parameters()]
31 | directions = [_filter_norm(dirs, params) for params, dirs in directions]
32 | return directions
33 |
34 |
35 | def _get_perturbed_model(model: nn.Module,
36 | direction: list[Tensor] or tuple[list[Tensor], list[Tensor]],
37 | step_size: float or tuple[float, float]
38 | ) -> nn.Module:
39 | # perturb the weight of model along direction with step size
40 | new_model = deepcopy(model)
41 | if len(direction) == 2:
42 | # 2d
43 | perturbation = [d_0 * step_size[0] + d_1 * step_size[1] for d_0, d_1 in zip(*direction)]
44 | else:
45 | # 1d
46 | perturbation = [d_0 * step_size for d_0 in direction]
47 |
48 | for param, pert in zip(new_model.parameters(), perturbation):
49 | if param.data.dim() <= 1:
50 | # ignore biasbn in the original code
51 | continue
52 | param.data.add_(pert)
53 |
54 | return new_model
55 |
56 |
@torch.no_grad()
def landscape1d(model: nn.Module,
                data: tuple[Tensor, Tensor],
                criterion: Callable[[Tensor, Tensor], Tensor],
                x_range: tuple[float, float],
                step_size: float,
                auto_cast: bool = False
                ) -> tuple[Tensor, Tensor]:
    """ Compute loss landscape along a random direction X. The landscape is

        [{criterion(input, target) at Θ+iX} for i in (x_min, x_min + α, ..., x_max)]

    Args:
        model: Trained model, parameterized by Θ
        data: Pairs of [input, target] to compute criterion
        criterion: Criterion of (input, target) -> scalar value
        x_range: (x_min, x_max); grid points x_min + kα within this range are used, both ends included
        step_size: α
        auto_cast: If True and CUDA is available, evaluate under CUDA automatic mixed precision

    Returns: x-coordinates, landscape

    """
    x_min, x_max = x_range
    # A tiny, step-relative tolerance on the end keeps x_max when (x_max - x_min) is a multiple of
    # step_size up to float rounding. An end of `x_max + step_size` can add a point beyond x_max,
    # e.g. (0, 0.2) with step 0.1 would yield 0.3.
    x_coord = torch.arange(x_min, x_max + step_size * 1e-6, step_size, dtype=torch.float)
    x_direction = _filter_normed_random_direction(model)
    loss_values = torch.zeros_like(x_coord, device=torch.device('cpu'))
    for i, x in enumerate(tqdm(x_coord.tolist(), ncols=80)):
        new_model = _get_perturbed_model(model, x_direction, x)
        loss_values[i] = _evaluate(new_model, data, criterion, auto_cast)
    return x_coord, loss_values
87 |
88 |
@torch.no_grad()
def landscape2d(model: nn.Module,
                data: tuple[Tensor, Tensor],
                criterion: Callable[[Tensor, Tensor], Tensor],
                x_range: tuple[float, float],
                y_range: tuple[float, float],
                step_size: float | tuple[float, float],
                auto_cast: bool = False
                ) -> tuple[Tensor, Tensor, Tensor]:
    """ Compute loss landscape along two random directions X and Y. The landscape is

        [
            [{criterion(input, target) at Θ+iX+jY}
             for i in (x_min, x_min + α, ..., x_max)]
            for j in (y_min, y_min + β, ..., y_max)
        ]

    Args:
        model: Trained model, parameterized by Θ
        data: Pairs of [input, target] to compute criterion
        criterion: Criterion of (input, target) -> scalar value
        x_range: (x_min, x_max), both ends included
        y_range: (y_min, y_max), both ends included
        step_size: α, β; a single number is used for both axes
        auto_cast: If True and CUDA is available, evaluate under CUDA automatic mixed precision

    Returns: x-coordinates, y-coordinates, landscape (all with the same 2d grid shape)

    """
    if not isinstance(step_size, (tuple, list)):
        # also accepts ints (e.g. step_size=1), not only floats
        step_size = (step_size, step_size)
    x_step, y_step = step_size
    # Step-relative tolerance on the ends: keeps x_max / y_max under float rounding, without
    # the overshoot an end of `max + step` can produce.
    x_coord = torch.arange(x_range[0], x_range[1] + x_step * 1e-6, x_step, dtype=torch.float)
    y_coord = torch.arange(y_range[0], y_range[1] + y_step * 1e-6, y_step, dtype=torch.float)
    x_coord, y_coord = torch.meshgrid(x_coord, y_coord, indexing='ij')
    shape = x_coord.shape
    x_coord, y_coord = x_coord.flatten(), y_coord.flatten()
    x_direction = _filter_normed_random_direction(model)
    y_direction = _filter_normed_random_direction(model)
    loss_values = torch.zeros_like(x_coord, device=torch.device('cpu'))
    # tqdm wraps a sized list (not the zip) so that the progress bar knows its total
    for i, (x, y) in enumerate(zip(
            tqdm(x_coord.tolist(), ncols=80),
            y_coord.tolist())
    ):
        new_model = _get_perturbed_model(model, (x_direction, y_direction), (x, y))
        loss_values[i] = _evaluate(new_model, data, criterion, auto_cast)
    return x_coord.view(shape), y_coord.view(shape), loss_values.view(shape)
135 |
--------------------------------------------------------------------------------
/assets/landscape2d.svg:
--------------------------------------------------------------------------------
1 |
2 |
4 |
5 |
63 |
--------------------------------------------------------------------------------
/anatome/distance.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from functools import partial
4 | from typing import Callable, Literal
5 |
6 | import torch
7 | from torch import Tensor, nn
8 | from torch.nn import functional as F
9 | from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor
10 |
11 | from .utils import _irfft, _rfft, _svd
12 |
13 |
14 | def _zero_mean(input: Tensor,
15 | dim: int
16 | ) -> Tensor:
17 | return input - input.mean(dim=dim, keepdim=True)
18 |
19 |
20 | def _check_shape_equal(x: Tensor,
21 | y: Tensor,
22 | dim: int
23 | ):
24 | if x.size(dim) != y.size(dim):
25 | raise ValueError(f'x.size({dim}) == y.size({dim}) is expected, but got {x.size(dim)=}, {y.size(dim)=} instead.')
26 |
27 |
def cca_by_svd(x: Tensor,
               y: Tensor
               ) -> tuple[Tensor, Tensor, Tensor]:
    """ CCA computed with SVDs only.
    See Press 2011, "Canonical Correlation Clarified by Singular Value Decomposition".

    Args:
        x: input tensor of shape DxH (expected to be centered along dim 0)
        y: input tensor of shape DxW (expected to be centered along dim 0)

    Returns: x-side coefficients, y-side coefficients, diagonal (canonical correlations)

    """
    x_left, x_sing, x_right = _svd(x)
    y_left, y_sing, y_right = _svd(y)
    # the singular values of U_x^T U_y are the canonical correlations
    u, diag, v = _svd(x_left.t() @ y_left)
    # x_right @ diag(1 / x_sing) @ u, scaling the rows of u instead of building the diagonal matrix
    a = x_right @ ((1 / x_sing)[:, None] * u)
    b = y_right @ ((1 / y_sing)[:, None] * v)
    return a, b, diag
51 |
52 |
def cca_by_qr(x: Tensor,
              y: Tensor
              ) -> tuple[Tensor, Tensor, Tensor]:
    """ CCA computed with QR decompositions and one SVD.
    See Press 2011, "Canonical Correlation Clarified by Singular Value Decomposition".

    Args:
        x: input tensor of shape DxH (expected to be centered along dim 0)
        y: input tensor of shape DxW (expected to be centered along dim 0)

    Returns: x-side coefficients, y-side coefficients, diagonal (canonical correlations)

    """
    # orthonormal bases of the two column spaces
    q_x, r_x = torch.linalg.qr(x)
    q_y, r_y = torch.linalg.qr(y)
    # the singular values of Q_x^T Q_y are the canonical correlations
    u, diag, v = _svd(q_x.t() @ q_y)
    # a = R_x^{-1} u and b = R_y^{-1} v, via a solve (faster and more stable than an explicit inverse)
    a = torch.linalg.solve(r_x, u)
    b = torch.linalg.solve(r_y, v)
    return a, b, diag
75 |
76 |
def cca(x: Tensor,
        y: Tensor,
        backend: str
        ) -> tuple[Tensor, Tensor, Tensor]:
    """ Compute CCA, Canonical Correlation Analysis, of two column-centered inputs.

    Args:
        x: input tensor of shape DxH, with D >= H
        y: input tensor of shape DxW, with D >= W
        backend: 'svd' or 'qr'

    Returns: x-side coefficients, y-side coefficients, diagonal

    Raises:
        ValueError: if x and y have different numbers of rows, if either has more columns
            than rows, or if backend is neither 'svd' nor 'qr'

    """
    _check_shape_equal(x, y, 0)

    if x.size(0) < x.size(1):
        raise ValueError(f'x.size(0) >= x.size(1) is expected, but got {x.size()=}.')
    if y.size(0) < y.size(1):
        raise ValueError(f'y.size(0) >= y.size(1) is expected, but got {y.size()=}.')
    if backend not in ('svd', 'qr'):
        raise ValueError(f'backend is svd or qr, but got {backend}')

    solver = cca_by_svd if backend == 'svd' else cca_by_qr
    # both solvers expect inputs centered along the sample dim
    x_centered = _zero_mean(x, dim=0)
    y_centered = _zero_mean(y, dim=0)
    return solver(x_centered, y_centered)
106 |
107 |
def _svd_reduction(input: Tensor,
                   accept_rate: float
                   ) -> Tensor:
    """ Project `input` onto its leading right singular vectors (the "SV" step of SVCCA).

    The number of kept directions is the number of leading singular values whose cumulative
    share of the total singular-value sum stays below `accept_rate`, but always at least one.

    Args:
        input: tensor of shape DxH
        accept_rate: threshold on the cumulative share of singular values, e.g. 0.99

    Returns: tensor of shape DxK with K >= 1

    """
    _, diag, right = _svd(input)
    ratio = diag.abs().cumsum(dim=0) / diag.abs().sum()
    # Keep at least one direction. When the top singular value alone reaches `accept_rate`
    # (e.g. a nearly rank-1 input), the count is 0 and the result would be an empty Dx0
    # matrix, which breaks the subsequent CCA.
    num = max(int((ratio < accept_rate).sum()), 1)
    return input @ right[:, :num]
119 |
120 |
def svcca_distance(x: Tensor,
                   y: Tensor,
                   accept_rate: float,
                   backend: str
                   ) -> Tensor:
    """ Singular Vector CCA distance proposed in Raghu et al. 2017.

    Both inputs are first reduced to their leading singular directions; CCA is then applied,
    and the distance is one minus the mean canonical correlation.

    Args:
        x: input tensor of shape DxH, where D>H
        y: input tensor of shape DxW, where D>W
        accept_rate: threshold of the SVD reduction, e.g. 0.99
        backend: svd or qr

    Returns: 1 - (sum of canonical correlations) / min(reduced widths of x and y)

    """
    x_reduced = _svd_reduction(x, accept_rate)
    y_reduced = _svd_reduction(y, accept_rate)
    num_directions = min(x_reduced.size(1), y_reduced.size(1))
    *_, correlations = cca(x_reduced, y_reduced, backend)
    return 1 - correlations.sum() / num_directions
143 |
144 |
def pwcca_distance(x: Tensor,
                   y: Tensor,
                   backend: str
                   ) -> Tensor:
    """ Projection Weighted CCA distance proposed in Morcos et al. 2018.

    Args:
        x: input tensor of shape DxH, where D>H
        y: input tensor of shape DxW, where D>W
        backend: svd or qr

    Returns: 1 - weighted sum of canonical correlations, weighted by how strongly x projects
        onto each (re-orthonormalized) x-side canonical direction

    """
    a, _, correlations = cca(x, y, backend)
    directions, _ = torch.linalg.qr(a)  # re-orthonormalize
    weights = (x @ directions).abs_().sum(dim=0)
    weights /= weights.sum()
    return 1 - weights @ correlations
165 |
166 |
167 | def _debiased_dot_product_similarity(z: Tensor,
168 | sum_row_x: Tensor,
169 | sum_row_y: Tensor,
170 | sq_norm_x: Tensor,
171 | sq_norm_y: Tensor,
172 | size: int
173 | ) -> Tensor:
174 | return (z
175 | - size / (size - 2) * (sum_row_x @ sum_row_y)
176 | + sq_norm_x * sq_norm_y / ((size - 1) * (size - 2)))
177 |
178 |
def linear_cka_distance(x: Tensor,
                        y: Tensor,
                        reduce_bias: bool
                        ) -> Tensor:
    """ Linear CKA distance as used in Kornblith et al. 2019.

    Args:
        x: input tensor of shape DxH
        y: input tensor of shape DxW
        reduce_bias: use the debiased CKA estimator, which might help when D is small

    Returns: 1 - linear CKA between x and y

    Raises:
        ValueError: if x and y have different numbers of rows

    """
    _check_shape_equal(x, y, 0)

    x = _zero_mean(x, dim=0)
    y = _zero_mean(y, dim=0)
    similarity = (y.t() @ x).norm('fro').pow(2)
    x_scale = (x.t() @ x).norm('fro')
    y_scale = (y.t() @ y).norm('fro')

    if reduce_bias:
        n = x.size(0)
        # squared row norms, i.e. (x @ x.t()).diag() without forming the n x n Gram matrix
        rows_x = torch.einsum('ij,ij->i', x, x)
        rows_y = torch.einsum('ij,ij->i', y, y)
        total_x = rows_x.sum()
        total_y = rows_y.sum()
        similarity = _debiased_dot_product_similarity(similarity, rows_x, rows_y, total_x, total_y, n)
        x_scale = _debiased_dot_product_similarity(x_scale.pow(2), rows_x, rows_x, total_x, total_x, n).sqrt()
        y_scale = _debiased_dot_product_similarity(y_scale.pow(2), rows_y, rows_y, total_y, total_y, n).sqrt()
    return 1 - similarity / (x_scale * y_scale)
215 |
216 |
def orthogonal_procrustes_distance(x: Tensor,
                                   y: Tensor,
                                   ) -> Tensor:
    """ Orthogonal Procrustes distance as used in Ding et al. 2021.

    Args:
        x: input tensor of shape DxH
        y: input tensor of shape DxW

    Returns: 1 - nuclear norm of x^T y after centering and unit Frobenius scaling,
        i.e. half of the Procrustes distance d_proc(x, y)

    Raises:
        ValueError: if x and y have different numbers of rows

    """
    _check_shape_equal(x, y, 0)

    def _centered_unit(mat: Tensor) -> Tensor:
        # center the columns, then scale to unit Frobenius norm
        mat = _zero_mean(mat, dim=0)
        mat /= torch.linalg.norm(mat, ord="fro")
        return mat

    x = _centered_unit(x)
    y = _centered_unit(y)
    # with ||x||_F = ||y||_F = 1 this equals 0.5 * d_proc(x, y)
    return 1 - torch.linalg.norm(x.t() @ y, ord="nuc")
241 |
242 |
243 | class Distance(object):
244 | """ Module to measure distance between `model1` and `model2`
245 |
246 | Args:
247 | method: Method to compute distance. 'pwcca' by default.
248 | model1_names: Names of modules of `model1` to be used. If None (default), all names are used.
249 | model2_names: Names of modules of `model2` to be used. If None (default), all names are used.
250 | model1_leaf_modules: Modules of model1 to be considered as single nodes (see https://pytorch.org/blog/FX-feature-extraction-torchvision/).
251 | model2_leaf_modules: Modules of model2 to be considered as single nodes (see https://pytorch.org/blog/FX-feature-extraction-torchvision/).
252 | train_mode: If True, models' `train_model` is used, otherwise `eval_mode`. False by default.
253 | """
254 |
255 | _supported_dims = (2, 4)
256 | _default_backends = {'pwcca': partial(pwcca_distance, backend='svd'),
257 | 'svcca': partial(svcca_distance, accept_rate=0.99, backend='svd'),
258 | 'lincka': partial(linear_cka_distance, reduce_bias=False),
259 | 'opd': orthogonal_procrustes_distance}
260 |
def __init__(self,
             model1: nn.Module,
             model2: nn.Module,
             method: str | Callable = 'pwcca',
             model1_names: str | list[str] = None,
             model2_names: str | list[str] = None,
             model1_leaf_modules: list[nn.Module] = None,
             model2_leaf_modules: list[nn.Module] = None,
             train_mode: bool = False
             ):
    """ Build feature extractors on both models for the requested node names.

    Args:
        model1: first model
        model2: second model
        method: key of `_default_backends` ('pwcca', 'svcca', 'lincka', 'opd') or a callable
        model1_names: node name(s) of model1 to use; None means all nodes
        model2_names: node name(s) of model2 to use; None means all nodes
        model1_leaf_modules: modules of model1 to be traced as single nodes
        model2_leaf_modules: modules of model2 to be traced as single nodes
        train_mode: if True, node names are taken from the train-mode graph, otherwise eval-mode

    Raises:
        RuntimeWarning: if either model is wrapped in DataParallel/DistributedDataParallel
        RuntimeError: if a requested name is not a node of the traced graph
    """
    dp_ddp = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    if isinstance(model1, dp_ddp) or isinstance(model2, dp_ddp):
        raise RuntimeWarning('model is nn.DataParallel or nn.DistributedDataParallel. '
                             'Distance may cause unexpected behavior.')
    if isinstance(method, str):
        method = self._default_backends[method]
    self.distance_func = method
    self.model1 = model1
    self.model2 = model2
    # The names are resolved while tracing with the given leaf modules, so the extractors must be
    # traced with the same leaf modules; otherwise a leaf module's name is not a node of their graph.
    tracer_kwargs1 = {} if model1_leaf_modules is None else {'leaf_modules': model1_leaf_modules}
    tracer_kwargs2 = {} if model2_leaf_modules is None else {'leaf_modules': model2_leaf_modules}
    self.extractor1 = create_feature_extractor(model1,
                                               self.convert_names(model1, model1_names,
                                                                  model1_leaf_modules, train_mode),
                                               tracer_kwargs=tracer_kwargs1)
    self.extractor2 = create_feature_extractor(model2,
                                               self.convert_names(model2, model2_names,
                                                                  model2_leaf_modules, train_mode),
                                               tracer_kwargs=tracer_kwargs2)
    self._model1_tensors: dict[str, torch.Tensor] | None = None
    self._model2_tensors: dict[str, torch.Tensor] | None = None
287 |
288 | def available_names(self,
289 | model1_leaf_modules: list[nn.Module] = None,
290 | model2_leaf_modules: list[nn.Module] = None,
291 | train_mode: bool = False
292 | ):
293 | return {'model1': self.convert_names(self.model1, None, model1_leaf_modules, train_mode),
294 | 'model2': self.convert_names(self.model2, None, model2_leaf_modules, train_mode)}
295 |
296 | @staticmethod
297 | def convert_names(model: nn.Module,
298 | names: str | list[str],
299 | leaf_modules: list[nn.Module],
300 | train_mode: bool
301 | ) -> list[str]:
302 | # a helper function
303 | if isinstance(names, str):
304 | names = [names]
305 | tracer_kwargs = {}
306 | if leaf_modules is not None:
307 | tracer_kwargs['leaf_modules'] = leaf_modules
308 |
309 | _names = get_graph_node_names(model, tracer_kwargs=tracer_kwargs)
310 | _names = _names[0] if train_mode else _names[1]
311 | _names = _names[1:] # because the first element is input
312 |
313 | if names is None:
314 | names = _names
315 | else:
316 | if not (set(names) <= set(_names)):
317 | diff = set(names) - set(_names)
318 | raise RuntimeError(f'Unknown names: {list(diff)}')
319 |
320 | return names
321 |
322 | def forward(self,
323 | data
324 | ) -> None:
325 | """ Forward pass of models. Used to store intermediate features.
326 |
327 | Args:
328 | data: input data to models
329 |
330 | """
331 | self._model1_tensors = self.extractor1(data)
332 | self._model2_tensors = self.extractor2(data)
333 |
334 | def between(self,
335 | name1: str,
336 | name2: str,
337 | size: int | tuple[int, int] = None,
338 | downsample_method: Literal['avg_pool', 'dft'] = 'avg_pool'
339 | ) -> torch.Tensor:
340 | """ Compute distance between modules corresponding to name1 and name2.
341 |
342 | Args:
343 | name1: Name of a module of `model1`
344 | name2: Name of a module of `model2`
345 | size: Size for downsampling if necessary. If size's type is int, both features of name1 and name2 are
346 | reshaped to (size, size). If size's type is tuple[int, int], features are reshaped to (size[0], size[0]) and
347 | (size[1], size[1]). If size is None (default), no downsampling is applied.
348 | downsample_method: Downsampling method: 'avg_pool' for average pooling and 'dft' for discrete
349 | Fourier transform
350 |
351 | Returns: Distance in tensor.
352 |
353 | """
354 | tensor1 = self._model1_tensors[name1]
355 | tensor2 = self._model2_tensors[name2]
356 | if tensor1.dim() not in self._supported_dims:
357 | raise RuntimeError(f'Supported dimensions are ={self._supported_dims}, but got {tensor1.dim()}')
358 | if tensor2.dim() not in self._supported_dims:
359 | raise RuntimeError(f'Supported dimensions are ={self._supported_dims}, but got {tensor2.dim()}')
360 |
361 | if size is not None:
362 | if isinstance(size, int):
363 | size = (size, size)
364 |
365 | def downsample_if_necessary(input, s):
366 | if input.dim() == 4:
367 | input = self._downsample_4d(input, s, downsample_method)
368 | return input
369 |
370 | tensor1 = downsample_if_necessary(tensor1, size[0])
371 | tensor2 = downsample_if_necessary(tensor2, size[1])
372 |
373 | def reshape_if_4d(input):
374 | if input.dim() == 4:
375 | # see https://arxiv.org/abs/1706.05806's P5.
376 | if name1 == name2: # same layer comparisons -> Cx(BHW)
377 | input = input.permute(1, 0, 2, 3).flatten(1)
378 | else: # different layer comparisons -> Bx(CHW)
379 | input = input.flatten(1)
380 | return input
381 |
382 | tensor1 = reshape_if_4d(tensor1)
383 | tensor2 = reshape_if_4d(tensor2)
384 |
385 | return self.distance_func(tensor1, tensor2)
386 |
387 | @staticmethod
388 | def _downsample_4d(input: Tensor,
389 | size: int,
390 | backend: Literal['avg_pool', 'dft']
391 | ) -> Tensor:
392 | if input.dim() != 4:
393 | raise RuntimeError(f'input is expected to be 4D tensor, but got {input.dim()=}.')
394 |
395 | # todo: what if channel-last?
396 | b, c, h, w = input.size()
397 |
398 | if (size, size) == (h, w):
399 | return input
400 |
401 | if (size, size) > (h, w):
402 | raise RuntimeError(f'size ({size}) is expected to be smaller than h or w, but got {h=}, {w=}.')
403 |
404 | if backend not in ('avg_pool', 'dft'):
405 | raise RuntimeError(f'backend is expected to be avg_pool or dft, but got {backend=}.')
406 |
407 | if backend == 'avg_pool':
408 | return F.adaptive_avg_pool2d(input, (size, size))
409 |
410 | # almost PyTorch implant of
411 | # https://github.com/google/svcca/blob/master/dft_ccas.py
412 | if input.size(2) != input.size(3):
413 | raise RuntimeError('width and height of input needs to be equal')
414 | h = input.size(2)
415 | input_fft = _rfft(input, 2, normalized=True, onesided=False)
416 | freqs = torch.fft.fftfreq(h, 1 / h, device=input.device)
417 | idx = (freqs >= -size / 2) & (freqs < size / 2)
418 | # BxCxHxWx2 -> BxCxhxwx2
419 | input_fft = input_fft[..., idx, :][..., idx, :, :]
420 | input = _irfft(input_fft, 2, normalized=True, onesided=False)
421 | return input
422 |
--------------------------------------------------------------------------------
/assets/fourier.svg:
--------------------------------------------------------------------------------
1 |
2 |
4 |
5 |
63 |
--------------------------------------------------------------------------------
/examples.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "tags": []
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import torch\n",
12 | "from torch.nn import functional as F\n",
13 | "from torchvision.models import resnet18\n",
14 | "from torchvision import transforms\n",
15 | "from torchvision.datasets import ImageFolder\n",
16 | "from torch.utils.data import DataLoader\n",
17 | "\n",
18 | "import matplotlib.pyplot as plt\n",
19 | "\n",
20 | "batch_size = 128\n",
21 | "\n",
22 | "model = resnet18(pretrained=True)\n",
23 | "imagenet = ImageFolder('~/.torch/data/imagenet/val', \n",
24 | " transforms.Compose([transforms.CenterCrop(224),transforms.ToTensor(),\n",
25 | " transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]))\n",
26 | "data = next(iter(DataLoader(imagenet, batch_size=batch_size, num_workers=8)))"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 2,
32 | "metadata": {},
33 | "outputs": [],
34 | "source": [
35 | "model.eval()\n",
36 | "model.cuda()\n",
37 | "data = data[0].cuda(), data[1].cuda()"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "metadata": {},
43 | "source": [
44 | "## CCA\n",
45 | "\n",
46 | "Measure the CCA distance between a trained and an untrained model."
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": 3,
52 | "metadata": {},
53 | "outputs": [
54 | {
55 | "output_type": "execute_result",
56 | "data": {
57 | "text/plain": "tensor(0.3290, device='cuda:0')"
58 | },
59 | "metadata": {},
60 | "execution_count": 3
61 | }
62 | ],
63 | "source": [
64 | "from anatome import CCAHook\n",
65 | "random_model = resnet18().cuda()\n",
66 | "\n",
67 | "hook1 = CCAHook(model, \"layer1.0.conv1\")\n",
68 | "hook2 = CCAHook(random_model, \"layer1.0.conv1\")\n",
69 | "\n",
70 | "with torch.no_grad():\n",
71 | " model(data[0])\n",
72 | " random_model(data[0])\n",
73 | "hook1.distance(hook2, size=8)"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": 4,
79 | "metadata": {},
80 | "outputs": [],
81 | "source": [
82 | "# clear hooked tensors and caches\n",
83 | "hook1.clear()\n",
84 | "hook2.clear()\n",
85 | "torch.cuda.empty_cache()"
86 | ]
87 | },
88 | {
89 | "cell_type": "markdown",
90 | "metadata": {},
91 | "source": [
92 | "## Loss Landscape\n",
93 | "\n",
94 | "Show the loss landscape in 1D space."
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": 5,
100 | "metadata": {
101 | "tags": []
102 | },
103 | "outputs": [
104 | {
105 | "output_type": "stream",
106 | "name": "stderr",
107 | "text": "100%|███████████████████████████████████████████| 21/21 [00:01<00:00, 13.02it/s]\n"
108 | }
109 | ],
110 | "source": [
111 | "from anatome import landscape1d\n",
112 | "x, y = landscape1d(model, data, F.cross_entropy, x_range=(-1, 1), step_size=0.1)"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": 6,
118 | "metadata": {},
119 | "outputs": [
120 | {
121 | "output_type": "execute_result",
122 | "data": {
123 | "text/plain": "[]"
124 | },
125 | "metadata": {},
126 | "execution_count": 6
127 | },
128 | {
129 | "output_type": "display_data",
130 | "data": {
131 | "text/plain": "",
132 | "image/svg+xml": "\n\n\n\n",
133 | "image/png": "\n"
134 | },
135 | "metadata": {
136 | "needs_background": "light"
137 | }
138 | }
139 | ],
140 | "source": [
141 | "plt.plot(x, y)"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": 7,
147 | "metadata": {},
148 | "outputs": [],
149 | "source": [
150 | "torch.cuda.empty_cache()"
151 | ]
152 | },
153 | {
154 | "cell_type": "markdown",
155 | "source": [
156 | "## Fourier Analysis\n",
157 | "\n",
158 | "Visualize a model's robustness as a Fourier heat map."
159 | ],
160 | "metadata": {
161 | "collapsed": false
162 | }
163 | },
164 | {
165 | "cell_type": "code",
166 | "execution_count": null,
167 | "outputs": [],
168 | "source": [
169 | "from anatome import fourier_map\n",
170 | "map = fourier_map(model, data, F.cross_entropy, 20, (6, 6), mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))"
171 | ],
172 | "metadata": {
173 | "collapsed": false,
174 | "pycharm": {
175 | "name": "#%%\n"
176 | }
177 | }
178 | },
179 | {
180 | "cell_type": "code",
181 | "execution_count": null,
182 | "outputs": [],
183 | "source": [
184 | "plt.imshow(map, interpolation='nearest')"
185 | ],
186 | "metadata": {
187 | "collapsed": false,
188 | "pycharm": {
189 | "name": "#%%\n"
190 | }
191 | }
192 | },
193 | {
194 | "cell_type": "code",
195 | "execution_count": null,
196 | "outputs": [],
197 | "source": [
198 | "torch.cuda.empty_cache()"
199 | ],
200 | "metadata": {
201 | "collapsed": false,
202 | "pycharm": {
203 | "name": "#%%\n"
204 | }
205 | }
206 | },
207 | {
208 | "cell_type": "markdown",
209 | "metadata": {},
210 | "source": []
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": 8,
215 | "metadata": {
216 | "tags": []
217 | },
218 | "outputs": [
219 | {
220 | "output_type": "stream",
221 | "name": "stderr",
222 | "text": "100%|███████████████████████████████████████████| 21/21 [00:03<00:00, 6.08it/s]\n"
223 | }
224 | ],
225 | "source": [
226 | "from anatome import fourier_map\n",
227 | "map = fourier_map(model, data, F.cross_entropy, 20, (6, 6), mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))"
228 | ]
229 | },
230 | {
231 | "cell_type": "code",
232 | "execution_count": 9,
233 | "metadata": {},
234 | "outputs": [
235 | {
236 | "output_type": "execute_result",
237 | "data": {
238 | "text/plain": ""
239 | },
240 | "metadata": {},
241 | "execution_count": 9
242 | },
243 | {
244 | "output_type": "display_data",
245 | "data": {
246 | "text/plain": "",
247 | "image/svg+xml": "\n\n\n\n",
248 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAKsUlEQVR4nO3dbWhdhR3H8d/PpNquPlRdV2tTlvqAILLpCH1TmKyzoz5Mh2+moC+G0iETqhtI3Ttf+sb5ZmCDyja0FkEFcW6uoMUVfErb6KzREZzT0I50dlVrpV3T/17ktiQmac69OSfn8N/3A8Gk93L9Kf325N70nuOIEIA8Tqt7AIByETWQDFEDyRA1kAxRA8l0V/Ggi849I865cHEVD92Rnu7DdU9ovOEjZ9c9YZKvDp9R94Qpepfsr3vCSf8a+a8+OzDm6W6rJOpzLlys27f8sIqH7siDywbrntB4Nw+vq3vCJIO7L657whSbf/xI3RNO+vmNn8x4G99+A8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kUitr2etsf2B62vanqUQA6N2vUtrsk/VbStZIul3Sr7curHgagM0WO1KslDUfEhxFxVNJWSTdVOwtAp4pEvULSxHdkj7R+bRLbG2wP2B44/J8jZe0D0KYiUU93ypQpVwCIiP6I6IuIvm+c27xT0QD/L4pEPSJp5YSveyTtrWYOgLkqEvVbki61vcr26ZJukfR8tbMAdGrWEw9GxDHbd0t6SVKXpMcjYk/lywB0pNDZRCPiRUkvVrwFQAn4G2VAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kU+gNHe3q6T6sB5cNVvHQHVn305/VPaHxDl6ysO4Jk5xXye/Mubnzy7vqnnDSyIHfzHgbR2ogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIJlZo7b9uO1R2+/OxyAAc1PkSP07Sesr3gGgJLNGHRGvSjowD1sAlKC059S2N9gesD2w/9Oxsh4WQJtKizoi+iOiLyL6lp7fVdbDAmgTr34DyRA1kEyRH2k9Jek1SZfZHrF9R/WzAHRq1rMrR8St8zEEQDn49htIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkZn1DB6px2l931z1hkqX7VtU9ofGOd19Q94STTjt6itvmbwaA+UDUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJFLpC30vYrtods77G9cT6GAehMkfdTH5P0q4jYZfssSTttb4uI9yreBqADsx6pI2JfROxqff6FpCFJK6oeBqAzbT2ntt0r6SpJb0xz2wbbA7YH9n86Vs46AG0rHLXtMyU9I+meiPj867dHRH9E9EVE39Lzu8rcCKANhaK2vUDjQT8ZEc9WOwnAXBR59duSHpM0FBEPVT8JwFwUOVKvkXS7pLW2B1sf11W8C0CHZv2RVkTskOR52AKgBPyNMiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpIpco6ytg0fOVs3D6+r4qE7cvCShXVPmGLpvlV1T5hkbPgfdU+YpPui3ronTLHp3i11Tzjp168emPE2jtRAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJFPkqpcLbb9p+23be2w/MB/DAHSmyPupj0haGxGHWtep3mH7TxHxesXbAHSgyFUvQ9Kh1pcLWh9R5SgAnSv0nNp2l+1BSaOStkXEG9PcZ4PtAdsDRw9+VfZOAAUVijoixiLiSkk9klbbvmKa+/RHRF9E9J2+ZFHZOwEU1Nar3xFxUNJ2SesrWQNgzoq8
+r3U9pLW54skXSPp/aqHAehMkVe/l0v6ve0ujf8h8HREvFDtLACdKvLq9zuSrpqHLQBKwN8oA5IhaiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIJki79Jq21eHz9Dg7oureOiOnFfJf2Uu3Rf11j1hkqMrzq17whSbP7667gkn7T86MuNtHKmBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSKZw1K0Lz++2zcXxgAZr50i9UdJQVUMAlKNQ1LZ7JF0v6dFq5wCYq6JH6ocl3Sfp+Ex3sL3B9oDtgbFDX5YyDkD7Zo3a9g2SRiNi56nuFxH9EdEXEX1dZy4ubSCA9hQ5Uq+RdKPtjyRtlbTW9hOVrgLQsVmjjoj7I6InInol3SLp5Yi4rfJlADrCz6mBZNo6eW5EbJe0vZIlAErBkRpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIpq13aRXVu2S/Nv/4kSoeuiN3fnlX3ROmON59Qd0TJtl075a6J0yy+eOr654wRd/5H9c94aSh7qMz3saRGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkCr31snVt6i8kjUk6FhF9VY4C0Ll23k/9g4j4d2VLAJSCb7+BZIpGHZL+Ynun7Q3T3cH2BtsDtgc++3SsvIUA2lL02+81EbHX9rckbbP9fkS8OvEOEdEvqV+SLvvOwih5J4CCCh2pI2Jv65+jkp6TtLrKUQA6N2vUthfbPuvE55J+JOndqocB6EyRb7+XSXrO9on7b4mIP1e6CkDHZo06Ij6U9N152AKgBPxIC0iGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWQcUf75DGzvl/TPEh7qm5KadF409pxa0/ZIzdtU1p5vR8TS6W6oJOqy2B5o0plL2XNqTdsjNW/TfOzh228gGaIGkml61P11D/ga9pxa0/ZIzdtU+Z5GP6cG0L6mH6kBtImogWQaGbXt9bY/sD1se1MD9jxue9R2I06NbHul7VdsD9neY3tjzXsW2n7T9tutPQ/UuecE2122d9t+oe4t0viFJm3/zfag7YHK/j1Ne05tu0vS3yWtkzQi6S1Jt0bEezVu+r6kQ5L+EBFX1LVjwp7lkpZHxK7WOdl3SvpJXf+PPH7+6MURccj2Akk7JG2MiNfr2DNh1y8l9Uk6OyJuqHNLa89HkvqqvtBkE4/UqyUNR8SHEXFU0lZJN9U5qHWJoQN1bpgoIvZFxK7W519IGpK0osY9ERGHWl8uaH3UerSw3SPpekmP1rmjDk2MeoWkTyZ8PaIaf8M2ne1eSVdJeqPmHV22ByWNStoWEbXukfSwpPskHa95x0SzXmiyDE2M2tP8WrOeIzSE7TMlPSPpnoj4vM4tETEWEVdK6pG02nZtT1Ns3yBpNCJ21rVhBmsi4nuSrpX0i9bTutI1MeoRSSsnfN0jaW9NWxqr9dz1GUlPRsSzde85ISIOStouaX2NM9ZIurH1HHarpLW2n6hxj6T5u9BkE6N+S9KltlfZPl3SLZKer3lTo7RemHpM0lBEPNSAPUttL2l9vkjSNZLer2tPRNwfET0R0avx3z8vR8Rtde2R5vdCk42LOiKOSbpb0ksafwHo6YjYU+cm209Jek3SZbZHbN9R5x6NH4lu1/gRaLD1cV2Ne5ZLesX2Oxr/Q3lbRDTix0gNskzSDttvS3pT0h+rutBk436kBWBuGnekBjA3RA0kQ9RAMkQNJEPUQDJEDSRD1EAy/wO97YMCPlrXMQAAAABJRU5ErkJggg==\n"
249 | },
250 | "metadata": {
251 | "needs_background": "light"
252 | }
253 | }
254 | ],
255 | "source": [
256 | "plt.imshow(map, interpolation='nearest')"
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "execution_count": 10,
262 | "metadata": {},
263 | "outputs": [],
264 | "source": [
265 | "torch.cuda.empty_cache()"
266 | ]
267 | },
268 | {
269 | "cell_type": "code",
270 | "execution_count": null,
271 | "metadata": {},
272 | "outputs": [],
273 | "source": []
274 | }
275 | ],
276 | "metadata": {
277 | "kernelspec": {
278 | "display_name": "Python 3.8.3 64-bit ('jupyter': conda)",
279 | "language": "python",
280 | "name": "python38364bitjupyterconda644e0e65b1f74b2f90dbbacb7386a62c"
281 | },
282 | "language_info": {
283 | "codemirror_mode": {
284 | "name": "ipython",
285 | "version": 2
286 | },
287 | "file_extension": ".py",
288 | "mimetype": "text/x-python",
289 | "name": "python",
290 | "nbconvert_exporter": "python",
291 | "pygments_lexer": "ipython2",
292 | "version": "3.8.3-final"
293 | }
294 | },
295 | "nbformat": 4,
296 | "nbformat_minor": 0
297 | }
--------------------------------------------------------------------------------