├── .gitignore
├── LICENSE
├── README.md
├── figure
│   ├── ssim_cifar_remove_class.png
│   ├── swd_cifar_inbalance.png
│   ├── swd_cifar_remove_class.png
│   ├── swd_cifar_test.png
│   ├── swd_cifar_test2.png
│   ├── swd_random_test.png
│   └── swd_random_test2.png
├── ssim_compare.py
├── swd.py
└── swd_test.py

/.gitignore:
--------------------------------------------------------------------------------
*/
!figure/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 こしあん

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Sliced Wasserstein Distance (SWD) in PyTorch
An implementation of Sliced Wasserstein Distance (SWD) in PyTorch. **GPU acceleration is available**.

SWD is not only for GANs. **SWD can measure image distribution mismatches or imbalances without additional labels.**

## About
The original idea is described in the [PGGAN paper](https://arxiv.org/pdf/1710.10196.pdf). This repo is an unofficial implementation.

The [original code](https://github.com/tkarras/progressive_growing_of_gans) is written in NumPy, while this repo's code is written in PyTorch, so you can calculate SWD on CUDA devices.

## How to use
A simple example of calculating SWD on a GPU (a further variation is sketched right after this block):

```python
import torch
from swd import swd

torch.manual_seed(123)  # fix seed
x1 = torch.rand(1024, 3, 128, 128)  # 1024 images, 3 chs, 128x128 resolution
x2 = torch.rand(1024, 3, 128, 128)
out = swd(x1, x2, device="cuda")  # Fast estimation if device="cuda"
print(out)  # tensor(53.6950)
```
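As a follow-up sketch, using only the options documented under "Parameter details" below (the concrete numbers are illustrative, not benchmarked): ```return_by_resolution=True``` returns one value per pyramid level instead of the averaged value, and lowering ```proj_per_repeat``` while raising ```n_repeat_projection``` keeps the recommended total of 512 projections while using less GPU memory.

```python
import torch
from swd import swd

torch.manual_seed(123)
x1 = torch.rand(1024, 3, 128, 128)
x2 = torch.rand(1024, 3, 128, 128)

# One SWD value per laplacian-pyramid level (highest resolution first)
print(swd(x1, x2, device="cuda", return_by_resolution=True))

# Same total number of projections (n_repeat_projection * proj_per_repeat = 512),
# but a smaller per-repeat memory footprint
print(swd(x1, x2, n_repeat_projection=512, proj_per_repeat=1, device="cuda"))
```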
## Japanese article
Implementing Sliced Wasserstein Distance (SWD) in PyTorch (in Japanese)
[https://blog.shikoan.com/swd-pytorch/](https://blog.shikoan.com/swd-pytorch/)

## Parameter details
Detailed information on the ```swd``` parameters.

* ```image1, image2``` : **Required.** Rank-4 PyTorch tensors of shape [N, ch, H, W]. Square images (H = W) are recommended.
* ```n_pyramids``` : (Optional) Number of laplacian pyramids. If ```None``` (default, same as the paper), the pyramids are downsampled toward a 16x16 resolution. The output number of pyramids is ```n_pyramids + 1```, because the lowest-resolution gaussian pyramid is appended to the laplacian pyramid sequence.
* ```slice_size``` : (Optional) Patch size used when slicing each pyramid layer. Default is 7 (same as the paper).
* ```n_descriptors``` : (Optional) Number of descriptors per image. Default is 128 (same as the paper).
* ```n_repeat_projection``` : (Optional) Number of times the random projection is repeated. **Please choose this value according to your GPU memory.** Default is 128. ```n_repeat_projection * proj_per_repeat = 512``` is recommended. This product of 512 is the same as the paper, but the official implementation uses 4 for n_repeat_projection and 128 for proj_per_repeat (which needs a huge amount of memory).
* ```proj_per_repeat``` : (Optional) Number of projection dimensions computed on each repeat. Default is 4. Higher values need much more GPU memory. ```n_repeat_projection * proj_per_repeat = 512``` is recommended.
* ```device``` : (Optional) ```"cpu"``` or ```"cuda"```. **Please specify ```"cuda"``` to use GPU acceleration.** Default is ```"cpu"```.
* ```return_by_resolution``` : (Optional) If True, returns the SWD for each resolution (laplacian pyramid level). If False, returns the average of the SWD values over resolutions. Default is False.
* ```pyramid_batchsize``` : (Optional) Mini-batch size used when computing the laplacian pyramids. Higher values may cause CUDA out-of-memory errors. This value does not affect the SWD estimate. Default is 128.


## Experiments
### Changing n_repeat_projection and proj_per_repeat
**Changing ```n_repeat_projection``` and ```proj_per_repeat``` has little effect on SWD** (as long as n_repeat_projection * proj_per_repeat is constant).

Each plot shows the SWD value for one resolution of the laplacian pyramid. The horizontal axis is proj_per_repeat and the vertical axis is SWD. Each condition is run 10 times.

In all conditions, n_repeat_projection * proj_per_repeat is fixed at 512.

#### Random noise
Compares two different sets of 16384 random tensors.

![](figure/swd_random_test.png)

#### CIFAR-10
Compares 10k CIFAR-10 training images with the 10k test images.

![](figure/swd_cifar_test.png)

So you can change the ```n_repeat_projection``` and ```proj_per_repeat``` values according to your GPU memory.

### Changing the number of data
**Changing the number of samples has a huge impact on SWD (important).**

Each plot shows the SWD value for one resolution of the laplacian pyramid. The horizontal axis is the number of samples and the vertical axis is SWD. Each condition is run 10 times.

#### Random noise
![](figure/swd_random_test2.png)

#### CIFAR-10
![](figure/swd_cifar_test2.png)

It is important to fix the number of samples from the start. If the number of samples changes, the resulting SWD values are not comparable.

## As a metric for distribution mismatch
**SWD can be used as a metric of distribution mismatch.**

Two experiments on CIFAR-10. SWD is measured between training and test data under the following conditions:

1. **Remove classes** : The test data is left unchanged, while 0-8 classes are removed from the training data.
2. **Imbalanced classes** : The test data is left unchanged, while imbalances are created artificially in the training data:
   Training A is the training set with 1 class removed (the imbalanced set). Training B is the unchanged training set (the balanced set).
   A and B are concatenated so that the total is always 10000 images, with the portion taken from A varied from 0 to 10000, so only 1 class becomes imbalanced.

Both experiments use imbalanced data, but experiment 1 produces a stronger imbalance (distribution mismatch).

### 1. Remove classes

Each plot shows the SWD value for one resolution of the laplacian pyramid. The horizontal axis is the number of removed classes and the vertical axis is SWD. Each condition is run 10 times.

![](figure/swd_cifar_remove_class.png)

The more classes are removed, the higher the observed SWD.

### 2. Imbalanced classes

Each plot shows the SWD values for one imbalanced class index. The horizontal axis is the size of the imbalanced set (training A) and the vertical axis is SWD. Each condition is run once.

![](figure/swd_cifar_inbalance.png)

This is a weaker imbalance than in experiment 1, but SWD can still capture the imbalance or mismatch.

### (Optional) Comparison to SSIM
One remaining question is whether this kind of imbalance can also be detected by other indicators (e.g. SSIM). Experiment 1 was therefore repeated with SSIM.

![](figure/ssim_cifar_remove_class.png)

SSIM does not detect the imbalances well.

Therefore, SWD is effective for mismatch detection.
--------------------------------------------------------------------------------
/figure/ssim_cifar_remove_class.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/koshian2/swd-pytorch/2b0c224fa4e43ab081a40380689d6a334959eb65/figure/ssim_cifar_remove_class.png
--------------------------------------------------------------------------------
/figure/swd_cifar_inbalance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/koshian2/swd-pytorch/2b0c224fa4e43ab081a40380689d6a334959eb65/figure/swd_cifar_inbalance.png
--------------------------------------------------------------------------------
/figure/swd_cifar_remove_class.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/koshian2/swd-pytorch/2b0c224fa4e43ab081a40380689d6a334959eb65/figure/swd_cifar_remove_class.png
--------------------------------------------------------------------------------
/figure/swd_cifar_test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/koshian2/swd-pytorch/2b0c224fa4e43ab081a40380689d6a334959eb65/figure/swd_cifar_test.png
--------------------------------------------------------------------------------
/figure/swd_cifar_test2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/koshian2/swd-pytorch/2b0c224fa4e43ab081a40380689d6a334959eb65/figure/swd_cifar_test2.png
--------------------------------------------------------------------------------
/figure/swd_random_test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/koshian2/swd-pytorch/2b0c224fa4e43ab081a40380689d6a334959eb65/figure/swd_random_test.png
--------------------------------------------------------------------------------
/figure/swd_random_test2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/koshian2/swd-pytorch/2b0c224fa4e43ab081a40380689d6a334959eb65/figure/swd_random_test2.png
--------------------------------------------------------------------------------
/ssim_compare.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import pickle
import matplotlib.pyplot as plt

def load_cifar(ignore_train_classes):
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
    if len(ignore_train_classes) > 0:
        filter_labels = y_train == np.array(ignore_train_classes).reshape(1, -1)
        filter_labels = np.any(filter_labels, axis=1)
    else:
        # np.bool is deprecated/removed in recent NumPy; use the builtin bool
        filter_labels = np.zeros(y_train.shape[0], bool)
    train_img = X_train[~filter_labels][:10000].astype(np.float32) / 255.0
    test_img = X_test.astype(np.float32) / 255.0
    return tf.constant(train_img), tf.constant(test_img)

# SSIM : for distribution mismatch
def cifar_remove_class_test():
    np.set_printoptions(precision=2)
    result = {}
    for i in range(9):
        train_img, test_img = load_cifar([j for j in range(i)])  # remove classes 0..i-1
        dist = tf.reduce_mean(tf.image.ssim(train_img, test_img, max_val=1.0)).numpy()
        result[i] = dist
        print("remove classes", i, dist)
    with open("ssim_cifar_remove_class.pkl", "wb") as fp:
        pickle.dump(result, fp)

def plot_results(filename):
    with open(filename, "rb") as fp:
        data = pickle.load(fp)
    # convert dict views to lists so matplotlib can plot them
    plt.plot(list(data.keys()), list(data.values()))
    plt.show()
--------------------------------------------------------------------------------
/swd.py:
--------------------------------------------------------------------------------
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
import torchvision

# Gaussian blur kernel
def get_gaussian_kernel(device="cpu"):
    kernel = np.array([
        [1, 4, 6, 4, 1],
        [4, 16, 24, 16, 4],
        [6, 24, 36, 24, 6],
        [4, 16, 24, 16, 4],
        [1, 4, 6, 4, 1]], np.float32) / 256.0
    gaussian_k = torch.as_tensor(kernel.reshape(1, 1, 5, 5)).to(device)
    return gaussian_k

def pyramid_down(image, device="cpu"):
    gaussian_k = get_gaussian_kernel(device=device)
    # channel-wise conv (important)
    multiband = [F.conv2d(image[:, i:i + 1, :, :], gaussian_k, padding=2, stride=2) for i in range(3)]
    down_image = torch.cat(multiband, dim=1)
    return down_image

def pyramid_up(image, device="cpu"):
    gaussian_k = get_gaussian_kernel(device=device)
    upsample = F.interpolate(image, scale_factor=2)
    multiband = [F.conv2d(upsample[:, i:i + 1, :, :], gaussian_k, padding=2) for i in range(3)]
    up_image = torch.cat(multiband, dim=1)
    return up_image

def gaussian_pyramid(original, n_pyramids, device="cpu"):
    x = original
    # pyramid down
    pyramids = [original]
    for i in range(n_pyramids):
        x = pyramid_down(x, device=device)
        pyramids.append(x)
    return pyramids

def laplacian_pyramid(original, n_pyramids, device="cpu"):
    # create gaussian pyramid
    pyramids = gaussian_pyramid(original, n_pyramids, device=device)

    # pyramid up - diff
    laplacian = []
    for i in range(len(pyramids) - 1):
        diff = pyramids[i] - pyramid_up(pyramids[i + 1], device=device)
        laplacian.append(diff)
    # Add the last (lowest-resolution) gaussian pyramid
    laplacian.append(pyramids[len(pyramids) - 1])
    return laplacian

def minibatch_laplacian_pyramid(image, n_pyramids, batch_size, device="cpu"):
    # ceil division: number of mini-batches
    n = image.size(0) // batch_size + np.sign(image.size(0) % batch_size)
    pyramids = []
    for i in range(n):
        x = image[i * batch_size:(i + 1) * batch_size]
        p = laplacian_pyramid(x.to(device), n_pyramids, device=device)
        p = [x.cpu() for x in p]
        pyramids.append(p)
        del x
    result = []
    for i in range(n_pyramids + 1):
        x = []
        for j in range(n):
            x.append(pyramids[j][i])
        result.append(torch.cat(x, dim=0))
    return result

def extract_patches(pyramid_layer, slice_indices,
                    slice_size=7, unfold_batch_size=128, device="cpu"):
    assert pyramid_layer.ndim == 4
    # ceil division: number of unfold mini-batches
    n = pyramid_layer.size(0) // unfold_batch_size + np.sign(pyramid_layer.size(0) % unfold_batch_size)
    # random slices (slice_size x slice_size)
    p_slice = []
    for i in range(n):
        # [unfold_batch_size, ch, n_slices, slice_size, slice_size]
        ind_start = i * unfold_batch_size
        ind_end = min((i + 1) * unfold_batch_size, pyramid_layer.size(0))
        x = pyramid_layer[ind_start:ind_end].unfold(
            2, slice_size, 1).unfold(3, slice_size, 1).reshape(
            ind_end - ind_start, pyramid_layer.size(1), -1, slice_size, slice_size)
        # [unfold_batch_size, ch, n_descriptors, slice_size, slice_size]
        x = x[:, :, slice_indices, :, :]
        # [unfold_batch_size, n_descriptors, ch, slice_size, slice_size]
        p_slice.append(x.permute([0, 2, 1, 3, 4]))
    # sliced tensor per layer [batch, n_descriptors, ch, slice_size, slice_size]
    x = torch.cat(p_slice, dim=0)
    # normalize along ch
    std, mean = torch.std_mean(x, dim=(0, 1, 3, 4), keepdim=True)
    x = (x - mean) / (std + 1e-8)
    # reshape to rank 2
    x = x.reshape(-1, 3 * slice_size * slice_size)
    return x

def swd(image1, image2,
        n_pyramids=None, slice_size=7, n_descriptors=128,
        n_repeat_projection=128, proj_per_repeat=4, device="cpu", return_by_resolution=False,
        pyramid_batchsize=128):
    # n_repeat_projection * proj_per_repeat = 512
    # Please change these values according to memory usage.
    # original: n_repeat_projection=4, proj_per_repeat=128
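    # Rough outline of the steps below, for orientation:
    #   (1) build laplacian pyramids of both image sets in mini-batches
    #   (2) per pyramid level, extract n_descriptors random slice_size x slice_size
    #       patches per image and normalize them channel-wise
    #   (3) project the patch descriptors onto random directions, sort the projections,
    #       and average the absolute differences (sliced 1D Wasserstein distance)
    #   (4) average over projections and, unless return_by_resolution, over levels (x1e3)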
    assert image1.size() == image2.size()
    assert image1.ndim == 4 and image2.ndim == 4

    if n_pyramids is None:
        n_pyramids = int(np.rint(np.log2(image1.size(2) // 16)))
    with torch.no_grad():
        # minibatch laplacian pyramid for cuda memory reasons
        pyramid1 = minibatch_laplacian_pyramid(image1, n_pyramids, pyramid_batchsize, device=device)
        pyramid2 = minibatch_laplacian_pyramid(image2, n_pyramids, pyramid_batchsize, device=device)
        result = []

        for i_pyramid in range(n_pyramids + 1):
            # indices of patch positions (stride 1 -> (H - slice_size + 1) * (W - slice_size + 1) positions)
            n = (pyramid1[i_pyramid].size(2) - slice_size + 1) * (pyramid1[i_pyramid].size(3) - slice_size + 1)
            indices = torch.randperm(n)[:n_descriptors]

            # extract patches on CPU
            # patch : rank 2 (n_image*n_descriptors, slice_size**2*ch)
            p1 = extract_patches(pyramid1[i_pyramid], indices,
                                 slice_size=slice_size, device="cpu")
            p2 = extract_patches(pyramid2[i_pyramid], indices,
                                 slice_size=slice_size, device="cpu")

            p1, p2 = p1.to(device), p2.to(device)

            distances = []
            for j in range(n_repeat_projection):
                # random projection directions
                rand = torch.randn(p1.size(1), proj_per_repeat).to(device)  # (slice_size**2*ch, proj_per_repeat)
                rand = rand / torch.std(rand, dim=0, keepdim=True)  # normalize
                # projection
                proj1 = torch.matmul(p1, rand)
                proj2 = torch.matmul(p2, rand)
                proj1, _ = torch.sort(proj1, dim=0)
                proj2, _ = torch.sort(proj2, dim=0)
                d = torch.abs(proj1 - proj2)
                distances.append(torch.mean(d))

            # swd
            result.append(torch.mean(torch.stack(distances)))

    # average over resolution
    result = torch.stack(result) * 1e3
    if return_by_resolution:
        return result.cpu()
    else:
        return torch.mean(result).cpu()
--------------------------------------------------------------------------------
/swd_test.py:
--------------------------------------------------------------------------------
import torch
from swd import swd
import pickle
import numpy as np
import matplotlib.pyplot as plt
import torchvision

# random image : change proj_per_repeat
def random_value_test():
    torch.manual_seed(123)
    image1 = torch.randn(16384, 3, 128, 128)
    image2 = torch.randn(16384, 3, 128, 128)
    np.set_printoptions(precision=2)
    result = {}
    for n_proj in [128, 64, 32, 16, 8, 4, 2, 1]:
        dists = []
        for i in range(10):
            dists.append(swd(image1, image2, proj_per_repeat=n_proj,
                             n_repeat_projection=512 // n_proj,
                             device="cuda" if n_proj <= 32 else "cpu",
                             return_by_resolution=True).numpy())
        dists = np.array(dists)
        result[n_proj] = dists
        print("proj=", n_proj, dists)
    with open("swd_random_test.pkl", "wb") as fp:
        pickle.dump(result, fp)

# random image : change image size
def random_value_test2():
    np.set_printoptions(precision=2)
    result = {}
    for image_size in [16384, 8192, 4096, 2048, 1024, 512, 256, 128, 64, 32]:
        torch.manual_seed(123)
        image1 = torch.randn(image_size, 3, 128, 128)
        image2 = torch.randn(image_size, 3, 128, 128)
        dists = []
        for i in range(10):
            dists.append(swd(image1, image2, proj_per_repeat=32,
                             n_repeat_projection=16,
                             device="cuda", return_by_resolution=True).numpy())
        dists = np.array(dists)
        result[image_size] = dists
        print("image_size", image_size, dists)
    with open("swd_random_test2.pkl", "wb") as fp:
        pickle.dump(result, fp)

## Test on cifar
def load_cifar(ignore_train_classes):
    dataset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True)
    labels = np.array(dataset.targets)
    if len(ignore_train_classes) > 0:
        filter_labels = labels.reshape(-1, 1) == np.array(ignore_train_classes).reshape(1, -1)
        filter_labels = np.any(filter_labels, axis=1)
    else:
        # np.bool is deprecated/removed in recent NumPy; use the builtin bool
        filter_labels = np.zeros(labels.shape[0], bool)
    train_img = torch.as_tensor(dataset.data[~filter_labels][:10000].transpose([0, 3, 1, 2]).astype(np.float32) / 255.0)
    dataset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True)
    test_img = torch.as_tensor(dataset.data.transpose([0, 3, 1, 2]).astype(np.float32) / 255.0)
    return train_img, test_img

# cifar10 : change proj_per_repeat
def cifar_test():
    train_img, test_img = load_cifar([])

    np.set_printoptions(precision=2)
    result = {}
    for n_proj in [128, 64, 32, 16, 8, 4, 2, 1]:
        dists = []
        for i in range(10):
            dists.append(swd(train_img, test_img, proj_per_repeat=n_proj,
                             n_repeat_projection=512 // n_proj,
                             device="cuda" if n_proj <= 64 else "cpu",
                             return_by_resolution=True).numpy())
        dists = np.array(dists)
        result[n_proj] = dists
        print("proj=", n_proj, dists)
    with open("swd_cifar_test.pkl", "wb") as fp:
        pickle.dump(result, fp)

# cifar : change image size
def cifar_test2():
    train_img, test_img = load_cifar([])

    np.set_printoptions(precision=2)
    result = {}
    for image_size in [10000, 8192, 4096, 2048, 1024, 512, 256, 128, 64, 32]:
        dists = []
        for i in range(10):
            dists.append(swd(train_img[:image_size], test_img[:image_size], proj_per_repeat=64,
                             n_repeat_projection=8,
                             device="cuda", return_by_resolution=True).numpy())
        dists = np.array(dists)
        result[image_size] = dists
        print("image_size", image_size, dists)
    with open("swd_cifar_test2.pkl", "wb") as fp:
        pickle.dump(result, fp)

# cifar : for distribution mismatch
def cifar_remove_class_test():
    np.set_printoptions(precision=2)
    result = {}
    for i in range(9):
        train_img, test_img = load_cifar([j for j in range(i)])  # remove classes 0..i-1
        dists = []
        for j in range(10):
            dists.append(swd(train_img, test_img, proj_per_repeat=64, n_repeat_projection=8,
                             device="cuda", return_by_resolution=True).numpy())
        dists = np.array(dists)
        result[i] = dists
        print("remove classes", i, dists)
    with open("swd_cifar_remove_class.pkl", "wb") as fp:
        pickle.dump(result, fp)

# cifar : imbalanced data mismatch
def cifar_inbalance_class_test():
    np.set_printoptions(precision=2)
    result = {}
    for i in range(11):
        dists = []
        for remove_class in range(10):
            # imbalance for one class, others are normal
            balanced_img, test_img = load_cifar([])
            inbalanced_img, _ = load_cifar([remove_class])
            train_img = torch.cat([inbalanced_img[:i * 1000], balanced_img[i * 1000:]], dim=0)
            dists.append(swd(train_img, test_img, n_repeat_projection=8, proj_per_repeat=64,
                             device="cuda", return_by_resolution=True).numpy())
        dists = np.array(dists)
        result[i * 1000] = dists
        print("n_inbalance", i * 1000, dists)
    with open("swd_cifar_inbalance.pkl", "wb") as fp:
        pickle.dump(result, fp)


def plot_results(filename):
    with open(filename, "rb") as fp:
        data = pickle.load(fp)
    n = len(next(iter(data.values()))[0])
    for i in range(n):
        points = []
        for key in data.keys():
            points.append(data[key][:, i])
        rn = int(np.ceil(np.sqrt(n)))
        ax = plt.subplot(rn, rn, i + 1)
        ax.boxplot(points)
        ax.set_xticklabels(data.keys())
    plt.show()

def plot_inbalance(filename):
    with open(filename, "rb") as fp:
        data = pickle.load(fp)
    plt.subplots_adjust(top=0.95, bottom=0.05, hspace=0.2, wspace=0.1, left=0.05, right=0.95)
    for i in range(10):
        points = []
        for key in data.keys():
            points.append(np.mean(data[key][i, :]))
        ax = plt.subplot(5, 2, i + 1)
        # convert the dict view to a list so matplotlib can plot it
        ax.plot(list(data.keys()), points, label="class = " + str(i))
        ax.legend()
    plt.show()
--------------------------------------------------------------------------------