├── LICENSE_apple
├── README.md
├── assets
│   ├── adapt_mcd_pytorch.gif
│   ├── adapt_swd_pytorch.gif
│   ├── outputs.gif
│   └── source_only_pytorch.gif
├── moon_data.npz
└── swd_pytorch.py


/LICENSE_apple:
--------------------------------------------------------------------------------
Copyright (C) 2019 Apple Inc. All Rights Reserved.

IMPORTANT: This Apple software is supplied to you by Apple
Inc. ("Apple") in consideration of your agreement to the following
terms, and your use, installation, modification or redistribution of
this Apple software constitutes acceptance of these terms. If you do
not agree with these terms, please do not use, install, modify or
redistribute this Apple software.

In consideration of your agreement to abide by the following terms, and
subject to these terms, Apple grants you a personal, non-exclusive
license, under Apple's copyrights in this original Apple software (the
"Apple Software"), to use, reproduce, modify and redistribute the Apple
Software, with or without modifications, in source and/or binary forms;
provided that if you redistribute the Apple Software in its entirety and
without modifications, you must retain this notice and the following
text and disclaimers in all such redistributions of the Apple Software.
Neither the name, trademarks, service marks or logos of Apple Inc. may
be used to endorse or promote products derived from the Apple Software
without specific prior written permission from Apple. Except as
expressly stated in this notice, no other rights or licenses, express or
implied, are granted by Apple herein, including but not limited to any
patent rights that may be infringed by your derivative works or by other
works in which the Apple Software may be incorporated.

The Apple Software is provided by Apple on an "AS IS" basis.  APPLE
MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.

IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Sliced Wasserstein Discrepancy for Unsupervised Domain Adaptation in PyTorch
This is a PyTorch re-implementation of the [CVPR 2019](http://cvpr2019.thecvf.com) paper "Sliced Wasserstein Discrepancy for Unsupervised Domain Adaptation" from Apple.

If you find this repository helpful, please consider citing the [original paper](https://arxiv.org/abs/1903.04064).

## Introduction
This repository aims to reproduce the results presented in the [official repository](https://github.com/apple/ml-cvpr2019-swd), so only a basic implementation on the [intertwining moons 2D dataset](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_moons.html) is provided here.

## Requirements
* Python 3.x
* PyTorch
* NumPy
* matplotlib
* imageio

This code has been tested on Ubuntu 16.04 with Python 3.6 and PyTorch 1.1.0. A GPU is **NOT** required to run this code.
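
`swd_pytorch.py` imports `numpy`, `torch`, `matplotlib` and `imageio`. The sketch below uses only the standard library and checks that each of these can be found before you run the demo. The helper `missing_modules` is hypothetical and is not part of the repository.

```python
import importlib.util


def missing_modules(names=("numpy", "torch", "matplotlib", "imageio")):
    """Return the top-level module names that cannot be found (hypothetical helper)."""
    # find_spec returns None when a top-level module is not installed
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules()
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages are importable.")
```

This checks import names only. The installable package name can differ: for example, PyTorch is imported as `torch`.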
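
## Running the code
To run the demo with SWD adaptation (the default mode):
```
python swd_pytorch.py -mode adapt_swd
```

To run the demo without adaptation (training on source samples only):
```
python swd_pytorch.py -mode source_only
```

The script also accepts `-mode adapt_mcd`, which replaces the sliced Wasserstein discrepancy with the MCD discrepancy (the mean absolute difference between the two classifiers' outputs). The `-seed` option sets the random seed (default: 1234).

## Interpreting Outputs
For each mode, the script saves a snapshot every 1000 iterations as `<mode>_pytorch_iter<step>.png` in the current folder. At the end of training, it combines all snapshots into `<mode>_pytorch.gif`.
Each snapshot shows the source and target samples together with the current decision boundary of the second classifier. Blue and red points are source samples of class 0 and class 1, respectively. Target samples are shown as green points.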
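
In `swd_pytorch.py`, the training loop runs `for step in range(10001)` and writes a PNG snapshot whenever `step % 1000 == 0`. After the loop, it writes one GIF. The sketch below lists the filenames a single run should produce under those settings. The helper `expected_output_files` is hypothetical and is not part of the repository.

```python
def expected_output_files(mode, total_steps=10001, every=1000):
    """List the PNG snapshots and the GIF written by one run (hypothetical helper)."""
    # A snapshot is saved for every step in range(total_steps) with step % every == 0
    pngs = ["%s_pytorch_iter%d.png" % (mode, step) for step in range(0, total_steps, every)]
    # One animated GIF is written after the training loop
    return pngs + ["%s_pytorch.gif" % mode]


files = expected_output_files("adapt_swd")
print(len(files))  # 12 (11 PNG snapshots + 1 GIF)
```

If you change the number of iterations or the snapshot interval in the script, pass the same values here.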

## Acknowledgement
[ml-cvpr2019-swd](https://github.com/apple/ml-cvpr2019-swd) (official implementation in TensorFlow)

--------------------------------------------------------------------------------
/assets/adapt_mcd_pytorch.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krumo/swd_pytorch/9d5b49c8eed758da4677410b1036d3c28b4d17e0/assets/adapt_mcd_pytorch.gif
--------------------------------------------------------------------------------
/assets/adapt_swd_pytorch.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krumo/swd_pytorch/9d5b49c8eed758da4677410b1036d3c28b4d17e0/assets/adapt_swd_pytorch.gif
--------------------------------------------------------------------------------
/assets/outputs.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krumo/swd_pytorch/9d5b49c8eed758da4677410b1036d3c28b4d17e0/assets/outputs.gif
--------------------------------------------------------------------------------
/assets/source_only_pytorch.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krumo/swd_pytorch/9d5b49c8eed758da4677410b1036d3c28b4d17e0/assets/source_only_pytorch.gif
--------------------------------------------------------------------------------
/moon_data.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krumo/swd_pytorch/9d5b49c8eed758da4677410b1036d3c28b4d17e0/moon_data.npz
--------------------------------------------------------------------------------
/swd_pytorch.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import argparse
import platform

import numpy as np
import torch
import torch.nn as nn
import matplotlib
if platform.system() == 'Darwin':
    # The backend must be selected before pyplot is imported
    matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import imageio


def toyNet():
    # Define network architecture: a shared feature generator and two classifiers
    class Generator(nn.Module):
        def __init__(self):
            super(Generator, self).__init__()
            self.l1 = nn.Linear(2, 15)
            self.l2 = nn.Linear(15, 15)
            self.l3 = nn.Linear(15, 15)
            self.relu = nn.ReLU(inplace=True)

            for m in self.modules():
                if isinstance(m, nn.Linear):
                    nn.init.xavier_uniform_(m.weight)
                    nn.init.constant_(m.bias, 0)

        def forward(self, x):
            x = self.relu(self.l1(x))
            x = self.relu(self.l2(x))
            x = self.relu(self.l3(x))
            return x

    class Classifier1(nn.Module):
        def __init__(self):
            super(Classifier1, self).__init__()
            self.l1 = nn.Linear(15, 15)
            self.l2 = nn.Linear(15, 15)
            self.l3 = nn.Linear(15, 1)
            self.relu = nn.ReLU(inplace=True)
            self.sigmoid = nn.Sigmoid()

            for m in self.modules():
                if isinstance(m, nn.Linear):
                    nn.init.xavier_uniform_(m.weight)
                    nn.init.constant_(m.bias, 0)

        def forward(self, x):
            x = self.relu(self.l1(x))
            x = self.relu(self.l2(x))
            x = self.sigmoid(self.l3(x))
            return x

    class Classifier2(nn.Module):
        def __init__(self):
            super(Classifier2, self).__init__()
            self.l1 = nn.Linear(15, 15)
            self.l2 = nn.Linear(15, 15)
            self.l3 = nn.Linear(15, 1)
            self.relu = nn.ReLU(inplace=True)
            self.sigmoid = nn.Sigmoid()

            for m in self.modules():
                if isinstance(m, nn.Linear):
                    nn.init.xavier_uniform_(m.weight)
                    nn.init.constant_(m.bias, 0)

        def forward(self, x):
            x = self.relu(self.l1(x))
            x = self.relu(self.l2(x))
            x = self.sigmoid(self.l3(x))
            return x

    return Generator(), Classifier1(), Classifier2()


def discrepancy_slice_wasserstein(p1, p2):
    # Sliced Wasserstein discrepancy between two batches of predictions (N x C)
    s = p1.shape
    if s[1] > 1:
        # For outputs with more than one dimension, project onto 128 random
        # directions, each normalized to unit length
        proj = torch.randn(s[1], 128, device=p1.device)
        proj *= torch.rsqrt(torch.sum(torch.mul(proj, proj), 0, keepdim=True))
        p1 = torch.matmul(p1, proj)
        p2 = torch.matmul(p2, proj)
    # Sort every 1-D projection (descending) so that samples are matched by rank
    p1 = torch.topk(p1, s[0], dim=0)[0]
    p2 = torch.topk(p2, s[0], dim=0)[0]
    dist = p1 - p2
    wdist = torch.mean(torch.mul(dist, dist))

    return wdist


def discrepancy_mcd(out1, out2):
    # Sample-wise L1 discrepancy between the two classifiers' outputs
    return torch.mean(torch.abs(out1 - out2))


def load_data():
    # Load the intertwining moons 2D dataset (scikit-learn, F. Pedregosa et al., JMLR 2011)
    moon_data = np.load('moon_data.npz')
    x_s = moon_data['x_s']
    y_s = moon_data['y_s']
    x_t = moon_data['x_t']
    return torch.from_numpy(x_s).float(), torch.from_numpy(y_s).float(), torch.from_numpy(x_t).float()


def generate_grid_point(x_s):
    # Dense grid covering the source samples (with a 0.5 margin) for plotting
    x_min, x_max = x_s[:, 0].min() - .5, x_s[:, 0].max() + 0.5
    y_min, y_max = x_s[:, 1].min() - .5, x_s[:, 1].max() + 0.5
    xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.01), np.arange(y_min, y_max, 0.01))
    return xx, yy


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-mode', type=str, default="adapt_swd",
                        choices=["source_only", "adapt_mcd", "adapt_swd"])
    parser.add_argument('-seed', type=int, default=1234)
    opts = parser.parse_args()

    # Load data
    x_s, y_s, x_t = load_data()

    # Set random seed
    torch.manual_seed(opts.seed)

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True

    # Network definition
    generator, cls1, cls2 = toyNet()
    generator.train()
    cls1.train()
    cls2.train()

    # Cost functions
    bce_loss = nn.BCELoss()

    # Set up optimizers
    optim_g = torch.optim.SGD(generator.parameters(), lr=0.005)
    optim_f = torch.optim.SGD(list(cls1.parameters()) + list(cls2.parameters()), lr=0.005)
    optim_g.zero_grad()
    optim_f.zero_grad()

    # Generate grid points for visualization
    xx, yy = generate_grid_point(x_s)

    # Frames for the output GIF
    gif_images = []

    for step in range(10001):
        if step % 1000 == 0:
            print("Iteration: %d / %d" % (step, 10000))
            # Plot the decision boundary of the second classifier on the grid
            z = torch.from_numpy(np.c_[xx.ravel(), yy.ravel()]).float()
            with torch.no_grad():
                fea = generator(z)
                Z = (cls2(fea).cpu().numpy() > 0.5).astype(np.float32)
            Z = Z.reshape(xx.shape)
            f = plt.figure()
            plt.contourf(xx, yy, Z, cmap=plt.cm.copper_r, alpha=0.9)
            plt.scatter(x_s[:, 0], x_s[:, 1], c=y_s.reshape((len(x_s))),
                        cmap=plt.cm.coolwarm, alpha=0.8)
            plt.scatter(x_t[:, 0], x_t[:, 1], color='green', alpha=0.7)
            plt.text(1.6, -0.9, 'Iter: ' + str(step), fontsize=14, color='#FFD700',
                     bbox=dict(facecolor='dimgray', alpha=0.7))
            plt.axis('off')
            f.savefig(opts.mode + '_pytorch_iter' + str(step) + ".png", bbox_inches='tight',
                      pad_inches=0, dpi=100, transparent=True)
            gif_images.append(imageio.imread(
                opts.mode + '_pytorch_iter' + str(step) + ".png"))
            plt.close()

        # Step A: train the generator and both classifiers on labeled source samples
        optim_g.zero_grad()
        optim_f.zero_grad()
        fea = generator(x_s)
        pred1 = cls1(fea)
        pred2 = cls2(fea)
        loss_s = bce_loss(pred1, y_s) + bce_loss(pred2, y_s)
        loss_s.backward()
        optim_g.step()
        optim_f.step()

        if opts.mode == 'source_only':
            continue

        # Step B: with the generator fixed (features detached), update the classifiers
        # to keep fitting the source labels while maximizing their discrepancy on target samples
        optim_g.zero_grad()
        optim_f.zero_grad()
        loss = 0
        src_fea = generator(x_s)
        src_fea = src_fea.detach()
        src_pred1 = cls1(src_fea)
        src_pred2 = cls2(src_fea)
        loss += bce_loss(src_pred1, y_s) + bce_loss(src_pred2, y_s)

        tgt_fea = generator(x_t)
        tgt_fea = tgt_fea.detach()
        tgt_pred1 = cls1(tgt_fea)
        tgt_pred2 = cls2(tgt_fea)
        if opts.mode == 'adapt_swd':
            loss_dis = 2 * discrepancy_slice_wasserstein(tgt_pred1, tgt_pred2)
        else:
            loss_dis = discrepancy_mcd(tgt_pred1, tgt_pred2)
        loss -= loss_dis
        loss.backward()
        optim_f.step()

        # Step C: update the generator to minimize the discrepancy on target samples
        optim_g.zero_grad()
        tgt_fea = generator(x_t)
        tgt_pred1 = cls1(tgt_fea)
        tgt_pred2 = cls2(tgt_fea)
        if opts.mode == 'adapt_swd':
            loss_dis = discrepancy_slice_wasserstein(tgt_pred1, tgt_pred2)
        else:
            loss_dis = discrepancy_mcd(tgt_pred1, tgt_pred2)
        loss_dis.backward()
        optim_g.step()

    # Save GIF
    imageio.mimsave(opts.mode + '_pytorch.gif', gif_images, duration=0.8)
    print("[Finished]\n-> Please see the current folder for outputs.")

--------------------------------------------------------------------------------
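
The core of `swd_pytorch.py` is `discrepancy_slice_wasserstein`. It works in three steps:

1. If the outputs have more than one dimension, project them onto random unit directions.
2. Sort each 1-D projection.
3. Average the squared differences between the rank-matched values.

The dependency-free sketch below mirrors those steps on plain Python lists instead of tensors. It also includes a sample-wise L1 discrepancy, analogous to `discrepancy_mcd`, for contrast. Both function names are hypothetical and are not part of the repository.

```python
import math
import random


def sliced_wasserstein_discrepancy(p1, p2, num_projections=128, rng=None):
    """Squared sliced Wasserstein discrepancy between two batches of equal size.

    p1, p2: lists of N rows, each a list of C floats (e.g. classifier outputs).
    """
    # Fixed seed only to keep this sketch deterministic; the PyTorch version
    # draws fresh random directions on every call.
    rng = rng or random.Random(0)
    n, c = len(p1), len(p1[0])
    if c > 1:
        # Draw random directions and scale each one to unit length
        dirs = []
        for _ in range(num_projections):
            v = [rng.gauss(0.0, 1.0) for _ in range(c)]
            norm = math.sqrt(sum(x * x for x in v))
            dirs.append([x / norm for x in v])
        # Project every row onto every direction: N x num_projections
        p1 = [[sum(a * b for a, b in zip(row, d)) for d in dirs] for row in p1]
        p2 = [[sum(a * b for a, b in zip(row, d)) for d in dirs] for row in p2]
    k = len(p1[0])
    total = 0.0
    for j in range(k):
        # Sort each 1-D projection and pair the values by rank
        a = sorted(row[j] for row in p1)
        b = sorted(row[j] for row in p2)
        total += sum((x - y) ** 2 for x, y in zip(a, b))
    # Mean squared difference over all N x k rank-matched pairs
    return total / (n * k)


def mean_abs_discrepancy(p1, p2):
    """Sample-wise L1 discrepancy, analogous to discrepancy_mcd."""
    n, c = len(p1), len(p1[0])
    return sum(abs(a - b) for r1, r2 in zip(p1, p2) for a, b in zip(r1, r2)) / (n * c)


out1 = [[0.2], [0.9], [0.5]]
out2 = [[0.9], [0.5], [0.2]]  # same values as out1, different sample order
print(sliced_wasserstein_discrepancy(out1, out2))  # 0.0
print(mean_abs_discrepancy(out1, out2) > 0)        # True
```

Sorting both batches in ascending order pairs the same ranks as the descending `torch.topk` in the script, so the value is unchanged. The example shows the practical difference between the two modes. The sliced discrepancy compares the two batches as sorted 1-D distributions, so permuting the rows of one batch does not change it. The sample-wise L1 discrepancy, by contrast, does change under such a permutation. The pure-Python loops are only for readability; inside the training loop, use the differentiable PyTorch function.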