├── anatome ├── __init__.py ├── fourier.py ├── utils.py ├── landscape.py └── distance.py ├── setup.py ├── tests ├── test_utils.py ├── test_fourier.py ├── test_landscape.py └── test_distance.py ├── .github └── workflows │ └── action.yml ├── LICENSE ├── README.md ├── assets ├── landscape2d.svg └── fourier.svg └── examples.ipynb /anatome/__init__.py: -------------------------------------------------------------------------------- 1 | from .distance import Distance 2 | from .fourier import fourier_map 3 | from .landscape import landscape1d, landscape2d 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | install_requires = [ 4 | 'torch>=1.10.0', 5 | 'torchvision>=0.11.1', 6 | ] 7 | 8 | setup( 9 | name='anatome', 10 | version='0.0.6', 11 | description='Ἀνατομή is a PyTorch library to analyze representation of neural networks', 12 | author='Ryuichiro Hataya', 13 | install_requires=install_requires, 14 | packages=find_packages() 15 | ) 16 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from anatome import utils 4 | 5 | 6 | def test_normalize_denormalize(): 7 | input = torch.randn(4, 3, 5, 5) 8 | mean = torch.as_tensor([0.5, 0.5, 0.5]) 9 | std = torch.as_tensor([0.5, 0.5, 0.5]) 10 | output = utils._denormalize(utils._normalize(input, mean, std), 11 | mean, std) 12 | assert torch.allclose(input, output, atol=1e-4) 13 | -------------------------------------------------------------------------------- /tests/test_fourier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from anatome import fourier 6 | 7 | 8 | def 
test_landscape(): 9 | model = nn.Sequential(nn.Conv2d(3, 4, 3), 10 | nn.ReLU(), 11 | nn.AdaptiveAvgPool2d(1), 12 | nn.Flatten(), 13 | nn.Linear(4, 3)) 14 | data = (torch.randn(10, 3, 16, 16), 15 | torch.randint(2, (10,))) 16 | fourier.fourier_map(model, data, F.cross_entropy, 4) 17 | fourier.fourier_map(model, data, F.cross_entropy, 4, (2, 2)) 18 | -------------------------------------------------------------------------------- /tests/test_landscape.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from anatome import landscape 6 | 7 | 8 | def test_landscape(): 9 | model = nn.Sequential(nn.Linear(4, 3), 10 | nn.ReLU(), 11 | nn.Linear(3, 2)) 12 | data = (torch.randn(10, 4), 13 | torch.randint(2, (10,))) 14 | x_coord, y = landscape.landscape1d(model, data, F.cross_entropy, (-1, 1), 0.5) 15 | assert x_coord.shape == y.shape 16 | 17 | x_coord, y_coord, z = landscape.landscape2d(model, data, F.cross_entropy, (-1, 1), (-1, 1), (0.5, 0.5)) 18 | assert x_coord.shape == y_coord.shape 19 | assert x_coord.shape == z.shape 20 | -------------------------------------------------------------------------------- /.github/workflows/action.yml: -------------------------------------------------------------------------------- 1 | name: pytest 2 | 3 | on: [ push, pull_request ] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | if: "!contains(github.event.head_commit.message, 'skip test')" 10 | 11 | strategy: 12 | matrix: 13 | python: [ '3.9' ] 14 | torch: [ 'torch==1.10.0+cpu torchvision==0.11.1+cpu -f https://download.pytorch.org/whl/torch_stable.html' ] 15 | 16 | steps: 17 | - uses: actions/checkout@v2 18 | - uses: actions/setup-python@v2 19 | with: 20 | python-version: ${{ matrix.python }} 21 | 22 | - name: install dependencies 23 | run: | 24 | python -m venv venv 25 | . 
venv/bin/activate 26 | pip install ${{ matrix.torch }} 27 | pip install -U pytest 28 | pip install -U . 29 | 30 | - name: run test 31 | run: | 32 | . venv/bin/activate 33 | pytest -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Ryuichiro Hataya 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # anatome ![](https://github.com/moskomule/anatome/workflows/pytest/badge.svg) 2 | 3 | Ἀνατομή is a PyTorch library to analyze internal representation of neural networks 4 | 5 | This project is under active development and the codebase is subject to change. 
6 | 7 | **v0.0.5 introduces significant changes to `distance`.** 8 | 9 | ## Installation 10 | 11 | `anatome` requires 12 | 13 | ``` 14 | Python>=3.9.0 15 | PyTorch>=1.10 16 | torchvision>=0.11 17 | ``` 18 | 19 | After the installation of PyTorch, install `anatome` as follows: 20 | 21 | ``` 22 | pip install -U git+https://github.com/moskomule/anatome 23 | ``` 24 | 25 | ## Available Tools 26 | 27 | ### Representation Similarity 28 | 29 | To measure the similarity of learned representations, `anatome.Distance` is a useful tool. Currently, the following 30 | methods are implemented. 31 | 32 | - [Raghu et al. NIPS2017 SVCCA](https://papers.nips.cc/paper/7188-svcca-singular-vector-canonical-correlation-analysis-for-deep-learning-dynamics-and-interpretability) 33 | - [Morcos et al. NeurIPS2018 PWCCA](https://papers.nips.cc/paper/7815-insights-on-representational-similarity-in-neural-networks-with-canonical-correlation) 34 | - [Kornblith et al. ICML2019 Linear CKA](http://proceedings.mlr.press/v97/kornblith19a.html) 35 | - [Ding et al. arXiv Orthogonal Procrustes distance](https://arxiv.org/abs/2108.01661) 36 | 37 | ```python 38 | import torch 39 | from torchvision.models import resnet18 40 | from anatome import Distance 41 | 42 | random_model = resnet18() 43 | learned_model = resnet18(pretrained=True) 44 | distance = Distance(random_model, learned_model, method='pwcca') 45 | with torch.no_grad(): 46 | distance.forward(torch.randn(256, 3, 224, 224)) 47 | 48 | # resize if necessary by specifying `size` 49 | distance.between("layer3.0.conv1", "layer3.0.conv1", size=8) 50 | ``` 51 | 52 | ### Loss Landscape Visualization 53 | 54 | - [Li et al.
NeurIPS2018](https://papers.nips.cc/paper/7875-visualizing-the-loss-landscape-of-neural-nets) 55 | 56 | ```python 57 | from torch.nn import functional as F 58 | from torchvision.models import resnet18 59 | from anatome import landscape2d 60 | 61 | x, y, z = landscape2d(resnet18(), 62 | data, 63 | F.cross_entropy, 64 | x_range=(-1, 1), 65 | y_range=(-1, 1), 66 | step_size=0.1) 67 | imshow(z) 68 | ``` 69 | 70 | ![](assets/landscape2d.svg) 71 | ![](assets/landscape3d.svg) 72 | 73 | ### Fourier Analysis 74 | 75 | - Yin et al. NeurIPS 2019 etc., 76 | 77 | ```python 78 | from torch.nn import functional as F 79 | from torchvision.models import resnet18 80 | from anatome import fourier_map 81 | 82 | map = fourier_map(resnet18(), 83 | data, 84 | F.cross_entropy, 85 | norm=4) 86 | imshow(map) 87 | ``` 88 | 89 | ![](assets/fourier.svg) 90 | 91 | ## Citation 92 | 93 | If you use this implementation in your research, please cite as: 94 | 95 | ``` 96 | @software{hataya2020anatome, 97 | author={Ryuichiro Hataya}, 98 | title={anatome, a PyTorch library to analyze internal representation of neural networks}, 99 | url={https://github.com/moskomule/anatome}, 100 | year={2020} 101 | } 102 | ``` -------------------------------------------------------------------------------- /anatome/fourier.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Callable 4 | 5 | import torch 6 | from torch import Tensor, nn 7 | from torch.nn import functional as F 8 | 9 | try: 10 | from tqdm import tqdm 11 | except ImportError: 12 | tqdm = lambda x, ncols=None: x 13 | 14 | from .utils import _denormalize, _evaluate, _normalize, ifft_shift, _irfft 15 | 16 | 17 | def add_fourier_noise(idx: tuple[int, int], 18 | images: Tensor, 19 | norm: float, 20 | size: tuple[int, int] = None, 21 | ) -> Tensor: 22 | """ Add Fourier noise 23 | 24 | Args: 25 | idx: index to be used 26 | images: original images 27 | norm: norm 
of additive noise 28 | size: 29 | 30 | Returns: images with Fourier noise 31 | 32 | """ 33 | 34 | images = images.clone() 35 | 36 | if size is None: 37 | _, _, h, w = images.size() 38 | else: 39 | h, w = size 40 | 41 | noise = images.new_zeros(1, h, w, 2) 42 | noise[:, idx[0], idx[1]] = 1 43 | noise[:, h - 1 - idx[0], w - 1 - idx[1]] = 1 44 | recon = _irfft(ifft_shift(noise), 2, normalized=True, onesided=False).unsqueeze(0) 45 | recon.div_(recon.norm(p=2)).mul_(norm) 46 | if size is not None: 47 | recon = F.interpolate(recon, images.shape[2:]) 48 | images.add_(recon).clamp_(0, 1) 49 | return images 50 | 51 | 52 | @torch.no_grad() 53 | def fourier_map(model: nn.Module, 54 | data: tuple[Tensor, Tensor], 55 | criterion: Callable[[Tensor, Tensor], Tensor], 56 | norm: float, 57 | fourier_map_size: Optional[tuple[int, int]] = None, 58 | mean: Optional[list[float] or Tensor] = None, 59 | std: Optional[list[float] or Tensor] = None, 60 | auto_cast: bool = False 61 | ) -> Tensor: 62 | """ 63 | 64 | Args: 65 | model: Trained model 66 | data: Pairs of [input, target] to compute criterion 67 | criterion: Criterion of (input, target) -> scalar value 68 | norm: Intensity of fourier noise 69 | fourier_map_size: Size of map, (H, W). Note that the computational time is dominated by HW. 70 | mean: If the range of input is [-1, 1], specify mean and std. 71 | std: If the range of input is [-1, 1], specify mean and std. 
72 | 73 | Returns: 74 | 75 | """ 76 | input, target = data 77 | if fourier_map_size is None: 78 | _, _, h, w = input.size() 79 | else: 80 | h, w = fourier_map_size 81 | if mean is not None: 82 | _mean = torch.as_tensor(mean, device=input.device, dtype=torch.float) 83 | _std = torch.as_tensor(std, device=input.device, dtype=torch.float) 84 | input = _denormalize(input, _mean, _std) # [0, 1] 85 | map = torch.zeros(h, w) 86 | for u_i in tqdm(torch.triu_indices(h, w).t(), ncols=80): 87 | l_i = h - 1 - u_i[0], w - 1 - u_i[1] 88 | noisy_input = add_fourier_noise(u_i, input, norm, fourier_map_size) 89 | if mean is not None: 90 | noisy_input = _normalize(noisy_input, _mean, _std) # to [-1, 1] 91 | loss = _evaluate(model, (noisy_input, target), criterion, auto_cast) 92 | map[u_i[0], u_i[1]] = loss 93 | map[l_i[0], l_i[1]] = loss 94 | return map 95 | -------------------------------------------------------------------------------- /anatome/utils.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | from typing import Callable, Optional 3 | 4 | import torch 5 | from torch import Tensor, nn 6 | 7 | 8 | def _svd(input: torch.Tensor 9 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 10 | # torch.svd style 11 | U, S, Vh = torch.linalg.svd(input, full_matrices=False) 12 | V = Vh.transpose(-2, -1) 13 | return U, S, V 14 | 15 | 16 | @torch.no_grad() 17 | def _evaluate(model: nn.Module, 18 | data: tuple[Tensor, Tensor], 19 | criterion: Callable[[Tensor, Tensor], Tensor], 20 | auto_cast: bool 21 | ) -> float: 22 | # evaluate model with given data points using the criterion 23 | with (torch.cuda.amp.autocast() if auto_cast and torch.cuda.is_available() else contextlib.nullcontext()): 24 | input, target = data 25 | return criterion(model(input), target).item() 26 | 27 | 28 | def _normalize(input: Tensor, 29 | mean: Tensor, 30 | std: Tensor 31 | ) -> Tensor: 32 | # normalize tensor in [0, 1] to [-1, 1] 33 | input = 
input.clone() 34 | input.add_(-mean[:, None, None]).div_(std[:, None, None]) 35 | return input 36 | 37 | 38 | def _denormalize(input: Tensor, 39 | mean: Tensor, 40 | std: Tensor 41 | ) -> Tensor: 42 | # denormalize tensor in [-1, 1] to [0, 1] 43 | input = input.clone() 44 | input.mul_(std[:, None, None]).add_(mean[:, None, None]) 45 | return input 46 | 47 | 48 | def fft_shift(input: torch.Tensor, 49 | dims: Optional[tuple[int, ...]] = None 50 | ) -> torch.Tensor: 51 | """ PyTorch version of np.fftshift 52 | 53 | Args: 54 | input: rFFTed Tensor of size [Bx]CxHxWx2 55 | dims: 56 | 57 | Returns: shifted tensor 58 | 59 | """ 60 | 61 | return torch.fft.fftshift(input, dims) 62 | 63 | 64 | def ifft_shift(input: torch.Tensor, 65 | dims: Optional[tuple[int, ...]] = None 66 | ) -> torch.Tensor: 67 | """ PyTorch version of np.ifftshift 68 | 69 | Args: 70 | input: rFFTed Tensor of size [Bx]CxHxWx2 71 | dims: 72 | 73 | Returns: shifted tensor 74 | 75 | """ 76 | 77 | return torch.fft.ifftshift(input, dims) 78 | 79 | 80 | def _rfft(self: Tensor, 81 | signal_ndim: int, 82 | normalized: bool = False, 83 | onesided: bool = True 84 | ) -> Tensor: 85 | # old-day's torch.rfft 86 | 87 | if signal_ndim > 4: 88 | raise RuntimeError("signal_ndim is expected to be 1, 2, 3.") 89 | 90 | m = torch.fft.rfftn if onesided else torch.fft.fftn 91 | dim = [-3, -2, -1][3 - signal_ndim:] 92 | return torch.view_as_real(m(self, dim=dim, norm="ortho" if normalized else None)) 93 | 94 | 95 | def _irfft(self: Tensor, 96 | signal_ndim: int, 97 | normalized: bool = False, 98 | onesided: bool = True, 99 | ) -> Tensor: 100 | # old-day's torch.irfft 101 | 102 | if signal_ndim > 4: 103 | raise RuntimeError("signal_ndim is expected to be 1, 2, 3.") 104 | if not torch.is_complex(self): 105 | self = torch.view_as_complex(self) 106 | 107 | m = torch.fft.irfftn if onesided else torch.fft.ifftn 108 | dim = [-3, -2, -1][3 - signal_ndim:] 109 | out = m(self, dim=dim, norm="ortho" if normalized else None) 110 | return 
out.real if torch.is_complex(out) else out 111 | -------------------------------------------------------------------------------- /tests/test_distance.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | 5 | from anatome import distance 6 | 7 | 8 | @pytest.fixture 9 | def matrices(): 10 | return torch.randn(10, 5), torch.randn(10, 8), torch.randn(10, 20), torch.randn(8, 5) 11 | 12 | 13 | @pytest.mark.parametrize('mat_size', ([10, 5], [100, 50], [500, 200])) 14 | def test_cca_consistency(mat_size): 15 | def gen(): 16 | x, y = torch.randn(*mat_size, dtype=torch.float64), torch.randn(*mat_size, dtype=torch.float64) 17 | x = distance._zero_mean(x, 0) 18 | y = distance._zero_mean(y, 0) 19 | return x, y 20 | 21 | x, y = gen() 22 | cca_svd = distance.cca_by_svd(x, y) 23 | cca_qr = distance.cca_by_qr(x, y) 24 | torch.testing.assert_close(cca_svd[0].abs(), cca_qr[0].abs()) 25 | torch.testing.assert_close(cca_svd[1].abs(), cca_qr[1].abs()) 26 | torch.testing.assert_close(cca_svd[2], cca_qr[2]) 27 | 28 | 29 | @pytest.mark.parametrize('method', ['svd', 'qr']) 30 | def test_cca_shape(matrices, method): 31 | i1, i2, i3, i4 = matrices 32 | distance.cca(i1, i1, method) 33 | a, b, diag = distance.cca(i1, i2, method) 34 | assert list(a.size()) == [i1.size(1), diag.size(0)] 35 | assert list(b.size()) == [i2.size(1), diag.size(0)] 36 | with pytest.raises(ValueError): 37 | # needs more batch size 38 | distance.cca(i1, i3, method) 39 | with pytest.raises(ValueError): 40 | distance.cca(i1, i4, method) 41 | with pytest.raises(ValueError): 42 | distance.cca(i1, i2, 'wrong') 43 | 44 | 45 | def test_cka_shape(matrices): 46 | i1, i2, i3, i4 = matrices 47 | distance.linear_cka_distance(i1, i2, True) 48 | distance.linear_cka_distance(i1, i3, True) 49 | distance.linear_cka_distance(i1, i2, False) 50 | with pytest.raises(ValueError): 51 | distance.linear_cka_distance(i1, i4, True) 52 | 53 | 54 | def 
test_opd(matrices): 55 | i1, i2, i3, i4 = matrices 56 | distance.orthogonal_procrustes_distance(i1, i1) 57 | distance.orthogonal_procrustes_distance(i1, i2) 58 | distance.orthogonal_procrustes_distance(i1, i3) 59 | with pytest.raises(ValueError): 60 | distance.orthogonal_procrustes_distance(i1, i4) 61 | 62 | 63 | @pytest.mark.parametrize('method', ['pwcca', 'svcca', 'lincka', 'opd']) 64 | def test_similarity_hook_linear(method): 65 | model1 = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 4)) 66 | model2 = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 4)) 67 | with pytest.raises(RuntimeError): 68 | distance.Distance(model1, model2, method=method, model1_names=['3']) 69 | 70 | dist = distance.Distance(model1, model2, method=method) 71 | 72 | assert dist.convert_names(model1, None, None, False) == ['0', '1'] 73 | with torch.no_grad(): 74 | dist.forward(torch.randn(13, 3)) 75 | 76 | dist.between("1", "1") 77 | 78 | 79 | @pytest.mark.parametrize('resize_by', ['avg_pool', 'dft']) 80 | def test_similarity_hook_conv2d(resize_by): 81 | model1 = nn.Sequential(nn.Conv2d(3, 3, kernel_size=3), nn.Conv2d(3, 4, kernel_size=3)) 82 | model2 = nn.Sequential(nn.Conv2d(3, 3, kernel_size=3), nn.Conv2d(3, 4, kernel_size=3)) 83 | 84 | dist = distance.Distance(model1, model2, model1_names=['0', '1'], model2_names=['0', '1'], method='lincka') 85 | 86 | with torch.no_grad(): 87 | dist.forward(torch.randn(13, 3, 11, 11)) 88 | 89 | dist.between('1', '1', size=5) 90 | dist.between('1', '1', size=7) 91 | with pytest.raises(RuntimeError): 92 | dist.between('1', '1', size=8) 93 | 94 | dist.between('0', '1') 95 | -------------------------------------------------------------------------------- /anatome/landscape.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from copy import deepcopy 4 | from typing import Callable 5 | 6 | import torch 7 | from torch import nn, Tensor 8 | 9 | try: 10 | from tqdm import tqdm 11 | except 
ImportError: 12 | tqdm = lambda x, ncols=None: x 13 | 14 | from .utils import _evaluate 15 | 16 | EPS = 1e-8 17 | 18 | 19 | def _filter_normed_random_direction(model: nn.Module 20 | ) -> list[Tensor]: 21 | # applies filter normalization proposed in Li+2018 22 | def _filter_norm(dirs: Tensor, 23 | params: Tensor 24 | ) -> Tensor: 25 | d_norm = dirs.view(dirs.size(0), -1).norm(dim=-1) 26 | p_norm = params.view(params.size(0), -1).norm(dim=-1) 27 | ones = [1 for _ in range(dirs.dim() - 1)] 28 | return dirs.mul_(p_norm.view(-1, *ones) / (d_norm.view(-1, *ones) + EPS)) 29 | 30 | directions = [(params, torch.randn_like(params)) for params in model.parameters()] 31 | directions = [_filter_norm(dirs, params) for params, dirs in directions] 32 | return directions 33 | 34 | 35 | def _get_perturbed_model(model: nn.Module, 36 | direction: list[Tensor] or tuple[list[Tensor], list[Tensor]], 37 | step_size: float or tuple[float, float] 38 | ) -> nn.Module: 39 | # perturb the weight of model along direction with step size 40 | new_model = deepcopy(model) 41 | if len(direction) == 2: 42 | # 2d 43 | perturbation = [d_0 * step_size[0] + d_1 * step_size[1] for d_0, d_1 in zip(*direction)] 44 | else: 45 | # 1d 46 | perturbation = [d_0 * step_size for d_0 in direction] 47 | 48 | for param, pert in zip(new_model.parameters(), perturbation): 49 | if param.data.dim() <= 1: 50 | # ignore biasbn in the original code 51 | continue 52 | param.data.add_(pert) 53 | 54 | return new_model 55 | 56 | 57 | @torch.no_grad() 58 | def landscape1d(model: nn.Module, 59 | data: tuple[Tensor, Tensor], 60 | criterion: Callable[[Tensor, Tensor], Tensor], 61 | x_range: tuple[float, float], 62 | step_size: float, 63 | auto_cast: bool = False 64 | ) -> tuple[Tensor, Tensor]: 65 | """ Compute loss landscape along a random direction X. 
The landscape is 66 | 67 | [{criterion(input, target) at Θ+iX} for i in range(x_min, x_max, α)] 68 | 69 | Args: 70 | model: Trained model, parameterized by Θ 71 | data: Pairs of [input, target] to compute criterion 72 | criterion: Criterion of (input, target) -> scalar value 73 | x_range: (x_min, x_max) 74 | step_size: α 75 | 76 | Returns: x-coordinates, landscape 77 | 78 | """ 79 | 80 | x_coord = torch.arange(x_range[0], x_range[1] + step_size, step_size, dtype=torch.float) 81 | x_direction = _filter_normed_random_direction(model) 82 | loss_values = torch.zeros_like(x_coord, device=torch.device('cpu')) 83 | for i, x in enumerate(tqdm(x_coord.tolist(), ncols=80)): 84 | new_model = _get_perturbed_model(model, x_direction, x) 85 | loss_values[i] = _evaluate(new_model, data, criterion, auto_cast) 86 | return x_coord, loss_values 87 | 88 | 89 | @torch.no_grad() 90 | def landscape2d(model: nn.Module, 91 | data: tuple[Tensor, Tensor], 92 | criterion: Callable[[Tensor, Tensor], Tensor], 93 | x_range: tuple[float, float], 94 | y_range: tuple[float, float], 95 | step_size: float or tuple[float, float], 96 | auto_cast: bool = False 97 | ) -> tuple[Tensor, Tensor, Tensor]: 98 | """ Compute loss landscape along two random directions X and Y. 
The landscape is 99 | 100 | [ 101 | [{criterion(input, target) at Θ+iX+jY} 102 | for i in range(x_min, x_max, α)] 103 | for j in range(y_min, y_max, β)] 104 | ] 105 | 106 | Args: 107 | model: Trained model, parameterized by Θ 108 | data: Pairs of [input, target] to compute criterion 109 | criterion: Criterion of (input, target) -> scalar value 110 | x_range: (x_min, x_max) 111 | y_range: (y_min, y_max) 112 | step_size: α, β 113 | 114 | Returns: x-coordinates, y-coordinates, landscape 115 | 116 | """ 117 | if isinstance(step_size, float): 118 | step_size = (step_size, step_size) 119 | x_coord = torch.arange(x_range[0], x_range[1] + step_size[0], step_size[0], dtype=torch.float) 120 | y_coord = torch.arange(y_range[0], y_range[1] + step_size[1], step_size[1], dtype=torch.float) 121 | x_coord, y_coord = torch.meshgrid(x_coord, y_coord, indexing='ij') 122 | shape = x_coord.shape 123 | x_coord, y_coord = x_coord.flatten(), y_coord.flatten() 124 | x_direction = _filter_normed_random_direction(model) 125 | y_direction = _filter_normed_random_direction(model) 126 | loss_values = torch.zeros_like(x_coord, device=torch.device('cpu')) 127 | # To enable tqdm 128 | for i, (x, y) in enumerate(zip( 129 | tqdm(x_coord.tolist(), ncols=80), 130 | y_coord.tolist()) 131 | ): 132 | new_model = _get_perturbed_model(model, (x_direction, y_direction), (x, y)) 133 | loss_values[i] = _evaluate(new_model, data, criterion, auto_cast) 134 | return x_coord.view(shape), y_coord.view(shape), loss_values.view(shape) 135 | -------------------------------------------------------------------------------- /assets/landscape2d.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 5 | 6 | 7 | 10 | 11 | 12 | 13 | 19 | 20 | 21 | 22 | 28 | 29 | 30 | 32 | 33 | 34 | 35 | 36 | 39 | 40 | 41 | 44 | 45 | 46 | 49 | 50 | 51 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- 
/anatome/distance.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from functools import partial 4 | from typing import Callable, Literal 5 | 6 | import torch 7 | from torch import Tensor, nn 8 | from torch.nn import functional as F 9 | from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor 10 | 11 | from .utils import _irfft, _rfft, _svd 12 | 13 | 14 | def _zero_mean(input: Tensor, 15 | dim: int 16 | ) -> Tensor: 17 | return input - input.mean(dim=dim, keepdim=True) 18 | 19 | 20 | def _check_shape_equal(x: Tensor, 21 | y: Tensor, 22 | dim: int 23 | ): 24 | if x.size(dim) != y.size(dim): 25 | raise ValueError(f'x.size({dim}) == y.size({dim}) is expected, but got {x.size(dim)=}, {y.size(dim)=} instead.') 26 | 27 | 28 | def cca_by_svd(x: Tensor, 29 | y: Tensor 30 | ) -> tuple[Tensor, Tensor, Tensor]: 31 | """ CCA using only SVD. 32 | For more details, check Press 2011 "Canonical Correlation Clarified by Singular Value Decomposition" 33 | 34 | Args: 35 | x: input tensor of Shape DxH 36 | y: input tensor of shape DxW 37 | 38 | Returns: x-side coefficients, y-side coefficients, diagonal 39 | 40 | """ 41 | 42 | # torch.svd(x)[1] is vector 43 | u_1, s_1, v_1 = _svd(x) 44 | u_2, s_2, v_2 = _svd(y) 45 | uu = u_1.t() @ u_2 46 | u, diag, v = _svd(uu) 47 | # a @ (1 / s_1).diag() @ u, without creating s_1.diag() 48 | a = v_1 @ (1 / s_1[:, None] * u) 49 | b = v_2 @ (1 / s_2[:, None] * v) 50 | return a, b, diag 51 | 52 | 53 | def cca_by_qr(x: Tensor, 54 | y: Tensor 55 | ) -> tuple[Tensor, Tensor, Tensor]: 56 | """ CCA using QR and SVD. 
57 | For more details, check Press 2011 "Canonical Correlation Clarified by Singular Value Decomposition" 58 | 59 | Args: 60 | x: input tensor of Shape DxH 61 | y: input tensor of shape DxW 62 | 63 | Returns: x-side coefficients, y-side coefficients, diagonal 64 | 65 | """ 66 | 67 | q_1, r_1 = torch.linalg.qr(x) 68 | q_2, r_2 = torch.linalg.qr(y) 69 | qq = q_1.t() @ q_2 70 | u, diag, v = _svd(qq) 71 | # a = r_1.inverse() @ u, but it is faster and more numerically stable 72 | a = torch.linalg.solve(r_1, u) 73 | b = torch.linalg.solve(r_2, v) 74 | return a, b, diag 75 | 76 | 77 | def cca(x: Tensor, 78 | y: Tensor, 79 | backend: str 80 | ) -> tuple[Tensor, Tensor, Tensor]: 81 | """ Compute CCA, Canonical Correlation Analysis 82 | 83 | Args: 84 | x: input tensor of Shape DxH 85 | y: input tensor of Shape DxW 86 | backend: svd or qr 87 | 88 | Returns: x-side coefficients, y-side coefficients, diagonal 89 | 90 | """ 91 | 92 | _check_shape_equal(x, y, 0) 93 | 94 | if x.size(0) < x.size(1): 95 | raise ValueError(f'x.size(0) >= x.size(1) is expected, but got {x.size()=}.') 96 | 97 | if y.size(0) < y.size(1): 98 | raise ValueError(f'y.size(0) >= y.size(1) is expected, but got {y.size()=}.') 99 | 100 | if backend not in ('svd', 'qr'): 101 | raise ValueError(f'backend is svd or qr, but got {backend}') 102 | 103 | x = _zero_mean(x, dim=0) 104 | y = _zero_mean(y, dim=0) 105 | return cca_by_svd(x, y) if backend == 'svd' else cca_by_qr(x, y) 106 | 107 | 108 | def _svd_reduction(input: Tensor, 109 | accept_rate: float 110 | ) -> Tensor: 111 | left, diag, right = _svd(input) 112 | full = diag.abs().sum() 113 | ratio = diag.abs().cumsum(dim=0) / full 114 | num = torch.where(ratio < accept_rate, 115 | input.new_ones(1, dtype=torch.long), 116 | input.new_zeros(1, dtype=torch.long) 117 | ).sum() 118 | return input @ right[:, : num] 119 | 120 | 121 | def svcca_distance(x: Tensor, 122 | y: Tensor, 123 | accept_rate: float, 124 | backend: str 125 | ) -> Tensor: 126 | """ Singular Vector 
CCA proposed in Raghu et al. 2017. 127 | 128 | Args: 129 | x: input tensor of Shape DxH, where D>H 130 | y: input tensor of Shape DxW, where D>H 131 | accept_rate: 0.99 132 | backend: svd or qr 133 | 134 | Returns: 135 | 136 | """ 137 | 138 | x = _svd_reduction(x, accept_rate) 139 | y = _svd_reduction(y, accept_rate) 140 | div = min(x.size(1), y.size(1)) 141 | a, b, diag = cca(x, y, backend) 142 | return 1 - diag.sum() / div 143 | 144 | 145 | def pwcca_distance(x: Tensor, 146 | y: Tensor, 147 | backend: str 148 | ) -> Tensor: 149 | """ Projection Weighted CCA proposed in Marcos et al. 2018. 150 | 151 | Args: 152 | x: input tensor of Shape DxH, where D>H 153 | y: input tensor of Shape DxW, where D>H 154 | backend: svd or qr 155 | 156 | Returns: 157 | 158 | """ 159 | 160 | a, b, diag = cca(x, y, backend) 161 | a, _ = torch.linalg.qr(a) # reorthonormalize 162 | alpha = (x @ a).abs_().sum(dim=0) 163 | alpha /= alpha.sum() 164 | return 1 - alpha @ diag 165 | 166 | 167 | def _debiased_dot_product_similarity(z: Tensor, 168 | sum_row_x: Tensor, 169 | sum_row_y: Tensor, 170 | sq_norm_x: Tensor, 171 | sq_norm_y: Tensor, 172 | size: int 173 | ) -> Tensor: 174 | return (z 175 | - size / (size - 2) * (sum_row_x @ sum_row_y) 176 | + sq_norm_x * sq_norm_y / ((size - 1) * (size - 2))) 177 | 178 | 179 | def linear_cka_distance(x: Tensor, 180 | y: Tensor, 181 | reduce_bias: bool 182 | ) -> Tensor: 183 | """ Linear CKA used in Kornblith et al. 
19 184 | 185 | Args: 186 | x: input tensor of Shape DxH 187 | y: input tensor of Shape DxW 188 | reduce_bias: debias CKA estimator, which might be helpful when D is limited 189 | 190 | Returns: 191 | 192 | """ 193 | 194 | _check_shape_equal(x, y, 0) 195 | 196 | x = _zero_mean(x, dim=0) 197 | y = _zero_mean(y, dim=0) 198 | dot_prod = (y.t() @ x).norm('fro').pow(2) 199 | norm_x = (x.t() @ x).norm('fro') 200 | norm_y = (y.t() @ y).norm('fro') 201 | 202 | if reduce_bias: 203 | size = x.size(0) 204 | # (x @ x.t()).diag() 205 | sum_row_x = torch.einsum('ij,ij->i', x, x) 206 | sum_row_y = torch.einsum('ij,ij->i', y, y) 207 | sq_norm_x = sum_row_x.sum() 208 | sq_norm_y = sum_row_y.sum() 209 | dot_prod = _debiased_dot_product_similarity(dot_prod, sum_row_x, sum_row_y, sq_norm_x, sq_norm_y, size) 210 | norm_x = _debiased_dot_product_similarity(norm_x.pow(2), sum_row_x, sum_row_x, sq_norm_x, sq_norm_x, size 211 | ).sqrt() 212 | norm_y = _debiased_dot_product_similarity(norm_y.pow(2), sum_row_y, sum_row_y, sq_norm_y, sq_norm_y, size 213 | ).sqrt() 214 | return 1 - dot_prod / (norm_x * norm_y) 215 | 216 | 217 | def orthogonal_procrustes_distance(x: Tensor, 218 | y: Tensor, 219 | ) -> Tensor: 220 | """ Orthogonal Procrustes distance used in Ding+21 221 | 222 | Args: 223 | x: input tensor of Shape DxH 224 | y: input tensor of Shape DxW 225 | 226 | Returns: 1 - nuclear_norm(x.t() @ y) after centering x and y along dim 0 and scaling each to unit Frobenius norm 227 | 228 | """ 229 | _check_shape_equal(x, y, 0) 230 | 231 | frobenius_norm = partial(torch.linalg.norm, ord="fro") 232 | nuclear_norm = partial(torch.linalg.norm, ord="nuc") 233 | 234 | x = _zero_mean(x, dim=0) 235 | x /= frobenius_norm(x) 236 | y = _zero_mean(y, dim=0) 237 | y /= frobenius_norm(y) 238 | # frobenius_norm(x) = 1, frobenius_norm(y) = 1 239 | # 0.5*d_proc(x, y) 240 | return 1 - nuclear_norm(x.t() @ y) 241 | 242 | 243 | class Distance(object): 244 | """ Module to measure distance between `model1` and `model2` 245 | 246 | Args: 247 | method: Method to compute distance. 'pwcca' by default.
248 | model1_names: Names of modules of `model1` to be used. If None (default), all names are used. 249 | model2_names: Names of modules of `model2` to be used. If None (default), all names are used. 250 | model1_leaf_modules: Modules of model1 to be considered as single nodes (see https://pytorch.org/blog/FX-feature-extraction-torchvision/). 251 | model2_leaf_modules: Modules of model2 to be considered as single nodes (see https://pytorch.org/blog/FX-feature-extraction-torchvision/). 252 | train_mode: If True, node names are taken from the models' train-mode graph, otherwise from the eval-mode graph. False by default. 253 | """ 254 | 255 | _supported_dims = (2, 4) 256 | _default_backends = {'pwcca': partial(pwcca_distance, backend='svd'), 257 | 'svcca': partial(svcca_distance, accept_rate=0.99, backend='svd'), 258 | 'lincka': partial(linear_cka_distance, reduce_bias=False), 259 | 'opd': orthogonal_procrustes_distance} 260 | 261 | def __init__(self, 262 | model1: nn.Module, 263 | model2: nn.Module, 264 | method: str | Callable = 'pwcca', 265 | model1_names: str | list[str] = None, 266 | model2_names: str | list[str] = None, 267 | model1_leaf_modules: list[nn.Module] = None, 268 | model2_leaf_modules: list[nn.Module] = None, 269 | train_mode: bool = False 270 | ): 271 | 272 | dp_ddp = (nn.DataParallel, nn.parallel.DistributedDataParallel) 273 | if isinstance(model1, dp_ddp) or isinstance(model2, dp_ddp): 274 | raise RuntimeWarning('model is nn.DataParallel or nn.DistributedDataParallel.
' 'Distance may cause unexpected behavior.') 276 | if isinstance(method, str): 277 | method = self._default_backends[method] 278 | self.distance_func = method 279 | self.model1 = model1 280 | self.model2 = model2 281 | self.extractor1 = create_feature_extractor(model1, self.convert_names(model1, model1_names, 282 | model1_leaf_modules, train_mode)) 283 | self.extractor2 = create_feature_extractor(model2, self.convert_names(model2, model2_names, 284 | model2_leaf_modules, train_mode)) 285 | self._model1_tensors: dict[str, torch.Tensor] = None 286 | self._model2_tensors: dict[str, torch.Tensor] = None 287 | 288 | def available_names(self, 289 | model1_leaf_modules: list[nn.Module] = None, 290 | model2_leaf_modules: list[nn.Module] = None, 291 | train_mode: bool = False 292 | ): 293 | return {'model1': self.convert_names(self.model1, None, model1_leaf_modules, train_mode), 294 | 'model2': self.convert_names(self.model2, None, model2_leaf_modules, train_mode)} 295 | 296 | @staticmethod 297 | def convert_names(model: nn.Module, 298 | names: str | list[str], 299 | leaf_modules: list[nn.Module], 300 | train_mode: bool 301 | ) -> list[str]: 302 | # resolve `names` (None -> all traced node names) and validate them against the traced graph 303 | if isinstance(names, str): 304 | names = [names] 305 | tracer_kwargs = {} 306 | if leaf_modules is not None: 307 | tracer_kwargs['leaf_modules'] = leaf_modules 308 | 309 | _names = get_graph_node_names(model, tracer_kwargs=tracer_kwargs) 310 | _names = _names[0] if train_mode else _names[1] 311 | _names = _names[1:] # because the first element is input 312 | 313 | if names is None: 314 | names = _names 315 | else: 316 | if not (set(names) <= set(_names)): 317 | diff = set(names) - set(_names) 318 | raise RuntimeError(f'Unknown names: {list(diff)}') 319 | 320 | return names 321 | 322 | def forward(self, 323 | data 324 | ) -> None: 325 | """ Forward pass of models. Used to store intermediate features.
326 | 327 | Args: 328 | data: input data to models 329 | 330 | """ 331 | self._model1_tensors = self.extractor1(data) 332 | self._model2_tensors = self.extractor2(data) 333 | 334 | def between(self, 335 | name1: str, 336 | name2: str, 337 | size: int | tuple[int, int] = None, 338 | downsample_method: Literal['avg_pool', 'dft'] = 'avg_pool' 339 | ) -> torch.Tensor: 340 | """ Compute distance between modules corresponding to name1 and name2. 341 | 342 | Args: 343 | name1: Name of a module of `model1` 344 | name2: Name of a module of `model2` 345 | size: Size for downsampling if necessary. If size's type is int, both features of name1 and name2 are 346 | downsampled to (size, size). If size's type is tuple[int, int], features are downsampled to (size[0], size[0]) and 347 | (size[1], size[1]). Only 4D features are downsampled. If size is None (default), no downsampling is applied. 348 | downsample_method: Downsampling method: 'avg_pool' for average pooling and 'dft' for discrete 349 | Fourier transform 350 | 351 | Returns: Distance in tensor. 352 | 353 | """ 354 | tensor1 = self._model1_tensors[name1] 355 | tensor2 = self._model2_tensors[name2] 356 | if tensor1.dim() not in self._supported_dims: 357 | raise RuntimeError(f'Supported dimensions are {self._supported_dims}, but got {tensor1.dim()}') 358 | if tensor2.dim() not in self._supported_dims: 359 | raise RuntimeError(f'Supported dimensions are {self._supported_dims}, but got {tensor2.dim()}') 360 | 361 | if size is not None: 362 | if isinstance(size, int): 363 | size = (size, size) 364 | 365 | def downsample_if_necessary(input, s): 366 | if input.dim() == 4: 367 | input = self._downsample_4d(input, s, downsample_method) 368 | return input 369 | 370 | tensor1 = downsample_if_necessary(tensor1, size[0]) 371 | tensor2 = downsample_if_necessary(tensor2, size[1]) 372 | 373 | def reshape_if_4d(input): 374 | if input.dim() == 4: 375 | # see https://arxiv.org/abs/1706.05806's P5.
376 | if name1 == name2: # same layer comparisons -> Cx(BHW) 377 | input = input.permute(1, 0, 2, 3).flatten(1) 378 | else: # different layer comparisons -> Bx(CHW) 379 | input = input.flatten(1) 380 | return input 381 | 382 | tensor1 = reshape_if_4d(tensor1) 383 | tensor2 = reshape_if_4d(tensor2) 384 | 385 | return self.distance_func(tensor1, tensor2) 386 | 387 | @staticmethod 388 | def _downsample_4d(input: Tensor, 389 | size: int, 390 | backend: Literal['avg_pool', 'dft'] 391 | ) -> Tensor: 392 | if input.dim() != 4: 393 | raise RuntimeError(f'input is expected to be 4D tensor, but got {input.dim()=}.') 394 | 395 | # todo: what if channel-last? 396 | b, c, h, w = input.size() 397 | 398 | if (size, size) == (h, w): 399 | return input 400 | 401 | if size > h or size > w: # compare per axis; tuple comparison would be lexicographic 402 | raise RuntimeError(f'size ({size}) is expected to be no larger than h and w, but got {h=}, {w=}.') 403 | 404 | if backend not in ('avg_pool', 'dft'): 405 | raise RuntimeError(f'backend is expected to be avg_pool or dft, but got {backend=}.') 406 | 407 | if backend == 'avg_pool': 408 | return F.adaptive_avg_pool2d(input, (size, size)) 409 | 410 | # near-direct PyTorch port of 411 | # https://github.com/google/svcca/blob/master/dft_ccas.py 412 | if input.size(2) != input.size(3): 413 | raise RuntimeError('width and height of input needs to be equal') 414 | h = input.size(2) 415 | input_fft = _rfft(input, 2, normalized=True, onesided=False) 416 | freqs = torch.fft.fftfreq(h, 1 / h, device=input.device) 417 | idx = (freqs >= -size / 2) & (freqs < size / 2) 418 | # BxCxHxWx2 -> BxCxhxwx2 419 | input_fft = input_fft[..., idx, :][..., idx, :, :] 420 | input = _irfft(input_fft, 2, normalized=True, onesided=False) 421 | return input 422 | -------------------------------------------------------------------------------- /assets/fourier.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 5 | 6 | 7 | 10 | 11 | 12 | 13 | 19 | 20 | 21 | 22 | 28 | 29 | 30 | 32 | 33
| 34 | 35 | 36 | 39 | 40 | 41 | 44 | 45 | 46 | 49 | 50 | 51 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "tags": [] 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import torch\n", 12 | "from torch.nn import functional as F\n", 13 | "from torchvision.models import resnet18\n", 14 | "from torchvision import transforms\n", 15 | "from torchvision.datasets import ImageFolder\n", 16 | "from torch.utils.data import DataLoader\n", 17 | "\n", 18 | "import matplotlib.pyplot as plt\n", 19 | "\n", 20 | "batch_size = 128\n", 21 | "\n", 22 | "model = resnet18(pretrained=True)\n", 23 | "imagenet = ImageFolder('~/.torch/data/imagenet/val', \n", 24 | " transforms.Compose([transforms.CenterCrop(224),transforms.ToTensor(),\n", 25 | " transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]))\n", 26 | "data = next(iter(DataLoader(imagenet, batch_size=batch_size, num_workers=8)))" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "model.eval()\n", 36 | "model.cuda()\n", 37 | "data = data[0].cuda(), data[1].cuda()" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "## CCA\n", 45 | "\n", 46 | "Measure CCA distance between trained and not trained models." 
47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 3, 52 | "metadata": {}, 53 | "outputs": [ 54 | { 55 | "output_type": "execute_result", 56 | "data": { 57 | "text/plain": "tensor(0.3290, device='cuda:0')" 58 | }, 59 | "metadata": {}, 60 | "execution_count": 3 61 | } 62 | ], 63 | "source": [ 64 | "from anatome import CCAHook\n", 65 | "random_model = resnet18().cuda()\n", 66 | "\n", 67 | "hook1 = CCAHook(model, \"layer1.0.conv1\")\n", 68 | "hook2 = CCAHook(random_model, \"layer1.0.conv1\")\n", 69 | "\n", 70 | "with torch.no_grad():\n", 71 | " model(data[0])\n", 72 | " random_model(data[0])\n", 73 | "hook1.distance(hook2, size=8)" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 4, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "# clear hooked tensors and caches\n", 83 | "hook1.clear()\n", 84 | "hook2.clear()\n", 85 | "torch.cuda.empty_cache()" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "## Loss Landscape\n", 93 | "\n", 94 | "Show loss landscape in 1D space." 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 5, 100 | "metadata": { 101 | "tags": [] 102 | }, 103 | "outputs": [ 104 | { 105 | "output_type": "stream", 106 | "name": "stderr", 107 | "text": "100%|███████████████████████████████████████████| 21/21 [00:01<00:00, 13.02it/s]\n" 108 | } 109 | ], 110 | "source": [ 111 | "from anatome import landscape1d\n", 112 | "x, y = landscape1d(model, data, F.cross_entropy, x_range=(-1, 1), step_size=0.1)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 6, 118 | "metadata": {}, 119 | "outputs": [ 120 | { 121 | "output_type": "execute_result", 122 | "data": { 123 | "text/plain": "[]" 124 | }, 125 | "metadata": {}, 126 | "execution_count": 6 127 | }, 128 | { 129 | "output_type": "display_data", 130 | "data": { 131 | "text/plain": "
", 132 | "image/svg+xml": "\n\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", 133 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAEDCAYAAAAlRP8qAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAbn0lEQVR4nO3de3Bc53nf8e+zixtBgBcQBERRlCiHgOxIU8UqR5HsOJHTpiMpTpV0nIxc+RLHMSvX6sQzTTpumnE84z96mTZTK5KlKLRGUeJQU9e2qslQde3YiVTLVESyupCiTVKyZcGkSGCXBLgLYheXp3/sWXCF64I4t138PjMQdvecPfvoYPHDy3ff9z3m7oiISOPLJF2AiIiEQ4EuItIkFOgiIk1CgS4i0iQU6CIiTUKBLiLSJBINdDN71MzOmtmROvf/LTN71cyOmtlfR12fiEgjsSTHoZvZLwIF4HF3v2GZfQeA/wH8srufM7M+dz8bR50iIo0g0Ra6uz8D5GsfM7OfMbP/bWaHzOxZM3tnsOmTwIPufi54rsJcRKRGGvvQHwH+jbv/Y+D3gS8Fjw8Cg2b2PTM7YGa3J1ahiEgKtSRdQC0z6wLeA3zVzKoPtwffW4AB4DbgKuBZM7vB3c/HXaeISBqlKtCp/IvhvLv/3ALbhoAD7j4J/MjMfkgl4F+Is0ARkbRKVZeLu49RCevfBLCKG4PNTwLvDx7vpdIF83oihYqIpFDSwxb3Ad8HrjOzITP7BHAP8Akzewk4CtwV7P5NIGdmrwLfBf7A3XNJ1C0ikkaJDlsUEZHwpKrLRURELl9iH4r29vb6zp07k3p5EZGGdOjQoRF337rQtsQCfefOnRw8eDCplxcRaUhm9sZi29TlIiLSJBToIiJNQoEuItIkFOgiIk1CgS4i0iQU6CIiTUKBLiLSJBToIiIx+uK3T/DsieFIjq1AFxGJycyM88W/Pc4//Ci//M6XQYEuIhKT8xcnmXHoWd8WyfEV6CIiMckVSgBs6WpfZs/Lo0AXEYlJrlgGYIta6CIijS1XCAK
9K6FAN7MdZvZdMztmZkfN7PcW2MfM7H4zO2lmL5vZTZFUKyLSwHLFSpdLVH3o9SyfOwX8W3c/bGbdwCEz+5a7v1qzzx1ULtg8APw88FDwXUREAtUWek9nQi10dz/t7oeD2xeAY8D2ObvdBTzuFQeATWa2LfRqRUQaWK5YYnNnKy3ZaHq7V3RUM9sJvBt4fs6m7cCbNfeHmB/6mNkeMztoZgeHh6MZWC8iklb5Yjmy7hZYQaCbWRfwNeAz7j42d/MCT5l39Wl3f8Tdd7v77q1bF7yCkohI0xoplCMbsgh1BrqZtVIJ86+4+9cX2GUI2FFz/yrg1OrLExFpHvliObIhi1DfKBcDvgwcc/c/WWS3p4CPBqNdbgFG3f10iHWKiDS8XKEU2ZBFqG+Uy3uBjwCvmNmLwWN/CFwN4O4PA/uBO4GTwDjw8fBLFRFpXFPTM5wbn6RnfXRdLssGurv/XxbuI6/dx4FPh1WUiEizOTc+CUBvhC10zRQVEYlBdVLRlghb6Ap0EZEY5KuTitIwbFFERC7fSLAwl7pcREQaXL4Q7TouoEAXEYlFrlgmY7AponVcQIEuIhKLkUKZzZ1tZDNLDhpcFQW6iEgM8sVoJxWBAl1EJBa5QjnSIYugQBcRiUW+WKZHLXQRkcY3UijRG+EIF1Cgi4hErjw1w9jEVKTruIACXUQkcufGo704dJUCXUQkYiOF6jouCnQRkYaWL1Zb6OpyERFpaLmCulxERJpCrtpCV5eLiEhjyxVKtGSMDR2tkb6OAl1EJGL5YpnN69vIRLiOCyjQRUQiN1IoR97dAgp0EZHI5WJYmAsU6CIikcsXo1+YCxToIiKRyxXKaqGLiDS6iclpCqUp9aGLiDS6uGaJggJdRCRS1UCP8uLQVQp0EZEIVRfm6lUfuohIY7vUQleXi4hIQ4trYS5QoIuIRGqkWKItm6G7vSXy11Kgi4hEKF8o07O+DbNo13EBBbqISKRyxXgmFYECXUQkUrliOZYhi6BAFxGJVK5QojeGSUWgQBcRiVReLXQRkcY3Xp5ivDytPnQRkUZXHYPeG8OkIlCgi4hEJs51XECBLiISmVyxso5LarpczOxRMztrZkcW2X6bmY2a2YvB1+fCL1NEpPHMTvuPqculnrmojwEPAI8vsc+z7v6BUCoSEWkSuWJ867hAHS10d38GyMdQi4hIU8kXy7S3ZOhsy8byemH1od9qZi+Z2dNmdv1iO5nZHjM7aGYHh4eHQ3ppEZF0GgkmFcWxjguEE+iHgWvc/UbgT4EnF9vR3R9x993uvnvr1q0hvLSISHrFdXHoqlUHuruPuXshuL0faDWz3lVXJiLS4OKcJQohBLqZXWHBvyfM7ObgmLnVHldEpNHlCqXYRrhAHaNczGwfcBvQa2ZDwB8DrQDu/jDwQeBTZjYFXATudnePrGIRkQbg7rEunQt1BLq7f2iZ7Q9QGdYoIiKBYnma0tQMWxqpy0VERObLF+Kd9g8KdBGRSIwE0/7jWgsdFOgiIpHIqYUuItIc8jEvzAUKdBGRSIzEvDAXKNBFRCKRL5bpbMuyLqZ1XECBLiISiVyhFGt3CyjQRUQikSuW6YmxuwUU6CIikcgVyvTGOMIFFOgiIpHIFUuxDlkEBbqISOjcnXyxzJYYJxWBAl1EJHRjE1NMTju9+lBURKSx5YvxzxIFBbqISOhyheosUXW5iIg0tFyxOktULXQRkYZWXZhLE4tERBpctctFfegiIg0uVyzT3d5Ce0t867iAAl1EJHRxX0u0SoEuIhKyfAKzREGBLiISulwh/lmioEAXEQldrliOfcgiKNBFREI1M1Ndx0WBLiLS0EYvTjI947GvhQ4KdBGRUFVnica9MBco0EVEQjW7jota6CIijS2plRZBgS4iEqoRdbmIiDSHfLAw12a10Jf33R+c5f3/9e94a3Qi6VJERObJFUtsXNdKazb+eG24QO9ozfKjkSLHz1xIuhQRkXlyhWQmFUEDBvpAfxc
AJ84WEq5ERGS+XLGUyKQiaMBA7+1qp2d9GyfPqoUuIulTaaHHP2QRGjDQAXb1dXH8jFroIpI++WKZHrXQ6zfY38WJMxdw96RLERGZNT3j5MfL9KoPvX4Dfd2MTUxx9kIp6VJERGadHy/jnsykImjYQA8+GFW3i4ikSHUdlyTWQodGDfT+bgBO6INREUmRkdl1XNRCr1tvVxubOlv1waiIpEo+7S10M3vUzM6a2ZFFtpuZ3W9mJ83sZTO7Kfwy570mA31dGrooIqmSK1QDPb0t9MeA25fYfgcwEHztAR5afVnLG+jv5viZgka6iEhq5IplzGBzZ0oD3d2fAfJL7HIX8LhXHAA2mdm2sApczEBfF6MXJxkuaKSLiKRDrlBic2cb2Ywl8vph9KFvB96suT8UPDaPme0xs4NmdnB4eHhVLzoYfDB6Uv3oIpIS+WI5sSGLEE6gL/SnaMF+EHd/xN13u/vurVu3rupFq0MXtUiXiKRFkgtzQTiBPgTsqLl/FXAqhOMuaWt3Oxs6WrRIl4ikxkiCC3NBOIH+FPDRYLTLLcCou58O4bhLMjMG+7s1uUhEUiNfTG5hLoCW5XYws33AbUCvmQ0Bfwy0Arj7w8B+4E7gJDAOfDyqYuca6O/i6SNv4e6YJfMhhIgIwOT0DOfHJxNtoS8b6O7+oWW2O/Dp0CpagYG+bvaNv0muWKY3oYH8IiIA58aDMegN3oeemOrFLvTBqIgk7dKkouQal40d6H3B0EV9MCoiCatO+2/0YYuJ6d/QTndHi1roIpK46sJcvQ0+yiUx1TVdNNJFRJJ2qYWuLpfLNtjfrS4XEUlcrlAmY7BpXWtiNTR8oO/q6yJXLJPTmi4ikqBcsUTP+nYyCa3jAk0Q6JcudqFWuogkJ+lp/9AEgT7YX70cnT4YFZHk5IrlRCcVQRME+hUbOuhu15ouIpKspFdahCYIdDNjV79GuohIskYKpcRnrDd8oENlKV1dMFpEklKemuHCxJRa6GEY6OtmpFCeHQcqIhKnSxeHVqCv2oA+GBWRBFVniSa5dC40TaBr6KKIJEct9BBdubGD9W1ZzRgVkUTkitUWugJ91SojXbq1SJeIJGJ26Vx1uYSjMtJFLXQRiV+uWKYlY2xYt+w1gyLVNIE+2N/F8IUS58c10kVE4pUvVCYVJX0pzKYJ9OrFLtRKF5G45YqlRK9UVNU0gb6rrzp0UYEuIvEaKZQTvbBFVdME+vZN6+hsy+qDURGJXRrWcYEmCvRMxtjV16WhiyISu1yhlPgIF2iiQIdKP7rWdBGROE1MTlMsTyc+qQiaLdD7uzgzVmL04mTSpYjIGpGrzhJVl0u4BoIPRk+qlS4iMckXqheHVqCHajBY0+W4RrqISExGqtP+NWwxXNs3rWNda1ZDF0UkNtVp/xq2GLLqSBd9MCoicckHLXR1uURgoE+XoxOR+OQKZdpaMnS1J7uOCzRhoO/q7+KtsQnGJjTSRUSilyuW2ZKCdVygCQN9sLqmi1rpIhKDXKGUijHo0ISBXr0cnYYuikgcKtP+kx/hAk0Y6Fdt7qSjNaMWuojEYqRQpjcFH4hCEwZ6NmP8zNYujmtNFxGJQWXpXAV6ZAb6ujipVRdFJGLj5SkmJmfU5RKlgf5uTo1OcEEjXUQkQrPXElULPTqX1nRRt4uIRCdNC3NBkwZ6dU0XXY5ORKKUK6RnHReoM9DN7HYz+6GZnTSzzy6w/TYzGzWzF4Ovz4Vfav129HTS1pLhhPrRRSRCaWuhLztX1cyywIPArwBDwAtm9pS7vzpn12fd/QMR1Lhi1ZEuaqGLSJQasQ/9ZuCku7/u7mXgCeCuaMtavcF+rekiItHKFUqsa83S2Zb8Oi5QX6BvB96suT8UPDbXrWb2kpk9bWbXL3QgM9tjZgfN7ODw8PBllFu/gb4ufnr+IoXSVKSvIyJrV1ouDl1VT6AvtOKMz7l/GLjG3W8E/hR4cqEDufsj7r7b3Xdv3bp1ZZWu0K5gTZf
X1O0iIhEZKZZTsQ56VT2BPgTsqLl/FXCqdgd3H3P3QnB7P9BqZr2hVXkZBoM1XY7rg1ERiUi+WGq4FvoLwICZXWtmbcDdwFO1O5jZFRasHWlmNwfHzYVd7Epc3dNJWzajsegiEplcoZyaIYtQxygXd58ys/uAbwJZ4FF3P2pm9wbbHwY+CHzKzKaAi8Dd7j63WyZWLdkM79i6Xi10EYmEu8+uhZ4WdX00G3Sj7J/z2MM1tx8AHgi3tNUb6O/m//3kXNJliEgTKpSmKE/NpGbIIjTpTNGqgb4uhs5dZLyskS4iEq7ZMegpWZgLmjzQB/u1pouIRKM6S7RHLfR47NLl6EQkItV1XHrVQo/Hzi2dtGaN47ocnYiELK8Werxashne0dvFSbXQRSRkaVuYC5o80AF29WuRLhEJX65QZn1blo7WbNKlzGr6QB/s6+bNc+NcLE8nXYqINJHKtUTT038OayDQB/q7cIfXhtVKF5Hw5ArpWpgL1kCga00XEYlCLmULc8EaCPRrtqynJWPqRxeRUOUKpVRNKoI1EOit2QzX9q7XWHQRCY27V9ZCVws9foP93ZzQWHQRCcnYxSmmZjxVQxZhjQT6rr4ufpIfZ2JSI11EZPVyxcos0TQtzAVrJNAH+7tx15ouIhKOS5OK1IceuwEt0iUiIaqu46JhiwnYOTvSRf3oIrJ61RZ6ryYWxa+tJcPO3vUc10gXEQlBdS10tdATMtDXpS4XEQlFvlimu6OFtpZ0RWi6qonQQF8Xb+SKGukiIqs2UiilrrsF1lKg93cz4/D6cDHpUkSkweWL6VvHBdZUoFdGuuiDURFZrVyhnLpJRbCGAv3a3vVkM6YlAERk1SpL5yrQE9PekuWaLZ1qoYvIqszMVNZxSdukIlhDgQ6VD0a16qKIrMb5i5PMePqm/cMaC/TB/m7eyI1TmtJIFxG5PK+eGgPSNwYd1lig7+rrYnrG+dGIRrqIyMq9PHSeT33lEDt61vG+ga1JlzPPmgr0wf5uAPa/fDrhSkSk0bw8dJ4P732eTZ2tPLHnVrXQk3Zdfzd33HAF93/nJJ9/6ijTM550SSLSAF4ZGuXDe59nw7pW9n3yFrZvWpd0SQtaU4GeyRgP/Mub+N1fuJbHnvsx/+ovDzFenkq6LBFJsVeGRrln7wE2rGvliT23cNXmzqRLWtSaCnSAbMb4ow/8LF+463q+84Mz/NaffZ+zYxNJlyUiKXTkp6N8+MvPN0SYwxoM9KqP3LqTvR/bzevDRX79we/xg7fGki5JRFLkyE9HuWfv83S1t7Dvk+kPc1jDgQ7wy+/s56v33sq0Ox986Pv8/fHhpEsSkRSoDfMn9tzCjp70hzms8UAHuP7KjTz56feyo6eT33nsBf76+Z8kXZKIJKhRwxwU6ABs27iOr957K+8b6OUPv/EK//HpY8xoBIzImnP0VKXPvBHDHBTos7raW9j70d18+Jar+bO/f5379h3W2ukia8jRU5WW+fq2xgxzUKC/TUs2wxfuuoE/+tV38fSRt/jQnx9gJLgYrIg0r2qYd7Zm2ffJxgxzUKDPY2b87vvewUP33MSx02P8xpe+x0mt0CjStF49NTYb5k/suZWrtzRmmIMCfVG337CNJ/bcysXyNP/iS8/x3GsjSZckIiGrhPmBSst8zy0NHeagQF/Sz+3YxDf+9Xvp29DBxx79B/7noaGkSxKRkFTDvCMI82u2rE+6pFUz9+VHc5jZ7cAXgSyw193/05ztFmy/ExgHftvdDy91zN27d/vBgwcvt+5YjV6c5FN/dYjnXsux+5rNbN+8jm0b13Hlpg62bVzHto0dbNvYQc/6NiqnQkTSolia4szYBGfGSpy9MMHZsRJnxib42uEhOlqzPNFgYW5mh9x990LbWup4chZ4EPgVYAh4wcyecvdXa3a7AxgIvn4eeCj43hQ2rmvlsY/fzH//9nEO/vgch944x5mx00xOv/2PYXtLJgj3IOSDwL9yUwdXbFhHz/o
2shmjJWNks8H3jNGSyZAx9MdAZAHuztSMMz0TfJ92pt2ZmpmhWJoOwnqC4QulBYO7WJ4/Wq2jNcNgfzf33/3uhgrz5Swb6MDNwEl3fx3AzJ4A7gJqA/0u4HGvNPcPmNkmM9vm7k2zTm1bS4Z/d/s7Z+/PzDgjxRKnz09wevQip85P8NbYBKfOX+T06AQHXs9x5kJpRSs6Xgp4I1MN/kxm9nGAaubPfufS49U/B9U/DDb7n0vbLkej/qFJoupGnb2w3L/Ul9zq8/epHs9n71f38bff9/mBPTNTe3+GlUwJaW/J0L+hg/4N7bzryg380nVbZ+/3d3fQt6Gdvg0ddLe3NOz7ein1BPp24M2a+0PMb30vtM924G2BbmZ7gD0AV1999UprTZVMxujr7qCvu4Mbd2xacJ/pGWf4QolToxc5fX6C0YuTTLszPT3z9hbH7PeZ2RZI9fHK/pX7Puc359Iviy/wizP/l+qyNGhCeYKFWyJ/SkKwTNlLbX5bI2L2sbc/b94+wY2sGS1ZI2M1jZjspcbNUo2czrYs/Rs66OuuBPWGjuYM6nrVE+gLnZ25vy317IO7PwI8ApU+9Dpeu6FlM8YVGzu4YmMHNPbfLxFpAPWMchkCdtTcvwo4dRn7iIhIhOoJ9BeAATO71szagLuBp+bs8xTwUau4BRhtpv5zEZFGsGyXi7tPmdl9wDepDFt81N2Pmtm9wfaHgf1UhiyepDJs8ePRlSwiIguppw8dd99PJbRrH3u45rYDnw63NBERWQnNFBURaRIKdBGRJqFAFxFpEgp0EZEmUdfiXJG8sNkw8MZlPr0XSON6tmmtC9Jbm+paGdW1Ms1Y1zXuvnWhDYkF+mqY2cHFVhtLUlrrgvTWprpWRnWtzFqrS10uIiJNQoEuItIkGjXQH0m6gEWktS5Ib22qa2VU18qsqboasg9dRETma9QWuoiIzKFAFxFpEqkNdDP7TTM7amYzZrbo8B4zu93MfmhmJ83sszWP95jZt8zsRPB9c0h1LXtcM7vOzF6s+Rozs88E2z5vZj+t2XZnXHUF+/3YzF4JXvvgSp8fRV1mtsPMvmtmx4Kf+e/VbAv1fC32fqnZbmZ2f7D9ZTO7qd7nRlzXPUE9L5vZc2Z2Y822BX+mMdV1m5mN1vx8PlfvcyOu6w9qajpiZtNm1hNsi/J8PWpmZ83syCLbo31/uXsqv4B3AdcBfwfsXmSfLPAa8A6gDXgJ+Nlg238BPhvc/izwn0Oqa0XHDWp8i8pkAIDPA78fwfmqqy7gx0Dvav+/wqwL2AbcFNzuBo7X/BxDO19LvV9q9rkTeJrKVbhuAZ6v97kR1/UeYHNw+45qXUv9TGOq6zbgby7nuVHWNWf/XwO+E/X5Co79i8BNwJFFtkf6/kptC93dj7n7D5fZbfYC1u5eBqoXsCb4/hfB7b8Afj2k0lZ63H8CvObulzsrtl6r/f9N7Hy5+2l3PxzcvgAco3JN2rAt9X6prfdxrzgAbDKzbXU+N7K63P05dz8X3D1A5apgUVvN/3Oi52uODwH7QnrtJbn7M0B+iV0ifX+lNtDrtNjFqQH6PbhqUvC9L6TXXOlx72b+m+m+4J9bj4bVtbGCuhz4P2Z2yCoX7V7p86OqCwAz2wm8G3i+5uGwztdS75fl9qnnuVHWVesTVFp5VYv9TOOq61Yze8nMnjaz61f43Cjrwsw6gduBr9U8HNX5qkek76+6LnARFTP7NnDFApv+g7v/r3oOscBjqx6HuVRdKzxOG/DPgX9f8/BDwBeo1PkF4L8BvxNjXe9191Nm1gd8y8x+ELQqLluI56uLyi/eZ9x9LHj4ss/XQi+xwGP1XvA8kvfaMq85f0ez91MJ9F+oeTj0n+kK6jpMpTuxEHy+8SQwUOdzo6yr6teA77l7bas5qvNVj0jfX4kGurv/01UeYqmLU58xs23ufjr4J83ZMOoys5Uc9w7gsLufqTn27G0z+3Pgb+Ksy91PBd/Pmtk3qPxT7xk
SPl9m1kolzL/i7l+vOfZln68FrOaC5211PDfKujCzfwTsBe5w91z18SV+ppHXVfOHF3ffb2ZfMrPeep4bZV015v0LOcLzVY9I31+N3uWy1AWsnwI+Ftz+GFBPi78eKznuvL67INSqfgNY8NPwKOoys/Vm1l29DfyzmtdP7HyZmQFfBo65+5/M2Rbm+VrNBc/reW5kdZnZ1cDXgY+4+/Gax5f6mcZR1xXBzw8zu5lKpuTqeW6UdQX1bAR+iZr3XMTnqx7Rvr+i+KQ3jC8qv7xDQAk4A3wzePxKYH/NfndSGRXxGpWumurjW4C/BU4E33tCqmvB4y5QVyeVN/bGOc//S+AV4OXgB7YtrrqofIL+UvB1NC3ni0r3gQfn5MXg684oztdC7xfgXuDe4LYBDwbbX6FmhNVi77WQztNyde0FztWcn4PL/Uxjquu+4HVfovJh7XvScL6C+78NPDHneVGfr33AaWCSSn59Is73l6b+i4g0iUbvchERkYACXUSkSSjQRUSahAJdRKRJKNBFRJqEAl1EpEko0EVEmsT/B+A9SeMvJTTJAAAAAElFTkSuQmCC\n" 134 | }, 135 | "metadata": { 136 | "needs_background": "light" 137 | } 138 | } 139 | ], 140 | "source": [ 141 | "plt.plot(x, y)" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 7, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "torch.cuda.empty_cache()" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "source": [ 156 | "## Fourier Analysis\n", 157 | "\n", 158 | "Show a model's robustness." 
159 | ], 160 | "metadata": { 161 | "collapsed": false 162 | } 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "outputs": [], 168 | "source": [ 169 | "from anatome import fourier_map\n", 170 | "map = fourier_map(model, data, F.cross_entropy, 20, (6, 6), mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))" 171 | ], 172 | "metadata": { 173 | "collapsed": false, 174 | "pycharm": { 175 | "name": "#%%\n" 176 | } 177 | } 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "outputs": [], 183 | "source": [ 184 | "plt.imshow(map, interpolation='nearest')" 185 | ], 186 | "metadata": { 187 | "collapsed": false, 188 | "pycharm": { 189 | "name": "#%%\n" 190 | } 191 | } 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "outputs": [], 197 | "source": [ 198 | "torch.cuda.empty_cache()" 199 | ], 200 | "metadata": { 201 | "collapsed": false, 202 | "pycharm": { 203 | "name": "#%%\n" 204 | } 205 | } 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "metadata": {}, 210 | "source": [] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 8, 215 | "metadata": { 216 | "tags": [] 217 | }, 218 | "outputs": [ 219 | { 220 | "output_type": "stream", 221 | "name": "stderr", 222 | "text": "100%|███████████████████████████████████████████| 21/21 [00:03<00:00, 6.08it/s]\n" 223 | } 224 | ], 225 | "source": [ 226 | "from anatome import fourier_map\n", 227 | "map = fourier_map(model, data, F.cross_entropy, 20, (6, 6), mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 9, 233 | "metadata": {}, 234 | "outputs": [ 235 | { 236 | "output_type": "execute_result", 237 | "data": { 238 | "text/plain": "" 239 | }, 240 | "metadata": {}, 241 | "execution_count": 9 242 | }, 243 | { 244 | "output_type": "display_data", 245 | "data": { 246 | "text/plain": "
", 247 | "image/svg+xml": "\n\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", 248 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAKsUlEQVR4nO3dbWhdhR3H8d/PpNquPlRdV2tTlvqAILLpCH1TmKyzoz5Mh2+moC+G0iETqhtI3Ttf+sb5ZmCDyja0FkEFcW6uoMUVfErb6KzREZzT0I50dlVrpV3T/17ktiQmac69OSfn8N/3A8Gk93L9Kf325N70nuOIEIA8Tqt7AIByETWQDFEDyRA1kAxRA8l0V/Ggi849I865cHEVD92Rnu7DdU9ovOEjZ9c9YZKvDp9R94Qpepfsr3vCSf8a+a8+OzDm6W6rJOpzLlys27f8sIqH7siDywbrntB4Nw+vq3vCJIO7L657whSbf/xI3RNO+vmNn8x4G99+A8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kUitr2etsf2B62vanqUQA6N2vUtrsk/VbStZIul3Sr7curHgagM0WO1KslDUfEhxFxVNJWSTdVOwtAp4pEvULSxHdkj7R+bRLbG2wP2B44/J8jZe0D0KYiUU93ypQpVwCIiP6I6IuIvm+c27xT0QD/L4pEPSJp5YSveyTtrWYOgLkqEvVbki61vcr26ZJukfR8tbMAdGrWEw9GxDHbd0t6SVKXpMcjYk/lywB0pNDZRCPiRUkvVrwFQAn4G2VAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kU+gNHe3q6T6sB5cNVvHQHVn305/VPaHxDl6ysO4Jk5xXye/Mubnzy7vqnnDSyIHfzHgbR2ogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIJlZo7b9uO1R2+/OxyAAc1PkSP07Sesr3gGgJLNGHRGvSjowD1sAlKC059S2N9gesD2w/9Oxsh4WQJtKizoi+iOiLyL6lp7fVdbDAmgTr34DyRA1kEyRH2k9Jek1SZfZHrF9R/WzAHRq1rMrR8St8zEEQDn49htIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkZn1DB6px2l931z1hkqX7VtU9ofGOd19Q94STTjt6itvmbwaA+UDUQ
DJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJFLpC30vYrtods77G9cT6GAehMkfdTH5P0q4jYZfssSTttb4uI9yreBqADsx6pI2JfROxqff6FpCFJK6oeBqAzbT2ntt0r6SpJb0xz2wbbA7YH9n86Vs46AG0rHLXtMyU9I+meiPj867dHRH9E9EVE39Lzu8rcCKANhaK2vUDjQT8ZEc9WOwnAXBR59duSHpM0FBEPVT8JwFwUOVKvkXS7pLW2B1sf11W8C0CHZv2RVkTskOR52AKgBPyNMiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpIpco6ytg0fOVs3D6+r4qE7cvCShXVPmGLpvlV1T5hkbPgfdU+YpPui3ronTLHp3i11Tzjp168emPE2jtRAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJFPkqpcLbb9p+23be2w/MB/DAHSmyPupj0haGxGHWtep3mH7TxHxesXbAHSgyFUvQ9Kh1pcLWh9R5SgAnSv0nNp2l+1BSaOStkXEG9PcZ4PtAdsDRw9+VfZOAAUVijoixiLiSkk9klbbvmKa+/RHRF9E9J2+ZFHZOwEU1Nar3xFxUNJ2SesrWQNgzoq8+r3U9pLW54skXSPp/aqHAehMkVe/l0v6ve0ujf8h8HREvFDtLACdKvLq9zuSrpqHLQBKwN8oA5IhaiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIJki79Jq21eHz9Dg7oureOiOnFfJf2Uu3Rf11j1hkqMrzq17whSbP7667gkn7T86MuNtHKmBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSKZw1K0Lz++2zcXxgAZr50i9UdJQVUMAlKNQ1LZ7JF0v6dFq5wCYq6JH6ocl3Sfp+Ex3sL3B9oDtgbFDX5YyDkD7Zo3a9g2SRiNi56nuFxH9EdEXEX1dZy4ubSCA9hQ5Uq+RdKPtjyRtlbTW9hOVrgLQsVmjjoj7I6InInol3SLp5Yi4rfJlADrCz6mBZNo6eW5EbJe0vZIlAErBkRpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIpq13aRXVu2S/Nv/4kSoeuiN3fnlX3ROmON59Qd0TJtl075a6J0yy+eOr654wRd/5H9c94aSh7qMz3saRGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkCr31snVt6i8kjUk6FhF9VY4C0Ll23k/9g4j4d2VLAJSCb7+BZIpGHZL+Ynun7Q3T3cH2BtsDtgc++3SsvIUA2lL02+81EbHX9rckbbP9fkS8OvEOEdEvqV+SLvvOwih5J4CCCh2pI2Jv65+jkp6TtLrKUQA6N2vUthfbPuvE55J+JOndqocB6EyRb7+XSXrO9on7b4mIP1e6CkDHZo06Ij6U9N152AKgBPxIC0iGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWQcUf75DGzvl/TPEh7qm5KadF409pxa0/ZIzdtU1p5vR8TS6W6oJOqy2B5o0plL2XNqTdsjNW/TfOzh228gGaIGkml61P11D/ga9pxa0/ZIzdtU+Z5GP6cG0L6mH6kBtImogWQaGbXt9bY/sD1se1MD9jxue9R2I06NbHul7VdsD9neY3tjzXsW2n7T9tutPQ/UuecE2122d9t+oe4t0viFJm3/zfag7YHK/j1Ne05tu0vS3yWtkzQi6S1Jt0bEezVu+r6kQ5L+EBFX1LVjwp7lkpZHxK7WOdl3SvpJXf+PPH7+6MURccj2Akk7JG2MiNfr2DNh1y8l9Uk6OyJuqHNLa89HkvqqvtBkE4/Uq
yUNR8SHEXFU0lZJN9U5qHWJoQN1bpgoIvZFxK7W519IGpK0osY9ERGHWl8uaH3UerSw3SPpekmP1rmjDk2MeoWkTyZ8PaIaf8M2ne1eSVdJeqPmHV22ByWNStoWEbXukfSwpPskHa95x0SzXmiyDE2M2tP8WrOeIzSE7TMlPSPpnoj4vM4tETEWEVdK6pG02nZtT1Ns3yBpNCJ21rVhBmsi4nuSrpX0i9bTutI1MeoRSSsnfN0jaW9NWxqr9dz1GUlPRsSzde85ISIOStouaX2NM9ZIurH1HHarpLW2n6hxj6T5u9BkE6N+S9KltlfZPl3SLZKer3lTo7RemHpM0lBEPNSAPUttL2l9vkjSNZLer2tPRNwfET0R0avx3z8vR8Rtde2R5vdCk42LOiKOSbpb0ksafwHo6YjYU+cm209Jek3SZbZHbN9R5x6NH4lu1/gRaLD1cV2Ne5ZLesX2Oxr/Q3lbRDTix0gNskzSDttvS3pT0h+rutBk436kBWBuGnekBjA3RA0kQ9RAMkQNJEPUQDJEDSRD1EAy/wO97YMCPlrXMQAAAABJRU5ErkJggg==\n" 249 | }, 250 | "metadata": { 251 | "needs_background": "light" 252 | } 253 | } 254 | ], 255 | "source": [ 256 | "plt.imshow(map, interpolation='nearest')" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 10, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "torch.cuda.empty_cache()" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [] 274 | } 275 | ], 276 | "metadata": { 277 | "kernelspec": { 278 | "display_name": "Python 3.8.3 64-bit ('jupyter': conda)", 279 | "language": "python", 280 | "name": "python38364bitjupyterconda644e0e65b1f74b2f90dbbacb7386a62c" 281 | }, 282 | "language_info": { 283 | "codemirror_mode": { 284 | "name": "ipython", 285 | "version": 2 286 | }, 287 | "file_extension": ".py", 288 | "mimetype": "text/x-python", 289 | "name": "python", 290 | "nbconvert_exporter": "python", 291 | "pygments_lexer": "ipython2", 292 | "version": "3.8.3-final" 293 | } 294 | }, 295 | "nbformat": 4, 296 | "nbformat_minor": 0 297 | } --------------------------------------------------------------------------------