├── LICENSE_apple
├── README.md
├── assets
│   ├── adapt_mcd_pytorch.gif
│   ├── adapt_swd_pytorch.gif
│   ├── outputs.gif
│   └── source_only_pytorch.gif
├── moon_data.npz
└── swd_pytorch.py
/LICENSE_apple:
--------------------------------------------------------------------------------
1 | Copyright (C) 2019 Apple Inc. All Rights Reserved.
2 |
3 | IMPORTANT: This Apple software is supplied to you by Apple
4 | Inc. ("Apple") in consideration of your agreement to the following
5 | terms, and your use, installation, modification or redistribution of
6 | this Apple software constitutes acceptance of these terms. If you do
7 | not agree with these terms, please do not use, install, modify or
8 | redistribute this Apple software.
9 |
10 | In consideration of your agreement to abide by the following terms, and
11 | subject to these terms, Apple grants you a personal, non-exclusive
12 | license, under Apple's copyrights in this original Apple software (the
13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple
14 | Software, with or without modifications, in source and/or binary forms;
15 | provided that if you redistribute the Apple Software in its entirety and
16 | without modifications, you must retain this notice and the following
17 | text and disclaimers in all such redistributions of the Apple Software.
18 | Neither the name, trademarks, service marks or logos of Apple Inc. may
19 | be used to endorse or promote products derived from the Apple Software
20 | without specific prior written permission from Apple. Except as
21 | expressly stated in this notice, no other rights or licenses, express or
22 | implied, are granted by Apple herein, including but not limited to any
23 | patent rights that may be infringed by your derivative works or by other
24 | works in which the Apple Software may be incorporated.
25 |
26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE
27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
31 |
32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
39 | POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Sliced Wasserstein Discrepancy for Unsupervised Domain Adaptation in PyTorch
2 | This is a PyTorch re-implementation of the [CVPR 2019](http://cvpr2019.thecvf.com) paper "Sliced Wasserstein Discrepancy for Unsupervised Domain Adaptation" from Apple.
3 |
4 | If you find this repository helpful, please consider citing the [original paper](https://arxiv.org/abs/1903.04064).
5 |
6 | ## Introduction
7 | This repository aims to reproduce the results presented in the [official repository](https://github.com/apple/ml-cvpr2019-swd). Accordingly, only the basic demo on the [intertwining
8 | moons 2D dataset](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_moons.html) is provided here.
9 |
10 | ## Requirements
11 | * Python 3.x
12 | * PyTorch
13 | * NumPy, matplotlib, and imageio
14 |
15 | This code was tested under Ubuntu 16.04 with Python 3.6 and PyTorch 1.1.0. A GPU is **NOT** required to run this code.
16 |
17 | ## Running the code
18 | To run the demo with SWD adaptation (`-mode adapt_mcd` switches to the MCD discrepancy instead):
19 | ```
20 | python swd_pytorch.py -mode adapt_swd
21 | ```
22 |
23 | To run the demo without adaptation:
24 | ```
25 | python swd_pytorch.py -mode source_only
26 | ```
27 |
28 | ## Interpreting Outputs
29 | Outputs are saved as PNG and GIF files in the current folder for each mode.
30 | Each frame shows the source and target samples with the current decision boundary. Blue and red points are source samples of classes 0
31 | and 1, respectively. Target samples are represented by green points.
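
The decision boundary in these plots comes from evaluating the classifier on a dense grid of 2D points and reshaping the predictions back onto that grid. The following is a minimal NumPy sketch of that idea only. `predict` is a hypothetical stand-in for the trained network, and `decision_grid` with its `step` and `margin` arguments is a name introduced here for illustration, not part of this repository.

```python
import numpy as np

def predict(points):
    # Hypothetical classifier: label 1.0 for points above the line y = x.
    return (points[:, 1] > points[:, 0]).astype(np.float32)

def decision_grid(x, step=0.01, margin=0.5):
    """Evaluate `predict` on a regular grid covering the samples in x (shape (N, 2))."""
    x_min, x_max = x[:, 0].min() - margin, x[:, 0].max() + margin
    y_min, y_max = x[:, 1].min() - margin, x[:, 1].max() + margin
    xx, yy = np.meshgrid(np.arange(x_min, x_max, step),
                         np.arange(y_min, y_max, step))
    # Flatten the grid to an (M, 2) batch, predict, then restore the grid shape.
    grid = np.c_[xx.ravel(), yy.ravel()]
    zz = predict(grid).reshape(xx.shape)
    return xx, yy, zz

samples = np.array([[0.0, 0.0], [1.0, 1.0]])
xx, yy, zz = decision_grid(samples, step=0.5)
print(zz.shape)
```

The returned `xx`, `yy`, `zz` arrays have matching shapes, which is the layout `plt.contourf` expects for a filled contour plot.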
32 |
33 |
34 |
35 | ## Acknowledgement
36 | [ml-cvpr2019-swd](https://github.com/apple/ml-cvpr2019-swd) (Official implementation in TensorFlow)
37 |
--------------------------------------------------------------------------------
/assets/adapt_mcd_pytorch.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krumo/swd_pytorch/9d5b49c8eed758da4677410b1036d3c28b4d17e0/assets/adapt_mcd_pytorch.gif
--------------------------------------------------------------------------------
/assets/adapt_swd_pytorch.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krumo/swd_pytorch/9d5b49c8eed758da4677410b1036d3c28b4d17e0/assets/adapt_swd_pytorch.gif
--------------------------------------------------------------------------------
/assets/outputs.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krumo/swd_pytorch/9d5b49c8eed758da4677410b1036d3c28b4d17e0/assets/outputs.gif
--------------------------------------------------------------------------------
/assets/source_only_pytorch.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krumo/swd_pytorch/9d5b49c8eed758da4677410b1036d3c28b4d17e0/assets/source_only_pytorch.gif
--------------------------------------------------------------------------------
/moon_data.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krumo/swd_pytorch/9d5b49c8eed758da4677410b1036d3c28b4d17e0/moon_data.npz
--------------------------------------------------------------------------------
/swd_pytorch.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import matplotlib.pyplot as plt
7 | import imageio
8 | import platform
9 | if platform.system() == 'Darwin':
10 |     import matplotlib
11 |     matplotlib.use('TkAgg')
12 |
13 | def toyNet():
14 |     # Define network architecture
15 |     class Generator(nn.Module):
16 |         def __init__(self):
17 |             super(Generator, self).__init__()
18 |             self.l1 = nn.Linear(2, 15)
19 |             self.l2 = nn.Linear(15, 15)
20 |             self.l3 = nn.Linear(15, 15)
21 |             self.relu = nn.ReLU(inplace=True)
22 |
23 |             for m in self.modules():
24 |                 if isinstance(m, nn.Linear):
25 |                     nn.init.xavier_uniform_(m.weight)
26 |                     nn.init.constant_(m.bias, 0)
27 |
28 |         def forward(self, x):
29 |             x = self.relu(self.l1(x))
30 |             x = self.relu(self.l2(x))
31 |             x = self.relu(self.l3(x))
32 |             return x
33 |     class Classifier1(nn.Module):
34 |         def __init__(self):
35 |             super(Classifier1, self).__init__()
36 |             self.l1 = nn.Linear(15, 15)
37 |             self.l2 = nn.Linear(15, 15)
38 |             self.l3 = nn.Linear(15, 1)
39 |             self.relu = nn.ReLU(inplace=True)
40 |             self.sigmoid = nn.Sigmoid()
41 |
42 |             for m in self.modules():
43 |                 if isinstance(m, nn.Linear):
44 |                     nn.init.xavier_uniform_(m.weight)
45 |                     nn.init.constant_(m.bias, 0)
46 |
47 |         def forward(self, x):
48 |             x = self.relu(self.l1(x))
49 |             x = self.relu(self.l2(x))
50 |             x = self.sigmoid(self.l3(x))
51 |             return x
52 |     class Classifier2(nn.Module):
53 |         def __init__(self):
54 |             super(Classifier2, self).__init__()
55 |             self.l1 = nn.Linear(15, 15)
56 |             self.l2 = nn.Linear(15, 15)
57 |             self.l3 = nn.Linear(15, 1)
58 |             self.relu = nn.ReLU(inplace=True)
59 |             self.sigmoid = nn.Sigmoid()
60 |
61 |             for m in self.modules():
62 |                 if isinstance(m, nn.Linear):
63 |                     nn.init.xavier_uniform_(m.weight)
64 |                     nn.init.constant_(m.bias, 0)
65 |
66 |         def forward(self, x):
67 |             x = self.relu(self.l1(x))
68 |             x = self.relu(self.l2(x))
69 |             x = self.sigmoid(self.l3(x))
70 |             return x
71 |     return Generator(), Classifier1(), Classifier2()
72 |
73 | def discrepancy_slice_wasserstein(p1, p2):
74 |     s = p1.shape
75 |     if s[1]>1:
76 |         proj = torch.randn(s[1], 128)
77 |         proj *= torch.rsqrt(torch.sum(torch.mul(proj, proj), 0, keepdim=True))
78 |         p1 = torch.matmul(p1, proj)
79 |         p2 = torch.matmul(p2, proj)
80 |     p1 = torch.topk(p1, s[0], dim=0)[0]
81 |     p2 = torch.topk(p2, s[0], dim=0)[0]
82 |     dist = p1-p2
83 |     wdist = torch.mean(torch.mul(dist, dist))
84 |
85 |     return wdist
86 |
87 | def discrepancy_mcd(out1, out2):
88 |     return torch.mean(torch.abs(out1 - out2))
89 |
90 |
91 | def load_data():
92 |     # Load the intertwining moons 2D dataset by F. Pedregosa et al. in JMLR 2011
93 |     moon_data = np.load('moon_data.npz')
94 |     x_s = moon_data['x_s']
95 |     y_s = moon_data['y_s']
96 |     x_t = moon_data['x_t']
97 |     return torch.from_numpy(x_s).float(), torch.from_numpy(y_s).float(), torch.from_numpy(x_t).float()
98 |
99 |
100 | def generate_grid_point():
101 |     x_min, x_max = x_s[:, 0].min() - .5, x_s[:, 0].max() + 0.5
102 |     y_min, y_max = x_s[:, 1].min() - .5, x_s[:, 1].max() + 0.5
103 |     xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.01), np.arange(y_min, y_max, 0.01))
104 |     return xx, yy
105 |
106 |
107 | if __name__ == "__main__":
108 |     parser = argparse.ArgumentParser()
109 |     parser.add_argument('-mode', type=str, default="adapt_swd",
110 |                         choices=["source_only", "adapt_mcd", "adapt_swd"])
111 |     parser.add_argument('-seed', type=int, default=1234)
112 |     opts = parser.parse_args()
113 |
114 |     # Load data
115 |     x_s, y_s, x_t = load_data()
116 |
117 |     # set random seed
118 |     torch.manual_seed(opts.seed)
119 |
120 |     torch.backends.cudnn.enabled = True
121 |     torch.backends.cudnn.deterministic = True
122 |
123 |     # Network definition
124 |     generator, cls1, cls2 = toyNet()
125 |     generator.train()
126 |     cls1.train()
127 |     cls2.train()
128 |
129 |     # Cost functions
130 |     bce_loss = nn.BCELoss()
131 |
132 |     # Setup optimizers
133 |     optim_g = torch.optim.SGD(generator.parameters(), lr=0.005)
134 |     optim_f = torch.optim.SGD(list(cls1.parameters())+list(cls2.parameters()), lr=0.005)
135 |     optim_g.zero_grad()
136 |     optim_f.zero_grad()
137 |
138 |     # Generate grid points for visualization
139 |     xx, yy = generate_grid_point()
140 |
141 |     # For creating GIF purpose
142 |     gif_images = []
143 |
144 |     for step in range(10001):
145 |         if step%1000==0:
146 |             print("Iteration: %d / %d" % (step, 10000))
147 |             z = torch.from_numpy(np.c_[xx.ravel(), yy.ravel()]).float()
148 |             with torch.no_grad():
149 |                 fea = generator(z)
150 |                 Z = (cls2(fea).cpu().numpy()>0.5).astype(np.float32)
151 |             Z = Z.reshape(xx.shape)
152 |             f = plt.figure()
153 |             plt.contourf(xx, yy, Z, cmap=plt.cm.copper_r, alpha=0.9)
154 |             plt.scatter(x_s[:, 0], x_s[:, 1], c=y_s.reshape((len(x_s))),
155 |                         cmap=plt.cm.coolwarm, alpha=0.8)
156 |             plt.scatter(x_t[:, 0], x_t[:, 1], color='green', alpha=0.7)
157 |             plt.text(1.6, -0.9, 'Iter: ' + str(step), fontsize=14, color='#FFD700',
158 |                      bbox=dict(facecolor='dimgray', alpha=0.7))
159 |             plt.axis('off')
160 |             f.savefig(opts.mode + '_pytorch_iter' + str(step) + ".png", bbox_inches='tight',
161 |                       pad_inches=0, dpi=100, transparent=True)
162 |             gif_images.append(imageio.imread(
163 |                 opts.mode + '_pytorch_iter' + str(step) + ".png"))
164 |             plt.close()
165 |
166 |         optim_g.zero_grad()
167 |         optim_f.zero_grad()
168 |         fea = generator(x_s)
169 |         pred1 = cls1(fea)
170 |         pred2 = cls2(fea)
171 |         loss_s = bce_loss(pred1, y_s) + bce_loss(pred2, y_s)
172 |         loss_s.backward()
173 |         optim_g.step()
174 |         optim_f.step()
175 |
176 |         if opts.mode == 'source_only':
177 |             continue
178 |
179 |         optim_g.zero_grad()
180 |         optim_f.zero_grad()
181 |         loss = 0
182 |         src_fea = generator(x_s)
183 |         src_fea = src_fea.detach()
184 |         src_pred1 = cls1(src_fea)
185 |         src_pred2 = cls2(src_fea)
186 |         loss += bce_loss(src_pred1, y_s) + bce_loss(src_pred2, y_s)
187 |         # loss_s.backward()
188 |
189 |         tgt_fea = generator(x_t)
190 |         tgt_fea = tgt_fea.detach()
191 |         tgt_pred1 = cls1(tgt_fea)
192 |         tgt_pred2 = cls2(tgt_fea)
193 |         if opts.mode == 'adapt_swd':
194 |             loss_dis = 2*discrepancy_slice_wasserstein(tgt_pred1, tgt_pred2)
195 |         else:
196 |             loss_dis = discrepancy_mcd(tgt_pred1, tgt_pred2)
197 |         loss -= loss_dis
198 |         loss.backward()
199 |         optim_f.step()
200 |
201 |         optim_g.zero_grad()
202 |         tgt_fea = generator(x_t)
203 |         tgt_pred1 = cls1(tgt_fea)
204 |         tgt_pred2 = cls2(tgt_fea)
205 |         if opts.mode == 'adapt_swd':
206 |             loss_dis = discrepancy_slice_wasserstein(tgt_pred1, tgt_pred2)
207 |         else:
208 |             loss_dis = discrepancy_mcd(tgt_pred1, tgt_pred2)
209 |         loss_dis.backward()
210 |         optim_g.step()
211 |
212 | # Save GIF
213 |     imageio.mimsave(opts.mode + '_pytorch.gif', gif_images, duration=0.8)
214 |     print("[Finished]\n-> Please see the current folder for outputs.")
215 |
--------------------------------------------------------------------------------
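
A note on `discrepancy_slice_wasserstein` in swd_pytorch.py: with the single-output classifiers defined in `toyNet`, the random projection branch is skipped, so the function reduces to sorting the two prediction vectors and averaging their squared differences. The block below is a minimal NumPy sketch of that idea, not the repository's code. The function name and the `num_proj` and `rng` arguments are assumptions introduced for illustration; sorting in ascending order instead of using `topk` gives the same rank pairing.

```python
import numpy as np

def sliced_wasserstein_discrepancy(p1, p2, num_proj=128, rng=None):
    """Mean squared difference between rank-matched samples of p1 and p2.

    p1, p2: arrays of shape (N, D) with the same N.
    num_proj, rng: illustrative arguments (assumptions, not from the repo).
    """
    rng = np.random.default_rng() if rng is None else rng
    if p1.shape[1] > 1:
        # Project onto random directions; each column is scaled to unit length.
        proj = rng.standard_normal((p1.shape[1], num_proj))
        proj /= np.linalg.norm(proj, axis=0, keepdims=True)
        p1 = p1 @ proj
        p2 = p2 @ proj
    # Sorting each column pairs samples by rank, which solves 1-D transport.
    p1 = np.sort(p1, axis=0)
    p2 = np.sort(p2, axis=0)
    return float(np.mean((p1 - p2) ** 2))

preds_a = np.array([[0.1], [0.9], [0.5]])
preds_b = preds_a[[2, 0, 1]]  # same values, different sample order
print(sliced_wasserstein_discrepancy(preds_a, preds_b))
```

Because samples are matched by rank rather than by index, the value does not change when the rows of either input are permuted. This is the main difference from `discrepancy_mcd`, which compares the two outputs sample by sample.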