├── LICENSE
├── digit
│   ├── data_load
│   │   ├── __init__.py
│   │   ├── mnist.py
│   │   ├── svhn.py
│   │   ├── usps.py
│   │   ├── utils.py
│   │   └── vision.py
│   ├── digit.sh
│   ├── loss.py
│   ├── network.py
│   └── uda_digit.py
├── figs
│   └── shot.jpg
├── object
│   ├── data_list.py
│   ├── image_multisource.py
│   ├── image_multitarget.py
│   ├── image_pretrained.py
│   ├── image_source.py
│   ├── image_target.py
│   ├── image_target_oda.py
│   ├── loss.py
│   ├── network.py
│   └── run.sh
├── pretrained-models.md
├── readme.md
└── results.md
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/digit/data_load/__init__.py:
--------------------------------------------------------------------------------
1 | from .svhn import *
2 | from .mnist import *
3 | from .usps import *
--------------------------------------------------------------------------------
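Since `__init__.py` star-imports the three dataset modules, the dataset classes are reachable both from the package root and from the individual modules. A minimal sketch, assuming the `digit/` directory is the working directory (as it is when running `uda_digit.py`):

```python
# Both import styles resolve to the same classes, because __init__.py
# does `from .mnist import *`, `from .svhn import *`, `from .usps import *`.
from data_load import MNIST, MNIST_idx, SVHN, SVHN_idx, USPS, USPS_idx
from data_load.mnist import MNIST as MNIST_direct

assert MNIST is MNIST_direct
```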
/digit/data_load/mnist.py:
--------------------------------------------------------------------------------
1 | from .vision import VisionDataset
2 | import warnings
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 | import torch
8 | import codecs
9 | import string
10 | from .utils import download_url, download_and_extract_archive, extract_archive, \
11 | verify_str_arg
12 |
13 |
14 | class MNIST(VisionDataset):
15 | """`MNIST `_ Dataset.
16 |
17 | Args:
18 | root (string): Root directory of dataset where ``processed/training.pt``
19 | and ``processed/test.pt`` exist.
20 | train (bool, optional): If True, creates dataset from ``training.pt``,
21 | otherwise from ``test.pt``.
22 | download (bool, optional): If true, downloads the dataset from the internet and
23 | puts it in root directory. If dataset is already downloaded, it is not
24 | downloaded again.
25 | transform (callable, optional): A function/transform that takes in a PIL image
26 | and returns a transformed version. E.g., ``transforms.RandomCrop``
27 | target_transform (callable, optional): A function/transform that takes in the
28 | target and transforms it.
29 | """
30 |
31 | resources = [
32 | ("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
33 | ("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
34 | ("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
35 | ("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")
36 | ]
37 |
38 | training_file = 'training.pt'
39 | test_file = 'test.pt'
40 | classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
41 | '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
42 |
43 | @property
44 | def train_labels(self):
45 | warnings.warn("train_labels has been renamed targets")
46 | return self.targets
47 |
48 | @property
49 | def test_labels(self):
50 | warnings.warn("test_labels has been renamed targets")
51 | return self.targets
52 |
53 | @property
54 | def train_data(self):
55 | warnings.warn("train_data has been renamed data")
56 | return self.data
57 |
58 | @property
59 | def test_data(self):
60 | warnings.warn("test_data has been renamed data")
61 | return self.data
62 |
63 | def __init__(self, root, train=True, transform=None, target_transform=None,
64 | download=False):
65 | super(MNIST, self).__init__(root, transform=transform,
66 | target_transform=target_transform)
67 | self.train = train # training set or test set
68 |
69 | if download:
70 | self.download()
71 |
72 | if not self._check_exists():
73 | raise RuntimeError('Dataset not found.' +
74 | ' You can use download=True to download it')
75 |
76 | if self.train:
77 | data_file = self.training_file
78 | else:
79 | data_file = self.test_file
80 | self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
81 |
82 | def __getitem__(self, index):
83 | """
84 | Args:
85 | index (int): Index
86 |
87 | Returns:
88 | tuple: (image, target) where target is index of the target class.
89 | """
90 | img, target = self.data[index], int(self.targets[index])
91 |
92 | # doing this so that it is consistent with all other datasets
93 | # to return a PIL Image
94 | img = Image.fromarray(img.numpy(), mode='L')
95 |
96 | if self.transform is not None:
97 | img = self.transform(img)
98 |
99 | if self.target_transform is not None:
100 | target = self.target_transform(target)
101 |
102 | return img, target
103 |
104 | def __len__(self):
105 | return len(self.data)
106 |
107 | # @property
108 | # def raw_folder(self):
109 | # return os.path.join(self.root, self.__class__.__name__, 'raw')
110 |
111 | # @property
112 | # def processed_folder(self):
113 | # return os.path.join(self.root, self.__class__.__name__, 'processed')
114 |
115 | @property
116 | def raw_folder(self):
117 | return os.path.join(self.root, 'raw')
118 |
119 | @property
120 | def processed_folder(self):
121 | return os.path.join(self.root, 'processed')
122 |
123 | @property
124 | def class_to_idx(self):
125 | return {_class: i for i, _class in enumerate(self.classes)}
126 |
127 | def _check_exists(self):
128 | return (os.path.exists(os.path.join(self.processed_folder,
129 | self.training_file)) and
130 | os.path.exists(os.path.join(self.processed_folder,
131 | self.test_file)))
132 |
133 | def download(self):
134 | """Download the MNIST data if it doesn't exist in processed_folder already."""
135 |
136 | if self._check_exists():
137 | return
138 |
139 | os.makedirs(self.raw_folder, exist_ok=True)
140 | os.makedirs(self.processed_folder, exist_ok=True)
141 |
142 | # download files
143 | for url, md5 in self.resources:
144 | filename = url.rpartition('/')[2]
145 | download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
146 |
147 | # process and save as torch files
148 | print('Processing...')
149 |
150 | training_set = (
151 | read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
152 | read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
153 | )
154 | test_set = (
155 | read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
156 | read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
157 | )
158 | with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
159 | torch.save(training_set, f)
160 | with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
161 | torch.save(test_set, f)
162 |
163 | print('Done!')
164 |
165 | def extra_repr(self):
166 | return "Split: {}".format("Train" if self.train is True else "Test")
167 |
168 |
169 |
170 | class MNIST_idx(VisionDataset):
171 | """`MNIST `_ Dataset.
172 |
173 | Args:
174 | root (string): Root directory of dataset where ``processed/training.pt``
175 | and ``processed/test.pt`` exist.
176 | train (bool, optional): If True, creates dataset from ``training.pt``,
177 | otherwise from ``test.pt``.
178 | download (bool, optional): If true, downloads the dataset from the internet and
179 | puts it in root directory. If dataset is already downloaded, it is not
180 | downloaded again.
181 | transform (callable, optional): A function/transform that takes in a PIL image
182 | and returns a transformed version. E.g., ``transforms.RandomCrop``
183 | target_transform (callable, optional): A function/transform that takes in the
184 | target and transforms it.
185 | """
186 |
187 | resources = [
188 | ("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
189 | ("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
190 | ("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
191 | ("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")
192 | ]
193 |
194 | training_file = 'training.pt'
195 | test_file = 'test.pt'
196 | classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
197 | '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
198 |
199 | @property
200 | def train_labels(self):
201 | warnings.warn("train_labels has been renamed targets")
202 | return self.targets
203 |
204 | @property
205 | def test_labels(self):
206 | warnings.warn("test_labels has been renamed targets")
207 | return self.targets
208 |
209 | @property
210 | def train_data(self):
211 | warnings.warn("train_data has been renamed data")
212 | return self.data
213 |
214 | @property
215 | def test_data(self):
216 | warnings.warn("test_data has been renamed data")
217 | return self.data
218 |
219 | def __init__(self, root, train=True, transform=None, target_transform=None,
220 | download=False):
221 | super(MNIST_idx, self).__init__(root, transform=transform,
222 | target_transform=target_transform)
223 | self.train = train # training set or test set
224 |
225 | if download:
226 | self.download()
227 |
228 | if not self._check_exists():
229 | raise RuntimeError('Dataset not found.' +
230 | ' You can use download=True to download it')
231 |
232 | if self.train:
233 | data_file = self.training_file
234 | else:
235 | data_file = self.test_file
236 | self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
237 |
238 | def __getitem__(self, index):
239 | """
240 | Args:
241 | index (int): Index
242 |
243 | Returns:
244 | tuple: (image, target) where target is index of the target class.
245 | """
246 | img, target = self.data[index], int(self.targets[index])
247 |
248 | # doing this so that it is consistent with all other datasets
249 | # to return a PIL Image
250 | img = Image.fromarray(img.numpy(), mode='L')
251 |
252 | if self.transform is not None:
253 | img = self.transform(img)
254 |
255 | if self.target_transform is not None:
256 | target = self.target_transform(target)
257 |
258 | return img, target, index
259 |
260 | def __len__(self):
261 | return len(self.data)
262 |
263 | # @property
264 | # def raw_folder(self):
265 | # return os.path.join(self.root, self.__class__.__name__, 'raw')
266 |
267 | # @property
268 | # def processed_folder(self):
269 | # return os.path.join(self.root, self.__class__.__name__, 'processed')
270 |
271 | @property
272 | def raw_folder(self):
273 | return os.path.join(self.root, 'raw')
274 |
275 | @property
276 | def processed_folder(self):
277 | return os.path.join(self.root, 'processed')
278 |
279 | @property
280 | def class_to_idx(self):
281 | return {_class: i for i, _class in enumerate(self.classes)}
282 |
283 | def _check_exists(self):
284 | return (os.path.exists(os.path.join(self.processed_folder,
285 | self.training_file)) and
286 | os.path.exists(os.path.join(self.processed_folder,
287 | self.test_file)))
288 |
289 | def download(self):
290 | """Download the MNIST data if it doesn't exist in processed_folder already."""
291 |
292 | if self._check_exists():
293 | return
294 |
295 | os.makedirs(self.raw_folder, exist_ok=True)
296 | os.makedirs(self.processed_folder, exist_ok=True)
297 |
298 | # download files
299 | for url, md5 in self.resources:
300 | filename = url.rpartition('/')[2]
301 | download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
302 |
303 | # process and save as torch files
304 | print('Processing...')
305 |
306 | training_set = (
307 | read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
308 | read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
309 | )
310 | test_set = (
311 | read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
312 | read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
313 | )
314 | with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
315 | torch.save(training_set, f)
316 | with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
317 | torch.save(test_set, f)
318 |
319 | print('Done!')
320 |
321 | def extra_repr(self):
322 | return "Split: {}".format("Train" if self.train is True else "Test")
323 |
324 | def get_int(b):
325 | return int(codecs.encode(b, 'hex'), 16)
326 |
327 |
328 | def open_maybe_compressed_file(path):
329 | """Return a file object that possibly decompresses 'path' on the fly.
330 | Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'.
331 | """
332 | if not isinstance(path, torch._six.string_classes):
333 | return path
334 | if path.endswith('.gz'):
335 | import gzip
336 | return gzip.open(path, 'rb')
337 | if path.endswith('.xz'):
338 | import lzma
339 | return lzma.open(path, 'rb')
340 | return open(path, 'rb')
341 |
342 | def read_sn3_pascalvincent_tensor(path, strict=True):
343 | """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
344 | Argument may be a filename, compressed filename, or file object.
345 | """
346 | # typemap
347 | if not hasattr(read_sn3_pascalvincent_tensor, 'typemap'):
348 | read_sn3_pascalvincent_tensor.typemap = {
349 | 8: (torch.uint8, np.uint8, np.uint8),
350 | 9: (torch.int8, np.int8, np.int8),
351 | 11: (torch.int16, np.dtype('>i2'), 'i2'),
352 | 12: (torch.int32, np.dtype('>i4'), 'i4'),
353 | 13: (torch.float32, np.dtype('>f4'), 'f4'),
354 | 14: (torch.float64, np.dtype('>f8'), 'f8')}
355 | # read
356 | with open_maybe_compressed_file(path) as f:
357 | data = f.read()
358 | # parse
359 | magic = get_int(data[0:4])
360 | nd = magic % 256
361 | ty = magic // 256
362 | assert nd >= 1 and nd <= 3
363 | assert ty >= 8 and ty <= 14
364 | m = read_sn3_pascalvincent_tensor.typemap[ty]
365 | s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)]
366 | parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
367 | assert parsed.shape[0] == np.prod(s) or not strict
368 | return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
369 |
370 |
371 | def read_label_file(path):
372 | with open(path, 'rb') as f:
373 | x = read_sn3_pascalvincent_tensor(f, strict=False)
374 | assert(x.dtype == torch.uint8)
375 | assert(x.ndimension() == 1)
376 | return x.long()
377 |
378 | def read_image_file(path):
379 | with open(path, 'rb') as f:
380 | x = read_sn3_pascalvincent_tensor(f, strict=False)
381 | assert(x.dtype == torch.uint8)
382 | assert(x.ndimension() == 3)
383 | return x
--------------------------------------------------------------------------------
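For reference, a minimal sketch of how these classes are consumed, mirroring the transforms used in `uda_digit.py` (the `./data/mnist/` root is the path used there). The only difference between `MNIST` and `MNIST_idx` is the extra sample index in each item, which the adaptation code uses to align pseudo-labels with samples:

```python
from torch.utils.data import DataLoader
from torchvision import transforms
from data_load import mnist

tfm = transforms.Compose([transforms.ToTensor(),
                          transforms.Normalize((0.5,), (0.5,))])
train_set = mnist.MNIST_idx('./data/mnist/', train=True, download=True,
                            transform=tfm)
loader = DataLoader(train_set, batch_size=64, shuffle=True)

imgs, targets, idx = next(iter(loader))      # MNIST_idx yields (img, target, index)
print(imgs.shape, targets.shape, idx.shape)  # [64, 1, 28, 28], [64], [64]
```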
/digit/data_load/svhn.py:
--------------------------------------------------------------------------------
1 | from .vision import VisionDataset
2 | from PIL import Image
3 | import os
4 | import os.path
5 | import numpy as np
6 | from .utils import download_url, check_integrity, verify_str_arg
7 |
8 |
9 | class SVHN(VisionDataset):
10 | """`SVHN `_ Dataset.
11 | Note: The SVHN dataset assigns the label `10` to the digit `0`. However, in this Dataset,
12 | we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which
13 | expect the class labels to be in the range `[0, C-1]`
14 |
15 | .. warning::
16 |
17 | This class needs `scipy <https://docs.scipy.org/doc/>`_ to load data from `.mat` format.
18 |
19 | Args:
20 | root (string): Root directory of dataset where directory
21 | ``SVHN`` exists.
22 | split (string): One of {'train', 'test', 'extra'}.
23 | The dataset split is selected accordingly; 'extra' is the extra training set.
24 | transform (callable, optional): A function/transform that takes in a PIL image
25 | and returns a transformed version. E.g., ``transforms.RandomCrop``
26 | target_transform (callable, optional): A function/transform that takes in the
27 | target and transforms it.
28 | download (bool, optional): If true, downloads the dataset from the internet and
29 | puts it in root directory. If dataset is already downloaded, it is not
30 | downloaded again.
31 |
32 | """
33 |
34 | split_list = {
35 | 'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
36 | "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"],
37 | 'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
38 | "test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"],
39 | 'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
40 | "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]}
41 |
42 | def __init__(self, root, split='train', transform=None, target_transform=None,
43 | download=False):
44 | super(SVHN, self).__init__(root, transform=transform,
45 | target_transform=target_transform)
46 | self.split = verify_str_arg(split, "split", tuple(self.split_list.keys()))
47 | self.url = self.split_list[split][0]
48 | self.filename = self.split_list[split][1]
49 | self.file_md5 = self.split_list[split][2]
50 |
51 | if download:
52 | self.download()
53 |
54 | if not self._check_integrity():
55 | raise RuntimeError('Dataset not found or corrupted.' +
56 | ' You can use download=True to download it')
57 |
58 | # import here rather than at top of file because this is
59 | # an optional dependency for torchvision
60 | import scipy.io as sio
61 |
62 | # reading(loading) mat file as array
63 | loaded_mat = sio.loadmat(os.path.join(self.root, self.filename))
64 |
65 | self.data = loaded_mat['X']
66 | # loading from the .mat file gives an np array of type np.uint8
67 | # converting to np.int64, so that we have a LongTensor after
68 | # the conversion from the numpy array
69 | # the squeeze is needed to obtain a 1D tensor
70 | self.labels = loaded_mat['y'].astype(np.int64).squeeze()
71 |
72 | # the svhn dataset assigns the class label "10" to the digit 0
73 | # this makes it inconsistent with several loss functions
74 | # which expect the class labels to be in the range [0, C-1]
75 | np.place(self.labels, self.labels == 10, 0)
76 | self.data = np.transpose(self.data, (3, 2, 0, 1))
77 |
78 | def __getitem__(self, index):
79 | """
80 | Args:
81 | index (int): Index
82 |
83 | Returns:
84 | tuple: (image, target) where target is index of the target class.
85 | """
86 | img, target = self.data[index], int(self.labels[index])
87 |
88 | # doing this so that it is consistent with all other datasets
89 | # to return a PIL Image
90 | img = Image.fromarray(np.transpose(img, (1, 2, 0)))
91 |
92 | if self.transform is not None:
93 | img = self.transform(img)
94 |
95 | if self.target_transform is not None:
96 | target = self.target_transform(target)
97 |
98 | return img, target
99 |
100 | def __len__(self):
101 | return len(self.data)
102 |
103 | def _check_integrity(self):
104 | root = self.root
105 | md5 = self.split_list[self.split][2]
106 | fpath = os.path.join(root, self.filename)
107 | return check_integrity(fpath, md5)
108 |
109 | def download(self):
110 | md5 = self.split_list[self.split][2]
111 | download_url(self.url, self.root, self.filename, md5)
112 |
113 | def extra_repr(self):
114 | return "Split: {split}".format(**self.__dict__)
115 |
116 | class SVHN_idx(VisionDataset):
117 | """`SVHN `_ Dataset.
118 | Note: The SVHN dataset assigns the label `10` to the digit `0`. However, in this Dataset,
119 | we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which
120 | expect the class labels to be in the range `[0, C-1]`
121 |
122 | .. warning::
123 |
124 | This class needs `scipy <https://docs.scipy.org/doc/>`_ to load data from `.mat` format.
125 |
126 | Args:
127 | root (string): Root directory of dataset where directory
128 | ``SVHN`` exists.
129 | split (string): One of {'train', 'test', 'extra'}.
130 | The dataset split is selected accordingly; 'extra' is the extra training set.
131 | transform (callable, optional): A function/transform that takes in a PIL image
132 | and returns a transformed version. E.g., ``transforms.RandomCrop``
133 | target_transform (callable, optional): A function/transform that takes in the
134 | target and transforms it.
135 | download (bool, optional): If true, downloads the dataset from the internet and
136 | puts it in root directory. If dataset is already downloaded, it is not
137 | downloaded again.
138 |
139 | """
140 |
141 | split_list = {
142 | 'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
143 | "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"],
144 | 'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
145 | "test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"],
146 | 'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
147 | "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]}
148 |
149 | def __init__(self, root, split='train', transform=None, target_transform=None,
150 | download=False):
151 | super(SVHN_idx, self).__init__(root, transform=transform,
152 | target_transform=target_transform)
153 | self.split = verify_str_arg(split, "split", tuple(self.split_list.keys()))
154 | self.url = self.split_list[split][0]
155 | self.filename = self.split_list[split][1]
156 | self.file_md5 = self.split_list[split][2]
157 |
158 | if download:
159 | self.download()
160 |
161 | if not self._check_integrity():
162 | raise RuntimeError('Dataset not found or corrupted.' +
163 | ' You can use download=True to download it')
164 |
165 | # import here rather than at top of file because this is
166 | # an optional dependency for torchvision
167 | import scipy.io as sio
168 |
169 | # reading(loading) mat file as array
170 | loaded_mat = sio.loadmat(os.path.join(self.root, self.filename))
171 |
172 | self.data = loaded_mat['X']
173 | # loading from the .mat file gives an np array of type np.uint8
174 | # converting to np.int64, so that we have a LongTensor after
175 | # the conversion from the numpy array
176 | # the squeeze is needed to obtain a 1D tensor
177 | self.labels = loaded_mat['y'].astype(np.int64).squeeze()
178 |
179 | # the svhn dataset assigns the class label "10" to the digit 0
180 | # this makes it inconsistent with several loss functions
181 | # which expect the class labels to be in the range [0, C-1]
182 | np.place(self.labels, self.labels == 10, 0)
183 | self.data = np.transpose(self.data, (3, 2, 0, 1))
184 |
185 | def __getitem__(self, index):
186 | """
187 | Args:
188 | index (int): Index
189 |
190 | Returns:
191 | tuple: (image, target) where target is index of the target class.
192 | """
193 | img, target = self.data[index], int(self.labels[index])
194 |
195 | # doing this so that it is consistent with all other datasets
196 | # to return a PIL Image
197 | img = Image.fromarray(np.transpose(img, (1, 2, 0)))
198 |
199 | if self.transform is not None:
200 | img = self.transform(img)
201 |
202 | if self.target_transform is not None:
203 | target = self.target_transform(target)
204 |
205 | return img, target, index
206 |
207 | def __len__(self):
208 | return len(self.data)
209 |
210 | def _check_integrity(self):
211 | root = self.root
212 | md5 = self.split_list[self.split][2]
213 | fpath = os.path.join(root, self.filename)
214 | return check_integrity(fpath, md5)
215 |
216 | def download(self):
217 | md5 = self.split_list[self.split][2]
218 | download_url(self.url, self.root, self.filename, md5)
219 |
220 | def extra_repr(self):
221 | return "Split: {split}".format(**self.__dict__)
--------------------------------------------------------------------------------
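A short sketch of the label remapping the docstring describes: the `.mat` files label the digit 0 as class 10, and the `np.place` call in the constructor maps it back to 0 so all targets lie in `[0, 9]`. The `./data/svhn/` root matches the path used in `uda_digit.py`:

```python
from torchvision import transforms
from data_load import svhn

tfm = transforms.Compose([transforms.Resize(32),
                          transforms.ToTensor(),
                          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_set = svhn.SVHN('./data/svhn/', split='train', download=True, transform=tfm)

# Label 10 has already been remapped to 0 in __init__.
assert set(train_set.labels.tolist()) <= set(range(10))
img, target = train_set[0]   # img: 3x32x32 tensor after the transform
```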
/digit/data_load/usps.py:
--------------------------------------------------------------------------------
1 | """Dataset setting and data loader for USPS.
2 | Modified from
3 | https://github.com/mingyuliutw/CoGAN/blob/master/cogan_pytorch/src/dataset_usps.py
4 | """
5 |
6 | import gzip
7 | import os
8 | import pickle
9 | import urllib.request
10 | from PIL import Image
11 |
12 | import numpy as np
13 | import torch
14 | import torch.utils.data as data
15 | from torch.utils.data.sampler import WeightedRandomSampler
16 | from torchvision import datasets, transforms
17 |
18 |
19 | class USPS(data.Dataset):
20 | """USPS Dataset.
21 | Args:
22 | root (string): Root directory of dataset where the dataset file exists.
23 | train (bool, optional): If True, use the training split; otherwise the test split.
24 | download (bool, optional): If true, downloads the dataset
25 | from the internet and puts it in root directory.
26 | If dataset is already downloaded, it is not downloaded again.
27 | transform (callable, optional): A function/transform that takes in
28 | a PIL image and returns a transformed version.
29 | E.g., ``transforms.RandomCrop``
30 | """
31 |
32 | url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl"
33 |
34 | def __init__(self, root, train=True, transform=None, download=False):
35 | """Init USPS dataset."""
36 | # init params
37 | self.root = os.path.expanduser(root)
38 | self.filename = "usps_28x28.pkl"
39 | self.train = train
40 | # Num of Train = 7438, Num of Test = 1860
41 | self.transform = transform
42 | self.dataset_size = None
43 |
44 | # download dataset.
45 | if download:
46 | self.download()
47 | if not self._check_exists():
48 | raise RuntimeError("Dataset not found." +
49 | " You can use download=True to download it")
50 |
51 | self.train_data, self.train_labels = self.load_samples()
52 | if self.train:
53 | total_num_samples = self.train_labels.shape[0]
54 | indices = np.arange(total_num_samples)
55 | self.train_data = self.train_data[indices[0:self.dataset_size], ::]
56 | self.train_labels = self.train_labels[indices[0:self.dataset_size]]
57 | self.train_data *= 255.0
58 | self.train_data = np.squeeze(self.train_data).astype(np.uint8)
59 |
60 | def __getitem__(self, index):
61 | """Get images and target for data loader.
62 | Args:
63 | index (int): Index
64 | Returns:
65 | tuple: (image, target) where target is index of the target class.
66 | """
67 | img, label = self.train_data[index], self.train_labels[index]
68 | img = Image.fromarray(img, mode='L')
69 | img = img.copy()
70 | if self.transform is not None:
71 | img = self.transform(img)
72 | return img, label.astype("int64")
73 |
74 | def __len__(self):
75 | """Return size of dataset."""
76 | return len(self.train_data)
77 |
78 | def _check_exists(self):
79 | """Check if dataset is download and in right place."""
80 | return os.path.exists(os.path.join(self.root, self.filename))
81 |
82 | def download(self):
83 | """Download dataset."""
84 | filename = os.path.join(self.root, self.filename)
85 | dirname = os.path.dirname(filename)
86 | if not os.path.isdir(dirname):
87 | os.makedirs(dirname)
88 | if os.path.isfile(filename):
89 | return
90 | print("Download %s to %s" % (self.url, os.path.abspath(filename)))
91 | urllib.request.urlretrieve(self.url, filename)
92 | print("[DONE]")
93 | return
94 |
95 | def load_samples(self):
96 | """Load sample images from dataset."""
97 | filename = os.path.join(self.root, self.filename)
98 | f = gzip.open(filename, "rb")
99 | data_set = pickle.load(f, encoding="bytes")
100 | f.close()
101 | if self.train:
102 | images = data_set[0][0]
103 | labels = data_set[0][1]
104 | self.dataset_size = labels.shape[0]
105 | else:
106 | images = data_set[1][0]
107 | labels = data_set[1][1]
108 | self.dataset_size = labels.shape[0]
109 | return images, labels
110 |
111 |
112 | class USPS_idx(data.Dataset):
113 | """USPS Dataset.
114 | Args:
115 | root (string): Root directory of dataset where the dataset file exists.
116 | train (bool, optional): If True, use the training split; otherwise the test split.
117 | download (bool, optional): If true, downloads the dataset
118 | from the internet and puts it in root directory.
119 | If dataset is already downloaded, it is not downloaded again.
120 | transform (callable, optional): A function/transform that takes in
121 | a PIL image and returns a transformed version.
122 | E.g., ``transforms.RandomCrop``
123 | """
124 |
125 | url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl"
126 |
127 | def __init__(self, root, train=True, transform=None, download=False):
128 | """Init USPS dataset."""
129 | # init params
130 | self.root = os.path.expanduser(root)
131 | self.filename = "usps_28x28.pkl"
132 | self.train = train
133 | # Num of Train = 7438, Num of Test = 1860
134 | self.transform = transform
135 | self.dataset_size = None
136 |
137 | # download dataset.
138 | if download:
139 | self.download()
140 | if not self._check_exists():
141 | raise RuntimeError("Dataset not found." +
142 | " You can use download=True to download it")
143 |
144 | self.train_data, self.train_labels = self.load_samples()
145 | if self.train:
146 | total_num_samples = self.train_labels.shape[0]
147 | indices = np.arange(total_num_samples)
148 | self.train_data = self.train_data[indices[0:self.dataset_size], ::]
149 | self.train_labels = self.train_labels[indices[0:self.dataset_size]]
150 | self.train_data *= 255.0
151 | self.train_data = np.squeeze(self.train_data).astype(np.uint8)
152 |
153 | def __getitem__(self, index):
154 | """Get images and target for data loader.
155 | Args:
156 | index (int): Index
157 | Returns:
158 | tuple: (image, target) where target is index of the target class.
159 | """
160 | img, label = self.train_data[index], self.train_labels[index]
161 | img = Image.fromarray(img, mode='L')
162 | img = img.copy()
163 | if self.transform is not None:
164 | img = self.transform(img)
165 | return img, label.astype("int64"), index
166 |
167 | def __len__(self):
168 | """Return size of dataset."""
169 | return len(self.train_data)
170 |
171 | def _check_exists(self):
172 | """Check if dataset is download and in right place."""
173 | return os.path.exists(os.path.join(self.root, self.filename))
174 |
175 | def download(self):
176 | """Download dataset."""
177 | filename = os.path.join(self.root, self.filename)
178 | dirname = os.path.dirname(filename)
179 | if not os.path.isdir(dirname):
180 | os.makedirs(dirname)
181 | if os.path.isfile(filename):
182 | return
183 | print("Download %s to %s" % (self.url, os.path.abspath(filename)))
184 | urllib.request.urlretrieve(self.url, filename)
185 | print("[DONE]")
186 | return
187 |
188 | def load_samples(self):
189 | """Load sample images from dataset."""
190 | filename = os.path.join(self.root, self.filename)
191 | f = gzip.open(filename, "rb")
192 | data_set = pickle.load(f, encoding="bytes")
193 | f.close()
194 | if self.train:
195 | images = data_set[0][0]
196 | labels = data_set[0][1]
197 | self.dataset_size = labels.shape[0]
198 | else:
199 | images = data_set[1][0]
200 | labels = data_set[1][1]
201 | self.dataset_size = labels.shape[0]
202 | return images, labels
--------------------------------------------------------------------------------
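A minimal usage sketch, assuming the `./data/usps/` root from `uda_digit.py`. The pickle holds 7438 training and 1860 test images (per the comment in `__init__`), and `USPS_idx` adds the sample index to each item:

```python
from torchvision import transforms
from data_load import usps

tfm = transforms.Compose([transforms.ToTensor(),
                          transforms.Normalize((0.5,), (0.5,))])
train_set = usps.USPS('./data/usps/', train=True, download=True, transform=tfm)
test_set = usps.USPS('./data/usps/', train=False, download=True, transform=tfm)
print(len(train_set), len(test_set))   # 7438 1860

img, label, index = usps.USPS_idx('./data/usps/', train=True,
                                  download=True, transform=tfm)[0]
```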
/digit/data_load/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import hashlib
4 | import gzip
5 | import errno
6 | import tarfile
7 | import zipfile
8 |
9 | import torch
10 | from torch.utils.model_zoo import tqdm
11 |
12 |
13 | def gen_bar_updater():
14 | pbar = tqdm(total=None)
15 |
16 | def bar_update(count, block_size, total_size):
17 | if pbar.total is None and total_size:
18 | pbar.total = total_size
19 | progress_bytes = count * block_size
20 | pbar.update(progress_bytes - pbar.n)
21 |
22 | return bar_update
23 |
24 |
25 | def calculate_md5(fpath, chunk_size=1024 * 1024):
26 | md5 = hashlib.md5()
27 | with open(fpath, 'rb') as f:
28 | for chunk in iter(lambda: f.read(chunk_size), b''):
29 | md5.update(chunk)
30 | return md5.hexdigest()
31 |
32 |
33 | def check_md5(fpath, md5, **kwargs):
34 | return md5 == calculate_md5(fpath, **kwargs)
35 |
36 |
37 | def check_integrity(fpath, md5=None):
38 | if not os.path.isfile(fpath):
39 | return False
40 | if md5 is None:
41 | return True
42 | return check_md5(fpath, md5)
43 |
44 |
45 | def download_url(url, root, filename=None, md5=None):
46 | """Download a file from a url and place it in root.
47 |
48 | Args:
49 | url (str): URL to download file from
50 | root (str): Directory to place downloaded file in
51 | filename (str, optional): Name to save the file under. If None, use the basename of the URL
52 | md5 (str, optional): MD5 checksum of the download. If None, do not check
53 | """
54 | import urllib.request, urllib.error
55 |
56 | root = os.path.expanduser(root)
57 | if not filename:
58 | filename = os.path.basename(url)
59 | fpath = os.path.join(root, filename)
60 |
61 | os.makedirs(root, exist_ok=True)
62 |
63 | # check if file is already present locally
64 | if check_integrity(fpath, md5):
65 | print('Using downloaded and verified file: ' + fpath)
66 | else: # download the file
67 | try:
68 | print('Downloading ' + url + ' to ' + fpath)
69 | urllib.request.urlretrieve(
70 | url, fpath,
71 | reporthook=gen_bar_updater()
72 | )
73 | except (urllib.error.URLError, IOError) as e:
74 | if url[:5] == 'https':
75 | url = url.replace('https:', 'http:')
76 | print('Failed download. Trying https -> http instead.'
77 | ' Downloading ' + url + ' to ' + fpath)
78 | urllib.request.urlretrieve(
79 | url, fpath,
80 | reporthook=gen_bar_updater()
81 | )
82 | else:
83 | raise e
84 | # check integrity of downloaded file
85 | if not check_integrity(fpath, md5):
86 | raise RuntimeError("File not found or corrupted.")
87 |
88 |
89 | def list_dir(root, prefix=False):
90 | """List all directories at a given root
91 |
92 | Args:
93 | root (str): Path to directory whose folders need to be listed
94 | prefix (bool, optional): If true, prepends the path to each result, otherwise
95 | only returns the name of the directories found
96 | """
97 | root = os.path.expanduser(root)
98 | directories = list(
99 | filter(
100 | lambda p: os.path.isdir(os.path.join(root, p)),
101 | os.listdir(root)
102 | )
103 | )
104 |
105 | if prefix is True:
106 | directories = [os.path.join(root, d) for d in directories]
107 |
108 | return directories
109 |
110 |
111 | def list_files(root, suffix, prefix=False):
112 | """List all files ending with a suffix at a given root
113 |
114 | Args:
115 | root (str): Path to directory whose files need to be listed
116 | suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
117 | It uses the Python "str.endswith" method and is passed directly
118 | prefix (bool, optional): If true, prepends the path to each result, otherwise
119 | only returns the name of the files found
120 | """
121 | root = os.path.expanduser(root)
122 | files = list(
123 | filter(
124 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),
125 | os.listdir(root)
126 | )
127 | )
128 |
129 | if prefix is True:
130 | files = [os.path.join(root, d) for d in files]
131 |
132 | return files
133 |
134 |
135 | def download_file_from_google_drive(file_id, root, filename=None, md5=None):
136 | """Download a Google Drive file from and place it in root.
137 |
138 | Args:
139 | file_id (str): id of file to be downloaded
140 | root (str): Directory to place downloaded file in
141 | filename (str, optional): Name to save the file under. If None, use the id of the file.
142 | md5 (str, optional): MD5 checksum of the download. If None, do not check
143 | """
144 | # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
145 | import requests
146 | url = "https://docs.google.com/uc?export=download"
147 |
148 | root = os.path.expanduser(root)
149 | if not filename:
150 | filename = file_id
151 | fpath = os.path.join(root, filename)
152 |
153 | os.makedirs(root, exist_ok=True)
154 |
155 | if os.path.isfile(fpath) and check_integrity(fpath, md5):
156 | print('Using downloaded and verified file: ' + fpath)
157 | else:
158 | session = requests.Session()
159 |
160 | response = session.get(url, params={'id': file_id}, stream=True)
161 | token = _get_confirm_token(response)
162 |
163 | if token:
164 | params = {'id': file_id, 'confirm': token}
165 | response = session.get(url, params=params, stream=True)
166 |
167 | _save_response_content(response, fpath)
168 |
169 |
170 | def _get_confirm_token(response):
171 | for key, value in response.cookies.items():
172 | if key.startswith('download_warning'):
173 | return value
174 |
175 | return None
176 |
177 |
178 | def _save_response_content(response, destination, chunk_size=32768):
179 | with open(destination, "wb") as f:
180 | pbar = tqdm(total=None)
181 | progress = 0
182 | for chunk in response.iter_content(chunk_size):
183 | if chunk: # filter out keep-alive new chunks
184 | f.write(chunk)
185 | progress += len(chunk)
186 | pbar.update(progress - pbar.n)
187 | pbar.close()
188 |
189 |
190 | def _is_tarxz(filename):
191 | return filename.endswith(".tar.xz")
192 |
193 |
194 | def _is_tar(filename):
195 | return filename.endswith(".tar")
196 |
197 |
198 | def _is_targz(filename):
199 | return filename.endswith(".tar.gz")
200 |
201 |
202 | def _is_tgz(filename):
203 | return filename.endswith(".tgz")
204 |
205 |
206 | def _is_gzip(filename):
207 | return filename.endswith(".gz") and not filename.endswith(".tar.gz")
208 |
209 |
210 | def _is_zip(filename):
211 | return filename.endswith(".zip")
212 |
213 |
214 | def extract_archive(from_path, to_path=None, remove_finished=False):
215 | if to_path is None:
216 | to_path = os.path.dirname(from_path)
217 |
218 | if _is_tar(from_path):
219 | with tarfile.open(from_path, 'r') as tar:
220 | tar.extractall(path=to_path)
221 | elif _is_targz(from_path) or _is_tgz(from_path):
222 | with tarfile.open(from_path, 'r:gz') as tar:
223 | tar.extractall(path=to_path)
224 | elif _is_tarxz(from_path):
225 | with tarfile.open(from_path, 'r:xz') as tar:
226 | tar.extractall(path=to_path)
227 | elif _is_gzip(from_path):
228 | to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
229 | with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
230 | out_f.write(zip_f.read())
231 | elif _is_zip(from_path):
232 | with zipfile.ZipFile(from_path, 'r') as z:
233 | z.extractall(to_path)
234 | else:
235 | raise ValueError("Extraction of {} not supported".format(from_path))
236 |
237 | if remove_finished:
238 | os.remove(from_path)
239 |
240 |
241 | def download_and_extract_archive(url, download_root, extract_root=None, filename=None,
242 | md5=None, remove_finished=False):
243 | download_root = os.path.expanduser(download_root)
244 | if extract_root is None:
245 | extract_root = download_root
246 | if not filename:
247 | filename = os.path.basename(url)
248 |
249 | download_url(url, download_root, filename, md5)
250 |
251 | archive = os.path.join(download_root, filename)
252 | print("Extracting {} to {}".format(archive, extract_root))
253 | extract_archive(archive, extract_root, remove_finished)
254 |
255 |
256 | def iterable_to_str(iterable):
257 | return "'" + "', '".join([str(item) for item in iterable]) + "'"
258 |
259 |
260 | def verify_str_arg(value, arg=None, valid_values=None, custom_msg=None):
261 | if not isinstance(value, torch._six.string_classes):
262 | if arg is None:
263 | msg = "Expected type str, but got type {type}."
264 | else:
265 | msg = "Expected type str for argument {arg}, but got type {type}."
266 | msg = msg.format(type=type(value), arg=arg)
267 | raise ValueError(msg)
268 |
269 | if valid_values is None:
270 | return value
271 |
272 | if value not in valid_values:
273 | if custom_msg is not None:
274 | msg = custom_msg
275 | else:
276 | msg = ("Unknown value '{value}' for argument {arg}. "
277 | "Valid values are {{{valid_values}}}.")
278 | msg = msg.format(value=value, arg=arg,
279 | valid_values=iterable_to_str(valid_values))
280 | raise ValueError(msg)
281 |
282 | return value
--------------------------------------------------------------------------------
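These helpers are what `mnist.py` calls during `download()`. A sketch using one of the `(URL, md5)` resource tuples listed in that file; `download_and_extract_archive` downloads into `download_root`, verifies the checksum, and gunzips the file in place. (The yann.lecun.com mirror may no longer be reachable; the tuple is shown as it appears above.)

```python
from data_load.utils import check_integrity, download_and_extract_archive

# (URL, md5) pair taken from the `resources` list in mnist.py.
url = "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"
md5 = "f68b3c2dcbeaaa9fbdd348bbdeb94873"

download_and_extract_archive(url, download_root='./data/mnist/raw',
                             filename='train-images-idx3-ubyte.gz', md5=md5)
assert check_integrity('./data/mnist/raw/train-images-idx3-ubyte.gz', md5)
```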
/digit/data_load/vision.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.utils.data as data
4 |
5 |
6 | class VisionDataset(data.Dataset):
7 | _repr_indent = 4
8 |
9 | def __init__(self, root, transforms=None, transform=None, target_transform=None):
10 | if isinstance(root, torch._six.string_classes):
11 | root = os.path.expanduser(root)
12 | self.root = root
13 |
14 | has_transforms = transforms is not None
15 | has_separate_transform = transform is not None or target_transform is not None
16 | if has_transforms and has_separate_transform:
17 | raise ValueError("Only transforms or transform/target_transform can "
18 | "be passed as argument")
19 |
20 | # for backwards-compatibility
21 | self.transform = transform
22 | self.target_transform = target_transform
23 |
24 | if has_separate_transform:
25 | transforms = StandardTransform(transform, target_transform)
26 | self.transforms = transforms
27 |
28 | def __getitem__(self, index):
29 | raise NotImplementedError
30 |
31 | def __len__(self):
32 | raise NotImplementedError
33 |
34 | def __repr__(self):
35 | head = "Dataset " + self.__class__.__name__
36 | body = ["Number of datapoints: {}".format(self.__len__())]
37 | if self.root is not None:
38 | body.append("Root location: {}".format(self.root))
39 | body += self.extra_repr().splitlines()
40 | if hasattr(self, "transforms") and self.transforms is not None:
41 | body += [repr(self.transforms)]
42 | lines = [head] + [" " * self._repr_indent + line for line in body]
43 | return '\n'.join(lines)
44 |
45 | def _format_transform_repr(self, transform, head):
46 | lines = transform.__repr__().splitlines()
47 | return (["{}{}".format(head, lines[0])] +
48 | ["{}{}".format(" " * len(head), line) for line in lines[1:]])
49 |
50 | def extra_repr(self):
51 | return ""
52 |
53 |
54 | class StandardTransform(object):
55 | def __init__(self, transform=None, target_transform=None):
56 | self.transform = transform
57 | self.target_transform = target_transform
58 |
59 | def __call__(self, input, target):
60 | if self.transform is not None:
61 | input = self.transform(input)
62 | if self.target_transform is not None:
63 | target = self.target_transform(target)
64 | return input, target
65 |
66 | def _format_transform_repr(self, transform, head):
67 | lines = transform.__repr__().splitlines()
68 | return (["{}{}".format(head, lines[0])] +
69 | ["{}{}".format(" " * len(head), line) for line in lines[1:]])
70 |
71 | def __repr__(self):
72 | body = [self.__class__.__name__]
73 | if self.transform is not None:
74 | body += self._format_transform_repr(self.transform,
75 | "Transform: ")
76 | if self.target_transform is not None:
77 | body += self._format_transform_repr(self.target_transform,
78 | "Target transform: ")
79 |
80 | return '\n'.join(body)
--------------------------------------------------------------------------------
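`VisionDataset` only supplies root handling, transform plumbing, and `__repr__`; subclasses implement `__getitem__`/`__len__` and optionally `extra_repr`, as `MNIST` does above. A toy sketch (the `Zeros` class is hypothetical, for illustration only, and assumes the older PyTorch this repo targets, since `vision.py` touches `torch._six`):

```python
import torch
from data_load.vision import VisionDataset

class Zeros(VisionDataset):          # hypothetical minimal subclass
    def __init__(self, root, n=4, transform=None):
        super(Zeros, self).__init__(root, transform=transform)
        self.n = n

    def __getitem__(self, index):
        img = torch.zeros(8, 8)      # a black 8x8 "image"
        if self.transform is not None:
            img = self.transform(img)
        return img, 0                # label is always class 0

    def __len__(self):
        return self.n

    def extra_repr(self):
        return "n: {}".format(self.n)

print(Zeros('./data'))   # repr lists class name, count, root, and extra_repr
```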
/digit/digit.sh:
--------------------------------------------------------------------------------
1 | ~/anaconda3/envs/pytorch/bin/python uda_digit.py --dset m2u --gpu_id 0 --cls_par 0.1 --output ckps_digits
2 | ~/anaconda3/envs/pytorch/bin/python uda_digit.py --dset u2m --gpu_id 0 --cls_par 0.1 --output ckps_digits
3 | ~/anaconda3/envs/pytorch/bin/python uda_digit.py --dset s2m --gpu_id 0 --cls_par 0.1 --output ckps_digits
--------------------------------------------------------------------------------
/digit/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import math
6 | import torch.nn.functional as F
7 | import pdb
8 |
9 | def Entropy(input_):
10 | # input_ holds per-sample probability vectors; returns per-sample entropy
11 | entropy = -input_ * torch.log(input_ + 1e-5)
12 | entropy = torch.sum(entropy, dim=1)
13 | return entropy
14 |
15 | class CrossEntropyLabelSmooth(nn.Module):
16 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, size_average=True):
17 | super(CrossEntropyLabelSmooth, self).__init__()
18 | self.num_classes = num_classes
19 | self.epsilon = epsilon
20 | self.use_gpu = use_gpu
21 | self.size_average = size_average
22 | self.logsoftmax = nn.LogSoftmax(dim=1)
23 |
24 | def forward(self, inputs, targets):
25 | log_probs = self.logsoftmax(inputs)
26 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1)
27 | if self.use_gpu: targets = targets.cuda()
28 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
29 | if self.size_average:
30 | loss = (- targets * log_probs).mean(0).sum()
31 | else:
32 | loss = (- targets * log_probs).sum(1)
33 | return loss
--------------------------------------------------------------------------------
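A toy sketch of the loss on CPU (`use_gpu=False` overrides the default so the example runs without CUDA; everything else follows the definitions above). With `epsilon=0.1` and 10 classes, the smoothed target row puts `0.9 + 0.1/10 = 0.91` on the true class and `0.01` on every other class:

```python
import torch
import torch.nn as nn
from loss import CrossEntropyLabelSmooth, Entropy

logits = torch.randn(4, 10)                # toy batch: 4 samples, 10 classes
targets = torch.tensor([3, 1, 4, 1])

criterion = CrossEntropyLabelSmooth(num_classes=10, epsilon=0.1, use_gpu=False)
loss_val = criterion(logits, targets)      # scalar, since size_average=True

probs = nn.Softmax(dim=1)(logits)          # Entropy expects probabilities
print(loss_val.item(), Entropy(probs).mean().item())
```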
/digit/network.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torchvision
5 | from torchvision import models
6 | from torch.autograd import Variable
7 | import math
8 | import torch.nn.utils.weight_norm as weightNorm
9 | from collections import OrderedDict
10 |
11 | def init_weights(m):
12 | classname = m.__class__.__name__
13 | if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1:
14 | nn.init.kaiming_uniform_(m.weight)
15 | nn.init.zeros_(m.bias)
16 | elif classname.find('BatchNorm') != -1:
17 | nn.init.normal_(m.weight, 1.0, 0.02)
18 | nn.init.zeros_(m.bias)
19 | elif classname.find('Linear') != -1:
20 | nn.init.xavier_normal_(m.weight)
21 | nn.init.zeros_(m.bias)
22 |
23 | class feat_bottleneck(nn.Module):
24 | def __init__(self, feature_dim, bottleneck_dim=256, type="ori"):
25 | super(feat_bottleneck, self).__init__()
26 | self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True)
27 | self.dropout = nn.Dropout(p=0.5)
28 | self.bottleneck = nn.Linear(feature_dim, bottleneck_dim)
29 | self.bottleneck.apply(init_weights)
30 | self.type = type
31 |
32 | def forward(self, x):
33 | x = self.bottleneck(x)
34 | if self.type == "bn":
35 | x = self.bn(x)
36 | x = self.dropout(x)
37 | return x
38 |
39 | class feat_classifier(nn.Module):
40 | def __init__(self, class_num, bottleneck_dim=256, type="linear"):
41 | super(feat_classifier, self).__init__()
42 | if type == "linear":
43 | self.fc = nn.Linear(bottleneck_dim, class_num)
44 | else:
45 | self.fc = weightNorm(nn.Linear(bottleneck_dim, class_num), name="weight")
46 | self.fc.apply(init_weights)
47 |
48 | def forward(self, x):
49 | x = self.fc(x)
50 | return x
51 |
52 | class DTNBase(nn.Module):
53 | def __init__(self):
54 | super(DTNBase, self).__init__()
55 | self.conv_params = nn.Sequential(
56 | nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2),
57 | nn.BatchNorm2d(64),
58 | nn.Dropout2d(0.1),
59 | nn.ReLU(),
60 | nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
61 | nn.BatchNorm2d(128),
62 | nn.Dropout2d(0.3),
63 | nn.ReLU(),
64 | nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
65 | nn.BatchNorm2d(256),
66 | nn.Dropout2d(0.5),
67 | nn.ReLU()
68 | )
69 | self.in_features = 256*4*4
70 |
71 | def forward(self, x):
72 | x = self.conv_params(x)
73 | x = x.view(x.size(0), -1)
74 | return x
75 |
76 | class LeNetBase(nn.Module):
77 | def __init__(self):
78 | super(LeNetBase, self).__init__()
79 | self.conv_params = nn.Sequential(
80 | nn.Conv2d(1, 20, kernel_size=5),
81 | nn.MaxPool2d(2),
82 | nn.ReLU(),
83 | nn.Conv2d(20, 50, kernel_size=5),
84 | nn.Dropout2d(p=0.5),
85 | nn.MaxPool2d(2),
86 | nn.ReLU(),
87 | )
88 | self.in_features = 50*4*4
89 |
90 | def forward(self, x):
91 | x = self.conv_params(x)
92 | x = x.view(x.size(0), -1)
93 | return x
--------------------------------------------------------------------------------
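A shape-check sketch wiring the pieces together the way `uda_digit.py` does. The `type="bn"` bottleneck and `type="wn"` (weight-normalized) classifier are assumptions here; the actual values come from the script's `--classifier` and `--layer` arguments:

```python
import torch
import network

netF = network.LeNetBase()            # u2m/m2u stem; in_features = 50*4*4 = 800
netB = network.feat_bottleneck(feature_dim=netF.in_features,
                               bottleneck_dim=256, type="bn")
netC = network.feat_classifier(class_num=10, bottleneck_dim=256, type="wn")

x = torch.randn(8, 1, 28, 28)         # MNIST/USPS-sized grayscale batch
logits = netC(netB(netF(x)))          # (8, 256) features -> (8, 10) scores
print(logits.shape)

y = torch.randn(8, 3, 32, 32)         # SVHN-sized RGB batch for the s2m stem
print(network.DTNBase()(y).shape)     # (8, 4096) = (8, 256*4*4)
```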
/digit/uda_digit.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 | import os.path as osp
4 | import torchvision
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torchvision import transforms
10 | import network, loss
11 | from torch.utils.data import DataLoader
12 | import random, pdb, math, copy
13 | from tqdm import tqdm
14 | from scipy.spatial.distance import cdist
15 | import pickle
16 | from data_load import mnist, svhn, usps
17 |
18 | def op_copy(optimizer):
19 | for param_group in optimizer.param_groups:
20 | param_group['lr0'] = param_group['lr']
21 | return optimizer
22 |
23 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
24 | decay = (1 + gamma * iter_num / max_iter) ** (-power)
25 | for param_group in optimizer.param_groups:
26 | param_group['lr'] = param_group['lr0'] * decay
27 | param_group['weight_decay'] = 1e-3
28 | param_group['momentum'] = 0.9
29 | param_group['nesterov'] = True
30 | return optimizer
31 |
32 | def digit_load(args):
33 | train_bs = args.batch_size
34 | if args.dset == 's2m':
35 | train_source = svhn.SVHN('./data/svhn/', split='train', download=True,
36 | transform=transforms.Compose([
37 | transforms.Resize(32),
38 | transforms.ToTensor(),
39 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
40 | ]))
41 | test_source = svhn.SVHN('./data/svhn/', split='test', download=True,
42 | transform=transforms.Compose([
43 | transforms.Resize(32),
44 | transforms.ToTensor(),
45 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
46 | ]))
47 | train_target = mnist.MNIST_idx('./data/mnist/', train=True, download=True,
48 | transform=transforms.Compose([
49 | transforms.Resize(32),
50 | transforms.Lambda(lambda x: x.convert("RGB")),
51 | transforms.ToTensor(),
52 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
53 | ]))
54 | test_target = mnist.MNIST('./data/mnist/', train=False, download=True,
55 | transform=transforms.Compose([
56 | transforms.Resize(32),
57 | transforms.Lambda(lambda x: x.convert("RGB")),
58 | transforms.ToTensor(),
59 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
60 | ]))
61 | elif args.dset == 'u2m':
62 | train_source = usps.USPS('./data/usps/', train=True, download=True,
63 | transform=transforms.Compose([
64 | transforms.RandomCrop(28, padding=4),
65 | transforms.RandomRotation(10),
66 | transforms.ToTensor(),
67 | transforms.Normalize((0.5,), (0.5,))
68 | ]))
69 | test_source = usps.USPS('./data/usps/', train=False, download=True,
70 | transform=transforms.Compose([
71 | transforms.RandomCrop(28, padding=4),
72 | transforms.RandomRotation(10),
73 | transforms.ToTensor(),
74 | transforms.Normalize((0.5,), (0.5,))
75 | ]))
76 | train_target = mnist.MNIST_idx('./data/mnist/', train=True, download=True,
77 | transform=transforms.Compose([
78 | transforms.ToTensor(),
79 | transforms.Normalize((0.5,), (0.5,))
80 | ]))
81 | test_target = mnist.MNIST('./data/mnist/', train=False, download=True,
82 | transform=transforms.Compose([
83 | transforms.ToTensor(),
84 | transforms.Normalize((0.5,), (0.5,))
85 | ]))
86 | elif args.dset == 'm2u':
87 | train_source = mnist.MNIST('./data/mnist/', train=True, download=True,
88 | transform=transforms.Compose([
89 | transforms.ToTensor(),
90 | transforms.Normalize((0.5,), (0.5,))
91 | ]))
92 | test_source = mnist.MNIST('./data/mnist/', train=False, download=True,
93 | transform=transforms.Compose([
94 | transforms.ToTensor(),
95 | transforms.Normalize((0.5,), (0.5,))
96 | ]))
97 |
98 | train_target = usps.USPS_idx('./data/usps/', train=True, download=True,
99 | transform=transforms.Compose([
100 | transforms.ToTensor(),
101 | transforms.Normalize((0.5,), (0.5,))
102 | ]))
103 | test_target = usps.USPS('./data/usps/', train=False, download=True,
104 | transform=transforms.Compose([
105 | transforms.ToTensor(),
106 | transforms.Normalize((0.5,), (0.5,))
107 | ]))
108 |
109 | dset_loaders = {}
110 | dset_loaders["source_tr"] = DataLoader(train_source, batch_size=train_bs, shuffle=True,
111 | num_workers=args.worker, drop_last=False)
112 | dset_loaders["source_te"] = DataLoader(test_source, batch_size=train_bs*2, shuffle=True,
113 | num_workers=args.worker, drop_last=False)
114 | dset_loaders["target"] = DataLoader(train_target, batch_size=train_bs, shuffle=True,
115 | num_workers=args.worker, drop_last=False)
116 | dset_loaders["target_te"] = DataLoader(train_target, batch_size=train_bs, shuffle=False,
117 | num_workers=args.worker, drop_last=False)
118 | dset_loaders["test"] = DataLoader(test_target, batch_size=train_bs*2, shuffle=False,
119 | num_workers=args.worker, drop_last=False)
120 | return dset_loaders
121 |
122 | def cal_acc(loader, netF, netB, netC):
123 | start_test = True
124 | with torch.no_grad():
125 | iter_test = iter(loader)
126 | for i in range(len(loader)):
127 | data = next(iter_test)
128 | inputs = data[0]
129 | labels = data[1]
130 | inputs = inputs.cuda()
131 | outputs = netC(netB(netF(inputs)))
132 | if start_test:
133 | all_output = outputs.float().cpu()
134 | all_label = labels.float()
135 | start_test = False
136 | else:
137 | all_output = torch.cat((all_output, outputs.float().cpu()), 0)
138 | all_label = torch.cat((all_label, labels.float()), 0)
139 | _, predict = torch.max(all_output, 1)
140 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
141 | mean_ent = torch.mean(loss.Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item()
142 | return accuracy*100, mean_ent
143 |
144 | def train_source(args):
145 | dset_loaders = digit_load(args)
146 | ## set base network
147 | if args.dset == 'u2m':
148 | netF = network.LeNetBase().cuda()
149 | elif args.dset == 'm2u':
150 | netF = network.LeNetBase().cuda()
151 | elif args.dset == 's2m':
152 | netF = network.DTNBase().cuda()
153 |
154 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
155 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
156 |
157 | param_group = []
158 | learning_rate = args.lr
159 | for k, v in netF.named_parameters():
160 | param_group += [{'params': v, 'lr': learning_rate}]
161 | for k, v in netB.named_parameters():
162 | param_group += [{'params': v, 'lr': learning_rate}]
163 | for k, v in netC.named_parameters():
164 | param_group += [{'params': v, 'lr': learning_rate}]
165 |
166 | optimizer = optim.SGD(param_group)
167 | optimizer = op_copy(optimizer)
168 |
169 | acc_init = 0
170 | max_iter = args.max_epoch * len(dset_loaders["source_tr"])
171 | interval_iter = max_iter // 10
172 | iter_num = 0
173 |
174 | netF.train()
175 | netB.train()
176 | netC.train()
177 |
178 | while iter_num < max_iter:
179 | try:
180 | inputs_source, labels_source = next(iter_source)
181 | except:
182 | iter_source = iter(dset_loaders["source_tr"])
183 | inputs_source, labels_source = next(iter_source)
184 |
185 | if inputs_source.size(0) == 1:
186 | continue
187 |
188 | iter_num += 1
189 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)
190 |
191 | inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
192 | outputs_source = netC(netB(netF(inputs_source)))
193 | classifier_loss = loss.CrossEntropyLabelSmooth(num_classes=args.class_num, epsilon=args.smooth)(outputs_source, labels_source)
194 | optimizer.zero_grad()
195 | classifier_loss.backward()
196 | optimizer.step()
197 |
198 | if iter_num % interval_iter == 0 or iter_num == max_iter:
199 | netF.eval()
200 | netB.eval()
201 | netC.eval()
202 | acc_s_tr, _ = cal_acc(dset_loaders['source_tr'], netF, netB, netC)
203 | acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC)
204 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%/ {:.2f}%'.format(args.dset, iter_num, max_iter, acc_s_tr, acc_s_te)
205 | args.out_file.write(log_str + '\n')
206 | args.out_file.flush()
207 | print(log_str+'\n')
208 |
209 | if acc_s_te >= acc_init:
210 | acc_init = acc_s_te
211 | best_netF = copy.deepcopy(netF.state_dict())
212 | best_netB = copy.deepcopy(netB.state_dict())
213 | best_netC = copy.deepcopy(netC.state_dict())
214 |
215 | netF.train()
216 | netB.train()
217 | netC.train()
218 |
219 | torch.save(best_netF, osp.join(args.output_dir, "source_F.pt"))
220 | torch.save(best_netB, osp.join(args.output_dir, "source_B.pt"))
221 | torch.save(best_netC, osp.join(args.output_dir, "source_C.pt"))
222 |
223 | return netF, netB, netC
224 |
225 | def test_target(args):
226 | dset_loaders = digit_load(args)
227 | ## set base network
228 | if args.dset == 'u2m':
229 | netF = network.LeNetBase().cuda()
230 | elif args.dset == 'm2u':
231 | netF = network.LeNetBase().cuda()
232 | elif args.dset == 's2m':
233 | netF = network.DTNBase().cuda()
234 |
235 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
236 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
237 |
238 | args.modelpath = args.output_dir + '/source_F.pt'
239 | netF.load_state_dict(torch.load(args.modelpath))
240 | args.modelpath = args.output_dir + '/source_B.pt'
241 | netB.load_state_dict(torch.load(args.modelpath))
242 | args.modelpath = args.output_dir + '/source_C.pt'
243 | netC.load_state_dict(torch.load(args.modelpath))
244 | netF.eval()
245 | netB.eval()
246 | netC.eval()
247 |
248 | acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC)
249 | log_str = 'Task: {}, Accuracy = {:.2f}%'.format(args.dset, acc)
250 | args.out_file.write(log_str + '\n')
251 | args.out_file.flush()
252 | print(log_str+'\n')
253 |
254 | def print_args(args):
255 | s = "==========================================\n"
256 | for arg, content in args.__dict__.items():
257 | s += "{}:{}\n".format(arg, content)
258 | return s
259 |
260 | def train_target(args):
261 | dset_loaders = digit_load(args)
262 | ## set base network
263 | if args.dset == 'u2m':
264 | netF = network.LeNetBase().cuda()
265 | elif args.dset == 'm2u':
266 | netF = network.LeNetBase().cuda()
267 | elif args.dset == 's2m':
268 | netF = network.DTNBase().cuda()
269 |
270 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
271 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
272 |
273 | args.modelpath = args.output_dir + '/source_F.pt'
274 | netF.load_state_dict(torch.load(args.modelpath))
275 | args.modelpath = args.output_dir + '/source_B.pt'
276 | netB.load_state_dict(torch.load(args.modelpath))
277 | args.modelpath = args.output_dir + '/source_C.pt'
278 | netC.load_state_dict(torch.load(args.modelpath))
279 | netC.eval()
280 | for k, v in netC.named_parameters():
281 | v.requires_grad = False
282 |
283 | param_group = []
284 | for k, v in netF.named_parameters():
285 | param_group += [{'params': v, 'lr': args.lr}]
286 | for k, v in netB.named_parameters():
287 | param_group += [{'params': v, 'lr': args.lr}]
288 |
289 | optimizer = optim.SGD(param_group)
290 | optimizer = op_copy(optimizer)
291 |
292 | max_iter = args.max_epoch * len(dset_loaders["target"])
293 | interval_iter = len(dset_loaders["target"])
294 | # interval_iter = max_iter // args.interval
295 | iter_num = 0
296 |
297 | while iter_num < max_iter:
298 | optimizer.zero_grad()
299 |         try:
300 |             inputs_test, _, tar_idx = next(iter_test)
301 |         except (StopIteration, NameError):  # first pass or exhausted epoch: (re)build the iterator
302 |             iter_test = iter(dset_loaders["target"])
303 |             inputs_test, _, tar_idx = next(iter_test)
304 |
305 | if inputs_test.size(0) == 1:
306 | continue
307 |
308 | if iter_num % interval_iter == 0 and args.cls_par > 0:
309 | netF.eval()
310 | netB.eval()
311 | mem_label = obtain_label(dset_loaders['target_te'], netF, netB, netC, args)
312 | mem_label = torch.from_numpy(mem_label).cuda()
313 | netF.train()
314 | netB.train()
315 |
316 | iter_num += 1
317 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)
318 |
319 | inputs_test = inputs_test.cuda()
320 | features_test = netB(netF(inputs_test))
321 | outputs_test = netC(features_test)
322 |
323 | if args.cls_par > 0:
324 | pred = mem_label[tar_idx]
325 | classifier_loss = args.cls_par * nn.CrossEntropyLoss()(outputs_test, pred)
326 | else:
327 | classifier_loss = torch.tensor(0.0).cuda()
328 |
329 | if args.ent:
330 | softmax_out = nn.Softmax(dim=1)(outputs_test)
331 | entropy_loss = torch.mean(loss.Entropy(softmax_out))
332 | if args.gent:
333 | msoftmax = softmax_out.mean(dim=0)
334 | entropy_loss -= torch.sum(-msoftmax * torch.log(msoftmax + 1e-5))
335 |
336 | im_loss = entropy_loss * args.ent_par
337 | classifier_loss += im_loss
338 |
339 | optimizer.zero_grad()
340 | classifier_loss.backward()
341 | optimizer.step()
342 |
343 | if iter_num % interval_iter == 0 or iter_num == max_iter:
344 | netF.eval()
345 | netB.eval()
346 | acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC)
347 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.dset, iter_num, max_iter, acc)
348 | args.out_file.write(log_str + '\n')
349 | args.out_file.flush()
350 | print(log_str+'\n')
351 | netF.train()
352 | netB.train()
353 |
354 | if args.issave:
355 | torch.save(netF.state_dict(), osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
356 | torch.save(netB.state_dict(), osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
357 | torch.save(netC.state_dict(), osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))
358 |
359 | return netF, netB, netC
360 |
361 | def obtain_label(loader, netF, netB, netC, args):
362 | start_test = True
363 | with torch.no_grad():
364 | iter_test = iter(loader)
365 | for _ in range(len(loader)):
366 |             data = next(iter_test)
367 | inputs = data[0]
368 | labels = data[1]
369 | inputs = inputs.cuda()
370 | feas = netB(netF(inputs))
371 | outputs = netC(feas)
372 | if start_test:
373 | all_fea = feas.float().cpu()
374 | all_output = outputs.float().cpu()
375 | all_label = labels.float()
376 | start_test = False
377 | else:
378 | all_fea = torch.cat((all_fea, feas.float().cpu()), 0)
379 | all_output = torch.cat((all_output, outputs.float().cpu()), 0)
380 | all_label = torch.cat((all_label, labels.float()), 0)
381 | all_output = nn.Softmax(dim=1)(all_output)
382 | _, predict = torch.max(all_output, 1)
383 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
384 |
385 |     all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)  # append a constant bias dimension
386 |     all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()  # L2-normalize each feature row
387 | all_fea = all_fea.float().cpu().numpy()
388 |
389 | K = all_output.size(1)
390 | aff = all_output.float().cpu().numpy()
391 |     initc = aff.transpose().dot(all_fea)  # softmax-weighted sum of features per class
392 |     initc = initc / (1e-8 + aff.sum(axis=0)[:,None])  # normalize into class centroids
393 |     dd = cdist(all_fea, initc, 'cosine')  # cosine distance from every sample to every centroid
394 |     pred_label = dd.argmin(axis=1)  # nearest-centroid pseudo-labels
395 | acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea)
396 |
397 |     for _ in range(1):  # one refinement pass over the cluster assignments
398 | aff = np.eye(K)[pred_label]
399 | initc = aff.transpose().dot(all_fea)
400 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None])
401 | dd = cdist(all_fea, initc, 'cosine')
402 | pred_label = dd.argmin(axis=1)
403 | acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea)
404 |
405 | log_str = 'Accuracy = {:.2f}% -> {:.2f}%'.format(accuracy*100, acc*100)
406 | args.out_file.write(log_str + '\n')
407 | args.out_file.flush()
408 | print(log_str+'\n')
409 | return pred_label.astype('int')
410 |
411 | if __name__ == "__main__":
412 | parser = argparse.ArgumentParser(description='SHOT')
413 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
414 | parser.add_argument('--s', type=int, default=0, help="source")
415 | parser.add_argument('--t', type=int, default=1, help="target")
416 | parser.add_argument('--max_epoch', type=int, default=30, help="maximum epoch")
417 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size")
418 | parser.add_argument('--worker', type=int, default=4, help="number of workers")
419 | parser.add_argument('--dset', type=str, default='s2m', choices=['u2m', 'm2u','s2m'])
420 | parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
421 | parser.add_argument('--seed', type=int, default=2020, help="random seed")
422 | parser.add_argument('--cls_par', type=float, default=0.3)
423 | parser.add_argument('--ent_par', type=float, default=1.0)
424 |     parser.add_argument('--gent', type=bool, default=True)  # note: argparse parses any non-empty string as True here
425 |     parser.add_argument('--ent', type=bool, default=True)
426 | parser.add_argument('--bottleneck', type=int, default=256)
427 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"])
428 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"])
429 | parser.add_argument('--smooth', type=float, default=0.1)
430 | parser.add_argument('--output', type=str, default='')
431 | parser.add_argument('--issave', type=bool, default=True)
432 | args = parser.parse_args()
433 | args.class_num = 10
434 |
435 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
436 | SEED = args.seed
437 | torch.manual_seed(SEED)
438 | torch.cuda.manual_seed(SEED)
439 | np.random.seed(SEED)
440 | random.seed(SEED)
441 | # torch.backends.cudnn.deterministic = True
442 |
443 | args.output_dir = osp.join(args.output, 'seed' + str(args.seed), args.dset)
444 |     os.makedirs(args.output_dir, exist_ok=True)
448 |
449 |     if not osp.exists(osp.join(args.output_dir, 'source_F.pt')):
450 | args.out_file = open(osp.join(args.output_dir, 'log_src.txt'), 'w')
451 | args.out_file.write(print_args(args)+'\n')
452 | args.out_file.flush()
453 | train_source(args)
454 | test_target(args)
455 |
456 | args.savename = 'par_' + str(args.cls_par)
457 | args.out_file = open(osp.join(args.output_dir, 'log_tar_' + args.savename + '.txt'), 'w')
458 | args.out_file.write(print_args(args)+'\n')
459 | args.out_file.flush()
460 | train_target(args)
461 |
--------------------------------------------------------------------------------
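Note on the objective above: train_target in uda_digit.py combines the pseudo-label cross-entropy (weighted by cls_par) with an information-maximization term that minimizes per-sample entropy while encouraging batch-level diversity. A minimal self-contained sketch of that second term, assuming only PyTorch (`logits` stands for a batch of classifier outputs; eps matches the 1e-5 constant used above):

    import torch
    import torch.nn.functional as F

    def im_loss(logits, eps=1e-5):
        p = F.softmax(logits, dim=1)                     # per-sample class probabilities
        ent = -(p * torch.log(p + eps)).sum(1).mean()    # mean per-sample entropy (pushed down)
        p_mean = p.mean(0)                               # marginal prediction over the batch
        div = -(p_mean * torch.log(p_mean + eps)).sum()  # batch diversity / "gentropy" (pushed up)
        return ent - div                                 # scaled by ent_par in train_target
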
/figs/shot.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tim-learn/SHOT/f7d555a0d53b525b885e5ef2a887a267a5be3c36/figs/shot.jpg
--------------------------------------------------------------------------------
/object/data_list.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import random
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | import os
7 | import os.path
8 | # import cv2          # unused in this module; avoids a hard OpenCV dependency
9 | # import torchvision  # unused in this module
10 |
11 | def make_dataset(image_list, labels):
12 |     if labels is not None:  # labels may be a numpy array; avoid its ambiguous truth value
13 | len_ = len(image_list)
14 | images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)]
15 | else:
16 | if len(image_list[0].split()) > 2:
17 | images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list]
18 | else:
19 | images = [(val.split()[0], int(val.split()[1])) for val in image_list]
20 | return images
21 |
22 |
23 | def rgb_loader(path):
24 | with open(path, 'rb') as f:
25 | with Image.open(f) as img:
26 | return img.convert('RGB')
27 |
28 | def l_loader(path):
29 | with open(path, 'rb') as f:
30 | with Image.open(f) as img:
31 | return img.convert('L')
32 |
33 | class ImageList(Dataset):
34 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'):
35 | imgs = make_dataset(image_list, labels)
36 | if len(imgs) == 0:
37 |             raise RuntimeError("Found 0 images in the supplied image list; "
38 |                                "expected lines of the form '<path> <label>'")
39 |
40 | self.imgs = imgs
41 | self.transform = transform
42 | self.target_transform = target_transform
43 | if mode == 'RGB':
44 | self.loader = rgb_loader
45 | elif mode == 'L':
46 | self.loader = l_loader
47 |
48 | def __getitem__(self, index):
49 | path, target = self.imgs[index]
50 | img = self.loader(path)
51 | if self.transform is not None:
52 | img = self.transform(img)
53 | if self.target_transform is not None:
54 | target = self.target_transform(target)
55 |
56 | return img, target
57 |
58 | def __len__(self):
59 | return len(self.imgs)
60 |
61 | class ImageList_idx(Dataset):
62 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'):
63 | imgs = make_dataset(image_list, labels)
64 | if len(imgs) == 0:
65 |             raise RuntimeError("Found 0 images in the supplied image list; "
66 |                                "expected lines of the form '<path> <label>'")
67 |
68 | self.imgs = imgs
69 | self.transform = transform
70 | self.target_transform = target_transform
71 | if mode == 'RGB':
72 | self.loader = rgb_loader
73 | elif mode == 'L':
74 | self.loader = l_loader
75 |
76 | def __getitem__(self, index):
77 | path, target = self.imgs[index]
78 | img = self.loader(path)
79 | if self.transform is not None:
80 | img = self.transform(img)
81 | if self.target_transform is not None:
82 | target = self.target_transform(target)
83 |
84 | return img, target, index
85 |
86 | def __len__(self):
87 | return len(self.imgs)
--------------------------------------------------------------------------------
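ImageList and ImageList_idx consume plain-text list files with one "<path> <label>" pair per line (make_dataset also accepts multi-label lines carrying several integers after the path). A minimal usage sketch; the list path below is hypothetical:

    from torchvision import transforms
    from data_list import ImageList

    lines = open('./data/office-caltech/amazon_list.txt').readlines()  # hypothetical list file
    dset = ImageList(lines, transform=transforms.ToTensor())
    img, target = dset[0]  # ImageList_idx.__getitem__ would additionally return the sample index
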
/object/image_multisource.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 | import os.path as osp
4 | import torchvision
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torchvision import transforms
10 | import network, loss
11 | from torch.utils.data import DataLoader
12 | from data_list import ImageList
13 | import random, pdb, math, copy
14 | from tqdm import tqdm
15 | from sklearn.metrics import confusion_matrix
16 |
17 | def image_train(resize_size=256, crop_size=224, alexnet=False):
18 | if not alexnet:
19 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
20 | std=[0.229, 0.224, 0.225])
21 | else:
22 |         normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')  # assumes a custom mean-file Normalize transform, not defined in this repo
23 | return transforms.Compose([
24 | transforms.Resize((resize_size, resize_size)),
25 | transforms.RandomCrop(crop_size),
26 | transforms.RandomHorizontalFlip(),
27 | transforms.ToTensor(),
28 | normalize
29 | ])
30 |
31 | def image_test(resize_size=256, crop_size=224, alexnet=False):
32 | if not alexnet:
33 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
34 | std=[0.229, 0.224, 0.225])
35 | else:
36 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')
37 | return transforms.Compose([
38 | transforms.Resize((resize_size, resize_size)),
39 | transforms.CenterCrop(crop_size),
40 | transforms.ToTensor(),
41 | normalize
42 | ])
43 |
44 | def data_load(args):
45 | ## prepare data
46 | dsets = {}
47 | dset_loaders = {}
48 | train_bs = args.batch_size
49 | txt_tar = open(args.t_dset_path).readlines()
50 | txt_test = open(args.test_dset_path).readlines()
51 |
52 |     if args.da != 'uda':
53 | label_map_s = {}
54 | for i in range(len(args.src_classes)):
55 | label_map_s[args.src_classes[i]] = i
56 |
57 | new_tar = []
58 | for i in range(len(txt_tar)):
59 | rec = txt_tar[i]
60 | reci = rec.strip().split(' ')
61 | if int(reci[1]) in args.tar_classes:
62 | if int(reci[1]) in args.src_classes:
63 | line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n'
64 | new_tar.append(line)
65 | else:
66 | line = reci[0] + ' ' + str(len(label_map_s)) + '\n'
67 | new_tar.append(line)
68 | txt_tar = new_tar.copy()
69 | txt_test = txt_tar.copy()
70 |
71 | dsets["target"] = ImageList(txt_tar, transform=image_test())
72 | dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
73 | dsets["test"] = ImageList(txt_test, transform=image_test())
74 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*3, shuffle=False, num_workers=args.worker, drop_last=False)
75 |
76 | return dset_loaders
77 |
78 | def cal_acc(loader, netF, netB, netC, flag=False):
79 | start_test = True
80 | with torch.no_grad():
81 | iter_test = iter(loader)
82 | for i in range(len(loader)):
83 |             data = next(iter_test)
84 | inputs = data[0]
85 | labels = data[1]
86 | inputs = inputs.cuda()
87 | outputs = netC(netB(netF(inputs)))
88 | if start_test:
89 | all_output = outputs.float().cpu()
90 | all_label = labels.float()
91 | start_test = False
92 | else:
93 | all_output = torch.cat((all_output, outputs.float().cpu()), 0)
94 | all_label = torch.cat((all_label, labels.float()), 0)
95 | _, predict = torch.max(all_output, 1)
96 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
97 |         mean_ent = torch.mean(loss.Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item()  # computed but not returned here
98 |
99 | return accuracy, all_label, nn.Softmax(dim=1)(all_output)
100 |
101 | def print_args(args):
102 | s = "==========================================\n"
103 | for arg, content in args.__dict__.items():
104 | s += "{}:{}\n".format(arg, content)
105 | return s
106 |
107 | def test_target_srconly(args):
108 | dset_loaders = data_load(args)
109 | ## set base network
110 | if args.net[0:3] == 'res':
111 | netF = network.ResBase(res_name=args.net).cuda()
112 | elif args.net[0:3] == 'vgg':
113 | netF = network.VGGBase(vgg_name=args.net).cuda()
114 |
115 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
116 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
117 |
118 | args.modelpath = args.output_dir_src + '/source_F.pt'
119 | netF.load_state_dict(torch.load(args.modelpath))
120 | args.modelpath = args.output_dir_src + '/source_B.pt'
121 | netB.load_state_dict(torch.load(args.modelpath))
122 | args.modelpath = args.output_dir_src + '/source_C.pt'
123 | netC.load_state_dict(torch.load(args.modelpath))
124 | netF.eval()
125 | netB.eval()
126 | netC.eval()
127 |
128 | acc, y, py = cal_acc(dset_loaders['test'], netF, netB, netC)
129 | log_str = '\nTask: {}, Accuracy = {:.2f}%'.format(args.name, acc*100)
130 | args.out_file.write(log_str)
131 | args.out_file.flush()
132 | print(log_str)
133 |
134 | return y, py
135 |
136 | def test_target(args):
137 | dset_loaders = data_load(args)
138 | ## set base network
139 | if args.net[0:3] == 'res':
140 | netF = network.ResBase(res_name=args.net).cuda()
141 | elif args.net[0:3] == 'vgg':
142 | netF = network.VGGBase(vgg_name=args.net).cuda()
143 |
144 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
145 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
146 |
147 | args.modelpath = args.output_dir_ori + "/target_F_" + args.savename + ".pt"
148 | netF.load_state_dict(torch.load(args.modelpath))
149 | args.modelpath = args.output_dir_ori + "/target_B_" + args.savename + ".pt"
150 | netB.load_state_dict(torch.load(args.modelpath))
151 | args.modelpath = args.output_dir_ori + "/target_C_" + args.savename + ".pt"
152 | netC.load_state_dict(torch.load(args.modelpath))
153 | netF.eval()
154 | netB.eval()
155 | netC.eval()
156 |
157 | acc, y, py = cal_acc(dset_loaders['test'], netF, netB, netC)
158 | log_str = '\nTask: {}, Accuracy = {:.2f}%'.format(args.name, acc*100)
159 | args.out_file.write(log_str)
160 | args.out_file.flush()
161 | print(log_str)
162 |
163 | return y, py
164 |
165 | if __name__ == "__main__":
166 | parser = argparse.ArgumentParser(description='SHOT')
167 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
168 | parser.add_argument('--s', type=int, default=0, help="source")
169 | parser.add_argument('--t', type=int, default=1, help="target")
170 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size")
171 | parser.add_argument('--worker', type=int, default=4, help="number of workers")
172 | parser.add_argument('--dset', type=str, default='office-caltech', choices=['office-caltech'])
173 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
174 | parser.add_argument('--net', type=str, default='resnet50', help="vgg16, resnet50, resnet101")
175 | parser.add_argument('--seed', type=int, default=2020, help="random seed")
176 |
177 | parser.add_argument('--threshold', type=int, default=0)
178 | parser.add_argument('--cls_par', type=float, default=0.3)
179 | parser.add_argument('--bottleneck', type=int, default=256)
180 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"])
181 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"])
182 | parser.add_argument('--output', type=str, default='san')
183 | parser.add_argument('--output_src', type=str, default='ckps')
184 | parser.add_argument('--da', type=str, default='uda', choices=['uda'])
185 | args = parser.parse_args()
186 |
187 | if args.dset == 'office-caltech':
188 | names = ['amazon', 'caltech', 'dslr', 'webcam']
189 | args.class_num = 10
190 |
191 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
192 | SEED = args.seed
193 | torch.manual_seed(SEED)
194 | torch.cuda.manual_seed(SEED)
195 | np.random.seed(SEED)
196 | random.seed(SEED)
197 | # torch.backends.cudnn.deterministic = True
198 |
199 | score_srconly = 0
200 | score = 0
201 |
202 | args.output_dir = osp.join(args.output, args.da, args.dset, str(0)+names[args.t][0].upper())
203 |     os.makedirs(args.output_dir, exist_ok=True)
207 |
208 | args.savename = 'par_' + str(args.cls_par)
209 | args.out_file = open(osp.join(args.output_dir, 'log_' + args.savename + '.txt'), 'w')
210 | args.out_file.write(print_args(args)+'\n')
211 | args.out_file.flush()
212 |
213 | for i in range(len(names)):
214 | if i == args.t:
215 | continue
216 | args.s = i
217 |
218 | folder = './data/'
219 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt'
220 | args.t_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt'
221 | args.test_dset_path = args.t_dset_path
222 |
223 | args.output_dir_src = osp.join(args.output_src, args.da, args.dset, names[args.s][0].upper())
224 | args.output_dir_ori = osp.join(args.output, args.da, args.dset, names[args.s][0].upper()+names[args.t][0].upper())
225 | args.name = names[args.s][0].upper() + names[args.t][0].upper()
226 |
227 | label, output_srconly = test_target_srconly(args)
228 | score_srconly += output_srconly
229 |
230 | _, output = test_target(args)
231 | score += output
232 |
233 | _, predict = torch.max(score_srconly, 1)
234 | acc = torch.sum(torch.squeeze(predict).float() == label).item() / float(label.size()[0])
235 |     log_str = '\nTask: {}, Source-only Ensemble Accuracy = {:.2f}%'.format('->' + names[args.t][0].upper(), acc*100)
236 | args.out_file.write(log_str)
237 | args.out_file.flush()
238 | print(log_str)
239 |
240 | _, predict = torch.max(score, 1)
241 | acc = torch.sum(torch.squeeze(predict).float() == label).item() / float(label.size()[0])
242 |     log_str = '\nTask: {}, Adapted Ensemble Accuracy = {:.2f}%'.format('->' + names[args.t][0].upper(), acc*100)
243 | args.out_file.write(log_str)
244 | args.out_file.flush()
245 | print(log_str)
--------------------------------------------------------------------------------
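The multi-source rule in the __main__ block above is plain probability averaging: the softmax outputs of each source-adapted model over the shared target list are summed, and the ensemble prediction is the argmax of the summed score. A condensed sketch (the function name is illustrative):

    import torch

    def ensemble_predict(prob_list):               # list of [N, C] softmax tensors, one per source model
        score = torch.stack(prob_list).sum(dim=0)  # elementwise sum of class probabilities
        return score.argmax(dim=1)                 # ensemble label per target sample
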
/object/image_multitarget.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 | import os.path as osp
4 | import torchvision
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torchvision import transforms
10 | import network, loss
11 | from torch.utils.data import DataLoader
12 | from data_list import ImageList, ImageList_idx
13 | import random, pdb, math, copy
14 | from tqdm import tqdm
15 | from scipy.spatial.distance import cdist
16 | from sklearn.metrics import confusion_matrix
17 | # import rotation  # unused in this file, and rotation.py is not part of this listing
18 |
19 | def op_copy(optimizer):
20 | for param_group in optimizer.param_groups:
21 | param_group['lr0'] = param_group['lr']
22 | return optimizer
23 |
24 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
25 | decay = (1 + gamma * iter_num / max_iter) ** (-power)
26 | for param_group in optimizer.param_groups:
27 | param_group['lr'] = param_group['lr0'] * decay
28 | param_group['weight_decay'] = 1e-3
29 | param_group['momentum'] = 0.9
30 | param_group['nesterov'] = True
31 | return optimizer
32 |
33 | def image_train(resize_size=256, crop_size=224, alexnet=False):
34 | if not alexnet:
35 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
36 | std=[0.229, 0.224, 0.225])
37 | else:
38 |         normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')  # assumes a custom mean-file Normalize transform, not defined in this repo
39 | return transforms.Compose([
40 | transforms.Resize((resize_size, resize_size)),
41 | transforms.RandomCrop(crop_size),
42 | transforms.RandomHorizontalFlip(),
43 | transforms.ToTensor(),
44 | normalize
45 | ])
46 |
47 | def image_test(resize_size=256, crop_size=224, alexnet=False):
48 | if not alexnet:
49 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
50 | std=[0.229, 0.224, 0.225])
51 | else:
52 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')
53 | return transforms.Compose([
54 | transforms.Resize((resize_size, resize_size)),
55 | transforms.CenterCrop(crop_size),
56 | transforms.ToTensor(),
57 | normalize
58 | ])
59 |
60 | def data_load(args):
61 | ## prepare data
62 | dsets = {}
63 | dset_loaders = {}
64 | train_bs = args.batch_size
65 | txt_src = open(args.s_dset_path).readlines()
66 |
67 | txt_tar = []
68 | for i in range(len(args.t_dset_path)):
69 | tmp = open(args.t_dset_path[i]).readlines()
70 | txt_tar.extend(tmp)
71 | txt_test = txt_tar.copy()
72 |
73 | dsets["target"] = ImageList_idx(txt_tar, transform=image_train())
74 | dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
75 | dsets["test"] = ImageList(txt_test, transform=image_test())
76 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*2, shuffle=False, num_workers=args.worker, drop_last=False)
77 |
78 | return dset_loaders
79 |
80 | def cal_acc(loader, netF, netB, netC, flag=False):
81 | start_test = True
82 | with torch.no_grad():
83 | iter_test = iter(loader)
84 | for i in range(len(loader)):
85 |             data = next(iter_test)
86 | inputs = data[0]
87 | labels = data[1]
88 | inputs = inputs.cuda()
89 | outputs = netC(netB(netF(inputs)))
90 | if start_test:
91 | all_output = outputs.float().cpu()
92 | all_label = labels.float()
93 | start_test = False
94 | else:
95 | all_output = torch.cat((all_output, outputs.float().cpu()), 0)
96 | all_label = torch.cat((all_label, labels.float()), 0)
97 | _, predict = torch.max(all_output, 1)
98 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
99 | mean_ent = torch.mean(loss.Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item()
100 |
101 | if flag:
102 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
103 | acc = matrix.diagonal()/matrix.sum(axis=1) * 100
104 | aacc = acc.mean()
105 | aa = [str(np.round(i, 2)) for i in acc]
106 | acc = ' '.join(aa)
107 | return aacc, acc
108 | else:
109 | return accuracy*100, mean_ent
110 |
111 | def print_args(args):
112 | s = "==========================================\n"
113 | for arg, content in args.__dict__.items():
114 | s += "{}:{}\n".format(arg, content)
115 | return s
116 |
117 | def train_target(args):
118 | dset_loaders = data_load(args)
119 | ## set base network
120 | if args.net[0:3] == 'res':
121 | netF = network.ResBase(res_name=args.net).cuda()
122 | elif args.net[0:3] == 'vgg':
123 | netF = network.VGGBase(vgg_name=args.net).cuda()
124 |
125 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
126 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
127 |
128 | args.modelpath = args.output_dir_src + '/source_F.pt'
129 | netF.load_state_dict(torch.load(args.modelpath))
130 | args.modelpath = args.output_dir_src + '/source_B.pt'
131 | netB.load_state_dict(torch.load(args.modelpath))
132 | args.modelpath = args.output_dir_src + '/source_C.pt'
133 | netC.load_state_dict(torch.load(args.modelpath))
134 | netC.eval()
135 | for k, v in netC.named_parameters():
136 | v.requires_grad = False
137 |
138 | param_group = []
139 | for k, v in netF.named_parameters():
140 | if args.lr_decay1 > 0:
141 | param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
142 | else:
143 | v.requires_grad = False
144 | for k, v in netB.named_parameters():
145 | if args.lr_decay2 > 0:
146 | param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
147 | else:
148 | v.requires_grad = False
149 | optimizer = optim.SGD(param_group)
150 | optimizer = op_copy(optimizer)
151 |
152 | max_iter = args.max_epoch * len(dset_loaders["target"])
153 | interval_iter = max_iter // args.interval
154 | iter_num = 0
155 |
156 | while iter_num < max_iter:
157 |         try:
158 |             inputs_test, _, tar_idx = next(iter_test)
159 |         except (StopIteration, NameError):  # first pass or exhausted epoch: (re)build the iterator
160 |             iter_test = iter(dset_loaders["target"])
161 |             inputs_test, _, tar_idx = next(iter_test)
162 |
163 | if inputs_test.size(0) == 1:
164 | continue
165 |
166 | if iter_num % interval_iter == 0 and args.cls_par > 0:
167 | netF.eval()
168 | netB.eval()
169 | mem_label = obtain_label(dset_loaders['test'], netF, netB, netC, args)
170 | mem_label = torch.from_numpy(mem_label).cuda()
171 | netF.train()
172 | netB.train()
173 |
174 | inputs_test = inputs_test.cuda()
175 |
176 | iter_num += 1
177 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)
178 |
179 | features_test = netB(netF(inputs_test))
180 | outputs_test = netC(features_test)
181 |
182 | if args.cls_par > 0:
183 | pred = mem_label[tar_idx]
184 | classifier_loss = nn.CrossEntropyLoss()(outputs_test, pred)
185 | classifier_loss *= args.cls_par
186 | else:
187 | classifier_loss = torch.tensor(0.0).cuda()
188 |
189 | if args.ent:
190 | softmax_out = nn.Softmax(dim=1)(outputs_test)
191 | entropy_loss = torch.mean(loss.Entropy(softmax_out))
192 | if args.gent:
193 | msoftmax = softmax_out.mean(dim=0)
194 | gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
195 | entropy_loss -= gentropy_loss
196 | classifier_loss += entropy_loss * args.ent_par
197 |
198 | optimizer.zero_grad()
199 | classifier_loss.backward()
200 | optimizer.step()
201 |
202 | if iter_num % interval_iter == 0 or iter_num == max_iter:
203 | netF.eval()
204 | netB.eval()
205 | acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False)
206 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name, iter_num, max_iter, acc)
207 | args.out_file.write(log_str + '\n')
208 | args.out_file.flush()
209 | print(log_str+'\n')
210 | netF.train()
211 | netB.train()
212 |
213 | if args.issave:
214 | torch.save(netF.state_dict(), osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
215 | torch.save(netB.state_dict(), osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
216 | torch.save(netC.state_dict(), osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))
217 |
218 | return netF, netB, netC
219 |
220 | def obtain_label(loader, netF, netB, netC, args):
221 | start_test = True
222 | with torch.no_grad():
223 | iter_test = iter(loader)
224 | for _ in range(len(loader)):
225 |             data = next(iter_test)
226 | inputs = data[0]
227 | labels = data[1]
228 | inputs = inputs.cuda()
229 | feas = netB(netF(inputs))
230 | outputs = netC(feas)
231 | if start_test:
232 | all_fea = feas.float().cpu()
233 | all_output = outputs.float().cpu()
234 | all_label = labels.float()
235 | start_test = False
236 | else:
237 | all_fea = torch.cat((all_fea, feas.float().cpu()), 0)
238 | all_output = torch.cat((all_output, outputs.float().cpu()), 0)
239 | all_label = torch.cat((all_label, labels.float()), 0)
240 |
241 | all_output = nn.Softmax(dim=1)(all_output)
242 | ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1)
243 | unknown_weight = 1 - ent / np.log(args.class_num)
244 | _, predict = torch.max(all_output, 1)
245 |
246 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
247 | if args.distance == 'cosine':
248 | all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
249 | all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()
250 |
251 | all_fea = all_fea.float().cpu().numpy()
252 | K = all_output.size(1)
253 | aff = all_output.float().cpu().numpy()
254 | initc = aff.transpose().dot(all_fea)
255 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None])
256 | cls_count = np.eye(K)[predict].sum(axis=0)
257 | labelset = np.where(cls_count>args.threshold)
258 | labelset = labelset[0]
259 | # print(labelset)
260 |
261 | dd = cdist(all_fea, initc[labelset], args.distance)
262 | pred_label = dd.argmin(axis=1)
263 | pred_label = labelset[pred_label]
264 |
265 |         for _ in range(1):  # one refinement pass over the cluster assignments
266 | aff = np.eye(K)[pred_label]
267 | initc = aff.transpose().dot(all_fea)
268 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None])
269 | dd = cdist(all_fea, initc[labelset], args.distance)
270 | pred_label = dd.argmin(axis=1)
271 | pred_label = labelset[pred_label]
272 |
273 | acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea)
274 | log_str = 'Accuracy = {:.2f}% -> {:.2f}%'.format(accuracy*100, acc*100)
275 |
276 | args.out_file.write(log_str + '\n')
277 | args.out_file.flush()
278 | print(log_str+'\n')
279 |
280 | return pred_label.astype('int') #, labelset
281 |
282 | if __name__ == "__main__":
283 | parser = argparse.ArgumentParser(description='Conditional Domain Adversarial Network')
284 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
285 | parser.add_argument('--s', type=int, default=0, help="source")
286 | parser.add_argument('--t', type=int, default=1, help="target")
287 |     parser.add_argument('--max_epoch', type=int, default=15, help="maximum epoch")
288 | parser.add_argument('--interval', type=int, default=15)
289 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size")
290 | parser.add_argument('--worker', type=int, default=4, help="number of workers")
291 | parser.add_argument('--dset', type=str, default='office-caltech', choices=['office-caltech'])
292 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
293 | parser.add_argument('--net', type=str, default='resnet50', help="vgg16, resnet50, resnet101")
294 | parser.add_argument('--seed', type=int, default=2020, help="random seed")
295 |
296 | parser.add_argument('--gent', type=bool, default=True)
297 | parser.add_argument('--ent', type=bool, default=True)
298 | parser.add_argument('--threshold', type=int, default=-1)
299 | parser.add_argument('--cls_par', type=float, default=0.3)
300 | parser.add_argument('--ent_par', type=float, default=1.0)
301 | parser.add_argument('--lr_decay1', type=float, default=0.1)
302 | parser.add_argument('--lr_decay2', type=float, default=1.0)
303 |
304 | parser.add_argument('--bottleneck', type=int, default=256)
305 | parser.add_argument('--epsilon', type=float, default=1e-5)
306 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"])
307 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"])
308 | parser.add_argument('--distance', type=str, default='cosine', choices=["euclidean", "cosine"])
309 | parser.add_argument('--output', type=str, default='san')
310 | parser.add_argument('--output_src', type=str, default='ckps')
311 | parser.add_argument('--da', type=str, default='uda', choices=['uda'])
312 | parser.add_argument('--issave', type=bool, default=True)
313 | args = parser.parse_args()
314 |
315 | if args.dset == 'office-caltech':
316 | names = ['amazon', 'caltech', 'dslr', 'webcam']
317 | args.class_num = 10
318 |
319 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
320 | SEED = args.seed
321 | torch.manual_seed(SEED)
322 | torch.cuda.manual_seed(SEED)
323 | np.random.seed(SEED)
324 | random.seed(SEED)
325 | # torch.backends.cudnn.deterministic = True
326 |
327 | t_dset = []
328 | for i in range(len(names)):
329 | if i == args.s:
330 | continue
331 | args.t = i
332 |
333 | folder = './data/'
334 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt'
335 | args.t_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt'
336 | t_dset.append(args.t_dset_path)
337 |
338 | args.t_dset_path = t_dset
339 | args.test_dset_path = args.t_dset_path
340 |
341 | args.output_dir_src = osp.join(args.output_src, args.da, args.dset, names[args.s][0].upper())
342 | args.output_dir = osp.join(args.output, args.da, args.dset, names[args.s][0].upper() + str(0))
343 | args.name = names[args.s][0].upper() + str(0)
344 |
345 |     os.makedirs(args.output_dir, exist_ok=True)
349 |
350 | args.savename = 'par_' + str(args.cls_par)
351 | args.out_file = open(osp.join(args.output_dir, 'log_' + args.savename + '.txt'), 'w')
352 | args.out_file.write(print_args(args)+'\n')
353 | args.out_file.flush()
354 | train_target(args)
--------------------------------------------------------------------------------
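obtain_label above is a one-step nearest-centroid self-labeling: class centroids are first estimated from softmax-weighted features, every sample is reassigned to its nearest centroid under the chosen distance, and the centroids are recomputed once from those hard assignments (the labelset logic additionally drops classes whose predicted count is at most args.threshold). A condensed sketch of the core loop, omitting the labelset filtering:

    import numpy as np
    from scipy.spatial.distance import cdist

    def cluster_labels(feas, probs, metric='cosine'):
        # feas: [N, D] L2-normalized features; probs: [N, K] softmax outputs
        initc = probs.T @ feas / (1e-8 + probs.sum(0)[:, None])    # soft class centroids
        pred = cdist(feas, initc, metric).argmin(1)                # nearest-centroid labels
        onehot = np.eye(probs.shape[1])[pred]
        initc = onehot.T @ feas / (1e-8 + onehot.sum(0)[:, None])  # recompute from hard labels
        return cdist(feas, initc, metric).argmin(1)
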
/object/image_pretrained.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 | import os.path as osp
4 | import torchvision
5 | from torchvision import transforms
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | import network, loss
11 | from torch.utils.data import DataLoader
12 | from data_list import ImageList, ImageList_idx
13 | import random, pdb, math, copy
14 | from tqdm import tqdm
15 | from scipy.spatial.distance import cdist
16 | from sklearn.metrics import confusion_matrix
17 |
18 | def op_copy(optimizer):
19 | for param_group in optimizer.param_groups:
20 | param_group['lr0'] = param_group['lr']
21 | return optimizer
22 |
23 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
24 | decay = (1 + gamma * iter_num / max_iter) ** (-power)
25 | for param_group in optimizer.param_groups:
26 | param_group['lr'] = param_group['lr0'] * decay
27 | param_group['weight_decay'] = 1e-3
28 | param_group['momentum'] = 0.9
29 | param_group['nesterov'] = True
30 | return optimizer
31 |
32 | def image_train(resize_size=256, crop_size=224, alexnet=False):
33 | if not alexnet:
34 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
35 | std=[0.229, 0.224, 0.225])
36 | else:
37 |         normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')  # assumes a custom mean-file Normalize transform, not defined in this repo
38 | return transforms.Compose([
39 | transforms.Resize((resize_size, resize_size)),
40 | transforms.RandomCrop(crop_size),
41 | transforms.RandomHorizontalFlip(),
42 | transforms.ToTensor(),
43 | normalize
44 | ])
45 |
46 | def image_test(resize_size=256, crop_size=224, alexnet=False):
47 | if not alexnet:
48 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
49 | std=[0.229, 0.224, 0.225])
50 | else:
51 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')
52 | return transforms.Compose([
53 | transforms.Resize((resize_size, resize_size)),
54 | transforms.CenterCrop(crop_size),
55 | transforms.ToTensor(),
56 | normalize
57 | ])
58 |
59 | def data_load(args):
60 | dsets = {}
61 | dset_loaders = {}
62 | train_bs = args.batch_size
63 | txt_tar = open(args.t_dset_path).readlines()
64 | txt_test = open(args.test_dset_path).readlines()
65 |
66 | dsets["target"] = ImageList_idx(txt_tar, transform=image_train())
67 | dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
68 | dsets["test"] = ImageList_idx(txt_test, transform=image_test())
69 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*3, shuffle=False, num_workers=args.worker, drop_last=False)
70 |
71 | return dset_loaders
72 |
73 | def cal_acc(loader, net, flag=False):
74 | start_test = True
75 | with torch.no_grad():
76 | iter_test = iter(loader)
77 | for i in range(len(loader)):
78 |             data = next(iter_test)
79 | inputs = data[0]
80 | labels = data[1]
81 | inputs = inputs.cuda()
82 | _, outputs = net(inputs)
83 | if start_test:
84 | all_output = outputs.float().cpu()
85 | all_label = labels.float()
86 | start_test = False
87 | else:
88 | all_output = torch.cat((all_output, outputs.float().cpu()), 0)
89 | all_label = torch.cat((all_label, labels.float()), 0)
90 | _, predict = torch.max(all_output, 1)
91 | all_output = nn.Softmax(dim=1)(all_output)
92 |         ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1) / np.log(all_output.size(1))  # normalized entropy (unused below)
93 |         accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
94 |         mean_ent = torch.mean(loss.Entropy(all_output)).cpu().data.item()  # all_output is already a softmax; avoid double normalization
95 |
96 | return accuracy, mean_ent
97 |
98 | def train_target(args):
99 | dset_loaders = data_load(args)
100 | netF = network.Res50().cuda()
101 |
102 | param_group = []
103 | for k, v in netF.named_parameters():
104 | if k.__contains__("fc"):
105 | v.requires_grad = False
106 | else:
107 | param_group += [{'params': v, 'lr': args.lr*args.lr_decay1}]
108 |
109 | optimizer = optim.SGD(param_group)
110 | optimizer = op_copy(optimizer)
111 |
112 | max_iter = args.max_epoch * len(dset_loaders["target"])
113 | interval_iter = max_iter // args.interval
114 | iter_num = 0
115 |
116 | netF.train()
117 | while iter_num < max_iter:
118 |         try:
119 |             inputs_test, _, tar_idx = next(iter_test)
120 |         except (StopIteration, NameError):  # first pass or exhausted epoch: (re)build the iterator
121 |             iter_test = iter(dset_loaders["target"])
122 |             inputs_test, _, tar_idx = next(iter_test)
123 |
124 | if inputs_test.size(0) == 1:
125 | continue
126 |
127 | if iter_num % interval_iter == 0 and args.cls_par > 0:
128 | netF.eval()
129 | mem_label = obtain_label(dset_loaders['test'], netF, args)
130 | mem_label = torch.from_numpy(mem_label).cuda()
131 | netF.train()
132 |
133 | inputs_test = inputs_test.cuda()
134 | iter_num += 1
135 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)
136 |
137 | features_test, outputs_test = netF(inputs_test)
138 |
139 | if args.cls_par > 0:
140 | pred = mem_label[tar_idx]
141 | classifier_loss = nn.CrossEntropyLoss()(outputs_test, pred)
142 | classifier_loss *= args.cls_par
143 | else:
144 | classifier_loss = torch.tensor(0.0).cuda()
145 |
146 | if args.ent:
147 | softmax_out = nn.Softmax(dim=1)(outputs_test)
148 | entropy_loss = torch.mean(loss.Entropy(softmax_out))
149 | if args.gent:
150 | msoftmax = softmax_out.mean(dim=0)
151 | gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
152 | entropy_loss -= gentropy_loss
153 | classifier_loss += entropy_loss * args.ent_par
154 |
155 | optimizer.zero_grad()
156 | classifier_loss.backward()
157 | optimizer.step()
158 |
159 | if iter_num % interval_iter == 0 or iter_num == max_iter:
160 | netF.eval()
161 | acc, ment = cal_acc(dset_loaders['test'], netF)
162 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.dset, iter_num, max_iter, acc*100)
163 | args.out_file.write(log_str + '\n')
164 | args.out_file.flush()
165 | print(log_str+'\n')
166 | netF.train()
167 |
168 | if args.issave:
169 | torch.save(netF.state_dict(), osp.join(args.output_dir, "target" + args.savename + ".pt"))
170 |
171 | return netF
172 |
173 | def print_args(args):
174 | s = "==========================================\n"
175 | for arg, content in args.__dict__.items():
176 | s += "{}:{}\n".format(arg, content)
177 | return s
178 |
179 | def obtain_label(loader, net, args):
180 | start_test = True
181 | with torch.no_grad():
182 | iter_test = iter(loader)
183 | for _ in range(len(loader)):
184 |             data = next(iter_test)
185 | inputs = data[0]
186 | labels = data[1]
187 | inputs = inputs.cuda()
188 | feas, outputs = net(inputs)
189 | if start_test:
190 | all_fea = feas.float().cpu()
191 | all_output = outputs.float().cpu()
192 | all_label = labels.float()
193 | start_test = False
194 | else:
195 | all_fea = torch.cat((all_fea, feas.float().cpu()), 0)
196 | all_output = torch.cat((all_output, outputs.float().cpu()), 0)
197 | all_label = torch.cat((all_label, labels.float()), 0)
198 |
199 | all_output = nn.Softmax(dim=1)(all_output)
200 | ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1)
201 | unknown_weight = 1 - ent / np.log(args.class_num)
202 | _, predict = torch.max(all_output, 1)
203 |
204 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
205 | if args.distance == 'cosine':
206 | all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
207 | all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()
208 |
209 | all_fea = all_fea.float().cpu().numpy()
210 | K = all_output.size(1)
211 | aff = all_output.float().cpu().numpy()
212 | initc = aff.transpose().dot(all_fea)
213 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None])
214 | cls_count = np.eye(K)[predict].sum(axis=0)
215 | labelset = np.where(cls_count>args.threshold)
216 | labelset = labelset[0]
217 | # print(labelset)
218 |
219 | dd = cdist(all_fea, initc[labelset], args.distance)
220 | pred_label = dd.argmin(axis=1)
221 | pred_label = labelset[pred_label]
222 |
223 |         for _ in range(1):  # one refinement pass over the cluster assignments
224 | aff = np.eye(K)[pred_label]
225 | initc = aff.transpose().dot(all_fea)
226 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None])
227 | dd = cdist(all_fea, initc[labelset], args.distance)
228 | pred_label = dd.argmin(axis=1)
229 | pred_label = labelset[pred_label]
230 |
231 | acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea)
232 | log_str = 'Accuracy = {:.2f}% -> {:.2f}%'.format(accuracy*100, acc*100)
233 |
234 | args.out_file.write(log_str + '\n')
235 | args.out_file.flush()
236 | print(log_str+'\n')
237 |
238 | return pred_label.astype('int') #, labelset
239 |
240 |
241 | if __name__ == "__main__":
242 | parser = argparse.ArgumentParser(description='SHOT')
243 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
244 |     parser.add_argument('--max_epoch', type=int, default=15, help="maximum epoch")
245 |     parser.add_argument('--interval', type=int, default=15, help="number of pseudo-label refresh / evaluation intervals")
246 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size")
247 | parser.add_argument('--worker', type=int, default=4, help="number of workers")
248 | parser.add_argument('--distance', type=str, default='cosine', choices=["euclidean", "cosine"])
249 | parser.add_argument('--dset', type=str, default='imagenet_caltech', choices=['imagenet_caltech'])
250 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
251 | parser.add_argument('--net', type=str, default='resnet50', help="vgg16, resnet50, resnet101")
252 | parser.add_argument('--seed', type=int, default=2019, help="random seed")
253 | parser.add_argument('--epsilon', type=float, default=1e-5)
254 | parser.add_argument('--gent', type=bool, default=False)
255 | parser.add_argument('--ent', type=bool, default=True)
256 | parser.add_argument('--threshold', type=int, default=30)
257 |
258 | parser.add_argument('--cls_par', type=float, default=0.3)
259 | parser.add_argument('--ent_par', type=float, default=1.0)
260 | parser.add_argument('--output', type=str, default='seed')
261 | parser.add_argument('--da', type=str, default='pda', choices=['pda'])
262 | parser.add_argument('--issave', type=bool, default=True)
263 | parser.add_argument('--lr_decay1', type=float, default=0.1)
264 |
265 | args = parser.parse_args()
266 |
267 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
268 | SEED = args.seed
269 | torch.manual_seed(SEED)
270 | torch.cuda.manual_seed(SEED)
271 | np.random.seed(SEED)
272 | random.seed(SEED)
273 | # torch.backends.cudnn.deterministic = True
274 |
275 | args.class_num = 1000
276 | folder = './data/'
277 | if args.da == 'pda':
278 | args.t_dset_path = folder + args.dset + '/' + 'caltech_84' + '_list.txt'
279 | args.test_dset_path = args.t_dset_path
280 |
281 | args.output_dir = osp.join(args.output, args.da, args.dset)
282 | args.name = args.dset
283 |
284 |     os.makedirs(args.output_dir, exist_ok=True)
288 |
289 | args.savename = 'par_' + str(args.cls_par)
290 | if args.da == 'pda':
291 | args.savename = 'par_' + str(args.cls_par) + '_thr' + str(args.threshold)
292 | args.out_file = open(osp.join(args.output_dir, 'log_' + args.savename + '.txt'), 'w')
293 | args.out_file.write(print_args(args)+'\n')
294 | args.out_file.flush()
295 | train_target(args)
--------------------------------------------------------------------------------
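All of these scripts share the same polynomial learning-rate schedule: op_copy records each parameter group's base rate as lr0, and lr_scheduler then sets lr(i) = lr0 * (1 + gamma * i / max_iter) ** (-power) with gamma=10 and power=0.75. A quick check of the endpoints (values approximate):

    for i in (0, 500, 1000):
        print(0.01 * (1 + 10 * i / 1000) ** (-0.75))  # 0.0100 -> ~0.0026 -> ~0.0017
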
/object/image_source.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 | import os.path as osp
4 | import torchvision
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torchvision import transforms
10 | import network, loss
11 | from torch.utils.data import DataLoader
12 | from data_list import ImageList
13 | import random, pdb, math, copy
14 | from tqdm import tqdm
15 | from loss import CrossEntropyLabelSmooth
16 | from scipy.spatial.distance import cdist
17 | from sklearn.metrics import confusion_matrix
18 | from sklearn.cluster import KMeans
19 |
20 | def op_copy(optimizer):
21 | for param_group in optimizer.param_groups:
22 | param_group['lr0'] = param_group['lr']
23 | return optimizer
24 |
25 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
26 | decay = (1 + gamma * iter_num / max_iter) ** (-power)
27 | for param_group in optimizer.param_groups:
28 | param_group['lr'] = param_group['lr0'] * decay
29 | param_group['weight_decay'] = 1e-3
30 | param_group['momentum'] = 0.9
31 | param_group['nesterov'] = True
32 | return optimizer
33 |
34 | def image_train(resize_size=256, crop_size=224, alexnet=False):
35 | if not alexnet:
36 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
37 | std=[0.229, 0.224, 0.225])
38 | else:
39 |         normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')  # assumes a custom mean-file Normalize transform, not defined in this repo
40 | return transforms.Compose([
41 | transforms.Resize((resize_size, resize_size)),
42 | transforms.RandomCrop(crop_size),
43 | transforms.RandomHorizontalFlip(),
44 | transforms.ToTensor(),
45 | normalize
46 | ])
47 |
48 | def image_test(resize_size=256, crop_size=224, alexnet=False):
49 | if not alexnet:
50 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
51 | std=[0.229, 0.224, 0.225])
52 | else:
53 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')
54 | return transforms.Compose([
55 | transforms.Resize((resize_size, resize_size)),
56 | transforms.CenterCrop(crop_size),
57 | transforms.ToTensor(),
58 | normalize
59 | ])
60 |
61 | def data_load(args):
62 | ## prepare data
63 | dsets = {}
64 | dset_loaders = {}
65 | train_bs = args.batch_size
66 | txt_src = open(args.s_dset_path).readlines()
67 | txt_test = open(args.test_dset_path).readlines()
68 |
69 |     if args.da != 'uda':
70 | label_map_s = {}
71 | for i in range(len(args.src_classes)):
72 | label_map_s[args.src_classes[i]] = i
73 |
74 | new_src = []
75 | for i in range(len(txt_src)):
76 | rec = txt_src[i]
77 | reci = rec.strip().split(' ')
78 | if int(reci[1]) in args.src_classes:
79 | line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n'
80 | new_src.append(line)
81 | txt_src = new_src.copy()
82 |
83 | new_tar = []
84 | for i in range(len(txt_test)):
85 | rec = txt_test[i]
86 | reci = rec.strip().split(' ')
87 | if int(reci[1]) in args.tar_classes:
88 | if int(reci[1]) in args.src_classes:
89 | line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n'
90 | new_tar.append(line)
91 | else:
92 | line = reci[0] + ' ' + str(len(label_map_s)) + '\n'
93 | new_tar.append(line)
94 | txt_test = new_tar.copy()
95 |
96 | if args.trte == "val":
97 | dsize = len(txt_src)
98 | tr_size = int(0.9*dsize)
99 | # print(dsize, tr_size, dsize - tr_size)
100 | tr_txt, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size])
101 | else:
102 | dsize = len(txt_src)
103 | tr_size = int(0.9*dsize)
104 | _, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size])
105 | tr_txt = txt_src
106 |
107 | dsets["source_tr"] = ImageList(tr_txt, transform=image_train())
108 | dset_loaders["source_tr"] = DataLoader(dsets["source_tr"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
109 | dsets["source_te"] = ImageList(te_txt, transform=image_test())
110 | dset_loaders["source_te"] = DataLoader(dsets["source_te"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
111 | dsets["test"] = ImageList(txt_test, transform=image_test())
112 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*2, shuffle=True, num_workers=args.worker, drop_last=False)
113 |
114 | return dset_loaders
115 |
116 | def cal_acc(loader, netF, netB, netC, flag=False):
117 | start_test = True
118 | with torch.no_grad():
119 | iter_test = iter(loader)
120 | for i in range(len(loader)):
121 |             data = next(iter_test)
122 | inputs = data[0]
123 | labels = data[1]
124 | inputs = inputs.cuda()
125 | outputs = netC(netB(netF(inputs)))
126 | if start_test:
127 | all_output = outputs.float().cpu()
128 | all_label = labels.float()
129 | start_test = False
130 | else:
131 | all_output = torch.cat((all_output, outputs.float().cpu()), 0)
132 | all_label = torch.cat((all_label, labels.float()), 0)
133 |
134 | all_output = nn.Softmax(dim=1)(all_output)
135 | _, predict = torch.max(all_output, 1)
136 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
137 | mean_ent = torch.mean(loss.Entropy(all_output)).cpu().data.item()
138 |
139 | if flag:
140 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
141 | acc = matrix.diagonal()/matrix.sum(axis=1) * 100
142 | aacc = acc.mean()
143 | aa = [str(np.round(i, 2)) for i in acc]
144 | acc = ' '.join(aa)
145 | return aacc, acc
146 | else:
147 | return accuracy*100, mean_ent
148 |
149 | def cal_acc_oda(loader, netF, netB, netC):
150 | start_test = True
151 | with torch.no_grad():
152 | iter_test = iter(loader)
153 | for i in range(len(loader)):
154 |             data = next(iter_test)
155 | inputs = data[0]
156 | labels = data[1]
157 | inputs = inputs.cuda()
158 | outputs = netC(netB(netF(inputs)))
159 | if start_test:
160 | all_output = outputs.float().cpu()
161 | all_label = labels.float()
162 | start_test = False
163 | else:
164 | all_output = torch.cat((all_output, outputs.float().cpu()), 0)
165 | all_label = torch.cat((all_label, labels.float()), 0)
166 |
167 | all_output = nn.Softmax(dim=1)(all_output)
168 | _, predict = torch.max(all_output, 1)
169 |     ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1) / np.log(args.class_num)  # note: reads the module-level args, not a parameter
170 | ent = ent.float().cpu()
171 | initc = np.array([[0], [1]])
172 | kmeans = KMeans(n_clusters=2, random_state=0, init=initc, n_init=1).fit(ent.reshape(-1,1))
173 | threshold = (kmeans.cluster_centers_).mean()
174 |
175 | predict[ent>threshold] = args.class_num
176 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
177 | matrix = matrix[np.unique(all_label).astype(int),:]
178 |
179 | acc = matrix.diagonal()/matrix.sum(axis=1) * 100
180 | unknown_acc = acc[-1:].item()
181 |
182 | return np.mean(acc[:-1]), np.mean(acc), unknown_acc
183 | # return np.mean(acc), np.mean(acc[:-1])
184 |
185 | def train_source(args):
186 | dset_loaders = data_load(args)
187 | ## set base network
188 | if args.net[0:3] == 'res':
189 | netF = network.ResBase(res_name=args.net).cuda()
190 | elif args.net[0:3] == 'vgg':
191 | netF = network.VGGBase(vgg_name=args.net).cuda()
192 |
193 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
194 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
195 |
196 | param_group = []
197 | learning_rate = args.lr
198 | for k, v in netF.named_parameters():
199 | param_group += [{'params': v, 'lr': learning_rate*0.1}]
200 | for k, v in netB.named_parameters():
201 | param_group += [{'params': v, 'lr': learning_rate}]
202 | for k, v in netC.named_parameters():
203 | param_group += [{'params': v, 'lr': learning_rate}]
204 | optimizer = optim.SGD(param_group)
205 | optimizer = op_copy(optimizer)
206 |
207 | acc_init = 0
208 | max_iter = args.max_epoch * len(dset_loaders["source_tr"])
209 | interval_iter = max_iter // 10
210 | iter_num = 0
211 |
212 | netF.train()
213 | netB.train()
214 | netC.train()
215 |
216 | while iter_num < max_iter:
217 |         try:
218 |             inputs_source, labels_source = next(iter_source)
219 |         except (StopIteration, NameError):  # first pass or exhausted epoch: (re)build the iterator
220 |             iter_source = iter(dset_loaders["source_tr"])
221 |             inputs_source, labels_source = next(iter_source)
222 |
223 | if inputs_source.size(0) == 1:
224 | continue
225 |
226 | iter_num += 1
227 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)
228 |
229 | inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
230 | outputs_source = netC(netB(netF(inputs_source)))
231 | classifier_loss = CrossEntropyLabelSmooth(num_classes=args.class_num, epsilon=args.smooth)(outputs_source, labels_source)
232 |
233 | optimizer.zero_grad()
234 | classifier_loss.backward()
235 | optimizer.step()
236 |
237 | if iter_num % interval_iter == 0 or iter_num == max_iter:
238 | netF.eval()
239 | netB.eval()
240 | netC.eval()
241 | if args.dset=='VISDA-C':
242 | acc_s_te, acc_list = cal_acc(dset_loaders['source_te'], netF, netB, netC, True)
243 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name_src, iter_num, max_iter, acc_s_te) + '\n' + acc_list
244 | else:
245 | acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC, False)
246 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name_src, iter_num, max_iter, acc_s_te)
247 | args.out_file.write(log_str + '\n')
248 | args.out_file.flush()
249 | print(log_str+'\n')
250 |
251 | if acc_s_te >= acc_init:
252 | acc_init = acc_s_te
253 | best_netF = copy.deepcopy(netF.state_dict())
254 | best_netB = copy.deepcopy(netB.state_dict())
255 | best_netC = copy.deepcopy(netC.state_dict())
256 |
257 | netF.train()
258 | netB.train()
259 | netC.train()
260 |
261 | torch.save(best_netF, osp.join(args.output_dir_src, "source_F.pt"))
262 | torch.save(best_netB, osp.join(args.output_dir_src, "source_B.pt"))
263 | torch.save(best_netC, osp.join(args.output_dir_src, "source_C.pt"))
264 |
265 | return netF, netB, netC
266 |
267 | def test_target(args):
268 | dset_loaders = data_load(args)
269 | ## set base network
270 | if args.net[0:3] == 'res':
271 | netF = network.ResBase(res_name=args.net).cuda()
272 | elif args.net[0:3] == 'vgg':
273 | netF = network.VGGBase(vgg_name=args.net).cuda()
274 |
275 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
276 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
277 |
278 | args.modelpath = args.output_dir_src + '/source_F.pt'
279 | netF.load_state_dict(torch.load(args.modelpath))
280 | args.modelpath = args.output_dir_src + '/source_B.pt'
281 | netB.load_state_dict(torch.load(args.modelpath))
282 | args.modelpath = args.output_dir_src + '/source_C.pt'
283 | netC.load_state_dict(torch.load(args.modelpath))
284 | netF.eval()
285 | netB.eval()
286 | netC.eval()
287 |
288 | if args.da == 'oda':
289 | acc_os1, acc_os2, acc_unknown = cal_acc_oda(dset_loaders['test'], netF, netB, netC)
290 | log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}% / {:.2f}% / {:.2f}%'.format(args.trte, args.name, acc_os2, acc_os1, acc_unknown)
291 | else:
292 | if args.dset=='VISDA-C':
293 | acc, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC, True)
294 | log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}%'.format(args.trte, args.name, acc) + '\n' + acc_list
295 | else:
296 | acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False)
297 | log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}%'.format(args.trte, args.name, acc)
298 |
299 | args.out_file.write(log_str)
300 | args.out_file.flush()
301 | print(log_str)
302 |
303 | def print_args(args):
304 | s = "==========================================\n"
305 | for arg, content in args.__dict__.items():
306 | s += "{}:{}\n".format(arg, content)
307 | return s
308 |
309 | if __name__ == "__main__":
310 | parser = argparse.ArgumentParser(description='SHOT')
311 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
312 | parser.add_argument('--s', type=int, default=0, help="source")
313 | parser.add_argument('--t', type=int, default=1, help="target")
314 | parser.add_argument('--max_epoch', type=int, default=20, help="maximum number of epochs")
315 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size")
316 | parser.add_argument('--worker', type=int, default=4, help="number of workers")
317 | parser.add_argument('--dset', type=str, default='office-home', choices=['VISDA-C', 'office', 'office-home', 'office-caltech'])
318 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
319 | parser.add_argument('--net', type=str, default='resnet50', help="vgg16, resnet50, resnet101")
320 | parser.add_argument('--seed', type=int, default=2020, help="random seed")
321 | parser.add_argument('--bottleneck', type=int, default=256)
322 | parser.add_argument('--epsilon', type=float, default=1e-5)
323 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"])
324 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"])
325 | parser.add_argument('--smooth', type=float, default=0.1)
326 | parser.add_argument('--output', type=str, default='san')
327 | parser.add_argument('--da', type=str, default='uda', choices=['uda', 'pda', 'oda'])
328 | parser.add_argument('--trte', type=str, default='val', choices=['full', 'val'])
329 | args = parser.parse_args()
330 |
331 | if args.dset == 'office-home':
332 | names = ['Art', 'Clipart', 'Product', 'RealWorld']
333 | args.class_num = 65
334 | if args.dset == 'office':
335 | names = ['amazon', 'dslr', 'webcam']
336 | args.class_num = 31
337 | if args.dset == 'VISDA-C':
338 | names = ['train', 'validation']
339 | args.class_num = 12
340 | if args.dset == 'office-caltech':
341 | names = ['amazon', 'caltech', 'dslr', 'webcam']
342 | args.class_num = 10
343 |
344 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
345 | SEED = args.seed
346 | torch.manual_seed(SEED)
347 | torch.cuda.manual_seed(SEED)
348 | np.random.seed(SEED)
349 | random.seed(SEED)
350 | # torch.backends.cudnn.deterministic = True
351 |
352 | folder = './data/'
353 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt'
354 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt'
355 |
356 | if args.dset == 'office-home':
357 | if args.da == 'pda':
358 | args.class_num = 65
359 | args.src_classes = [i for i in range(65)]
360 | args.tar_classes = [i for i in range(25)]
361 | if args.da == 'oda':
362 | args.class_num = 25
363 | args.src_classes = [i for i in range(25)]
364 | args.tar_classes = [i for i in range(65)]
365 |
366 | args.output_dir_src = osp.join(args.output, args.da, args.dset, names[args.s][0].upper())
367 | args.name_src = names[args.s][0].upper()
368 | if not osp.exists(args.output_dir_src):
369 | os.system('mkdir -p ' + args.output_dir_src)
370 | if not osp.exists(args.output_dir_src):
371 | os.mkdir(args.output_dir_src)
372 |
373 | args.out_file = open(osp.join(args.output_dir_src, 'log.txt'), 'w')
374 | args.out_file.write(print_args(args)+'\n')
375 | args.out_file.flush()
376 | train_source(args)
377 |
378 | args.out_file = open(osp.join(args.output_dir_src, 'log_test.txt'), 'w')
379 | for i in range(len(names)):
380 | if i == args.s:
381 | continue
382 | args.t = i
383 | args.name = names[args.s][0].upper() + names[args.t][0].upper()
384 |
385 | folder = './data/'
386 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt'
387 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt'
388 |
389 | if args.dset == 'office-home':
390 | if args.da == 'pda':
391 | args.class_num = 65
392 | args.src_classes = [i for i in range(65)]
393 | args.tar_classes = [i for i in range(25)]
394 | if args.da == 'oda':
395 | args.class_num = 25
396 | args.src_classes = [i for i in range(25)]
397 | args.tar_classes = [i for i in range(65)]
398 |
399 | test_target(args)
400 |
--------------------------------------------------------------------------------
/object/image_target.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 | import os.path as osp
4 | import torchvision
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torchvision import transforms
10 | import network, loss
11 | from torch.utils.data import DataLoader
12 | from data_list import ImageList, ImageList_idx
13 | import random, pdb, math, copy
14 | from tqdm import tqdm
15 | from scipy.spatial.distance import cdist
16 | from sklearn.metrics import confusion_matrix
17 |
18 | def op_copy(optimizer):
19 | for param_group in optimizer.param_groups:
20 | param_group['lr0'] = param_group['lr']
21 | return optimizer
22 |
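    | # DANN-style annealing: lr = lr0 * (1 + gamma * p)^(-power) with
    | # p = iter_num / max_iter; momentum, weight decay, and Nesterov flags
    | # are (re)set here on every step.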
23 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
24 | decay = (1 + gamma * iter_num / max_iter) ** (-power)
25 | for param_group in optimizer.param_groups:
26 | param_group['lr'] = param_group['lr0'] * decay
27 | param_group['weight_decay'] = 1e-3
28 | param_group['momentum'] = 0.9
29 | param_group['nesterov'] = True
30 | return optimizer
31 |
32 | def image_train(resize_size=256, crop_size=224, alexnet=False):
33 | if not alexnet:
34 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
35 | std=[0.229, 0.224, 0.225])
36 | else:
37 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')  # Caffe-style mean-file transform; not defined/imported in this file, only reached when alexnet=True
38 | return transforms.Compose([
39 | transforms.Resize((resize_size, resize_size)),
40 | transforms.RandomCrop(crop_size),
41 | transforms.RandomHorizontalFlip(),
42 | transforms.ToTensor(),
43 | normalize
44 | ])
45 |
46 | def image_test(resize_size=256, crop_size=224, alexnet=False):
47 | if not alexnet:
48 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
49 | std=[0.229, 0.224, 0.225])
50 | else:
51 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')
52 | return transforms.Compose([
53 | transforms.Resize((resize_size, resize_size)),
54 | transforms.CenterCrop(crop_size),
55 | transforms.ToTensor(),
56 | normalize
57 | ])
58 |
59 | def data_load(args):
60 | ## prepare data
61 | dsets = {}
62 | dset_loaders = {}
63 | train_bs = args.batch_size
64 | txt_tar = open(args.t_dset_path).readlines()
65 | txt_test = open(args.test_dset_path).readlines()
66 |
67 | if not args.da == 'uda':
68 | label_map_s = {}
69 | for i in range(len(args.src_classes)):
70 | label_map_s[args.src_classes[i]] = i
71 |
72 | new_tar = []
73 | for i in range(len(txt_tar)):
74 | rec = txt_tar[i]
75 | reci = rec.strip().split(' ')
76 | if int(reci[1]) in args.tar_classes:
77 | if int(reci[1]) in args.src_classes:
78 | line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n'
79 | new_tar.append(line)
80 | else:
81 | line = reci[0] + ' ' + str(len(label_map_s)) + '\n'
82 | new_tar.append(line)
83 | txt_tar = new_tar.copy()
84 | txt_test = txt_tar.copy()
85 |
86 | dsets["target"] = ImageList_idx(txt_tar, transform=image_train())
87 | dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
88 | dsets["test"] = ImageList_idx(txt_test, transform=image_test())
89 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*3, shuffle=False, num_workers=args.worker, drop_last=False)
90 |
91 | return dset_loaders
92 |
93 | def cal_acc(loader, netF, netB, netC, flag=False):
94 | start_test = True
95 | with torch.no_grad():
96 | iter_test = iter(loader)
97 | for i in range(len(loader)):
98 | data = next(iter_test)
99 | inputs = data[0]
100 | labels = data[1]
101 | inputs = inputs.cuda()
102 | outputs = netC(netB(netF(inputs)))
103 | if start_test:
104 | all_output = outputs.float().cpu()
105 | all_label = labels.float()
106 | start_test = False
107 | else:
108 | all_output = torch.cat((all_output, outputs.float().cpu()), 0)
109 | all_label = torch.cat((all_label, labels.float()), 0)
110 | _, predict = torch.max(all_output, 1)
111 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
112 | mean_ent = torch.mean(loss.Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item()
113 |
114 | if flag:
115 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
116 | acc = matrix.diagonal()/matrix.sum(axis=1) * 100
117 | aacc = acc.mean()
118 | aa = [str(np.round(i, 2)) for i in acc]
119 | acc = ' '.join(aa)
120 | return aacc, acc
121 | else:
122 | return accuracy*100, mean_ent
123 |
124 | def train_target(args):
125 | dset_loaders = data_load(args)
126 | ## set base network
127 | if args.net[0:3] == 'res':
128 | netF = network.ResBase(res_name=args.net).cuda()
129 | elif args.net[0:3] == 'vgg':
130 | netF = network.VGGBase(vgg_name=args.net).cuda()
131 |
132 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
133 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
134 |
135 | modelpath = args.output_dir_src + '/source_F.pt'
136 | netF.load_state_dict(torch.load(modelpath))
137 | modelpath = args.output_dir_src + '/source_B.pt'
138 | netB.load_state_dict(torch.load(modelpath))
139 | modelpath = args.output_dir_src + '/source_C.pt'
140 | netC.load_state_dict(torch.load(modelpath))
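    | # The source hypothesis (classifier netC) is kept frozen below; only the
    | # feature extractor netF and bottleneck netB are adapted to the target.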
141 | netC.eval()
142 | for k, v in netC.named_parameters():
143 | v.requires_grad = False
144 |
145 | param_group = []
146 | for k, v in netF.named_parameters():
147 | if args.lr_decay1 > 0:
148 | param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
149 | else:
150 | v.requires_grad = False
151 | for k, v in netB.named_parameters():
152 | if args.lr_decay2 > 0:
153 | param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
154 | else:
155 | v.requires_grad = False
156 |
157 | optimizer = optim.SGD(param_group)
158 | optimizer = op_copy(optimizer)
159 |
160 | max_iter = args.max_epoch * len(dset_loaders["target"])
161 | interval_iter = max_iter // args.interval
162 | iter_num = 0
163 |
164 | while iter_num < max_iter:
165 | try:
166 | inputs_test, _, tar_idx = next(iter_test)
167 | except (StopIteration, NameError):  # restart the target loader when exhausted (or on first use)
168 | iter_test = iter(dset_loaders["target"])
169 | inputs_test, _, tar_idx = next(iter_test)
170 |
171 | if inputs_test.size(0) == 1:
172 | continue
173 |
174 | if iter_num % interval_iter == 0 and args.cls_par > 0:
175 | netF.eval()
176 | netB.eval()
177 | mem_label = obtain_label(dset_loaders['test'], netF, netB, netC, args)
178 | mem_label = torch.from_numpy(mem_label).cuda()
179 | netF.train()
180 | netB.train()
181 |
182 | inputs_test = inputs_test.cuda()
183 |
184 | iter_num += 1
185 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)
186 |
187 | features_test = netB(netF(inputs_test))
188 | outputs_test = netC(features_test)
189 |
190 | if args.cls_par > 0:
191 | pred = mem_label[tar_idx]
192 | classifier_loss = nn.CrossEntropyLoss()(outputs_test, pred)
193 | classifier_loss *= args.cls_par
194 | if iter_num < interval_iter and args.dset == "VISDA-C":
195 | classifier_loss *= 0
196 | else:
197 | classifier_loss = torch.tensor(0.0).cuda()
198 |
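    | # Information-maximization (IM) loss: the entropy term pushes each
    | # prediction to be confident, while the subtracted diversity term keeps
    | # the batch-mean prediction spread over classes to prevent collapse.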
199 | if args.ent:
200 | softmax_out = nn.Softmax(dim=1)(outputs_test)
201 | entropy_loss = torch.mean(loss.Entropy(softmax_out))
202 | if args.gent:
203 | msoftmax = softmax_out.mean(dim=0)
204 | gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
205 | entropy_loss -= gentropy_loss
206 | im_loss = entropy_loss * args.ent_par
207 | classifier_loss += im_loss
208 |
209 | optimizer.zero_grad()
210 | classifier_loss.backward()
211 | optimizer.step()
212 |
213 | if iter_num % interval_iter == 0 or iter_num == max_iter:
214 | netF.eval()
215 | netB.eval()
216 | if args.dset=='VISDA-C':
217 | acc_s_te, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC, True)
218 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name, iter_num, max_iter, acc_s_te) + '\n' + acc_list
219 | else:
220 | acc_s_te, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False)
221 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name, iter_num, max_iter, acc_s_te)
222 |
223 | args.out_file.write(log_str + '\n')
224 | args.out_file.flush()
225 | print(log_str+'\n')
226 | netF.train()
227 | netB.train()
228 |
229 | if args.issave:
230 | torch.save(netF.state_dict(), osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
231 | torch.save(netB.state_dict(), osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
232 | torch.save(netC.state_dict(), osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))
233 |
234 | return netF, netB, netC
235 |
236 | def print_args(args):
237 | s = "==========================================\n"
238 | for arg, content in args.__dict__.items():
239 | s += "{}:{}\n".format(arg, content)
240 | return s
241 |
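    | # Self-supervised pseudo-labeling: class centroids are computed in the
    | # (optionally length-normalized) feature space, weighted by the softmax
    | # outputs; each sample is then re-labeled by its nearest centroid under
    | # args.distance, and centroids/labels are refined for two rounds. Classes
    | # predicted at most args.threshold times are dropped from the label set.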
242 | def obtain_label(loader, netF, netB, netC, args):
243 | start_test = True
244 | with torch.no_grad():
245 | iter_test = iter(loader)
246 | for _ in range(len(loader)):
247 | data = next(iter_test)
248 | inputs = data[0]
249 | labels = data[1]
250 | inputs = inputs.cuda()
251 | feas = netB(netF(inputs))
252 | outputs = netC(feas)
253 | if start_test:
254 | all_fea = feas.float().cpu()
255 | all_output = outputs.float().cpu()
256 | all_label = labels.float()
257 | start_test = False
258 | else:
259 | all_fea = torch.cat((all_fea, feas.float().cpu()), 0)
260 | all_output = torch.cat((all_output, outputs.float().cpu()), 0)
261 | all_label = torch.cat((all_label, labels.float()), 0)
262 |
263 | all_output = nn.Softmax(dim=1)(all_output)
264 | ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1)
265 | unknown_weight = 1 - ent / np.log(args.class_num)
266 | _, predict = torch.max(all_output, 1)
267 |
268 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
269 | if args.distance == 'cosine':
270 | all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
271 | all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()
272 |
273 | all_fea = all_fea.float().cpu().numpy()
274 | K = all_output.size(1)
275 | aff = all_output.float().cpu().numpy()
276 |
277 | for _ in range(2):
278 | initc = aff.transpose().dot(all_fea)
279 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None])
280 | cls_count = np.eye(K)[predict].sum(axis=0)
281 | labelset = np.where(cls_count>args.threshold)
282 | labelset = labelset[0]
283 |
284 | dd = cdist(all_fea, initc[labelset], args.distance)
285 | pred_label = dd.argmin(axis=1)
286 | predict = labelset[pred_label]
287 |
288 | aff = np.eye(K)[predict]
289 |
290 | acc = np.sum(predict == all_label.float().numpy()) / len(all_fea)
291 | log_str = 'Accuracy = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc * 100)
292 |
293 | args.out_file.write(log_str + '\n')
294 | args.out_file.flush()
295 | print(log_str+'\n')
296 |
297 | return predict.astype('int')
298 |
299 |
300 | if __name__ == "__main__":
301 | parser = argparse.ArgumentParser(description='SHOT')
302 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
303 | parser.add_argument('--s', type=int, default=0, help="source")
304 | parser.add_argument('--t', type=int, default=1, help="target")
305 | parser.add_argument('--max_epoch', type=int, default=15, help="maximum number of epochs")
306 | parser.add_argument('--interval', type=int, default=15)
307 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size")
308 | parser.add_argument('--worker', type=int, default=4, help="number of workers")
309 | parser.add_argument('--dset', type=str, default='office-home', choices=['VISDA-C', 'office', 'office-home', 'office-caltech'])
310 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
311 | parser.add_argument('--net', type=str, default='resnet50', help="alexnet, vgg16, resnet50, resnet101")
312 | parser.add_argument('--seed', type=int, default=2020, help="random seed")
313 |
314 | parser.add_argument('--gent', type=bool, default=True)  # caution: argparse's type=bool parses any non-empty string as True
315 | parser.add_argument('--ent', type=bool, default=True)
316 | parser.add_argument('--threshold', type=int, default=0)
317 | parser.add_argument('--cls_par', type=float, default=0.3)
318 | parser.add_argument('--ent_par', type=float, default=1.0)
319 | parser.add_argument('--lr_decay1', type=float, default=0.1)
320 | parser.add_argument('--lr_decay2', type=float, default=1.0)
321 |
322 | parser.add_argument('--bottleneck', type=int, default=256)
323 | parser.add_argument('--epsilon', type=float, default=1e-5)
324 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"])
325 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"])
326 | parser.add_argument('--distance', type=str, default='cosine', choices=["euclidean", "cosine"])
327 | parser.add_argument('--output', type=str, default='san')
328 | parser.add_argument('--output_src', type=str, default='san')
329 | parser.add_argument('--da', type=str, default='uda', choices=['uda', 'pda'])
330 | parser.add_argument('--issave', type=bool, default=True)
331 | args = parser.parse_args()
332 |
333 | if args.dset == 'office-home':
334 | names = ['Art', 'Clipart', 'Product', 'RealWorld']
335 | args.class_num = 65
336 | if args.dset == 'office':
337 | names = ['amazon', 'dslr', 'webcam']
338 | args.class_num = 31
339 | if args.dset == 'VISDA-C':
340 | names = ['train', 'validation']
341 | args.class_num = 12
342 | if args.dset == 'office-caltech':
343 | names = ['amazon', 'caltech', 'dslr', 'webcam']
344 | args.class_num = 10
345 |
346 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
347 | SEED = args.seed
348 | torch.manual_seed(SEED)
349 | torch.cuda.manual_seed(SEED)
350 | np.random.seed(SEED)
351 | random.seed(SEED)
352 | # torch.backends.cudnn.deterministic = True
353 |
354 | for i in range(len(names)):
355 | if i == args.s:
356 | continue
357 | args.t = i
358 |
359 | folder = './data/'
360 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt'
361 | args.t_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt'
362 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt'
363 |
364 | if args.dset == 'office-home':
365 | if args.da == 'pda':
366 | args.class_num = 65
367 | args.src_classes = [i for i in range(65)]
368 | args.tar_classes = [i for i in range(25)]
369 |
370 | args.output_dir_src = osp.join(args.output_src, args.da, args.dset, names[args.s][0].upper())
371 | args.output_dir = osp.join(args.output, args.da, args.dset, names[args.s][0].upper()+names[args.t][0].upper())
372 | args.name = names[args.s][0].upper()+names[args.t][0].upper()
373 |
374 | if not osp.exists(args.output_dir):
375 | os.system('mkdir -p ' + args.output_dir)
376 | if not osp.exists(args.output_dir):
377 | os.mkdir(args.output_dir)
378 |
379 | args.savename = 'par_' + str(args.cls_par)
380 | if args.da == 'pda':
381 | args.gent = ''
382 | args.savename = 'par_' + str(args.cls_par) + '_thr' + str(args.threshold)
383 | args.out_file = open(osp.join(args.output_dir, 'log_' + args.savename + '.txt'), 'w')
384 | args.out_file.write(print_args(args)+'\n')
385 | args.out_file.flush()
386 | train_target(args)
--------------------------------------------------------------------------------
/object/image_target_oda.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 | import os.path as osp
4 | import torchvision
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torchvision import transforms
10 | import network, loss
11 | from torch.utils.data import DataLoader
12 | from data_list import ImageList, ImageList_idx
13 | import random, pdb, math, copy
14 | from tqdm import tqdm
15 | from scipy.spatial.distance import cdist
16 | from sklearn.metrics import confusion_matrix
17 | from sklearn.cluster import KMeans
18 |
19 | def op_copy(optimizer):
20 | for param_group in optimizer.param_groups:
21 | param_group['lr0'] = param_group['lr']
22 | return optimizer
23 |
24 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
25 | decay = (1 + gamma * iter_num / max_iter) ** (-power)
26 | for param_group in optimizer.param_groups:
27 | param_group['lr'] = param_group['lr0'] * decay
28 | param_group['weight_decay'] = 1e-3
29 | param_group['momentum'] = 0.9
30 | param_group['nesterov'] = True
31 | return optimizer
32 |
33 | def image_train(resize_size=256, crop_size=224, alexnet=False):
34 | if not alexnet:
35 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
36 | std=[0.229, 0.224, 0.225])
37 | else:
38 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')  # Caffe-style mean-file transform; not defined/imported in this file, only reached when alexnet=True
39 | return transforms.Compose([
40 | transforms.Resize((resize_size, resize_size)),
41 | transforms.RandomCrop(crop_size),
42 | transforms.RandomHorizontalFlip(),
43 | transforms.ToTensor(),
44 | normalize
45 | ])
46 |
47 | def image_test(resize_size=256, crop_size=224, alexnet=False):
48 | if not alexnet:
49 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
50 | std=[0.229, 0.224, 0.225])
51 | else:
52 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')
53 | return transforms.Compose([
54 | transforms.Resize((resize_size, resize_size)),
55 | transforms.CenterCrop(crop_size),
56 | transforms.ToTensor(),
57 | normalize
58 | ])
59 |
60 | def data_load(args):
61 | ## prepare data
62 | dsets = {}
63 | dset_loaders = {}
64 | train_bs = args.batch_size
65 | txt_src = open(args.s_dset_path).readlines()
66 | txt_tar = open(args.t_dset_path).readlines()
67 | txt_test = open(args.test_dset_path).readlines()
68 |
69 | if not args.da == 'uda':
70 | label_map_s = {}
71 | for i in range(len(args.src_classes)):
72 | label_map_s[args.src_classes[i]] = i
73 |
74 | new_tar = []
75 | for i in range(len(txt_tar)):
76 | rec = txt_tar[i]
77 | reci = rec.strip().split(' ')
78 | if int(reci[1]) in args.tar_classes:
79 | if int(reci[1]) in args.src_classes:
80 | line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n'
81 | new_tar.append(line)
82 | else:
83 | line = reci[0] + ' ' + str(len(label_map_s)) + '\n'
84 | new_tar.append(line)
85 | txt_tar = new_tar.copy()
86 | txt_test = txt_tar.copy()
87 |
88 | dsets["target"] = ImageList_idx(txt_tar, transform=image_train())
89 | dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
90 | dsets["test"] = ImageList(txt_test, transform=image_test())
91 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*3, shuffle=False, num_workers=args.worker, drop_last=False)
92 |
93 | return dset_loaders
94 |
95 | def cal_acc(loader, netF, netB, netC, flag=False, threshold=0.1):  # note: the passed-in threshold is unused; an entropy split is recomputed below
96 | start_test = True
97 | with torch.no_grad():
98 | iter_test = iter(loader)
99 | for i in range(len(loader)):
100 | data = next(iter_test)
101 | inputs = data[0]
102 | labels = data[1]
103 | inputs = inputs.cuda()
104 | outputs = netC(netB(netF(inputs)))
105 | if start_test:
106 | all_output = outputs.float().cpu()
107 | all_label = labels.float()
108 | start_test = False
109 | else:
110 | all_output = torch.cat((all_output, outputs.float().cpu()), 0)
111 | all_label = torch.cat((all_label, labels.float()), 0)
112 | _, predict = torch.max(all_output, 1)
113 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
114 | mean_ent = torch.mean(loss.Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item()
115 |
116 | if flag:
117 | all_output = nn.Softmax(dim=1)(all_output)
118 | ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1) / np.log(args.class_num)
119 |
120 | from sklearn.cluster import KMeans
121 | kmeans = KMeans(2, random_state=0).fit(ent.reshape(-1,1))
122 | labels = kmeans.predict(ent.reshape(-1,1))
123 |
124 | idx = np.where(labels==1)[0]
125 | iidx = 0
126 | if ent[idx].mean() > ent.mean():
127 | iidx = 1
128 | predict[np.where(labels==iidx)[0]] = args.class_num
129 |
130 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
131 | matrix = matrix[np.unique(all_label).astype(int),:]
132 |
133 | acc = matrix.diagonal()/matrix.sum(axis=1) * 100
134 | unknown_acc = acc[-1:].item()
135 | return np.mean(acc[:-1]), np.mean(acc), unknown_acc
136 | else:
137 | return accuracy*100, mean_ent
138 |
139 | def print_args(args):
140 | s = "==========================================\n"
141 | for arg, content in args.__dict__.items():
142 | s += "{}:{}\n".format(arg, content)
143 | return s
144 |
145 | def train_target(args):
146 | dset_loaders = data_load(args)
147 | ## set base network
148 | if args.net[0:3] == 'res':
149 | netF = network.ResBase(res_name=args.net).cuda()
150 | elif args.net[0:3] == 'vgg':
151 | netF = network.VGGBase(vgg_name=args.net).cuda()
152 |
153 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
154 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
155 |
156 | args.modelpath = args.output_dir_src + '/source_F.pt'
157 | netF.load_state_dict(torch.load(args.modelpath))
158 | args.modelpath = args.output_dir_src + '/source_B.pt'
159 | netB.load_state_dict(torch.load(args.modelpath))
160 | args.modelpath = args.output_dir_src + '/source_C.pt'
161 | netC.load_state_dict(torch.load(args.modelpath))
162 | netC.eval()
163 | for k, v in netC.named_parameters():
164 | v.requires_grad = False
165 |
166 | param_group = []
167 | for k, v in netF.named_parameters():
168 | if args.lr_decay1 > 0:
169 | param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
170 | else:
171 | v.requires_grad = False
172 | for k, v in netB.named_parameters():
173 | if args.lr_decay2 > 0:
174 | param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
175 | else:
176 | v.requires_grad = False
177 |
178 | optimizer = optim.SGD(param_group)
179 | optimizer = op_copy(optimizer)
180 |
181 | tt = 0
182 | iter_num = 0
183 | max_iter = args.max_epoch * len(dset_loaders["target"])
184 | interval_iter = max_iter // args.interval
185 |
186 | while iter_num < max_iter:
187 | try:
188 | inputs_test, _, tar_idx = next(iter_test)
189 | except (StopIteration, NameError):  # restart the target loader when exhausted (or on first use)
190 | iter_test = iter(dset_loaders["target"])
191 | inputs_test, _, tar_idx = next(iter_test)
192 |
193 | if inputs_test.size(0) == 1:
194 | continue
195 |
196 | if iter_num % interval_iter == 0:
197 | netF.eval()
198 | netB.eval()
199 | mem_label, ENT_THRESHOLD = obtain_label(dset_loaders['test'], netF, netB, netC, args)
200 | mem_label = torch.from_numpy(mem_label).cuda()
201 | netF.train()
202 | netB.train()
203 |
204 | inputs_test = inputs_test.cuda()
205 |
206 | iter_num += 1
207 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)
208 |
209 | pred = mem_label[tar_idx]
210 | features_test = netB(netF(inputs_test))
211 | outputs_test = netC(features_test)
212 |
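    | # Samples currently pseudo-labeled as unknown (== args.class_num) are
    | # excluded from the classification and entropy terms; the diversity term
    | # below still uses the full batch via softmax_out.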
213 | softmax_out = nn.Softmax(dim=1)(outputs_test)
214 | outputs_test_known = outputs_test[pred < args.class_num, :]
215 | pred = pred[pred < args.class_num]
216 |
217 | if len(pred) == 0:
218 | print(tt)
219 | del features_test
220 | del outputs_test
221 | tt += 1
222 | continue
223 |
224 | if args.cls_par > 0:
225 | classifier_loss = nn.CrossEntropyLoss()(outputs_test_known, pred)
226 | classifier_loss *= args.cls_par
227 | else:
228 | classifier_loss = torch.tensor(0.0).cuda()
229 |
230 | if args.ent:
231 | softmax_out_known = nn.Softmax(dim=1)(outputs_test_known)
232 | entropy_loss = torch.mean(loss.Entropy(softmax_out_known))
233 | if args.gent:
234 | msoftmax = softmax_out.mean(dim=0)
235 | gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
236 | entropy_loss -= gentropy_loss
237 | classifier_loss += entropy_loss * args.ent_par
238 |
239 | optimizer.zero_grad()
240 | classifier_loss.backward()
241 | optimizer.step()
242 |
243 | if iter_num % interval_iter == 0 or iter_num == max_iter:
244 | netF.eval()
245 | netB.eval()
246 | acc_os1, acc_os2, acc_unknown = cal_acc(dset_loaders['test'], netF, netB, netC, True, ENT_THRESHOLD)
247 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}% / {:.2f}% / {:.2f}%'.format(args.name, iter_num, max_iter, acc_os2, acc_os1, acc_unknown)
248 | args.out_file.write(log_str + '\n')
249 | args.out_file.flush()
250 | print(log_str+'\n')
251 | netF.train()
252 | netB.train()
253 |
254 | if args.issave:
255 | torch.save(netF.state_dict(), osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
256 | torch.save(netB.state_dict(), osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
257 | torch.save(netC.state_dict(), osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))
258 |
259 | return netF, netB, netC
260 |
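    | # Open-set pseudo-labeling: a 2-means split of the normalized prediction
    | # entropy separates presumed-known from presumed-unknown samples; only the
    | # known split is pseudo-labeled by nearest class centroid, and the rest
    | # receive the extra label args.class_num. The entropy threshold (mean of
    | # the two cluster centers) is returned alongside the labels.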
261 | def obtain_label(loader, netF, netB, netC, args):
262 | start_test = True
263 | with torch.no_grad():
264 | iter_test = iter(loader)
265 | for _ in range(len(loader)):
266 | data = next(iter_test)
267 | inputs = data[0]
268 | labels = data[1]
269 | inputs = inputs.cuda()
270 | feas = netB(netF(inputs))
271 | outputs = netC(feas)
272 | if start_test:
273 | all_fea = feas.float().cpu()
274 | all_output = outputs.float().cpu()
275 | all_label = labels.float()
276 | start_test = False
277 | else:
278 | all_fea = torch.cat((all_fea, feas.float().cpu()), 0)
279 | all_output = torch.cat((all_output, outputs.float().cpu()), 0)
280 | all_label = torch.cat((all_label, labels.float()), 0)
281 |
282 | all_output = nn.Softmax(dim=1)(all_output)
283 | _, predict = torch.max(all_output, 1)
284 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
285 | if args.distance == 'cosine':
286 | all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
287 | all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()
288 |
289 | ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1) / np.log(args.class_num)
290 | ent = ent.float().cpu()
291 |
292 | from sklearn.cluster import KMeans
293 | kmeans = KMeans(2, random_state=0).fit(ent.reshape(-1,1))
294 | labels = kmeans.predict(ent.reshape(-1,1))
295 |
296 | idx = np.where(labels==1)[0]
297 | iidx = 0
298 | if ent[idx].mean() > ent.mean():
299 | iidx = 1
300 | known_idx = np.where(kmeans.labels_ != iidx)[0]
301 |
302 | all_fea = all_fea[known_idx,:]
303 | all_output = all_output[known_idx,:]
304 | predict = predict[known_idx]
305 | all_label_idx = all_label[known_idx]
306 | ENT_THRESHOLD = (kmeans.cluster_centers_).mean()
307 |
308 | all_fea = all_fea.float().cpu().numpy()
309 | K = all_output.size(1)
310 | aff = all_output.float().cpu().numpy()
311 | initc = aff.transpose().dot(all_fea)
312 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None])
313 | cls_count = np.eye(K)[predict].sum(axis=0)
314 | labelset = np.where(cls_count>args.threshold)
315 | labelset = labelset[0]
316 |
317 | dd = cdist(all_fea, initc[labelset], args.distance)
318 | pred_label = dd.argmin(axis=1)
319 | pred_label = labelset[pred_label]
320 |
321 | for _ in range(1):
322 | aff = np.eye(K)[pred_label]
323 | initc = aff.transpose().dot(all_fea)
324 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None])
325 | dd = cdist(all_fea, initc[labelset], args.distance)
326 | pred_label = dd.argmin(axis=1)
327 | pred_label = labelset[pred_label]
328 |
329 | guess_label = args.class_num * np.ones(len(all_label), )
330 | guess_label[known_idx] = pred_label
331 |
332 | acc = np.sum(guess_label == all_label.float().numpy()) / len(all_label_idx)
333 | log_str = 'Threshold = {:.2f}, Accuracy = {:.2f}% -> {:.2f}%'.format(ENT_THRESHOLD, accuracy*100, acc*100)
    | args.out_file.write(log_str + '\n')
    | args.out_file.flush()
    | print(log_str + '\n')
334 |
335 | return guess_label.astype('int'), ENT_THRESHOLD
336 |
337 |
338 | if __name__ == "__main__":
339 | parser = argparse.ArgumentParser(description='SHOT')
340 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
341 | parser.add_argument('--s', type=int, default=0, help="source")
342 | parser.add_argument('--t', type=int, default=1, help="target")
343 | parser.add_argument('--max_epoch', type=int, default=15, help="maximum number of epochs")
344 | parser.add_argument('--interval', type=int, default=15)
345 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size")
346 | parser.add_argument('--worker', type=int, default=4, help="number of workers")
347 | parser.add_argument('--dset', type=str, default='office-home', choices=['office-home'])
348 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
349 | parser.add_argument('--net', type=str, default='resnet50', help="vgg16, resnet50, resnet101")
350 | parser.add_argument('--seed', type=int, default=2020, help="random seed")
351 |
352 | parser.add_argument('--gent', type=bool, default=True)  # caution: argparse's type=bool parses any non-empty string as True
353 | parser.add_argument('--ent', type=bool, default=True)
354 | parser.add_argument('--threshold', type=int, default=0)
355 | parser.add_argument('--cls_par', type=float, default=0.3)
356 | parser.add_argument('--ent_par', type=float, default=1.0)
357 | parser.add_argument('--lr_decay1', type=float, default=0.1)
358 | parser.add_argument('--lr_decay2', type=float, default=1.0)
359 |
360 | parser.add_argument('--bottleneck', type=int, default=256)
361 | parser.add_argument('--epsilon', type=float, default=1e-5)
362 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"])
363 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"])
364 | parser.add_argument('--distance', type=str, default='cosine', choices=["euclidean", "cosine"])
365 | parser.add_argument('--output', type=str, default='san')
366 | parser.add_argument('--output_src', type=str, default='san')
367 | parser.add_argument('--da', type=str, default='oda', choices=['oda'])
368 | parser.add_argument('--issave', type=bool, default=True)
369 | args = parser.parse_args()
370 |
371 | if args.dset == 'office-home':
372 | names = ['Art', 'Clipart', 'Product', 'RealWorld']
373 | args.class_num = 65
374 |
375 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
376 | SEED = args.seed
377 | torch.manual_seed(SEED)
378 | torch.cuda.manual_seed(SEED)
379 | np.random.seed(SEED)
380 | random.seed(SEED)
381 | # torch.backends.cudnn.deterministic = True
382 |
383 | for i in range(len(names)):
384 | if i == args.s:
385 | continue
386 | args.t = i
387 |
388 | folder = './data/'
389 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt'
390 | args.t_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt'
391 | args.test_dset_path = args.t_dset_path
392 |
393 | if args.dset == 'office-home':
394 | if args.da == 'oda':
395 | args.class_num = 25
396 | args.src_classes = [i for i in range(25)]
397 | args.tar_classes = [i for i in range(65)]
398 |
399 | args.output_dir_src = osp.join(args.output_src, args.da, args.dset, names[args.s][0].upper())
400 | args.output_dir = osp.join(args.output, args.da, args.dset, names[args.s][0].upper()+names[args.t][0].upper())
401 | args.name = names[args.s][0].upper()+names[args.t][0].upper()
402 |
403 | if not osp.exists(args.output_dir):
404 | os.system('mkdir -p ' + args.output_dir)
405 | if not osp.exists(args.output_dir):
406 | os.mkdir(args.output_dir)
407 |
408 | args.savename = 'par_' + str(args.cls_par)
409 | args.out_file = open(osp.join(args.output_dir, 'log_' + args.savename + '.txt'), 'w')
410 | args.out_file.write(print_args(args)+'\n')
411 | args.out_file.flush()
412 | train_target(args)
--------------------------------------------------------------------------------
/object/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import math
6 | import torch.nn.functional as F
7 | import pdb
8 |
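    | # Per-sample Shannon entropy H(p) = -sum_k p_k * log(p_k); expects softmax
    | # probabilities as input and adds a small epsilon for numerical stability.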
9 | def Entropy(input_):
10 | bs = input_.size(0)
11 | epsilon = 1e-5
12 | entropy = -input_ * torch.log(input_ + epsilon)
13 | entropy = torch.sum(entropy, dim=1)
14 | return entropy
15 |
16 | def grl_hook(coeff):
17 | def fun1(grad):
18 | return -coeff*grad.clone()
19 | return fun1
20 |
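    | # CDAN (Long et al., NeurIPS 2018): the domain discriminator ad_net is
    | # conditioned on the outer product of features and softmax predictions (or
    | # on a random multilinear map); the optional entropy weighting down-weights
    | # uncertain samples. Not called by the SHOT scripts; kept for reference.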
21 | def CDAN(input_list, ad_net, entropy=None, coeff=None, random_layer=None):
22 | softmax_output = input_list[1].detach()
23 | feature = input_list[0]
24 | if random_layer is None:
25 | op_out = torch.bmm(softmax_output.unsqueeze(2), feature.unsqueeze(1))
26 | ad_out = ad_net(op_out.view(-1, softmax_output.size(1) * feature.size(1)))
27 | else:
28 | random_out = random_layer.forward([feature, softmax_output])
29 | ad_out = ad_net(random_out.view(-1, random_out.size(1)))
30 | batch_size = softmax_output.size(0) // 2
31 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda()
32 | if entropy is not None:
33 | entropy.register_hook(grl_hook(coeff))
34 | entropy = 1.0+torch.exp(-entropy)
35 | source_mask = torch.ones_like(entropy)
36 | source_mask[feature.size(0)//2:] = 0
37 | source_weight = entropy*source_mask
38 | target_mask = torch.ones_like(entropy)
39 | target_mask[0:feature.size(0)//2] = 0
40 | target_weight = entropy*target_mask
41 | weight = source_weight / torch.sum(source_weight).detach().item() + \
42 | target_weight / torch.sum(target_weight).detach().item()
43 | return torch.sum(weight.view(-1, 1) * nn.BCELoss(reduction='none')(ad_out, dc_target)) / torch.sum(weight).detach().item()
44 | else:
45 | return nn.BCELoss()(ad_out, dc_target)
46 |
47 | def DANN(features, ad_net):
48 | ad_out = ad_net(features)
49 | batch_size = ad_out.size(0) // 2
50 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda()
51 | return nn.BCELoss()(ad_out, dc_target)
52 |
53 |
54 | class CrossEntropyLabelSmooth(nn.Module):
55 | """Cross entropy loss with label smoothing regularizer.
56 | Reference:
57 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
58 | Equation: y = (1 - epsilon) * y + epsilon / K.
59 | Args:
60 | num_classes (int): number of classes.
61 | epsilon (float): weight.
62 | """
63 |
64 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, reduction=True):
65 | super(CrossEntropyLabelSmooth, self).__init__()
66 | self.num_classes = num_classes
67 | self.epsilon = epsilon
68 | self.use_gpu = use_gpu
69 | self.reduction = reduction
70 | self.logsoftmax = nn.LogSoftmax(dim=1)
71 |
72 | def forward(self, inputs, targets):
73 | """
74 | Args:
75 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
76 | targets: ground truth labels with shape (batch_size)
77 | """
78 | log_probs = self.logsoftmax(inputs)
79 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1)
80 | if self.use_gpu: targets = targets.cuda()
81 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
82 | loss = (- targets * log_probs).sum(dim=1)
83 | if self.reduction:
84 | return loss.mean()
85 | else:
86 | return loss
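    | # Usage sketch (hypothetical values; not called in this file):
    | #   criterion = CrossEntropyLabelSmooth(num_classes=31, epsilon=0.1)
    | #   smoothed_ce = criterion(logits, labels)        # logits: (batch, 31)
    | #   per_sample_entropy = Entropy(logits.softmax(dim=1))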
--------------------------------------------------------------------------------
/object/network.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torchvision
5 | from torchvision import models
6 | from torch.autograd import Variable
7 | import math
8 | import torch.nn.utils.weight_norm as weightNorm
9 | from collections import OrderedDict
10 |
11 | def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0):
12 | return float(2.0 * (high - low) / (1.0 + np.exp(-alpha*iter_num / max_iter)) - (high - low) + low)
13 |
14 | def init_weights(m):
15 | classname = m.__class__.__name__
16 | if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1:
17 | nn.init.kaiming_uniform_(m.weight)
18 | nn.init.zeros_(m.bias)
19 | elif classname.find('BatchNorm') != -1:
20 | nn.init.normal_(m.weight, 1.0, 0.02)
21 | nn.init.zeros_(m.bias)
22 | elif classname.find('Linear') != -1:
23 | nn.init.xavier_normal_(m.weight)
24 | nn.init.zeros_(m.bias)
25 |
26 | vgg_dict = {"vgg11":models.vgg11, "vgg13":models.vgg13, "vgg16":models.vgg16, "vgg19":models.vgg19,
27 | "vgg11bn":models.vgg11_bn, "vgg13bn":models.vgg13_bn, "vgg16bn":models.vgg16_bn, "vgg19bn":models.vgg19_bn}
28 | class VGGBase(nn.Module):
29 | def __init__(self, vgg_name):
30 | super(VGGBase, self).__init__()
31 | model_vgg = vgg_dict[vgg_name](pretrained=True)
32 | self.features = model_vgg.features
33 | self.classifier = nn.Sequential()
34 | for i in range(6):
35 | self.classifier.add_module("classifier"+str(i), model_vgg.classifier[i])
36 | self.in_features = model_vgg.classifier[6].in_features
37 |
38 | def forward(self, x):
39 | x = self.features(x)
40 | x = x.view(x.size(0), -1)
41 | x = self.classifier(x)
42 | return x
43 |
44 | res_dict = {"resnet18":models.resnet18, "resnet34":models.resnet34, "resnet50":models.resnet50,
45 | "resnet101":models.resnet101, "resnet152":models.resnet152, "resnext50":models.resnext50_32x4d, "resnext101":models.resnext101_32x8d}
46 |
47 | class ResBase(nn.Module):
48 | def __init__(self, res_name):
49 | super(ResBase, self).__init__()
50 | model_resnet = res_dict[res_name](pretrained=True)
51 | self.conv1 = model_resnet.conv1
52 | self.bn1 = model_resnet.bn1
53 | self.relu = model_resnet.relu
54 | self.maxpool = model_resnet.maxpool
55 | self.layer1 = model_resnet.layer1
56 | self.layer2 = model_resnet.layer2
57 | self.layer3 = model_resnet.layer3
58 | self.layer4 = model_resnet.layer4
59 | self.avgpool = model_resnet.avgpool
60 | self.in_features = model_resnet.fc.in_features
61 |
62 | def forward(self, x):
63 | x = self.conv1(x)
64 | x = self.bn1(x)
65 | x = self.relu(x)
66 | x = self.maxpool(x)
67 | x = self.layer1(x)
68 | x = self.layer2(x)
69 | x = self.layer3(x)
70 | x = self.layer4(x)
71 | x = self.avgpool(x)
72 | x = x.view(x.size(0), -1)
73 | return x
74 |
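    | # SHOT splits the model into three modules: netF (a ResBase/VGGBase
    | # backbone), netB (feat_bottleneck: a linear projection with optional
    | # BatchNorm), and netC (feat_classifier: a plain or weight-normalized
    | # linear layer). During target adaptation only netF and netB are updated.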
75 | class feat_bottleneck(nn.Module):
76 | def __init__(self, feature_dim, bottleneck_dim=256, type="ori"):
77 | super(feat_bottleneck, self).__init__()
78 | self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True)
79 | self.relu = nn.ReLU(inplace=True)
80 | self.dropout = nn.Dropout(p=0.5)
81 | self.bottleneck = nn.Linear(feature_dim, bottleneck_dim)
82 | self.bottleneck.apply(init_weights)
83 | self.type = type
84 |
85 | def forward(self, x):
86 | x = self.bottleneck(x)
87 | if self.type == "bn":
88 | x = self.bn(x)
89 | return x
90 |
91 | class feat_classifier(nn.Module):
92 | def __init__(self, class_num, bottleneck_dim=256, type="linear"):
93 | super(feat_classifier, self).__init__()
94 | self.type = type
95 | if type == 'wn':
96 | self.fc = weightNorm(nn.Linear(bottleneck_dim, class_num), name="weight")
97 | self.fc.apply(init_weights)
98 | else:
99 | self.fc = nn.Linear(bottleneck_dim, class_num)
100 | self.fc.apply(init_weights)
101 |
102 | def forward(self, x):
103 | x = self.fc(x)
104 | return x
105 |
106 | class feat_classifier_two(nn.Module):
107 | def __init__(self, class_num, input_dim, bottleneck_dim=256):
108 | super(feat_classifier_two, self).__init__()
109 | # (this two-layer classifier has no 'type' option)
110 | self.fc0 = nn.Linear(input_dim, bottleneck_dim)
111 | self.fc0.apply(init_weights)
112 | self.fc1 = nn.Linear(bottleneck_dim, class_num)
113 | self.fc1.apply(init_weights)
114 |
115 | def forward(self, x):
116 | x = self.fc0(x)
117 | x = self.fc1(x)
118 | return x
119 |
120 | class Res50(nn.Module):
121 | def __init__(self):
122 | super(Res50, self).__init__()
123 | model_resnet = models.resnet50(pretrained=True)
124 | self.conv1 = model_resnet.conv1
125 | self.bn1 = model_resnet.bn1
126 | self.relu = model_resnet.relu
127 | self.maxpool = model_resnet.maxpool
128 | self.layer1 = model_resnet.layer1
129 | self.layer2 = model_resnet.layer2
130 | self.layer3 = model_resnet.layer3
131 | self.layer4 = model_resnet.layer4
132 | self.avgpool = model_resnet.avgpool
133 | self.in_features = model_resnet.fc.in_features
134 | self.fc = model_resnet.fc
135 |
136 | def forward(self, x):
137 | x = self.conv1(x)
138 | x = self.bn1(x)
139 | x = self.relu(x)
140 | x = self.maxpool(x)
141 | x = self.layer1(x)
142 | x = self.layer2(x)
143 | x = self.layer3(x)
144 | x = self.layer4(x)
145 | x = self.avgpool(x)
146 | x = x.view(x.size(0), -1)
147 | y = self.fc(x)
148 | return x, y
--------------------------------------------------------------------------------
/object/run.sh:
--------------------------------------------------------------------------------
1 | # Table 3 A->D,W
2 | ~/anaconda3/envs/pytorch/bin/python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset office --max_epoch 100 --s 0
3 | ~/anaconda3/envs/pytorch/bin/python image_target.py --cls_par 0.3 --da uda --dset office --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/
4 |
5 | # Table 4 A->C,P,R
6 | ~/anaconda3/envs/pytorch/bin/python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset office-home --max_epoch 50 --s 0
7 | ~/anaconda3/envs/pytorch/bin/python image_target.py --cls_par 0.3 --da uda --dset office-home --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/
8 |
9 | # Table 5 VisDA-C
10 | ~/anaconda3/envs/pytorch/bin/python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset VISDA-C --net resnet101 --lr 1e-3 --max_epoch 10 --s 0
11 | ~/anaconda3/envs/pytorch/bin/python image_target.py --cls_par 0.3 --da uda --dset VISDA-C --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/ --net resnet101 --lr 1e-3
12 |
13 | # Table 7 A->C,P,R (PDA)
14 | ~/anaconda3/envs/pytorch/bin/python image_source.py --trte val --output ckps/source/ --da pda --gpu_id 0 --dset office-home --max_epoch 50 --s 0
15 | ~/anaconda3/envs/pytorch/bin/python image_target.py --cls_par 0.3 --threshold 10 --da pda --dset office-home --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/
16 |
17 | # Table 7 A->C,P,R (ODA)
18 | ~/anaconda3/envs/pytorch/bin/python image_source.py --trte val --output ckps/source/ --da oda --gpu_id 0 --dset office-home --max_epoch 50 --s 0
19 | ~/anaconda3/envs/pytorch/bin/python image_target_oda.py --cls_par 0.3 --da oda --dset office-home --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/
20 |
21 |
22 | # Table 8 C,D,W->A (MSDA)
23 | ~/anaconda3/envs/pytorch/bin/python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset office-caltech --net resnet101 --max_epoch 100 --s 1
24 | ~/anaconda3/envs/pytorch/bin/python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset office-caltech --net resnet101 --max_epoch 100 --s 2
25 | ~/anaconda3/envs/pytorch/bin/python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset office-caltech --net resnet101 --max_epoch 100 --s 3
26 |
27 | ~/anaconda3/envs/pytorch/bin/python image_target.py --cls_par 0.3 --da uda --dset office-caltech --net resnet101 --gpu_id 0 --s 1 --output_src ckps/source/ --output ckps/target/
28 | ~/anaconda3/envs/pytorch/bin/python image_target.py --cls_par 0.3 --da uda --dset office-caltech --net resnet101 --gpu_id 0 --s 2 --output_src ckps/source/ --output ckps/target/
29 | ~/anaconda3/envs/pytorch/bin/python image_target.py --cls_par 0.3 --da uda --dset office-caltech --net resnet101 --gpu_id 0 --s 3 --output_src ckps/source/ --output ckps/target/
30 |
31 | ~/anaconda3/envs/pytorch/bin/python image_multisource.py --cls_par 0.3 --da uda --dset office-caltech --gpu_id 0 --t 0 --output_src ckps/source/ --output ckps/target/
32 |
33 | # Table 8 A->(C,D,W)(MTDA)
34 | ~/anaconda3/envs/pytorch/bin/python image_multitarget.py --cls_par 0.3 --da uda --dset office-caltech --net resnet101 --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/
35 |
36 |
37 | # Table 9 ImageNet->Caltech(PDA)
38 | ~/anaconda3/envs/pytorch/bin/python image_pretrained.py --gpu_id 0 --output ckps/target --cls_par 0.3
39 |
40 |
--------------------------------------------------------------------------------
/pretrained-models.md:
--------------------------------------------------------------------------------
1 | https://drive.google.com/drive/folders/1Hn3MXbwQF-A6UTBZG3L3ZBiwSrxctB35?usp=sharing
2 |
3 | All the pre-trained source models are provided here.
4 |
5 | | source (resnet50)\ seed | 2019 | 2020 | 2021 |
6 | | :---------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
7 | | Amazon | [pretrained source model](https://drive.google.com/drive/folders/1tR5nzN6GSSzwJBQpeMNdwlFmgudf1LMt) | [pretrained source model](https://drive.google.com/drive/folders/12Tt9xjoCPoouNvxyaYefjA8utu_Ns15o) | [pretrained source model](https://drive.google.com/drive/folders/1Ky-dryAkIFanjZG8zvpwtFKqcKfXoJwX) |
8 | | Dslr | [pretrained source model](https://drive.google.com/drive/folders/1gyRALSpKlPPBtj8fpk722s3JTEjD2JR_) | [pretrained source model](https://drive.google.com/drive/folders/1EO2ZN4fuWEM5uH0yowZpxqIbAm3G0Qgf) | [pretrained source model](https://drive.google.com/drive/folders/1cPUKwimnK4dfT4K-FbjgRnLsD1iqjysA) |
9 | | Webcam | [pretrained source model](https://drive.google.com/drive/folders/1P4GH-BOoFoVWRqrV2h5sZhSvwmetVuWU) | [pretrained source model](https://drive.google.com/drive/folders/1EVxejkRJAvgdR_PleWo6WC42yVusAbae) | [pretrained source model](https://drive.google.com/drive/folders/1US_yJD0dDubyjKT2vXOIVyZXkQJ3WbU7) |
10 |
11 | | source (resnet50)\ seed | 2019 | 2020 | 2021 |
12 | | :---------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
13 | | Art | [pretrained source model](https://drive.google.com/drive/folders/1t44cr406AKNwhMg0TmHWyXKoUyOAZhKL) | [pretrained source model](https://drive.google.com/drive/folders/1stSi6Lx5T-PRp-Hxc29wdvUdlXjPMiDZ) | [pretrained source model](https://drive.google.com/drive/folders/1xEkqFyTDoj5rBmf-XnnE5uuiNMVdz28W) |
14 | | Clipart | [pretrained source model](https://drive.google.com/drive/folders/1K7NXVqKwCG0HZlYPGLqa0Klpa47HmXcT) | [pretrained source model](https://drive.google.com/drive/folders/1mZK0v1XlocKWezvd5bdrM28_6hI5A543) | [pretrained source model](https://drive.google.com/drive/folders/1HrsVlb5KnBmxGQ-ZyepNKGyZr4wRSzyq) |
15 | | Product | [pretrained source model](https://drive.google.com/drive/folders/18-JZDSyrahcSx4IV-H2lPiLptgWmOfFy) | [pretrained source model](https://drive.google.com/drive/folders/1Liyf9VGepW2ulBp7EHWjN6ho14S-Q_3i) | [pretrained source model](https://drive.google.com/drive/folders/1ej81eKV8gBfos4byUjg9TZSAez13RkbO) |
16 | | RealWorld | [pretrained source model](https://drive.google.com/drive/folders/1f_s4i3l1HZl2HrovSQBJpwamFqsZjBEE) | [pretrained source model](https://drive.google.com/drive/folders/1jn-pi2IIWIbVQ_cMXCKd3UwuCiDrQnz3) | [pretrained source model](https://drive.google.com/drive/folders/18RTAq7wlyloZn00QR03DQvmKG0VX63rg) |
17 |
18 | | source (resnet101)\ seed | 2019 | 2020 | 2021 |
19 | | :----------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
20 | | train | [pretrained source model](https://drive.google.com/drive/folders/1Dev9TFgdyw1hcc8F9ngjJ6omHuRpWliK) | [pretrained source model](https://drive.google.com/drive/folders/1AeTt5sPbo-7oNX5u7Jbm8LSf3Pp8buTd) | [pretrained source model](https://drive.google.com/drive/folders/1JGQGCQwLLI5A2FNAHeEJgtjqAXkrco6d) |
21 |
22 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # Official implementation for **SHOT**
2 |
3 | ## [**[ICML-2020] Do We Really Need to Access the Source Data? Source Hypothesis Transfer for Unsupervised Domain Adaptation**](http://proceedings.mlr.press/v119/liang20a.html)
4 |
5 |
6 |
7 | - **2022/6/6 We corrected a bug in the pseudo-labeling function (def obtain_label); many thanks to @TomSheng21.**
8 | - **2022/2/8 We uploaded the pretrained source models to Google Drive; see [pretrained-models.md](./pretrained-models.md).**
9 |
10 |
11 |
12 | ### Attention-v2: ***we release the code of our recent black-box UDA method (DINE, https://arxiv.org/pdf/2104.01539.pdf) in the following repository (https://github.com/tim-learn/DINE).***
13 |
14 | #### Attention: ***The code of our stronger TPAMI extension (SHOT++, https://arxiv.org/pdf/2012.07297.pdf) has been released in a new repository (https://github.com/tim-learn/SHOT-plus).***
15 |
16 |
17 |
18 | ### Results:
19 |
20 | #### **Note that we updated the code to adopt a standard learning rate scheduler, as in DANN, and report the new results in the final camera-ready version.** Please refer to [results.md](./results.md) for detailed results on the various datasets.
21 |
22 | *We have updated the results for **Digits**. The results of SHOT-IM on **Digits** are now stable and promising. (Thanks to @wengzejia1 for pointing out the bugs in **uda_digit.py**.)*
23 |
24 |
25 | ### Framework:
26 |
27 |
28 |
29 | ### Prerequisites:
30 | - python == 3.6.8
31 | - pytorch == 1.1.0
32 | - torchvision == 0.3.0
33 | - numpy, scipy, sklearn, PIL, argparse, tqdm
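To confirm the pinned versions at runtime, the following quick check may help (a convenience snippet, not part of the repo):

```python
# Sanity-check the pinned environment before running the experiments.
import sys
import torch, torchvision

print("python     :", sys.version.split()[0])   # expected 3.6.8
print("pytorch    :", torch.__version__)        # expected 1.1.0
print("torchvision:", torchvision.__version__)  # expected 0.3.0
```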
34 |
35 | ### Dataset:
36 |
37 | - Please manually download the datasets [Office](https://drive.google.com/file/d/0B4IapRTv9pJ1WGZVd1VDMmhwdlE/view), [Office-Home](https://drive.google.com/file/d/0B81rNlvomiwed0V1YUxQdC1uOTg/view), [VisDA-C](https://github.com/VisionLearningGroup/taskcv-2017-public/tree/master/classification), and [Office-Caltech](http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar) from the official websites, and modify the image paths in each '.txt' file under the folder './object/data/'. [**Instructions for generating such txt files can be found at https://github.com/tim-learn/Generate_list**]
38 |
39 | - For the **Digits** datasets, the code will automatically download the three digit datasets (i.e., MNIST, USPS, and SVHN) into './digit/data/'.
40 |
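Each '.txt' list is assumed to hold one `image_path class_index` pair per line, the format typically consumed by ./object/data_list.py; please verify against the lists produced by the Generate_list repository. A minimal generator sketch under that assumption:

```python
# Sketch that writes an image list in the assumed "image_path class_index" format.
import os

def write_image_list(root, class_names, out_txt):
    with open(out_txt, "w") as f:
        for label, cls in enumerate(sorted(class_names)):
            cls_dir = os.path.join(root, cls)
            for name in sorted(os.listdir(cls_dir)):
                f.write("{} {}\n".format(os.path.join(cls_dir, name), label))

# e.g. write_image_list('/data/office/amazon/images',
#                       os.listdir('/data/office/amazon/images'),
#                       'amazon_list.txt')
```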
41 |
42 | ### Training:
43 | 1. ##### Unsupervised Closed-set Domain Adaptation (UDA) on the Digits dataset
44 | - MNIST -> USPS (**m2u**): SHOT (**cls_par = 0.1**) and SHOT-IM (**cls_par = 0.0**)
45 | ```python
46 | cd digit/
47 | python uda_digit.py --dset m2u --gpu_id 0 --output ckps_digits --cls_par 0.0
48 | python uda_digit.py --dset m2u --gpu_id 0 --output ckps_digits --cls_par 0.1
49 | ```
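Here `cls_par` weights the self-supervised pseudo-label loss on top of the information-maximization (IM) objective, so `cls_par = 0.0` recovers SHOT-IM. A minimal sketch of the target-side objective (tensor names are illustrative, not the exact code):

```python
# Sketch of the SHOT target objective: IM loss + cls_par * pseudo-label cross-entropy.
import torch
import torch.nn.functional as F

def shot_loss(logits, pseudo_labels, cls_par=0.1, eps=1e-5):
    # logits: (N, K) classifier outputs; pseudo_labels: (N,) LongTensor of hard labels.
    probs = F.softmax(logits, dim=1)
    ent = -(probs * torch.log(probs + eps)).sum(dim=1).mean()  # push individual predictions to be confident
    mean_p = probs.mean(dim=0)
    div = (mean_p * torch.log(mean_p + eps)).sum()             # push the batch-level prediction to be diverse
    im_loss = ent + div                                        # information maximization (SHOT-IM)
    if cls_par > 0:                                            # SHOT adds the pseudo-label term
        im_loss = im_loss + cls_par * F.cross_entropy(logits, pseudo_labels)
    return im_loss
```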
50 |
51 | 2. ##### Unsupervised Closed-set Domain Adaptation (UDA) on the Office/Office-Home datasets
52 | - Train model on the source domain **A** (**s = 0**)
53 | ```python
54 | cd object/
55 | python image_source.py --trte val --da uda --output ckps/source/ --gpu_id 0 --dset office --max_epoch 100 --s 0
56 | ```
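Source training in the paper uses cross-entropy with label smoothing (smoothing factor 0.1); a minimal sketch of that loss, assuming the standard formulation:

```python
# Sketch of label-smoothing cross-entropy for source training (eps = 0.1 in the paper).
import torch
import torch.nn.functional as F

def smoothed_ce(logits, targets, eps=0.1):
    n_cls = logits.size(1)
    log_p = F.log_softmax(logits, dim=1)
    one_hot = torch.zeros_like(log_p).scatter_(1, targets.unsqueeze(1), 1)
    soft = (1 - eps) * one_hot + eps / n_cls   # smoothed target distribution
    return -(soft * log_p).sum(dim=1).mean()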
57 |
58 | - Adaptation to other target domains **D and W**, respectively
59 | ```python
60 | python image_target.py --cls_par 0.3 --da uda --output_src ckps/source/ --output ckps/target/ --gpu_id 0 --dset office --s 0
61 | ```
62 |
63 | 3. ##### Unsupervised Closed-set Domain Adaptation (UDA) on the VisDA-C dataset
64 | - Synthetic-to-real
65 | ```python
66 | cd object/
67 | python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset VISDA-C --net resnet101 --lr 1e-3 --max_epoch 10 --s 0
68 | python image_target.py --cls_par 0.3 --da uda --dset VISDA-C --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/ --net resnet101 --lr 1e-3
69 | ```
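The standard DANN-style schedule mentioned in the Results note decays the base rate as lr_p = lr_0 * (1 + gamma * p)^(-power), with progress p in [0, 1]; gamma = 10 and power = 0.75 are the usual DANN constants, though the exact values used in this repo are an assumption. A sketch:

```python
# DANN-style learning-rate decay: lr_p = lr_0 * (1 + gamma * p) ** (-power).
def lr_scheduler(optimizer, progress, gamma=10.0, power=0.75):
    # progress = current_iter / max_iter, in [0, 1]
    decay = (1 + gamma * progress) ** (-power)
    for group in optimizer.param_groups:
        group.setdefault("lr0", group["lr"])  # remember the initial lr once
        group["lr"] = group["lr0"] * decay
    return optimizer
```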
70 |
71 | 4. ##### Unsupervised Partial-set Domain Adaptation (PDA) on the Office-Home dataset
72 | - Train model on the source domain **A** (**s = 0**)
73 | ```python
74 | cd object/
75 | python image_source.py --trte val --da pda --output ckps/source/ --gpu_id 0 --dset office-home --max_epoch 50 --s 0
76 | ```
77 |
78 | - Adaptation to other target domains **C and P and R**, respectively
79 | ```python
80 | python image_target.py --cls_par 0.3 --threshold 10 --da pda --dset office-home --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/
81 | ```
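For PDA, `--threshold` shrinks the effective label space on the target side; one plausible reading (an assumption, please check `obtain_label` in image_target.py for the exact rule) keeps only classes that collect more than `threshold` predictions:

```python
# Sketch of partial-set filtering with a count threshold (assumed rule, not the exact code).
import numpy as np

def active_label_set(predict, num_classes, threshold=10):
    # predict: (N,) hard predictions on the target domain.
    cls_count = np.eye(num_classes)[predict].sum(axis=0)  # per-class prediction counts
    return np.where(cls_count > threshold)[0]             # classes kept for adaptation
```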
82 |
83 | 5. ##### Unsupervised Open-set Domain Adaptation (ODA) on the Office-Home dataset
84 | - Train model on the source domain **A** (**s = 0**)
85 | ```python
86 | cd object/
87 | python image_source.py --trte val --da oda --output ckps/source/ --gpu_id 0 --dset office-home --max_epoch 50 --s 0
88 | ```
89 |
90 | - Adaptation to other target domains **C and P and R**, respectively
91 | ```python
92 | python image_target_oda.py --cls_par 0.3 --da oda --dset office-home --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/
93 | ```
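For ODA, the target set contains unknown classes the source model cannot name; a common recipe, and a hedged guess at what image_target_oda.py does, is to split samples into known/unknown via two-cluster k-means on the normalized prediction entropy:

```python
# Sketch of an open-set known/unknown split: k-means (k=2) on normalized prediction
# entropy, treating the high-entropy cluster as "unknown" (illustrative assumption).
import numpy as np
from sklearn.cluster import KMeans

def split_known_unknown(probs):
    # probs: (N, K) softmax predictions on the target domain.
    ent = -(probs * np.log(probs + 1e-5)).sum(axis=1) / np.log(probs.shape[1])
    km = KMeans(n_clusters=2, random_state=0).fit(ent.reshape(-1, 1))
    unknown_cluster = km.cluster_centers_.argmax()  # the higher-entropy cluster
    return km.labels_ != unknown_cluster            # True -> treated as a known class
```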
94 |
95 | 6. ##### Unsupervised Multi-source Domain Adaptation (MSDA) on the Office-Caltech dataset
96 | - Train model on the source domains **A** (**s = 0**), **C** (**s = 1**), **D** (**s = 2**), respectively
97 | ```python
98 | cd object/
99 | python image_source.py --trte val --da uda --output ckps/source/ --gpu_id 0 --dset office-caltech --max_epoch 100 --s 0
100 | python image_source.py --trte val --da uda --output ckps/source/ --gpu_id 0 --dset office-caltech --max_epoch 100 --s 1
101 | python image_source.py --trte val --da uda --output ckps/source/ --gpu_id 0 --dset office-caltech --max_epoch 100 --s 2
102 | ```
103 |
104 | - Adaptation to the target domain **W** (**t = 3**)
105 | ```python
106 | python image_target.py --cls_par 0.3 --da uda --output_src ckps/source/ --output ckps/target/ --gpu_id 0 --dset office-caltech --s 0
107 | python image_target.py --cls_par 0.3 --da uda --output_src ckps/source/ --output ckps/target/ --gpu_id 0 --dset office-caltech --s 1
108 | python image_target.py --cls_par 0.3 --da uda --output_src ckps/source/ --output ckps/target/ --gpu_id 0 --dset office-caltech --s 2
109 | python image_multisource.py --cls_par 0.0 --da uda --dset office-caltech --gpu_id 0 --t 3 --output_src ckps/source/ --output ckps/target/
110 | ```
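Conceptually, the multi-source step combines the per-source adapted models on the target domain; a minimal fusion sketch, assuming simple softmax averaging (the exact rule in image_multisource.py may differ):

```python
# Sketch of multi-source fusion: average the softmax outputs of the adapted models.
import torch
import torch.nn.functional as F

def fuse_predictions(models, x):
    probs = [F.softmax(m(x), dim=1) for m in models]  # one adapted model per source domain
    return torch.stack(probs).mean(dim=0)             # ensemble over the sources
```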
111 |
112 | 7. ##### Unsupervised Multi-target Domain Adaptation (MTDA) on the Office-Caltech dataset
113 | - Train model on the source domain **A** (**s = 0**)
114 | ```python
115 | cd object/
116 | python image_source.py --trte val --da uda --output ckps/source/ --gpu_id 0 --dset office-caltech --max_epoch 100 --s 0
117 | ```
118 |
119 | - Adaptation to multiple target domains **C and D and W** at the same time
120 | ```python
121 | python image_multitarget.py --cls_par 0.3 --da uda --dset office-caltech --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/
122 | ```
123 |
124 | 8. ##### Unsupervised Partial-set Domain Adaptation (PDA) on the ImageNet-Caltech dataset, without training the source model ourselves (using the downloaded PyTorch ResNet-50 model directly)
125 | - ImageNet -> Caltech (84 classes) [following the protocol in [PADA](https://github.com/thuml/PADA/tree/master/pytorch/data/imagenet-caltech)]
126 | ```python
127 | cd object/
128 | python image_pretrained.py --gpu_id 0 --output ckps/target/ --cls_par 0.3
129 | ```
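Since the source hypothesis here is just the off-the-shelf torchvision ResNet-50 (a 1000-way ImageNet classifier), no source training is required; a minimal sketch of loading it as the frozen source model:

```python
# Sketch: the downloaded torchvision ResNet-50 serves directly as the source model,
# keeping its 1000-way ImageNet classification head.
import torch
import torchvision

model = torchvision.models.resnet50(pretrained=True)
model.eval()  # frozen source hypothesis for SHOT-style adaptation
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # 1000-way ImageNet predictions
print(logits.shape)  # torch.Size([1, 1000])
```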
130 |
131 | **Please refer to *./object/run.sh* for all the settings for different methods and scenarios.**
132 |
133 | ### Citation
134 |
135 | If you find this code useful for your research, please cite our papers:
136 |
137 | ```
138 | @inproceedings{liang2020we,
139 | title={Do We Really Need to Access the Source Data? Source Hypothesis Transfer for Unsupervised Domain Adaptation},
140 | author={Liang, Jian and Hu, Dapeng and Feng, Jiashi},
141 | booktitle={International Conference on Machine Learning (ICML)},
142 | pages={6028--6039},
143 | year={2020}
144 | }
145 |
146 | @article{liang2021source,
147 | title={Source Data-absent Unsupervised Domain Adaptation through Hypothesis Transfer and Labeling Transfer},
148 | author={Liang, Jian and Hu, Dapeng and Wang, Yunbo and He, Ran and Feng, Jiashi},
149 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)},
150 | year={2021},
151 | note={In Press}
152 | }
153 | ```
154 |
155 |
156 | ### Contact
157 |
158 | - [liangjian92@gmail.com](mailto:liangjian92@gmail.com)
159 | - [dapeng.hu@u.nus.edu](mailto:dapeng.hu@u.nus.edu)
160 | - [elefjia@nus.edu.sg](mailto:elefjia@nus.edu.sg)
161 |
--------------------------------------------------------------------------------
/results.md:
--------------------------------------------------------------------------------
1 | Code for our ICML-2020 paper [**Do We Really Need to Access the Source Data? Source Hypothesis Transfer for Unsupervised Domain Adaptation**](https://arxiv.org/abs/2002.08546).
2 |
3 | ### Framework:
4 |
5 |
6 |
7 | ### Results:
8 |
9 | #### Table 2 [UDA results on Digits]
10 |
11 | | Methods | S->M | U->M | M->U | Avg. |
12 | | -------------- | ---- | ---- | ---- | ---- |
13 | | srconly (2019) | 71.5 | 85.5 | 82.5 | |
14 | | srconly (2020) | 69.2 | 89.8 | 77.6 | |
15 | | srconly (2021) | 69.7 | 88.7 | 79.0 | |
16 | | srconly (Avg.) | 70.2 | 88.0 | 79.7 | 79.3 |
17 | | SHOT-IM (2019) | 98.9 | 98.6 | 97.8 | |
18 | | SHOT-IM (2020) | 99.0 | 97.8 | 97.7 | |
19 | | SHOT-IM (2021) | 98.9 | 97.6 | 97.7 | |
20 | | SHOT-IM (Avg.) | 99.0 | 98.0 | 97.7 | 98.2 |
21 | | SHOT (2019) | 98.8 | 98.6 | 98.0 | |
22 | | SHOT (2020) | 99.0 | 97.6 | 97.8 | |
23 | | SHOT (2021) | 99.0 | 97.7 | 97.7 | |
24 | | SHOT (Avg.) | 98.9 | 98.0 | 97.9 | 98.3 |
25 | | Oracle (2019) | 99.2 | 99.2 | 97.1 | |
26 | | Oracle (2020) | 99.2 | 99.2 | 97.0 | |
27 | | Oracle (2021) | 99.3 | 99.3 | 97.0 | |
28 | | Oracle (Avg.) | 99.2 | 99.2 | 97.0 | 98.8 |
29 |
30 | #### Table 3 [UDA results on Office]
31 |
32 | | Methods | A->D | A->W | D->A | D->W | W->A | W->D | Avg. |
33 | | -------------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
34 | | srconly (2019) | 79.9 | 77.5 | 58.9 | 95.0 | 64.6 | 98.4 | |
35 | | srconly (2020) | 81.5 | 75.8 | 61.6 | 96.0 | 63.3 | 99.0 | |
36 | | srconly (2021) | 80.9 | 77.5 | 60.2 | 94.8 | 62.9 | 98.8 | |
37 | | srconly (Avg.) | 80.8 | 76.9 | 60.3 | 95.3 | 63.6 | 98.7 | 79.3 |
38 | | SHOT-IM (2019) | 88.8 | 90.7 | 71.7 | 98.5 | 71.7 | 99.8 | |
39 | | SHOT-IM (2020) | 92.6 | 92.2 | 72.4 | 98.4 | 71.1 | 100. | |
40 | | SHOT-IM (2021) | 90.6 | 90.7 | 73.3 | 98.0 | 71.2 | 99.8 | |
41 | | SHOT-IM (Avg.) | 90.6 | 91.2 | 72.5 | 98.3 | 71.4 | 99.9 | 87.3 |
42 | | SHOT (2019) | 93.4 | 88.8 | 74.9 | 98.5 | 74.6 | 99.8 | |
43 | | SHOT (2020) | 95.0 | 92.0 | 75.7 | 98.6 | 73.7 | 100. | |
44 | | SHOT (2021) | 93.8 | 89.7 | 73.6 | 98.2 | 74.6 | 99.8 | |
45 | | SHOT (Avg.) | 94.0 | 90.1 | 74.7 | 98.4 | 74.3 | 99.9 | 88.6 |
46 |
47 | #### Table 4 [UDA results on Office-Home]
48 |
49 | | Methods |Ar->Cl|Ar->Pr|Ar->Re|Cl->Ar|Cl->Pr|Cl->Re|Pr->Ar|Pr->Cl|Pr->Re|Re->Ar|Re->Cl|Re->Pr| Avg. |
50 | | -------------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
51 | | srconly (2019) | 45.2 | 67.2 | 75.0 | 52.4 | 62.6 | 64.9 | 52.4 | 40.6 | 73.0 | 65.0 | 43.8 | 78.1 | |
52 | | srconly (2020) | 44.2 | 67.2 | 74.4 | 52.3 | 63.1 | 64.5 | 53.1 | 41.0 | 73.7 | 65.3 | 46.8 | 77.9 | |
53 | | srconly (2021) | 44.5 | 67.7 | 74.8 | 53.4 | 62.4 | 64.9 | 53.4 | 40.4 | 72.9 | 65.7 | 45.8 | 78.1 | |
54 | | srconly (Avg.) | 44.6 | 67.3 | 74.8 | 52.7 | 62.7 | 64.8 | 53.0 | 40.6 | 73.2 | 65.3 | 45.4 | 78.0 | 60.2 |
55 | | SHOT-IM (2019) | 56.5 | 77.1 | 80.8 | 67.7 | 73.3 | 75.1 | 65.5 | 54.5 | 80.6 | 73.4 | 57.2 | 84.0 | |
56 | | SHOT-IM (2020) | 54.7 | 76.3 | 80.2 | 66.8 | 75.8 | 76.2 | 65.6 | 53.9 | 80.7 | 73.6 | 58.3 | 83.5 | |
57 | | SHOT-IM (2021) | 54.9 | 76.4 | 80.1 | 66.2 | 73.8 | 75.0 | 65.7 | 56.1 | 80.7 | 74.2 | 59.6 | 82.9 | |
58 | | SHOT-IM (Avg.) | 55.4 | 76.6 | 80.4 | 66.9 | 74.3 | 75.4 | 65.6 | 54.8 | 80.7 | 73.7 | 58.4 | 83.4 | 70.5 |
59 | | SHOT (2019) | 57.3 | 79.3 | 81.8 | 68.1 | 77.1 | 78.0 | 67.8 | 55.0 | 82.5 | 73.2 | 58.5 | 84.1 | |
60 | | SHOT (2020) | 57.1 | 77.5 | 81.6 | 68.4 | 78.2 | 77.9 | 67.0 | 55.6 | 82.4 | 73.6 | 60.2 | 84.6 | |
61 | | SHOT (2021) | 57.0 | 77.6 | 81.0 | 67.5 | 79.2 | 78.3 | 67.3 | 54.1 | 81.6 | 73.0 | 57.8 | 84.2 | |
62 | | SHOT (Avg.) | 57.1 | 78.1 | 81.5 | 68.0 | 78.2 | 78.1 | 67.4 | 54.9 | 82.2 | 73.3 | 58.8 | 84.3 | 71.8 |
63 |
64 | #### Table 5 [UDA results on VisDA-C]
65 |
66 | | Methods | plane | bcycl | bus | car | horse | knife | mcycl | person | plant | sktbrd | train | truck | Per-class |
67 | | -------------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
68 | | srconly (2019) | 57.1 | 20.5 | 48.6 | 60.8 | 66.2 | 3.6 | 80.7 | 23.9 | 38.5 | 31.0 | 87.0 | 10.7 | |
69 | | srconly (2020) | 65.1 | 18.9 | 57.2 | 66.9 | 69.9 | 11.0| 84.7 | 23.9 | 69.4 | 34.0 | 83.8 | 9.3 | |
70 | | srconly (2021) | 60.5 | 25.5 | 47.0 | 75.2 | 61.3 | 4.2 | 81.1 | 21.9 | 63.9 | 26.7 | 83.1 | 4.0 | |
71 | | srconly (Avg.) | 60.9 | 21.6 | 50.9 | 67.6 | 65.8 | 6.3 | 82.2 | 23.2 | 57.3 | 30.6 | 84.6 | 8.0 | 46.6 |
72 | | SHOT-IM (2019) | 94.3 | 86.6 | 78.1 | 54.0 | 91.0 | 92.3 | 79.1 | 78.9 | 88.4 | 86.0 | 88.0 | 50.7 | |
73 | | SHOT-IM (2020) | 93.4 | 87.1 | 80.4 | 51.7 | 91.5 | 92.9 | 80.0 | 78.0 | 89.6 | 85.1 | 87.2 | 51.3 | |
74 | | SHOT-IM (2021) | 93.5 | 85.7 | 77.6 | 46.3 | 90.5 | 95.1 | 77.9 | 78.1 | 89.7 | 85.0 | 88.5 | 51.2 | |
75 | | SHOT-IM (Avg.) | 93.7 | 86.4 | 78.7 | 50.7 | 91.0 | 93.5 | 79.0 | 78.3 | 89.2 | 85.4 | 87.9 | 51.1 | 80.4 |
76 | | SHOT (2019) | 93.8 | 89.0 | 81.4 | 57.0 | 93.4 | 94.7 | 81.3 | 80.3 | 90.5 | 89.1 | 85.3 | 58.4 | |
77 | | SHOT (2020) | 94.5 | 87.3 | 80.0 | 57.1 | 93.1 | 94.5 | 82.0 | 80.7 | 91.7 | 89.4 | 87.0 | 58.3 | |
78 | | SHOT (2021) | 94.7 | 89.1 | 78.7 | 57.8 | 92.8 | 95.5 | 78.8 | 79.9 | 92.4 | 89.0 | 86.6 | 57.9 | |
79 | | SHOT (Avg.) | 94.3 | 88.5 | 80.1 | 57.3 | 93.1 | 94.9 | 80.7 | 80.3 | 91.5 | 89.1 | 86.3 | 58.2 | 82.9 |
80 |
81 | #### Table 7 [PDA/ODA results on Office-Home]
82 |
83 | | Methods@PDA |Ar->Cl|Ar->Pr|Ar->Re|Cl->Ar|Cl->Pr|Cl->Re|Pr->Ar|Pr->Cl|Pr->Re|Re->Ar|Re->Cl|Re->Pr| Avg. |
84 | | -------------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
85 | | srconly (2019) | 46.0 | 69.7 | 80.7 | 56.3 | 60.4 | 66.9 | 60.2 | 40.6 | 76.0 | 70.8 | 48.6 | 78.5 | |
86 | | srconly (2020) | 45.1 | 71.0 | 80.8 | 55.7 | 61.8 | 66.4 | 61.4 | 39.7 | 76.1 | 70.6 | 49.7 | 76.3 | |
87 | | srconly (2021) | 44.5 | 70.5 | 81.3 | 56.8 | 60.2 | 65.2 | 61.2 | 40.0 | 76.5 | 70.9 | 47.2 | 77.2 | |
88 | | srconly (Avg.) | 45.2 | 70.4 | 81.0 | 56.2 | 60.8 | 66.2 | 60.9 | 40.1 | 76.2 | 70.8 | 48.5 | 77.3 | 62.8 |
89 | | SHOT-IM (2019) | 57.5 | 86.2 | 88.2 | 69.3 | 73.6 | 79.9 | 79.7 | 62.2 | 89.0 | 80.8 | 66.6 | 91.0 | |
90 | | SHOT-IM (2020) | 61.2 | 82.0 | 87.8 | 73.3 | 74.4 | 80.6 | 74.1 | 58.8 | 90.0 | 81.7 | 70.8 | 87.1 | |
91 | | SHOT-IM (2021) | 55.0 | 82.6 | 90.3 | 74.5 | 74.0 | 76.5 | 74.4 | 60.8 | 91.4 | 83.0 | 67.5 | 87.3 | |
92 | | SHOT-IM (Avg.) | 57.9 | 83.6 | 88.8 | 72.4 | 74.0 | 79.0 | 76.1 | 60.6 | 90.1 | 81.9 | 68.3 | 88.5 | 76.8 |
93 | | SHOT (2019) | 65.0 | 85.0 | 93.3 | 75.7 | 79.3 | 88.9 | 80.5 | 65.3 | 90.1 | 80.9 | 67.0 | 86.3 | |
94 | | SHOT (2020) | 64.1 | 82.0 | 92.7 | 77.6 | 74.8 | 90.7 | 80.0 | 63.5 | 88.4 | 79.9 | 66.8 | 85.0 | |
95 | | SHOT (2021) | 65.2 | 88.7 | 92.2 | 75.7 | 78.8 | 86.8 | 78.5 | 64.1 | 90.1 | 80.9 | 65.3 | 86.0 | |
96 | | SHOT (Avg.) | 64.8 | 85.2 | 92.7 | 76.3 | 77.6 | 88.8 | 79.7 | 64.3 | 89.5 | 80.6 | 66.4 | 85.8 | 79.3 |
97 |
98 |
99 | | Methods@ODA |Ar->Cl|Ar->Pr|Ar->Re|Cl->Ar|Cl->Pr|Cl->Re|Pr->Ar|Pr->Cl|Pr->Re|Re->Ar|Re->Cl|Re->Pr| Avg. |
100 | | -------------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
101 | | srconly (2019) | 37.4 | 54.7 | 69.9 | 34.2 | 44.3 | 49.7 | 37.7 | 30.1 | 56.2 | 50.6 | 35.2 | 61.6 | |
102 | | srconly (2020) | 36.4 | 55.0 | 69.0 | 33.3 | 44.7 | 47.8 | 34.6 | 29.2 | 55.7 | 53.2 | 36.0 | 62.4 | |
103 | | srconly (2021) | 35.1 | 54.8 | 68.4 | 33.8 | 44.1 | 50.1 | 38.2 | 28.1 | 58.3 | 50.3 | 34.1 | 62.9 | |
104 | | srconly (Avg.) | 36.3 | 54.8 | 69.1 | 33.8 | 44.4 | 49.2 | 36.8 | 29.2 | 56.8 | 51.4 | 35.1 | 62.3 | 46.6 |
105 | | SHOT-IM (2019) | 61.6 | 80.1 | 84.4 | 61.8 | 74.0 | 81.9 | 63.6 | 58.5 | 83.1 | 68.4 | 63.7 | 82.2 | |
106 | | SHOT-IM (2020) | 63.4 | 76.0 | 83.2 | 61.4 | 74.3 | 78.7 | 63.8 | 59.6 | 83.1 | 70.0 | 61.8 | 82.7 | |
107 | | SHOT-IM (2021) | 62.4 | 77.3 | 84.1 | 59.6 | 71.9 | 77.7 | 66.7 | 58.0 | 83.0 | 68.9 | 60.6 | 81.6 | |
108 | | SHOT-IM (Avg.) | 62.5 | 77.8 | 83.9 | 60.9 | 73.4 | 79.4 | 64.7 | 58.7 | 83.1 | 69.1 | 62.0 | 82.1 | 71.5 |
109 | | SHOT (2019) | 63.9 | 80.6 | 85.6 | 63.6 | 77.1 | 83.2 | 64.9 | 58.3 | 83.2 | 69.7 | 65.2 | 82.8 | |
110 | | SHOT (2020) | 64.0 | 80.4 | 84.7 | 63.4 | 75.3 | 81.6 | 65.1 | 60.9 | 82.8 | 69.9 | 64.4 | 82.4 | |
111 | | SHOT (2021) | 65.6 | 80.2 | 83.8 | 62.2 | 73.7 | 78.8 | 65.9 | 58.8 | 83.9 | 69.2 | 64.1 | 81.7 | |
112 | | SHOT (Avg.) | 64.5 | 80.4 | 84.7 | 63.1 | 75.4 | 81.2 | 65.3 | 59.3 | 83.3 | 69.6 | 64.6 | 82.3 | 72.8 |
113 |
114 | #### Table 8 [MSDA/MTDA results on Office-Caltech]
115 |
116 | | Methods@MSDA | ->A | ->C | ->D | ->W | Avg. |
117 | | -------------- | ---- | ---- | ---- | ---- | ---- |
118 | | srconly (2019) | 95.2 | 93.9 | 99.4 | 98.0 | |
119 | | srconly (2020) | 95.4 | 93.5 | 98.7 | 98.6 | |
120 | | srconly (2021) | 95.6 | 93.7 | 98.7 | 98.3 | |
121 | | srconly (Avg.) | 95.4 | 93.7 | 98.9 | 98.3 | 96.6 |
122 | | SHOT-IM (2019) | 95.8 | 96.0 | 99.4 | 99.7 | |
123 | | SHOT-IM (2020) | 96.5 | 95.9 | 97.5 | 99.7 | |
124 | | SHOT-IM (2021) | 96.4 | 96.3 | 98.7 | 99.7 | |
125 | | SHOT-IM (Avg.) | 96.2 | 96.1 | 98.5 | 99.7 | 97.6 |
126 | | SHOT (2019) | 96.2 | 95.9 | 98.7 | 99.7 | |
127 | | SHOT (2020) | 96.5 | 96.1 | 98.7 | 99.7 | |
128 | | SHOT (2021) | 96.6 | 96.6 | 98.1 | 99.7 | |
129 | | SHOT (Avg.) | 96.4 | 96.2 | 98.5 | 99.7 | 97.7 |
130 |
131 | | Methods@MTDA | A-> | C-> | D-> | W-> | Avg. |
132 | | -------------- | ---- | ---- | ---- | ---- | ---- |
133 | | srconly (2019) | 90.4 | 95.9 | 90.3 | 90.6 | |
134 | | srconly (2020) | 91.2 | 95.9 | 90.2 | 91.1 | |
135 | | srconly (2021) | 90.5 | 96.5 | 90.2 | 91.1 | |
136 | | srconly (Avg.) | 90.7 | 96.1 | 90.2 | 90.9 | 92.0 |
137 | | SHOT-IM (2019) | 96.6 | 97.5 | 96.3 | 96.0 | |
138 | | SHOT-IM (2020) | 95.1 | 96.7 | 96.3 | 96.4 | |
139 | | SHOT-IM (2021) | 95.4 | 97.3 | 96.3 | 96.0 | |
140 | | SHOT-IM (Avg.) | 95.7 | 97.2 | 96.3 | 96.1 | 96.3 |
141 | | SHOT (2019) | 96.6 | 97.5 | 96.4 | 96.0 | |
142 | | SHOT (2020) | 95.4 | 97.0 | 96.5 | 96.7 | |
143 | | SHOT (2021) | 96.6 | 97.5 | 96.0 | 96.0 | |
144 | | SHOT (Avg.) | 96.2 | 97.3 | 96.3 | 96.2 | 96.5 |
145 |
146 |
147 | #### Table 9 [PDA results on ImageNet->Caltech]
148 |
149 | | Methods@PDA | 2019 | 2020 | 2021 | Avg. |
150 | | -------------- | ---- | ---- | ---- | ---- |
151 | | srconly | 69.7 | 69.7 | 69.7 | 69.7 |
152 | | SHOT-IM | 81.1 | 82.2 | 81.8 | 81.7 |
153 | | SHOT | 83.2 | 83.3 | 83.4 | 83.3 |
154 |
155 |
156 |
157 |
158 | ### Citation
159 |
160 | If you find this code useful for your research, please cite our paper:
161 |
162 | > @inproceedings{liang2020shot,
163 | > title={Do We Really Need to Access the Source Data? Source Hypothesis Transfer for Unsupervised Domain Adaptation},
164 | > author={Liang, Jian and Hu, Dapeng and Feng, Jiashi},
165 | > booktitle={International Conference on Machine Learning (ICML)},
166 | > pages={6028--6039},
167 | > month={July},
168 | > year={2020}
169 | > }
170 |
171 | ### Contact
172 |
173 | - [liangjian92@gmail.com](mailto:liangjian92@gmail.com)
174 | - [dapeng.hu@u.nus.edu](mailto:dapeng.hu@u.nus.edu)
175 | - [elefjia@nus.edu.sg](mailto:elefjia@nus.edu.sg)
176 |
--------------------------------------------------------------------------------