├── LICENSE
├── digit
│   ├── data_load
│   │   ├── __init__.py
│   │   ├── mnist.py
│   │   ├── svhn.py
│   │   ├── usps.py
│   │   ├── utils.py
│   │   └── vision.py
│   ├── digit.sh
│   ├── loss.py
│   ├── network.py
│   └── uda_digit.py
├── figs
│   └── shot.jpg
├── object
│   ├── data_list.py
│   ├── image_multisource.py
│   ├── image_multitarget.py
│   ├── image_pretrained.py
│   ├── image_source.py
│   ├── image_target.py
│   ├── image_target_oda.py
│   ├── loss.py
│   ├── network.py
│   └── run.sh
├── pretrained-models.md
├── readme.md
└── results.md

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/digit/data_load/__init__.py:
--------------------------------------------------------------------------------
1 | from .svhn import *
2 | from .mnist import *
3 | from .usps import *
--------------------------------------------------------------------------------
/digit/data_load/mnist.py:
--------------------------------------------------------------------------------
1 | from .vision import VisionDataset
2 | import warnings
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 | import torch
8 | import codecs
9 | import string
10 | from .utils import download_url, download_and_extract_archive, extract_archive, \
11 |     verify_str_arg
12 | 
13 | 
14 | class MNIST(VisionDataset):
15 |     """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
16 | 
17 |     Args:
18 |         root (string): Root directory of dataset where ``MNIST/processed/training.pt``
19 |             and ``MNIST/processed/test.pt`` exist.
20 |         train (bool, optional): If True, creates dataset from ``training.pt``,
21 |             otherwise from ``test.pt``.
22 |         download (bool, optional): If true, downloads the dataset from the internet and
23 |             puts it in root directory. If dataset is already downloaded, it is not
24 |             downloaded again.
25 |         transform (callable, optional): A function/transform that takes in an PIL image
26 |             and returns a transformed version. E.g, ``transforms.RandomCrop``
27 |         target_transform (callable, optional): A function/transform that takes in the
28 |             target and transforms it.
29 | """ 30 | 31 | resources = [ 32 | ("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"), 33 | ("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"), 34 | ("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"), 35 | ("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c") 36 | ] 37 | 38 | training_file = 'training.pt' 39 | test_file = 'test.pt' 40 | classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', 41 | '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine'] 42 | 43 | @property 44 | def train_labels(self): 45 | warnings.warn("train_labels has been renamed targets") 46 | return self.targets 47 | 48 | @property 49 | def test_labels(self): 50 | warnings.warn("test_labels has been renamed targets") 51 | return self.targets 52 | 53 | @property 54 | def train_data(self): 55 | warnings.warn("train_data has been renamed data") 56 | return self.data 57 | 58 | @property 59 | def test_data(self): 60 | warnings.warn("test_data has been renamed data") 61 | return self.data 62 | 63 | def __init__(self, root, train=True, transform=None, target_transform=None, 64 | download=False): 65 | super(MNIST, self).__init__(root, transform=transform, 66 | target_transform=target_transform) 67 | self.train = train # training set or test set 68 | 69 | if download: 70 | self.download() 71 | 72 | if not self._check_exists(): 73 | raise RuntimeError('Dataset not found.' + 74 | ' You can use download=True to download it') 75 | 76 | if self.train: 77 | data_file = self.training_file 78 | else: 79 | data_file = self.test_file 80 | self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file)) 81 | 82 | def __getitem__(self, index): 83 | """ 84 | Args: 85 | index (int): Index 86 | 87 | Returns: 88 | tuple: (image, target) where target is index of the target class. 
89 | """ 90 | img, target = self.data[index], int(self.targets[index]) 91 | 92 | # doing this so that it is consistent with all other datasets 93 | # to return a PIL Image 94 | img = Image.fromarray(img.numpy(), mode='L') 95 | 96 | if self.transform is not None: 97 | img = self.transform(img) 98 | 99 | if self.target_transform is not None: 100 | target = self.target_transform(target) 101 | 102 | return img, target 103 | 104 | def __len__(self): 105 | return len(self.data) 106 | 107 | # @property 108 | # def raw_folder(self): 109 | # return os.path.join(self.root, self.__class__.__name__, 'raw') 110 | 111 | # @property 112 | # def processed_folder(self): 113 | # return os.path.join(self.root, self.__class__.__name__, 'processed') 114 | 115 | @property 116 | def raw_folder(self): 117 | return os.path.join(self.root, 'raw') 118 | 119 | @property 120 | def processed_folder(self): 121 | return os.path.join(self.root, 'processed') 122 | 123 | @property 124 | def class_to_idx(self): 125 | return {_class: i for i, _class in enumerate(self.classes)} 126 | 127 | def _check_exists(self): 128 | return (os.path.exists(os.path.join(self.processed_folder, 129 | self.training_file)) and 130 | os.path.exists(os.path.join(self.processed_folder, 131 | self.test_file))) 132 | 133 | def download(self): 134 | """Download the MNIST data if it doesn't exist in processed_folder already.""" 135 | 136 | if self._check_exists(): 137 | return 138 | 139 | os.makedirs(self.raw_folder, exist_ok=True) 140 | os.makedirs(self.processed_folder, exist_ok=True) 141 | 142 | # download files 143 | for url, md5 in self.resources: 144 | filename = url.rpartition('/')[2] 145 | download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5) 146 | 147 | # process and save as torch files 148 | print('Processing...') 149 | 150 | training_set = ( 151 | read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')), 152 | read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte')) 153 | ) 154 | test_set = ( 155 | read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')), 156 | read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte')) 157 | ) 158 | with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f: 159 | torch.save(training_set, f) 160 | with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f: 161 | torch.save(test_set, f) 162 | 163 | print('Done!') 164 | 165 | def extra_repr(self): 166 | return "Split: {}".format("Train" if self.train is True else "Test") 167 | 168 | 169 | 170 | class MNIST_idx(VisionDataset): 171 | """`MNIST `_ Dataset. 172 | 173 | Args: 174 | root (string): Root directory of dataset where ``MNIST/processed/training.pt`` 175 | and ``MNIST/processed/test.pt`` exist. 176 | train (bool, optional): If True, creates dataset from ``training.pt``, 177 | otherwise from ``test.pt``. 178 | download (bool, optional): If true, downloads the dataset from the internet and 179 | puts it in root directory. If dataset is already downloaded, it is not 180 | downloaded again. 181 | transform (callable, optional): A function/transform that takes in an PIL image 182 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 183 | target_transform (callable, optional): A function/transform that takes in the 184 | target and transforms it. 
185 | """ 186 | 187 | resources = [ 188 | ("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"), 189 | ("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"), 190 | ("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"), 191 | ("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c") 192 | ] 193 | 194 | training_file = 'training.pt' 195 | test_file = 'test.pt' 196 | classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', 197 | '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine'] 198 | 199 | @property 200 | def train_labels(self): 201 | warnings.warn("train_labels has been renamed targets") 202 | return self.targets 203 | 204 | @property 205 | def test_labels(self): 206 | warnings.warn("test_labels has been renamed targets") 207 | return self.targets 208 | 209 | @property 210 | def train_data(self): 211 | warnings.warn("train_data has been renamed data") 212 | return self.data 213 | 214 | @property 215 | def test_data(self): 216 | warnings.warn("test_data has been renamed data") 217 | return self.data 218 | 219 | def __init__(self, root, train=True, transform=None, target_transform=None, 220 | download=False): 221 | super(MNIST_idx, self).__init__(root, transform=transform, 222 | target_transform=target_transform) 223 | self.train = train # training set or test set 224 | 225 | if download: 226 | self.download() 227 | 228 | if not self._check_exists(): 229 | raise RuntimeError('Dataset not found.' + 230 | ' You can use download=True to download it') 231 | 232 | if self.train: 233 | data_file = self.training_file 234 | else: 235 | data_file = self.test_file 236 | self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file)) 237 | 238 | def __getitem__(self, index): 239 | """ 240 | Args: 241 | index (int): Index 242 | 243 | Returns: 244 | tuple: (image, target) where target is index of the target class. 
245 | """ 246 | img, target = self.data[index], int(self.targets[index]) 247 | 248 | # doing this so that it is consistent with all other datasets 249 | # to return a PIL Image 250 | img = Image.fromarray(img.numpy(), mode='L') 251 | 252 | if self.transform is not None: 253 | img = self.transform(img) 254 | 255 | if self.target_transform is not None: 256 | target = self.target_transform(target) 257 | 258 | return img, target, index 259 | 260 | def __len__(self): 261 | return len(self.data) 262 | 263 | # @property 264 | # def raw_folder(self): 265 | # return os.path.join(self.root, self.__class__.__name__, 'raw') 266 | 267 | # @property 268 | # def processed_folder(self): 269 | # return os.path.join(self.root, self.__class__.__name__, 'processed') 270 | 271 | @property 272 | def raw_folder(self): 273 | return os.path.join(self.root, 'raw') 274 | 275 | @property 276 | def processed_folder(self): 277 | return os.path.join(self.root, 'processed') 278 | 279 | @property 280 | def class_to_idx(self): 281 | return {_class: i for i, _class in enumerate(self.classes)} 282 | 283 | def _check_exists(self): 284 | return (os.path.exists(os.path.join(self.processed_folder, 285 | self.training_file)) and 286 | os.path.exists(os.path.join(self.processed_folder, 287 | self.test_file))) 288 | 289 | def download(self): 290 | """Download the MNIST data if it doesn't exist in processed_folder already.""" 291 | 292 | if self._check_exists(): 293 | return 294 | 295 | os.makedirs(self.raw_folder, exist_ok=True) 296 | os.makedirs(self.processed_folder, exist_ok=True) 297 | 298 | # download files 299 | for url, md5 in self.resources: 300 | filename = url.rpartition('/')[2] 301 | download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5) 302 | 303 | # process and save as torch files 304 | print('Processing...') 305 | 306 | training_set = ( 307 | read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')), 308 | read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte')) 309 | ) 310 | test_set = ( 311 | read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')), 312 | read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte')) 313 | ) 314 | with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f: 315 | torch.save(training_set, f) 316 | with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f: 317 | torch.save(test_set, f) 318 | 319 | print('Done!') 320 | 321 | def extra_repr(self): 322 | return "Split: {}".format("Train" if self.train is True else "Test") 323 | 324 | def get_int(b): 325 | return int(codecs.encode(b, 'hex'), 16) 326 | 327 | 328 | def open_maybe_compressed_file(path): 329 | """Return a file object that possibly decompresses 'path' on the fly. 330 | Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'. 331 | """ 332 | if not isinstance(path, torch._six.string_classes): 333 | return path 334 | if path.endswith('.gz'): 335 | import gzip 336 | return gzip.open(path, 'rb') 337 | if path.endswith('.xz'): 338 | import lzma 339 | return lzma.open(path, 'rb') 340 | return open(path, 'rb') 341 | 342 | def read_sn3_pascalvincent_tensor(path, strict=True): 343 | """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh'). 344 | Argument may be a filename, compressed filename, or file object. 
345 | """ 346 | # typemap 347 | if not hasattr(read_sn3_pascalvincent_tensor, 'typemap'): 348 | read_sn3_pascalvincent_tensor.typemap = { 349 | 8: (torch.uint8, np.uint8, np.uint8), 350 | 9: (torch.int8, np.int8, np.int8), 351 | 11: (torch.int16, np.dtype('>i2'), 'i2'), 352 | 12: (torch.int32, np.dtype('>i4'), 'i4'), 353 | 13: (torch.float32, np.dtype('>f4'), 'f4'), 354 | 14: (torch.float64, np.dtype('>f8'), 'f8')} 355 | # read 356 | with open_maybe_compressed_file(path) as f: 357 | data = f.read() 358 | # parse 359 | magic = get_int(data[0:4]) 360 | nd = magic % 256 361 | ty = magic // 256 362 | assert nd >= 1 and nd <= 3 363 | assert ty >= 8 and ty <= 14 364 | m = read_sn3_pascalvincent_tensor.typemap[ty] 365 | s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)] 366 | parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1))) 367 | assert parsed.shape[0] == np.prod(s) or not strict 368 | return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s) 369 | 370 | 371 | def read_label_file(path): 372 | with open(path, 'rb') as f: 373 | x = read_sn3_pascalvincent_tensor(f, strict=False) 374 | assert(x.dtype == torch.uint8) 375 | assert(x.ndimension() == 1) 376 | return x.long() 377 | 378 | def read_image_file(path): 379 | with open(path, 'rb') as f: 380 | x = read_sn3_pascalvincent_tensor(f, strict=False) 381 | assert(x.dtype == torch.uint8) 382 | assert(x.ndimension() == 3) 383 | return x -------------------------------------------------------------------------------- /digit/data_load/svhn.py: -------------------------------------------------------------------------------- 1 | from .vision import VisionDataset 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import numpy as np 6 | from .utils import download_url, check_integrity, verify_str_arg 7 | 8 | 9 | class SVHN(VisionDataset): 10 | """`SVHN `_ Dataset. 11 | Note: The SVHN dataset assigns the label `10` to the digit `0`. However, in this Dataset, 12 | we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which 13 | expect the class labels to be in the range `[0, C-1]` 14 | 15 | .. warning:: 16 | 17 | This class needs `scipy `_ to load data from `.mat` format. 18 | 19 | Args: 20 | root (string): Root directory of dataset where directory 21 | ``SVHN`` exists. 22 | split (string): One of {'train', 'test', 'extra'}. 23 | Accordingly dataset is selected. 'extra' is Extra training set. 24 | transform (callable, optional): A function/transform that takes in an PIL image 25 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 26 | target_transform (callable, optional): A function/transform that takes in the 27 | target and transforms it. 28 | download (bool, optional): If true, downloads the dataset from the internet and 29 | puts it in root directory. If dataset is already downloaded, it is not 30 | downloaded again. 
31 | 32 | """ 33 | 34 | split_list = { 35 | 'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat", 36 | "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"], 37 | 'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat", 38 | "test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"], 39 | 'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat", 40 | "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]} 41 | 42 | def __init__(self, root, split='train', transform=None, target_transform=None, 43 | download=False): 44 | super(SVHN, self).__init__(root, transform=transform, 45 | target_transform=target_transform) 46 | self.split = verify_str_arg(split, "split", tuple(self.split_list.keys())) 47 | self.url = self.split_list[split][0] 48 | self.filename = self.split_list[split][1] 49 | self.file_md5 = self.split_list[split][2] 50 | 51 | if download: 52 | self.download() 53 | 54 | if not self._check_integrity(): 55 | raise RuntimeError('Dataset not found or corrupted.' + 56 | ' You can use download=True to download it') 57 | 58 | # import here rather than at top of file because this is 59 | # an optional dependency for torchvision 60 | import scipy.io as sio 61 | 62 | # reading(loading) mat file as array 63 | loaded_mat = sio.loadmat(os.path.join(self.root, self.filename)) 64 | 65 | self.data = loaded_mat['X'] 66 | # loading from the .mat file gives an np array of type np.uint8 67 | # converting to np.int64, so that we have a LongTensor after 68 | # the conversion from the numpy array 69 | # the squeeze is needed to obtain a 1D tensor 70 | self.labels = loaded_mat['y'].astype(np.int64).squeeze() 71 | 72 | # the svhn dataset assigns the class label "10" to the digit 0 73 | # this makes it inconsistent with several loss functions 74 | # which expect the class labels to be in the range [0, C-1] 75 | np.place(self.labels, self.labels == 10, 0) 76 | self.data = np.transpose(self.data, (3, 2, 0, 1)) 77 | 78 | def __getitem__(self, index): 79 | """ 80 | Args: 81 | index (int): Index 82 | 83 | Returns: 84 | tuple: (image, target) where target is index of the target class. 85 | """ 86 | img, target = self.data[index], int(self.labels[index]) 87 | 88 | # doing this so that it is consistent with all other datasets 89 | # to return a PIL Image 90 | img = Image.fromarray(np.transpose(img, (1, 2, 0))) 91 | 92 | if self.transform is not None: 93 | img = self.transform(img) 94 | 95 | if self.target_transform is not None: 96 | target = self.target_transform(target) 97 | 98 | return img, target 99 | 100 | def __len__(self): 101 | return len(self.data) 102 | 103 | def _check_integrity(self): 104 | root = self.root 105 | md5 = self.split_list[self.split][2] 106 | fpath = os.path.join(root, self.filename) 107 | return check_integrity(fpath, md5) 108 | 109 | def download(self): 110 | md5 = self.split_list[self.split][2] 111 | download_url(self.url, self.root, self.filename, md5) 112 | 113 | def extra_repr(self): 114 | return "Split: {split}".format(**self.__dict__) 115 | 116 | class SVHN_idx(VisionDataset): 117 | """`SVHN `_ Dataset. 118 | Note: The SVHN dataset assigns the label `10` to the digit `0`. However, in this Dataset, 119 | we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which 120 | expect the class labels to be in the range `[0, C-1]` 121 | 122 | .. warning:: 123 | 124 | This class needs `scipy `_ to load data from `.mat` format. 125 | 126 | Args: 127 | root (string): Root directory of dataset where directory 128 | ``SVHN`` exists. 
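# Shape bookkeeping for the .mat loading above: scipy.io returns SVHN images
# as (H, W, C, N); the (3, 2, 0, 1) transpose stores them as PyTorch-style
# (N, C, H, W), and __getitem__ flips one image back to (H, W, C) for PIL:
import numpy as np

X = np.zeros((32, 32, 3, 5), dtype=np.uint8)   # five dummy SVHN images
data = np.transpose(X, (3, 2, 0, 1))           # (5, 3, 32, 32)
hwc = np.transpose(data[0], (1, 2, 0))         # (32, 32, 3), PIL-ready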
129 | split (string): One of {'train', 'test', 'extra'}. 130 | Accordingly dataset is selected. 'extra' is Extra training set. 131 | transform (callable, optional): A function/transform that takes in an PIL image 132 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 133 | target_transform (callable, optional): A function/transform that takes in the 134 | target and transforms it. 135 | download (bool, optional): If true, downloads the dataset from the internet and 136 | puts it in root directory. If dataset is already downloaded, it is not 137 | downloaded again. 138 | 139 | """ 140 | 141 | split_list = { 142 | 'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat", 143 | "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"], 144 | 'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat", 145 | "test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"], 146 | 'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat", 147 | "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]} 148 | 149 | def __init__(self, root, split='train', transform=None, target_transform=None, 150 | download=False): 151 | super(SVHN_idx, self).__init__(root, transform=transform, 152 | target_transform=target_transform) 153 | self.split = verify_str_arg(split, "split", tuple(self.split_list.keys())) 154 | self.url = self.split_list[split][0] 155 | self.filename = self.split_list[split][1] 156 | self.file_md5 = self.split_list[split][2] 157 | 158 | if download: 159 | self.download() 160 | 161 | if not self._check_integrity(): 162 | raise RuntimeError('Dataset not found or corrupted.' + 163 | ' You can use download=True to download it') 164 | 165 | # import here rather than at top of file because this is 166 | # an optional dependency for torchvision 167 | import scipy.io as sio 168 | 169 | # reading(loading) mat file as array 170 | loaded_mat = sio.loadmat(os.path.join(self.root, self.filename)) 171 | 172 | self.data = loaded_mat['X'] 173 | # loading from the .mat file gives an np array of type np.uint8 174 | # converting to np.int64, so that we have a LongTensor after 175 | # the conversion from the numpy array 176 | # the squeeze is needed to obtain a 1D tensor 177 | self.labels = loaded_mat['y'].astype(np.int64).squeeze() 178 | 179 | # the svhn dataset assigns the class label "10" to the digit 0 180 | # this makes it inconsistent with several loss functions 181 | # which expect the class labels to be in the range [0, C-1] 182 | np.place(self.labels, self.labels == 10, 0) 183 | self.data = np.transpose(self.data, (3, 2, 0, 1)) 184 | 185 | def __getitem__(self, index): 186 | """ 187 | Args: 188 | index (int): Index 189 | 190 | Returns: 191 | tuple: (image, target) where target is index of the target class. 
192 | """ 193 | img, target = self.data[index], int(self.labels[index]) 194 | 195 | # doing this so that it is consistent with all other datasets 196 | # to return a PIL Image 197 | img = Image.fromarray(np.transpose(img, (1, 2, 0))) 198 | 199 | if self.transform is not None: 200 | img = self.transform(img) 201 | 202 | if self.target_transform is not None: 203 | target = self.target_transform(target) 204 | 205 | return img, target, index 206 | 207 | def __len__(self): 208 | return len(self.data) 209 | 210 | def _check_integrity(self): 211 | root = self.root 212 | md5 = self.split_list[self.split][2] 213 | fpath = os.path.join(root, self.filename) 214 | return check_integrity(fpath, md5) 215 | 216 | def download(self): 217 | md5 = self.split_list[self.split][2] 218 | download_url(self.url, self.root, self.filename, md5) 219 | 220 | def extra_repr(self): 221 | return "Split: {split}".format(**self.__dict__) -------------------------------------------------------------------------------- /digit/data_load/usps.py: -------------------------------------------------------------------------------- 1 | """Dataset setting and data loader for USPS. 2 | Modified from 3 | https://github.com/mingyuliutw/CoGAN/blob/master/cogan_pytorch/src/dataset_usps.py 4 | """ 5 | 6 | import gzip 7 | import os 8 | import pickle 9 | import urllib 10 | from PIL import Image 11 | 12 | import numpy as np 13 | import torch 14 | import torch.utils.data as data 15 | from torch.utils.data.sampler import WeightedRandomSampler 16 | from torchvision import datasets, transforms 17 | 18 | 19 | class USPS(data.Dataset): 20 | """USPS Dataset. 21 | Args: 22 | root (string): Root directory of dataset where dataset file exist. 23 | train (bool, optional): If True, resample from dataset randomly. 24 | download (bool, optional): If true, downloads the dataset 25 | from the internet and puts it in root directory. 26 | If dataset is already downloaded, it is not downloaded again. 27 | transform (callable, optional): A function/transform that takes in 28 | an PIL image and returns a transformed version. 29 | E.g, ``transforms.RandomCrop`` 30 | """ 31 | 32 | url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl" 33 | 34 | def __init__(self, root, train=True, transform=None, download=False): 35 | """Init USPS dataset.""" 36 | # init params 37 | self.root = os.path.expanduser(root) 38 | self.filename = "usps_28x28.pkl" 39 | self.train = train 40 | # Num of Train = 7438, Num ot Test 1860 41 | self.transform = transform 42 | self.dataset_size = None 43 | 44 | # download dataset. 45 | if download: 46 | self.download() 47 | if not self._check_exists(): 48 | raise RuntimeError("Dataset not found." + 49 | " You can use download=True to download it") 50 | 51 | self.train_data, self.train_labels = self.load_samples() 52 | if self.train: 53 | total_num_samples = self.train_labels.shape[0] 54 | indices = np.arange(total_num_samples) 55 | self.train_data = self.train_data[indices[0:self.dataset_size], ::] 56 | self.train_labels = self.train_labels[indices[0:self.dataset_size]] 57 | self.train_data *= 255.0 58 | self.train_data = np.squeeze(self.train_data).astype(np.uint8) 59 | 60 | def __getitem__(self, index): 61 | """Get images and target for data loader. 62 | Args: 63 | index (int): Index 64 | Returns: 65 | tuple: (image, target) where target is index of the target class. 
66 | """ 67 | img, label = self.train_data[index], self.train_labels[index] 68 | img = Image.fromarray(img, mode='L') 69 | img = img.copy() 70 | if self.transform is not None: 71 | img = self.transform(img) 72 | return img, label.astype("int64") 73 | 74 | def __len__(self): 75 | """Return size of dataset.""" 76 | return len(self.train_data) 77 | 78 | def _check_exists(self): 79 | """Check if dataset is download and in right place.""" 80 | return os.path.exists(os.path.join(self.root, self.filename)) 81 | 82 | def download(self): 83 | """Download dataset.""" 84 | filename = os.path.join(self.root, self.filename) 85 | dirname = os.path.dirname(filename) 86 | if not os.path.isdir(dirname): 87 | os.makedirs(dirname) 88 | if os.path.isfile(filename): 89 | return 90 | print("Download %s to %s" % (self.url, os.path.abspath(filename))) 91 | urllib.request.urlretrieve(self.url, filename) 92 | print("[DONE]") 93 | return 94 | 95 | def load_samples(self): 96 | """Load sample images from dataset.""" 97 | filename = os.path.join(self.root, self.filename) 98 | f = gzip.open(filename, "rb") 99 | data_set = pickle.load(f, encoding="bytes") 100 | f.close() 101 | if self.train: 102 | images = data_set[0][0] 103 | labels = data_set[0][1] 104 | self.dataset_size = labels.shape[0] 105 | else: 106 | images = data_set[1][0] 107 | labels = data_set[1][1] 108 | self.dataset_size = labels.shape[0] 109 | return images, labels 110 | 111 | 112 | class USPS_idx(data.Dataset): 113 | """USPS Dataset. 114 | Args: 115 | root (string): Root directory of dataset where dataset file exist. 116 | train (bool, optional): If True, resample from dataset randomly. 117 | download (bool, optional): If true, downloads the dataset 118 | from the internet and puts it in root directory. 119 | If dataset is already downloaded, it is not downloaded again. 120 | transform (callable, optional): A function/transform that takes in 121 | an PIL image and returns a transformed version. 122 | E.g, ``transforms.RandomCrop`` 123 | """ 124 | 125 | url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl" 126 | 127 | def __init__(self, root, train=True, transform=None, download=False): 128 | """Init USPS dataset.""" 129 | # init params 130 | self.root = os.path.expanduser(root) 131 | self.filename = "usps_28x28.pkl" 132 | self.train = train 133 | # Num of Train = 7438, Num ot Test 1860 134 | self.transform = transform 135 | self.dataset_size = None 136 | 137 | # download dataset. 138 | if download: 139 | self.download() 140 | if not self._check_exists(): 141 | raise RuntimeError("Dataset not found." + 142 | " You can use download=True to download it") 143 | 144 | self.train_data, self.train_labels = self.load_samples() 145 | if self.train: 146 | total_num_samples = self.train_labels.shape[0] 147 | indices = np.arange(total_num_samples) 148 | self.train_data = self.train_data[indices[0:self.dataset_size], ::] 149 | self.train_labels = self.train_labels[indices[0:self.dataset_size]] 150 | self.train_data *= 255.0 151 | self.train_data = np.squeeze(self.train_data).astype(np.uint8) 152 | 153 | def __getitem__(self, index): 154 | """Get images and target for data loader. 155 | Args: 156 | index (int): Index 157 | Returns: 158 | tuple: (image, target) where target is index of the target class. 
159 | """ 160 | img, label = self.train_data[index], self.train_labels[index] 161 | img = Image.fromarray(img, mode='L') 162 | img = img.copy() 163 | if self.transform is not None: 164 | img = self.transform(img) 165 | return img, label.astype("int64"), index 166 | 167 | def __len__(self): 168 | """Return size of dataset.""" 169 | return len(self.train_data) 170 | 171 | def _check_exists(self): 172 | """Check if dataset is download and in right place.""" 173 | return os.path.exists(os.path.join(self.root, self.filename)) 174 | 175 | def download(self): 176 | """Download dataset.""" 177 | filename = os.path.join(self.root, self.filename) 178 | dirname = os.path.dirname(filename) 179 | if not os.path.isdir(dirname): 180 | os.makedirs(dirname) 181 | if os.path.isfile(filename): 182 | return 183 | print("Download %s to %s" % (self.url, os.path.abspath(filename))) 184 | urllib.request.urlretrieve(self.url, filename) 185 | print("[DONE]") 186 | return 187 | 188 | def load_samples(self): 189 | """Load sample images from dataset.""" 190 | filename = os.path.join(self.root, self.filename) 191 | f = gzip.open(filename, "rb") 192 | data_set = pickle.load(f, encoding="bytes") 193 | f.close() 194 | if self.train: 195 | images = data_set[0][0] 196 | labels = data_set[0][1] 197 | self.dataset_size = labels.shape[0] 198 | else: 199 | images = data_set[1][0] 200 | labels = data_set[1][1] 201 | self.dataset_size = labels.shape[0] 202 | return images, labels -------------------------------------------------------------------------------- /digit/data_load/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import hashlib 4 | import gzip 5 | import errno 6 | import tarfile 7 | import zipfile 8 | 9 | import torch 10 | from torch.utils.model_zoo import tqdm 11 | 12 | 13 | def gen_bar_updater(): 14 | pbar = tqdm(total=None) 15 | 16 | def bar_update(count, block_size, total_size): 17 | if pbar.total is None and total_size: 18 | pbar.total = total_size 19 | progress_bytes = count * block_size 20 | pbar.update(progress_bytes - pbar.n) 21 | 22 | return bar_update 23 | 24 | 25 | def calculate_md5(fpath, chunk_size=1024 * 1024): 26 | md5 = hashlib.md5() 27 | with open(fpath, 'rb') as f: 28 | for chunk in iter(lambda: f.read(chunk_size), b''): 29 | md5.update(chunk) 30 | return md5.hexdigest() 31 | 32 | 33 | def check_md5(fpath, md5, **kwargs): 34 | return md5 == calculate_md5(fpath, **kwargs) 35 | 36 | 37 | def check_integrity(fpath, md5=None): 38 | if not os.path.isfile(fpath): 39 | return False 40 | if md5 is None: 41 | return True 42 | return check_md5(fpath, md5) 43 | 44 | 45 | def download_url(url, root, filename=None, md5=None): 46 | """Download a file from a url and place it in root. 47 | 48 | Args: 49 | url (str): URL to download file from 50 | root (str): Directory to place downloaded file in 51 | filename (str, optional): Name to save the file under. If None, use the basename of the URL 52 | md5 (str, optional): MD5 checksum of the download. 
If None, do not check 53 | """ 54 | import urllib 55 | 56 | root = os.path.expanduser(root) 57 | if not filename: 58 | filename = os.path.basename(url) 59 | fpath = os.path.join(root, filename) 60 | 61 | os.makedirs(root, exist_ok=True) 62 | 63 | # check if file is already present locally 64 | if check_integrity(fpath, md5): 65 | print('Using downloaded and verified file: ' + fpath) 66 | else: # download the file 67 | try: 68 | print('Downloading ' + url + ' to ' + fpath) 69 | urllib.request.urlretrieve( 70 | url, fpath, 71 | reporthook=gen_bar_updater() 72 | ) 73 | except (urllib.error.URLError, IOError) as e: 74 | if url[:5] == 'https': 75 | url = url.replace('https:', 'http:') 76 | print('Failed download. Trying https -> http instead.' 77 | ' Downloading ' + url + ' to ' + fpath) 78 | urllib.request.urlretrieve( 79 | url, fpath, 80 | reporthook=gen_bar_updater() 81 | ) 82 | else: 83 | raise e 84 | # check integrity of downloaded file 85 | if not check_integrity(fpath, md5): 86 | raise RuntimeError("File not found or corrupted.") 87 | 88 | 89 | def list_dir(root, prefix=False): 90 | """List all directories at a given root 91 | 92 | Args: 93 | root (str): Path to directory whose folders need to be listed 94 | prefix (bool, optional): If true, prepends the path to each result, otherwise 95 | only returns the name of the directories found 96 | """ 97 | root = os.path.expanduser(root) 98 | directories = list( 99 | filter( 100 | lambda p: os.path.isdir(os.path.join(root, p)), 101 | os.listdir(root) 102 | ) 103 | ) 104 | 105 | if prefix is True: 106 | directories = [os.path.join(root, d) for d in directories] 107 | 108 | return directories 109 | 110 | 111 | def list_files(root, suffix, prefix=False): 112 | """List all files ending with a suffix at a given root 113 | 114 | Args: 115 | root (str): Path to directory whose folders need to be listed 116 | suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). 117 | It uses the Python "str.endswith" method and is passed directly 118 | prefix (bool, optional): If true, prepends the path to each result, otherwise 119 | only returns the name of the files found 120 | """ 121 | root = os.path.expanduser(root) 122 | files = list( 123 | filter( 124 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix), 125 | os.listdir(root) 126 | ) 127 | ) 128 | 129 | if prefix is True: 130 | files = [os.path.join(root, d) for d in files] 131 | 132 | return files 133 | 134 | 135 | def download_file_from_google_drive(file_id, root, filename=None, md5=None): 136 | """Download a Google Drive file from and place it in root. 137 | 138 | Args: 139 | file_id (str): id of file to be downloaded 140 | root (str): Directory to place downloaded file in 141 | filename (str, optional): Name to save the file under. If None, use the id of the file. 142 | md5 (str, optional): MD5 checksum of the download. 
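# list_files above is a thin filter over os.listdir; `suffix` is passed
# straight to str.endswith, so a tuple of extensions also works. A
# hypothetical call against the SVHN folder used elsewhere in this repo:
mats = list_files('./data/svhn', suffix='.mat', prefix=True)
# e.g. ['./data/svhn/train_32x32.mat', './data/svhn/test_32x32.mat']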
If None, do not check 143 | """ 144 | # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url 145 | import requests 146 | url = "https://docs.google.com/uc?export=download" 147 | 148 | root = os.path.expanduser(root) 149 | if not filename: 150 | filename = file_id 151 | fpath = os.path.join(root, filename) 152 | 153 | os.makedirs(root, exist_ok=True) 154 | 155 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 156 | print('Using downloaded and verified file: ' + fpath) 157 | else: 158 | session = requests.Session() 159 | 160 | response = session.get(url, params={'id': file_id}, stream=True) 161 | token = _get_confirm_token(response) 162 | 163 | if token: 164 | params = {'id': file_id, 'confirm': token} 165 | response = session.get(url, params=params, stream=True) 166 | 167 | _save_response_content(response, fpath) 168 | 169 | 170 | def _get_confirm_token(response): 171 | for key, value in response.cookies.items(): 172 | if key.startswith('download_warning'): 173 | return value 174 | 175 | return None 176 | 177 | 178 | def _save_response_content(response, destination, chunk_size=32768): 179 | with open(destination, "wb") as f: 180 | pbar = tqdm(total=None) 181 | progress = 0 182 | for chunk in response.iter_content(chunk_size): 183 | if chunk: # filter out keep-alive new chunks 184 | f.write(chunk) 185 | progress += len(chunk) 186 | pbar.update(progress - pbar.n) 187 | pbar.close() 188 | 189 | 190 | def _is_tarxz(filename): 191 | return filename.endswith(".tar.xz") 192 | 193 | 194 | def _is_tar(filename): 195 | return filename.endswith(".tar") 196 | 197 | 198 | def _is_targz(filename): 199 | return filename.endswith(".tar.gz") 200 | 201 | 202 | def _is_tgz(filename): 203 | return filename.endswith(".tgz") 204 | 205 | 206 | def _is_gzip(filename): 207 | return filename.endswith(".gz") and not filename.endswith(".tar.gz") 208 | 209 | 210 | def _is_zip(filename): 211 | return filename.endswith(".zip") 212 | 213 | 214 | def extract_archive(from_path, to_path=None, remove_finished=False): 215 | if to_path is None: 216 | to_path = os.path.dirname(from_path) 217 | 218 | if _is_tar(from_path): 219 | with tarfile.open(from_path, 'r') as tar: 220 | tar.extractall(path=to_path) 221 | elif _is_targz(from_path) or _is_tgz(from_path): 222 | with tarfile.open(from_path, 'r:gz') as tar: 223 | tar.extractall(path=to_path) 224 | elif _is_tarxz(from_path): 225 | with tarfile.open(from_path, 'r:xz') as tar: 226 | tar.extractall(path=to_path) 227 | elif _is_gzip(from_path): 228 | to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0]) 229 | with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f: 230 | out_f.write(zip_f.read()) 231 | elif _is_zip(from_path): 232 | with zipfile.ZipFile(from_path, 'r') as z: 233 | z.extractall(to_path) 234 | else: 235 | raise ValueError("Extraction of {} not supported".format(from_path)) 236 | 237 | if remove_finished: 238 | os.remove(from_path) 239 | 240 | 241 | def download_and_extract_archive(url, download_root, extract_root=None, filename=None, 242 | md5=None, remove_finished=False): 243 | download_root = os.path.expanduser(download_root) 244 | if extract_root is None: 245 | extract_root = download_root 246 | if not filename: 247 | filename = os.path.basename(url) 248 | 249 | download_url(url, download_root, filename, md5) 250 | 251 | archive = os.path.join(download_root, filename) 252 | print("Extracting {} to {}".format(archive, extract_root)) 253 | extract_archive(archive, 
extract_root, remove_finished) 254 | 255 | 256 | def iterable_to_str(iterable): 257 | return "'" + "', '".join([str(item) for item in iterable]) + "'" 258 | 259 | 260 | def verify_str_arg(value, arg=None, valid_values=None, custom_msg=None): 261 | if not isinstance(value, torch._six.string_classes): 262 | if arg is None: 263 | msg = "Expected type str, but got type {type}." 264 | else: 265 | msg = "Expected type str for argument {arg}, but got type {type}." 266 | msg = msg.format(type=type(value), arg=arg) 267 | raise ValueError(msg) 268 | 269 | if valid_values is None: 270 | return value 271 | 272 | if value not in valid_values: 273 | if custom_msg is not None: 274 | msg = custom_msg 275 | else: 276 | msg = ("Unknown value '{value}' for argument {arg}. " 277 | "Valid values are {{{valid_values}}}.") 278 | msg = msg.format(value=value, arg=arg, 279 | valid_values=iterable_to_str(valid_values)) 280 | raise ValueError(msg) 281 | 282 | return value -------------------------------------------------------------------------------- /digit/data_load/vision.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | 5 | 6 | class VisionDataset(data.Dataset): 7 | _repr_indent = 4 8 | 9 | def __init__(self, root, transforms=None, transform=None, target_transform=None): 10 | if isinstance(root, torch._six.string_classes): 11 | root = os.path.expanduser(root) 12 | self.root = root 13 | 14 | has_transforms = transforms is not None 15 | has_separate_transform = transform is not None or target_transform is not None 16 | if has_transforms and has_separate_transform: 17 | raise ValueError("Only transforms or transform/target_transform can " 18 | "be passed as argument") 19 | 20 | # for backwards-compatibility 21 | self.transform = transform 22 | self.target_transform = target_transform 23 | 24 | if has_separate_transform: 25 | transforms = StandardTransform(transform, target_transform) 26 | self.transforms = transforms 27 | 28 | def __getitem__(self, index): 29 | raise NotImplementedError 30 | 31 | def __len__(self): 32 | raise NotImplementedError 33 | 34 | def __repr__(self): 35 | head = "Dataset " + self.__class__.__name__ 36 | body = ["Number of datapoints: {}".format(self.__len__())] 37 | if self.root is not None: 38 | body.append("Root location: {}".format(self.root)) 39 | body += self.extra_repr().splitlines() 40 | if hasattr(self, "transforms") and self.transforms is not None: 41 | body += [repr(self.transforms)] 42 | lines = [head] + [" " * self._repr_indent + line for line in body] 43 | return '\n'.join(lines) 44 | 45 | def _format_transform_repr(self, transform, head): 46 | lines = transform.__repr__().splitlines() 47 | return (["{}{}".format(head, lines[0])] + 48 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 49 | 50 | def extra_repr(self): 51 | return "" 52 | 53 | 54 | class StandardTransform(object): 55 | def __init__(self, transform=None, target_transform=None): 56 | self.transform = transform 57 | self.target_transform = target_transform 58 | 59 | def __call__(self, input, target): 60 | if self.transform is not None: 61 | input = self.transform(input) 62 | if self.target_transform is not None: 63 | target = self.target_transform(target) 64 | return input, target 65 | 66 | def _format_transform_repr(self, transform, head): 67 | lines = transform.__repr__().splitlines() 68 | return (["{}{}".format(head, lines[0])] + 69 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 
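# A minimal sketch (hypothetical, not part of this repo) of what subclassing
# VisionDataset above requires — only __getitem__ and __len__; root
# expansion, transform bookkeeping, and __repr__ are inherited, which is
# exactly how the MNIST and SVHN classes in this package use it:
class ArrayDigits(VisionDataset):
    def __init__(self, root, data, targets, transform=None):
        super(ArrayDigits, self).__init__(root, transform=transform)
        self.data, self.targets = data, targets

    def __getitem__(self, index):
        img, target = self.data[index], int(self.targets[index])
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.data)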
70 | 71 | def __repr__(self): 72 | body = [self.__class__.__name__] 73 | if self.transform is not None: 74 | body += self._format_transform_repr(self.transform, 75 | "Transform: ") 76 | if self.target_transform is not None: 77 | body += self._format_transform_repr(self.target_transform, 78 | "Target transform: ") 79 | 80 | return '\n'.join(body) -------------------------------------------------------------------------------- /digit/digit.sh: -------------------------------------------------------------------------------- 1 | ~/anaconda3/envs/pytorch/bin/python uda_digit.py --dset m2u --gpu_id 0 --cls_par 0.1 --output ckps_digits 2 | ~/anaconda3/envs/pytorch/bin/python uda_digit.py --dset u2m --gpu_id 0 --cls_par 0.1 --output ckps_digits 3 | ~/anaconda3/envs/pytorch/bin/python uda_digit.py --dset s2m --gpu_id 0 --cls_par 0.1 --output ckps_digits -------------------------------------------------------------------------------- /digit/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import math 6 | import torch.nn.functional as F 7 | import pdb 8 | 9 | def Entropy(input_): 10 | bs = input_.size(0) 11 | entropy = -input_ * torch.log(input_ + 1e-5) 12 | entropy = torch.sum(entropy, dim=1) 13 | return entropy 14 | 15 | class CrossEntropyLabelSmooth(nn.Module): 16 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, size_average=True): 17 | super(CrossEntropyLabelSmooth, self).__init__() 18 | self.num_classes = num_classes 19 | self.epsilon = epsilon 20 | self.use_gpu = use_gpu 21 | self.size_average = size_average 22 | self.logsoftmax = nn.LogSoftmax(dim=1) 23 | 24 | def forward(self, inputs, targets): 25 | log_probs = self.logsoftmax(inputs) 26 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1) 27 | if self.use_gpu: targets = targets.cuda() 28 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 29 | if self.size_average: 30 | loss = (- targets * log_probs).mean(0).sum() 31 | else: 32 | loss = (- targets * log_probs).sum(1) 33 | return loss -------------------------------------------------------------------------------- /digit/network.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torchvision 5 | from torchvision import models 6 | from torch.autograd import Variable 7 | import math 8 | import torch.nn.utils.weight_norm as weightNorm 9 | from collections import OrderedDict 10 | 11 | def init_weights(m): 12 | classname = m.__class__.__name__ 13 | if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1: 14 | nn.init.kaiming_uniform_(m.weight) 15 | nn.init.zeros_(m.bias) 16 | elif classname.find('BatchNorm') != -1: 17 | nn.init.normal_(m.weight, 1.0, 0.02) 18 | nn.init.zeros_(m.bias) 19 | elif classname.find('Linear') != -1: 20 | nn.init.xavier_normal_(m.weight) 21 | nn.init.zeros_(m.bias) 22 | 23 | class feat_bottleneck(nn.Module): 24 | def __init__(self, feature_dim, bottleneck_dim=256, type="ori"): 25 | super(feat_bottleneck, self).__init__() 26 | self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True) 27 | self.dropout = nn.Dropout(p=0.5) 28 | self.bottleneck = nn.Linear(feature_dim, bottleneck_dim) 29 | self.bottleneck.apply(init_weights) 30 | self.type = type 31 | 32 | def forward(self, x): 33 | x = self.bottleneck(x) 34 | if self.type == "bn": 35 | x 
= self.bn(x) 36 | x = self.dropout(x) 37 | return x 38 | 39 | class feat_classifier(nn.Module): 40 | def __init__(self, class_num, bottleneck_dim=256, type="linear"): 41 | super(feat_classifier, self).__init__() 42 | if type == "linear": 43 | self.fc = nn.Linear(bottleneck_dim, class_num) 44 | else: 45 | self.fc = weightNorm(nn.Linear(bottleneck_dim, class_num), name="weight") 46 | self.fc.apply(init_weights) 47 | 48 | def forward(self, x): 49 | x = self.fc(x) 50 | return x 51 | 52 | class DTNBase(nn.Module): 53 | def __init__(self): 54 | super(DTNBase, self).__init__() 55 | self.conv_params = nn.Sequential( 56 | nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2), 57 | nn.BatchNorm2d(64), 58 | nn.Dropout2d(0.1), 59 | nn.ReLU(), 60 | nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2), 61 | nn.BatchNorm2d(128), 62 | nn.Dropout2d(0.3), 63 | nn.ReLU(), 64 | nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2), 65 | nn.BatchNorm2d(256), 66 | nn.Dropout2d(0.5), 67 | nn.ReLU() 68 | ) 69 | self.in_features = 256*4*4 70 | 71 | def forward(self, x): 72 | x = self.conv_params(x) 73 | x = x.view(x.size(0), -1) 74 | return x 75 | 76 | class LeNetBase(nn.Module): 77 | def __init__(self): 78 | super(LeNetBase, self).__init__() 79 | self.conv_params = nn.Sequential( 80 | nn.Conv2d(1, 20, kernel_size=5), 81 | nn.MaxPool2d(2), 82 | nn.ReLU(), 83 | nn.Conv2d(20, 50, kernel_size=5), 84 | nn.Dropout2d(p=0.5), 85 | nn.MaxPool2d(2), 86 | nn.ReLU(), 87 | ) 88 | self.in_features = 50*4*4 89 | 90 | def forward(self, x): 91 | x = self.conv_params(x) 92 | x = x.view(x.size(0), -1) 93 | return x -------------------------------------------------------------------------------- /digit/uda_digit.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import torchvision 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torchvision import transforms 10 | import network, loss 11 | from torch.utils.data import DataLoader 12 | import random, pdb, math, copy 13 | from tqdm import tqdm 14 | from scipy.spatial.distance import cdist 15 | import pickle 16 | from data_load import mnist, svhn, usps 17 | 18 | def op_copy(optimizer): 19 | for param_group in optimizer.param_groups: 20 | param_group['lr0'] = param_group['lr'] 21 | return optimizer 22 | 23 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75): 24 | decay = (1 + gamma * iter_num / max_iter) ** (-power) 25 | for param_group in optimizer.param_groups: 26 | param_group['lr'] = param_group['lr0'] * decay 27 | param_group['weight_decay'] = 1e-3 28 | param_group['momentum'] = 0.9 29 | param_group['nesterov'] = True 30 | return optimizer 31 | 32 | def digit_load(args): 33 | train_bs = args.batch_size 34 | if args.dset == 's2m': 35 | train_source = svhn.SVHN('./data/svhn/', split='train', download=True, 36 | transform=transforms.Compose([ 37 | transforms.Resize(32), 38 | transforms.ToTensor(), 39 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 40 | ])) 41 | test_source = svhn.SVHN('./data/svhn/', split='test', download=True, 42 | transform=transforms.Compose([ 43 | transforms.Resize(32), 44 | transforms.ToTensor(), 45 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 46 | ])) 47 | train_target = mnist.MNIST_idx('./data/mnist/', train=True, download=True, 48 | transform=transforms.Compose([ 49 | transforms.Resize(32), 50 | transforms.Lambda(lambda x: x.convert("RGB")), 51 | 
transforms.ToTensor(), 52 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 53 | ])) 54 | test_target = mnist.MNIST('./data/mnist/', train=False, download=True, 55 | transform=transforms.Compose([ 56 | transforms.Resize(32), 57 | transforms.Lambda(lambda x: x.convert("RGB")), 58 | transforms.ToTensor(), 59 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 60 | ])) 61 | elif args.dset == 'u2m': 62 | train_source = usps.USPS('./data/usps/', train=True, download=True, 63 | transform=transforms.Compose([ 64 | transforms.RandomCrop(28, padding=4), 65 | transforms.RandomRotation(10), 66 | transforms.ToTensor(), 67 | transforms.Normalize((0.5,), (0.5,)) 68 | ])) 69 | test_source = usps.USPS('./data/usps/', train=False, download=True, 70 | transform=transforms.Compose([ 71 | transforms.RandomCrop(28, padding=4), 72 | transforms.RandomRotation(10), 73 | transforms.ToTensor(), 74 | transforms.Normalize((0.5,), (0.5,)) 75 | ])) 76 | train_target = mnist.MNIST_idx('./data/mnist/', train=True, download=True, 77 | transform=transforms.Compose([ 78 | transforms.ToTensor(), 79 | transforms.Normalize((0.5,), (0.5,)) 80 | ])) 81 | test_target = mnist.MNIST('./data/mnist/', train=False, download=True, 82 | transform=transforms.Compose([ 83 | transforms.ToTensor(), 84 | transforms.Normalize((0.5,), (0.5,)) 85 | ])) 86 | elif args.dset == 'm2u': 87 | train_source = mnist.MNIST('./data/mnist/', train=True, download=True, 88 | transform=transforms.Compose([ 89 | transforms.ToTensor(), 90 | transforms.Normalize((0.5,), (0.5,)) 91 | ])) 92 | test_source = mnist.MNIST('./data/mnist/', train=False, download=True, 93 | transform=transforms.Compose([ 94 | transforms.ToTensor(), 95 | transforms.Normalize((0.5,), (0.5,)) 96 | ])) 97 | 98 | train_target = usps.USPS_idx('./data/usps/', train=True, download=True, 99 | transform=transforms.Compose([ 100 | transforms.ToTensor(), 101 | transforms.Normalize((0.5,), (0.5,)) 102 | ])) 103 | test_target = usps.USPS('./data/usps/', train=False, download=True, 104 | transform=transforms.Compose([ 105 | transforms.ToTensor(), 106 | transforms.Normalize((0.5,), (0.5,)) 107 | ])) 108 | 109 | dset_loaders = {} 110 | dset_loaders["source_tr"] = DataLoader(train_source, batch_size=train_bs, shuffle=True, 111 | num_workers=args.worker, drop_last=False) 112 | dset_loaders["source_te"] = DataLoader(test_source, batch_size=train_bs*2, shuffle=True, 113 | num_workers=args.worker, drop_last=False) 114 | dset_loaders["target"] = DataLoader(train_target, batch_size=train_bs, shuffle=True, 115 | num_workers=args.worker, drop_last=False) 116 | dset_loaders["target_te"] = DataLoader(train_target, batch_size=train_bs, shuffle=False, 117 | num_workers=args.worker, drop_last=False) 118 | dset_loaders["test"] = DataLoader(test_target, batch_size=train_bs*2, shuffle=False, 119 | num_workers=args.worker, drop_last=False) 120 | return dset_loaders 121 | 122 | def cal_acc(loader, netF, netB, netC): 123 | start_test = True 124 | with torch.no_grad(): 125 | iter_test = iter(loader) 126 | for i in range(len(loader)): 127 | data = iter_test.next() 128 | inputs = data[0] 129 | labels = data[1] 130 | inputs = inputs.cuda() 131 | outputs = netC(netB(netF(inputs))) 132 | if start_test: 133 | all_output = outputs.float().cpu() 134 | all_label = labels.float() 135 | start_test = False 136 | else: 137 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 138 | all_label = torch.cat((all_label, labels.float()), 0) 139 | _, predict = torch.max(all_output, 1) 140 | accuracy = 
torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 141 | mean_ent = torch.mean(loss.Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item() 142 | return accuracy*100, mean_ent 143 | 144 | def train_source(args): 145 | dset_loaders = digit_load(args) 146 | ## set base network 147 | if args.dset == 'u2m': 148 | netF = network.LeNetBase().cuda() 149 | elif args.dset == 'm2u': 150 | netF = network.LeNetBase().cuda() 151 | elif args.dset == 's2m': 152 | netF = network.DTNBase().cuda() 153 | 154 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda() 155 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda() 156 | 157 | param_group = [] 158 | learning_rate = args.lr 159 | for k, v in netF.named_parameters(): 160 | param_group += [{'params': v, 'lr': learning_rate}] 161 | for k, v in netB.named_parameters(): 162 | param_group += [{'params': v, 'lr': learning_rate}] 163 | for k, v in netC.named_parameters(): 164 | param_group += [{'params': v, 'lr': learning_rate}] 165 | 166 | optimizer = optim.SGD(param_group) 167 | optimizer = op_copy(optimizer) 168 | 169 | acc_init = 0 170 | max_iter = args.max_epoch * len(dset_loaders["source_tr"]) 171 | interval_iter = max_iter // 10 172 | iter_num = 0 173 | 174 | netF.train() 175 | netB.train() 176 | netC.train() 177 | 178 | while iter_num < max_iter: 179 | try: 180 | inputs_source, labels_source = iter_source.next() 181 | except: 182 | iter_source = iter(dset_loaders["source_tr"]) 183 | inputs_source, labels_source = iter_source.next() 184 | 185 | if inputs_source.size(0) == 1: 186 | continue 187 | 188 | iter_num += 1 189 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter) 190 | 191 | inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda() 192 | outputs_source = netC(netB(netF(inputs_source))) 193 | classifier_loss = loss.CrossEntropyLabelSmooth(num_classes=args.class_num, epsilon=args.smooth)(outputs_source, labels_source) 194 | optimizer.zero_grad() 195 | classifier_loss.backward() 196 | optimizer.step() 197 | 198 | if iter_num % interval_iter == 0 or iter_num == max_iter: 199 | netF.eval() 200 | netB.eval() 201 | netC.eval() 202 | acc_s_tr, _ = cal_acc(dset_loaders['source_tr'], netF, netB, netC) 203 | acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC) 204 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%/ {:.2f}%'.format(args.dset, iter_num, max_iter, acc_s_tr, acc_s_te) 205 | args.out_file.write(log_str + '\n') 206 | args.out_file.flush() 207 | print(log_str+'\n') 208 | 209 | if acc_s_te >= acc_init: 210 | acc_init = acc_s_te 211 | best_netF = copy.deepcopy(netF.state_dict()) 212 | best_netB = copy.deepcopy(netB.state_dict()) 213 | best_netC = copy.deepcopy(netC.state_dict()) 214 | 215 | netF.train() 216 | netB.train() 217 | netC.train() 218 | 219 | torch.save(best_netF, osp.join(args.output_dir, "source_F.pt")) 220 | torch.save(best_netB, osp.join(args.output_dir, "source_B.pt")) 221 | torch.save(best_netC, osp.join(args.output_dir, "source_C.pt")) 222 | 223 | return netF, netB, netC 224 | 225 | def test_target(args): 226 | dset_loaders = digit_load(args) 227 | ## set base network 228 | if args.dset == 'u2m': 229 | netF = network.LeNetBase().cuda() 230 | elif args.dset == 'm2u': 231 | netF = network.LeNetBase().cuda() 232 | elif args.dset == 's2m': 233 | netF = network.DTNBase().cuda() 234 | 235 | netB = 
network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda() 236 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda() 237 | 238 | args.modelpath = args.output_dir + '/source_F.pt' 239 | netF.load_state_dict(torch.load(args.modelpath)) 240 | args.modelpath = args.output_dir + '/source_B.pt' 241 | netB.load_state_dict(torch.load(args.modelpath)) 242 | args.modelpath = args.output_dir + '/source_C.pt' 243 | netC.load_state_dict(torch.load(args.modelpath)) 244 | netF.eval() 245 | netB.eval() 246 | netC.eval() 247 | 248 | acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC) 249 | log_str = 'Task: {}, Accuracy = {:.2f}%'.format(args.dset, acc) 250 | args.out_file.write(log_str + '\n') 251 | args.out_file.flush() 252 | print(log_str+'\n') 253 | 254 | def print_args(args): 255 | s = "==========================================\n" 256 | for arg, content in args.__dict__.items(): 257 | s += "{}:{}\n".format(arg, content) 258 | return s 259 | 260 | def train_target(args): 261 | dset_loaders = digit_load(args) 262 | ## set base network 263 | if args.dset == 'u2m': 264 | netF = network.LeNetBase().cuda() 265 | elif args.dset == 'm2u': 266 | netF = network.LeNetBase().cuda() 267 | elif args.dset == 's2m': 268 | netF = network.DTNBase().cuda() 269 | 270 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda() 271 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda() 272 | 273 | args.modelpath = args.output_dir + '/source_F.pt' 274 | netF.load_state_dict(torch.load(args.modelpath)) 275 | args.modelpath = args.output_dir + '/source_B.pt' 276 | netB.load_state_dict(torch.load(args.modelpath)) 277 | args.modelpath = args.output_dir + '/source_C.pt' 278 | netC.load_state_dict(torch.load(args.modelpath)) 279 | netC.eval() 280 | for k, v in netC.named_parameters(): 281 | v.requires_grad = False 282 | 283 | param_group = [] 284 | for k, v in netF.named_parameters(): 285 | param_group += [{'params': v, 'lr': args.lr}] 286 | for k, v in netB.named_parameters(): 287 | param_group += [{'params': v, 'lr': args.lr}] 288 | 289 | optimizer = optim.SGD(param_group) 290 | optimizer = op_copy(optimizer) 291 | 292 | max_iter = args.max_epoch * len(dset_loaders["target"]) 293 | interval_iter = len(dset_loaders["target"]) 294 | # interval_iter = max_iter // args.interval 295 | iter_num = 0 296 | 297 | while iter_num < max_iter: 298 | optimizer.zero_grad() 299 | try: 300 | inputs_test, _, tar_idx = iter_test.next() 301 | except: 302 | iter_test = iter(dset_loaders["target"]) 303 | inputs_test, _, tar_idx = iter_test.next() 304 | 305 | if inputs_test.size(0) == 1: 306 | continue 307 | 308 | if iter_num % interval_iter == 0 and args.cls_par > 0: 309 | netF.eval() 310 | netB.eval() 311 | mem_label = obtain_label(dset_loaders['target_te'], netF, netB, netC, args) 312 | mem_label = torch.from_numpy(mem_label).cuda() 313 | netF.train() 314 | netB.train() 315 | 316 | iter_num += 1 317 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter) 318 | 319 | inputs_test = inputs_test.cuda() 320 | features_test = netB(netF(inputs_test)) 321 | outputs_test = netC(features_test) 322 | 323 | if args.cls_par > 0: 324 | pred = mem_label[tar_idx] 325 | classifier_loss = args.cls_par * nn.CrossEntropyLoss()(outputs_test, pred) 326 | else: 327 | classifier_loss = torch.tensor(0.0).cuda() 328 | 
329 | if args.ent: 330 | softmax_out = nn.Softmax(dim=1)(outputs_test) 331 | entropy_loss = torch.mean(loss.Entropy(softmax_out)) 332 | if args.gent: 333 | msoftmax = softmax_out.mean(dim=0) 334 | entropy_loss -= torch.sum(-msoftmax * torch.log(msoftmax + 1e-5)) 335 | 336 | im_loss = entropy_loss * args.ent_par 337 | classifier_loss += im_loss 338 | 339 | optimizer.zero_grad() 340 | classifier_loss.backward() 341 | optimizer.step() 342 | 343 | if iter_num % interval_iter == 0 or iter_num == max_iter: 344 | netF.eval() 345 | netB.eval() 346 | acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC) 347 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.dset, iter_num, max_iter, acc) 348 | args.out_file.write(log_str + '\n') 349 | args.out_file.flush() 350 | print(log_str+'\n') 351 | netF.train() 352 | netB.train() 353 | 354 | if args.issave: 355 | torch.save(netF.state_dict(), osp.join(args.output_dir, "target_F_" + args.savename + ".pt")) 356 | torch.save(netB.state_dict(), osp.join(args.output_dir, "target_B_" + args.savename + ".pt")) 357 | torch.save(netC.state_dict(), osp.join(args.output_dir, "target_C_" + args.savename + ".pt")) 358 | 359 | return netF, netB, netC 360 | 361 | def obtain_label(loader, netF, netB, netC, args, c=None): 362 | start_test = True 363 | with torch.no_grad(): 364 | iter_test = iter(loader) 365 | for _ in range(len(loader)): 366 | data = iter_test.next() 367 | inputs = data[0] 368 | labels = data[1] 369 | inputs = inputs.cuda() 370 | feas = netB(netF(inputs)) 371 | outputs = netC(feas) 372 | if start_test: 373 | all_fea = feas.float().cpu() 374 | all_output = outputs.float().cpu() 375 | all_label = labels.float() 376 | start_test = False 377 | else: 378 | all_fea = torch.cat((all_fea, feas.float().cpu()), 0) 379 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 380 | all_label = torch.cat((all_label, labels.float()), 0) 381 | all_output = nn.Softmax(dim=1)(all_output) 382 | _, predict = torch.max(all_output, 1) 383 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 384 | 385 | all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1) 386 | all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t() 387 | all_fea = all_fea.float().cpu().numpy() 388 | 389 | K = all_output.size(1) 390 | aff = all_output.float().cpu().numpy() 391 | initc = aff.transpose().dot(all_fea) 392 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None]) 393 | dd = cdist(all_fea, initc, 'cosine') 394 | pred_label = dd.argmin(axis=1) 395 | acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea) 396 | 397 | for round in range(1): 398 | aff = np.eye(K)[pred_label] 399 | initc = aff.transpose().dot(all_fea) 400 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None]) 401 | dd = cdist(all_fea, initc, 'cosine') 402 | pred_label = dd.argmin(axis=1) 403 | acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea) 404 | 405 | log_str = 'Accuracy = {:.2f}% -> {:.2f}%'.format(accuracy*100, acc*100) 406 | args.out_file.write(log_str + '\n') 407 | args.out_file.flush() 408 | print(log_str+'\n') 409 | return pred_label.astype('int') 410 | 411 | if __name__ == "__main__": 412 | parser = argparse.ArgumentParser(description='SHOT') 413 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run") 414 | parser.add_argument('--s', type=int, default=0, help="source") 415 | parser.add_argument('--t', type=int, default=1, help="target") 416 | parser.add_argument('--max_epoch', 
type=int, default=30, help="maximum epoch") 417 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size") 418 | parser.add_argument('--worker', type=int, default=4, help="number of workers") 419 | parser.add_argument('--dset', type=str, default='s2m', choices=['u2m', 'm2u','s2m']) 420 | parser.add_argument('--lr', type=float, default=0.01, help="learning rate") 421 | parser.add_argument('--seed', type=int, default=2020, help="random seed") 422 | parser.add_argument('--cls_par', type=float, default=0.3) 423 | parser.add_argument('--ent_par', type=float, default=1.0) 424 | parser.add_argument('--gent', type=bool, default=True) 425 | parser.add_argument('--ent', type=bool, default=True) 426 | parser.add_argument('--bottleneck', type=int, default=256) 427 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"]) 428 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"]) 429 | parser.add_argument('--smooth', type=float, default=0.1) 430 | parser.add_argument('--output', type=str, default='') 431 | parser.add_argument('--issave', type=bool, default=True) 432 | args = parser.parse_args() 433 | args.class_num = 10 434 | 435 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 436 | SEED = args.seed 437 | torch.manual_seed(SEED) 438 | torch.cuda.manual_seed(SEED) 439 | np.random.seed(SEED) 440 | random.seed(SEED) 441 | # torch.backends.cudnn.deterministic = True 442 | 443 | args.output_dir = osp.join(args.output, 'seed' + str(args.seed), args.dset) 444 | if not osp.exists(args.output_dir): 445 | os.system('mkdir -p ' + args.output_dir) 446 | if not osp.exists(args.output_dir): 447 | os.mkdir(args.output_dir) 448 | 449 | if not osp.exists(osp.join(args.output_dir + '/source_F.pt')): 450 | args.out_file = open(osp.join(args.output_dir, 'log_src.txt'), 'w') 451 | args.out_file.write(print_args(args)+'\n') 452 | args.out_file.flush() 453 | train_source(args) 454 | test_target(args) 455 | 456 | args.savename = 'par_' + str(args.cls_par) 457 | args.out_file = open(osp.join(args.output_dir, 'log_tar_' + args.savename + '.txt'), 'w') 458 | args.out_file.write(print_args(args)+'\n') 459 | args.out_file.flush() 460 | train_target(args) 461 | -------------------------------------------------------------------------------- /figs/shot.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-learn/SHOT/f7d555a0d53b525b885e5ef2a887a267a5be3c36/figs/shot.jpg -------------------------------------------------------------------------------- /object/data_list.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | import os 7 | import os.path 8 | import cv2 9 | import torchvision 10 | 11 | def make_dataset(image_list, labels): 12 | if labels: 13 | len_ = len(image_list) 14 | images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)] 15 | else: 16 | if len(image_list[0].split()) > 2: 17 | images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list] 18 | else: 19 | images = [(val.split()[0], int(val.split()[1])) for val in image_list] 20 | return images 21 | 22 | 23 | def rgb_loader(path): 24 | with open(path, 'rb') as f: 25 | with Image.open(f) as img: 26 | return img.convert('RGB') 27 | 28 | def l_loader(path): 29 | with open(path, 'rb') as f: 30 | with Image.open(f) as img: 31 
| return img.convert('L') 32 | 33 | class ImageList(Dataset): 34 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'): 35 | imgs = make_dataset(image_list, labels) 36 | if len(imgs) == 0: 37 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 38 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 39 | 40 | self.imgs = imgs 41 | self.transform = transform 42 | self.target_transform = target_transform 43 | if mode == 'RGB': 44 | self.loader = rgb_loader 45 | elif mode == 'L': 46 | self.loader = l_loader 47 | 48 | def __getitem__(self, index): 49 | path, target = self.imgs[index] 50 | img = self.loader(path) 51 | if self.transform is not None: 52 | img = self.transform(img) 53 | if self.target_transform is not None: 54 | target = self.target_transform(target) 55 | 56 | return img, target 57 | 58 | def __len__(self): 59 | return len(self.imgs) 60 | 61 | class ImageList_idx(Dataset): 62 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'): 63 | imgs = make_dataset(image_list, labels) 64 | if len(imgs) == 0: 65 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 66 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 67 | 68 | self.imgs = imgs 69 | self.transform = transform 70 | self.target_transform = target_transform 71 | if mode == 'RGB': 72 | self.loader = rgb_loader 73 | elif mode == 'L': 74 | self.loader = l_loader 75 | 76 | def __getitem__(self, index): 77 | path, target = self.imgs[index] 78 | img = self.loader(path) 79 | if self.transform is not None: 80 | img = self.transform(img) 81 | if self.target_transform is not None: 82 | target = self.target_transform(target) 83 | 84 | return img, target, index 85 | 86 | def __len__(self): 87 | return len(self.imgs) -------------------------------------------------------------------------------- /object/image_multisource.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import torchvision 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torchvision import transforms 10 | import network, loss 11 | from torch.utils.data import DataLoader 12 | from data_list import ImageList 13 | import random, pdb, math, copy 14 | from tqdm import tqdm 15 | from sklearn.metrics import confusion_matrix 16 | 17 | def image_train(resize_size=256, crop_size=224, alexnet=False): 18 | if not alexnet: 19 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 20 | std=[0.229, 0.224, 0.225]) 21 | else: 22 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 23 | return transforms.Compose([ 24 | transforms.Resize((resize_size, resize_size)), 25 | transforms.RandomCrop(crop_size), 26 | transforms.RandomHorizontalFlip(), 27 | transforms.ToTensor(), 28 | normalize 29 | ]) 30 | 31 | def image_test(resize_size=256, crop_size=224, alexnet=False): 32 | if not alexnet: 33 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 34 | std=[0.229, 0.224, 0.225]) 35 | else: 36 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 37 | return transforms.Compose([ 38 | transforms.Resize((resize_size, resize_size)), 39 | transforms.CenterCrop(crop_size), 40 | transforms.ToTensor(), 41 | normalize 42 | ]) 43 | 44 | def data_load(args): 45 | ## prepare data 46 | dsets = {} 47 | dset_loaders = {} 48 | train_bs = args.batch_size 49 | txt_tar = 
open(args.t_dset_path).readlines() 50 | txt_test = open(args.test_dset_path).readlines() 51 | 52 | if not args.da == 'uda': 53 | label_map_s = {} 54 | for i in range(len(args.src_classes)): 55 | label_map_s[args.src_classes[i]] = i 56 | 57 | new_tar = [] 58 | for i in range(len(txt_tar)): 59 | rec = txt_tar[i] 60 | reci = rec.strip().split(' ') 61 | if int(reci[1]) in args.tar_classes: 62 | if int(reci[1]) in args.src_classes: 63 | line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n' 64 | new_tar.append(line) 65 | else: 66 | line = reci[0] + ' ' + str(len(label_map_s)) + '\n' 67 | new_tar.append(line) 68 | txt_tar = new_tar.copy() 69 | txt_test = txt_tar.copy() 70 | 71 | dsets["target"] = ImageList(txt_tar, transform=image_test()) 72 | dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False) 73 | dsets["test"] = ImageList(txt_test, transform=image_test()) 74 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*3, shuffle=False, num_workers=args.worker, drop_last=False) 75 | 76 | return dset_loaders 77 | 78 | def cal_acc(loader, netF, netB, netC, flag=False): 79 | start_test = True 80 | with torch.no_grad(): 81 | iter_test = iter(loader) 82 | for i in range(len(loader)): 83 | data = iter_test.next() 84 | inputs = data[0] 85 | labels = data[1] 86 | inputs = inputs.cuda() 87 | outputs = netC(netB(netF(inputs))) 88 | if start_test: 89 | all_output = outputs.float().cpu() 90 | all_label = labels.float() 91 | start_test = False 92 | else: 93 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 94 | all_label = torch.cat((all_label, labels.float()), 0) 95 | _, predict = torch.max(all_output, 1) 96 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 97 | mean_ent = torch.mean(loss.Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item() 98 | 99 | return accuracy, all_label, nn.Softmax(dim=1)(all_output) 100 | 101 | def print_args(args): 102 | s = "==========================================\n" 103 | for arg, content in args.__dict__.items(): 104 | s += "{}:{}\n".format(arg, content) 105 | return s 106 | 107 | def test_target_srconly(args): 108 | dset_loaders = data_load(args) 109 | ## set base network 110 | if args.net[0:3] == 'res': 111 | netF = network.ResBase(res_name=args.net).cuda() 112 | elif args.net[0:3] == 'vgg': 113 | netF = network.VGGBase(vgg_name=args.net).cuda() 114 | 115 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda() 116 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda() 117 | 118 | args.modelpath = args.output_dir_src + '/source_F.pt' 119 | netF.load_state_dict(torch.load(args.modelpath)) 120 | args.modelpath = args.output_dir_src + '/source_B.pt' 121 | netB.load_state_dict(torch.load(args.modelpath)) 122 | args.modelpath = args.output_dir_src + '/source_C.pt' 123 | netC.load_state_dict(torch.load(args.modelpath)) 124 | netF.eval() 125 | netB.eval() 126 | netC.eval() 127 | 128 | acc, y, py = cal_acc(dset_loaders['test'], netF, netB, netC) 129 | log_str = '\nTask: {}, Accuracy = {:.2f}%'.format(args.name, acc*100) 130 | args.out_file.write(log_str) 131 | args.out_file.flush() 132 | print(log_str) 133 | 134 | return y, py 135 | 136 | def test_target(args): 137 | dset_loaders = data_load(args) 138 | ## set base network 139 | if args.net[0:3] == 'res': 140 | netF = 
network.ResBase(res_name=args.net).cuda() 141 | elif args.net[0:3] == 'vgg': 142 | netF = network.VGGBase(vgg_name=args.net).cuda() 143 | 144 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda() 145 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda() 146 | 147 | args.modelpath = args.output_dir_ori + "/target_F_" + args.savename + ".pt" 148 | netF.load_state_dict(torch.load(args.modelpath)) 149 | args.modelpath = args.output_dir_ori + "/target_B_" + args.savename + ".pt" 150 | netB.load_state_dict(torch.load(args.modelpath)) 151 | args.modelpath = args.output_dir_ori + "/target_C_" + args.savename + ".pt" 152 | netC.load_state_dict(torch.load(args.modelpath)) 153 | netF.eval() 154 | netB.eval() 155 | netC.eval() 156 | 157 | acc, y, py = cal_acc(dset_loaders['test'], netF, netB, netC) 158 | log_str = '\nTask: {}, Accuracy = {:.2f}%'.format(args.name, acc*100) 159 | args.out_file.write(log_str) 160 | args.out_file.flush() 161 | print(log_str) 162 | 163 | return y, py 164 | 165 | if __name__ == "__main__": 166 | parser = argparse.ArgumentParser(description='SHOT') 167 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run") 168 | parser.add_argument('--s', type=int, default=0, help="source") 169 | parser.add_argument('--t', type=int, default=1, help="target") 170 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size") 171 | parser.add_argument('--worker', type=int, default=4, help="number of workers") 172 | parser.add_argument('--dset', type=str, default='office-caltech', choices=['office-caltech']) 173 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate") 174 | parser.add_argument('--net', type=str, default='resnet50', help="vgg16, resnet50, resnet101") 175 | parser.add_argument('--seed', type=int, default=2020, help="random seed") 176 | 177 | parser.add_argument('--threshold', type=int, default=0) 178 | parser.add_argument('--cls_par', type=float, default=0.3) 179 | parser.add_argument('--bottleneck', type=int, default=256) 180 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"]) 181 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"]) 182 | parser.add_argument('--output', type=str, default='san') 183 | parser.add_argument('--output_src', type=str, default='ckps') 184 | parser.add_argument('--da', type=str, default='uda', choices=['uda']) 185 | args = parser.parse_args() 186 | 187 | if args.dset == 'office-caltech': 188 | names = ['amazon', 'caltech', 'dslr', 'webcam'] 189 | args.class_num = 10 190 | 191 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 192 | SEED = args.seed 193 | torch.manual_seed(SEED) 194 | torch.cuda.manual_seed(SEED) 195 | np.random.seed(SEED) 196 | random.seed(SEED) 197 | # torch.backends.cudnn.deterministic = True 198 | 199 | score_srconly = 0 200 | score = 0 201 | 202 | args.output_dir = osp.join(args.output, args.da, args.dset, str(0)+names[args.t][0].upper()) 203 | if not osp.exists(args.output_dir): 204 | os.system('mkdir -p ' + args.output_dir) 205 | if not osp.exists(args.output_dir): 206 | os.mkdir(args.output_dir) 207 | 208 | args.savename = 'par_' + str(args.cls_par) 209 | args.out_file = open(osp.join(args.output_dir, 'log_' + args.savename + '.txt'), 'w') 210 | args.out_file.write(print_args(args)+'\n') 211 | args.out_file.flush() 212 | 213 | for i in range(len(names)): 214 | if 
i == args.t: 215 | continue 216 | args.s = i 217 | 218 | folder = './data/' 219 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt' 220 | args.t_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt' 221 | args.test_dset_path = args.t_dset_path 222 | 223 | args.output_dir_src = osp.join(args.output_src, args.da, args.dset, names[args.s][0].upper()) 224 | args.output_dir_ori = osp.join(args.output, args.da, args.dset, names[args.s][0].upper()+names[args.t][0].upper()) 225 | args.name = names[args.s][0].upper() + names[args.t][0].upper() 226 | 227 | label, output_srconly = test_target_srconly(args) 228 | score_srconly += output_srconly 229 | 230 | _, output = test_target(args) 231 | score += output 232 | 233 | _, predict = torch.max(score_srconly, 1) 234 | acc = torch.sum(torch.squeeze(predict).float() == label).item() / float(label.size()[0]) 235 | log_str = '\nTask: {}, Accuracy = {:.2f}%'.format('->' + names[args.t][0].upper(), acc*100) 236 | args.out_file.write(log_str) 237 | args.out_file.flush() 238 | print(log_str) 239 | 240 | _, predict = torch.max(score, 1) 241 | acc = torch.sum(torch.squeeze(predict).float() == label).item() / float(label.size()[0]) 242 | log_str = '\nTask: {}, Accuracy = {:.2f}%'.format('->' + names[args.t][0].upper(), acc*100) 243 | args.out_file.write(log_str) 244 | args.out_file.flush() 245 | print(log_str) -------------------------------------------------------------------------------- /object/image_multitarget.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import torchvision 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torchvision import transforms 10 | import network, loss 11 | from torch.utils.data import DataLoader 12 | from data_list import ImageList, ImageList_idx 13 | import random, pdb, math, copy 14 | from tqdm import tqdm 15 | from scipy.spatial.distance import cdist 16 | from sklearn.metrics import confusion_matrix 17 | import rotation 18 | 19 | def op_copy(optimizer): 20 | for param_group in optimizer.param_groups: 21 | param_group['lr0'] = param_group['lr'] 22 | return optimizer 23 | 24 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75): 25 | decay = (1 + gamma * iter_num / max_iter) ** (-power) 26 | for param_group in optimizer.param_groups: 27 | param_group['lr'] = param_group['lr0'] * decay 28 | param_group['weight_decay'] = 1e-3 29 | param_group['momentum'] = 0.9 30 | param_group['nesterov'] = True 31 | return optimizer 32 | 33 | def image_train(resize_size=256, crop_size=224, alexnet=False): 34 | if not alexnet: 35 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 36 | std=[0.229, 0.224, 0.225]) 37 | else: 38 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 39 | return transforms.Compose([ 40 | transforms.Resize((resize_size, resize_size)), 41 | transforms.RandomCrop(crop_size), 42 | transforms.RandomHorizontalFlip(), 43 | transforms.ToTensor(), 44 | normalize 45 | ]) 46 | 47 | def image_test(resize_size=256, crop_size=224, alexnet=False): 48 | if not alexnet: 49 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 50 | std=[0.229, 0.224, 0.225]) 51 | else: 52 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 53 | return transforms.Compose([ 54 | transforms.Resize((resize_size, resize_size)), 55 | transforms.CenterCrop(crop_size), 56 | transforms.ToTensor(), 57 | normalize 58 | 
]) 59 | 60 | def data_load(args): 61 | ## prepare data 62 | dsets = {} 63 | dset_loaders = {} 64 | train_bs = args.batch_size 65 | txt_src = open(args.s_dset_path).readlines() 66 | 67 | txt_tar = [] 68 | for i in range(len(args.t_dset_path)): 69 | tmp = open(args.t_dset_path[i]).readlines() 70 | txt_tar.extend(tmp) 71 | txt_test = txt_tar.copy() 72 | 73 | dsets["target"] = ImageList_idx(txt_tar, transform=image_train()) 74 | dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False) 75 | dsets["test"] = ImageList(txt_test, transform=image_test()) 76 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*2, shuffle=False, num_workers=args.worker, drop_last=False) 77 | 78 | return dset_loaders 79 | 80 | def cal_acc(loader, netF, netB, netC, flag=False): 81 | start_test = True 82 | with torch.no_grad(): 83 | iter_test = iter(loader) 84 | for i in range(len(loader)): 85 | data = iter_test.next() 86 | inputs = data[0] 87 | labels = data[1] 88 | inputs = inputs.cuda() 89 | outputs = netC(netB(netF(inputs))) 90 | if start_test: 91 | all_output = outputs.float().cpu() 92 | all_label = labels.float() 93 | start_test = False 94 | else: 95 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 96 | all_label = torch.cat((all_label, labels.float()), 0) 97 | _, predict = torch.max(all_output, 1) 98 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 99 | mean_ent = torch.mean(loss.Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item() 100 | 101 | if flag: 102 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float()) 103 | acc = matrix.diagonal()/matrix.sum(axis=1) * 100 104 | aacc = acc.mean() 105 | aa = [str(np.round(i, 2)) for i in acc] 106 | acc = ' '.join(aa) 107 | return aacc, acc 108 | else: 109 | return accuracy*100, mean_ent 110 | 111 | def print_args(args): 112 | s = "==========================================\n" 113 | for arg, content in args.__dict__.items(): 114 | s += "{}:{}\n".format(arg, content) 115 | return s 116 | 117 | def train_target(args): 118 | dset_loaders = data_load(args) 119 | ## set base network 120 | if args.net[0:3] == 'res': 121 | netF = network.ResBase(res_name=args.net).cuda() 122 | elif args.net[0:3] == 'vgg': 123 | netF = network.VGGBase(vgg_name=args.net).cuda() 124 | 125 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda() 126 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda() 127 | 128 | args.modelpath = args.output_dir_src + '/source_F.pt' 129 | netF.load_state_dict(torch.load(args.modelpath)) 130 | args.modelpath = args.output_dir_src + '/source_B.pt' 131 | netB.load_state_dict(torch.load(args.modelpath)) 132 | args.modelpath = args.output_dir_src + '/source_C.pt' 133 | netC.load_state_dict(torch.load(args.modelpath)) 134 | netC.eval() 135 | for k, v in netC.named_parameters(): 136 | v.requires_grad = False 137 | 138 | param_group = [] 139 | for k, v in netF.named_parameters(): 140 | if args.lr_decay1 > 0: 141 | param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}] 142 | else: 143 | v.requires_grad = False 144 | for k, v in netB.named_parameters(): 145 | if args.lr_decay2 > 0: 146 | param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}] 147 | else: 148 | v.requires_grad = False 149 | optimizer = optim.SGD(param_group) 150 | optimizer = 
op_copy(optimizer) 151 | 152 | max_iter = args.max_epoch * len(dset_loaders["target"]) 153 | interval_iter = max_iter // args.interval 154 | iter_num = 0 155 | 156 | while iter_num < max_iter: 157 | try: 158 | inputs_test, _, tar_idx = iter_test.next() 159 | except: 160 | iter_test = iter(dset_loaders["target"]) 161 | inputs_test, _, tar_idx = iter_test.next() 162 | 163 | if inputs_test.size(0) == 1: 164 | continue 165 | 166 | if iter_num % interval_iter == 0 and args.cls_par > 0: 167 | netF.eval() 168 | netB.eval() 169 | mem_label = obtain_label(dset_loaders['test'], netF, netB, netC, args) 170 | mem_label = torch.from_numpy(mem_label).cuda() 171 | netF.train() 172 | netB.train() 173 | 174 | inputs_test = inputs_test.cuda() 175 | 176 | iter_num += 1 177 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter) 178 | 179 | features_test = netB(netF(inputs_test)) 180 | outputs_test = netC(features_test) 181 | 182 | if args.cls_par > 0: 183 | pred = mem_label[tar_idx] 184 | classifier_loss = nn.CrossEntropyLoss()(outputs_test, pred) 185 | classifier_loss *= args.cls_par 186 | else: 187 | classifier_loss = torch.tensor(0.0).cuda() 188 | 189 | if args.ent: 190 | softmax_out = nn.Softmax(dim=1)(outputs_test) 191 | entropy_loss = torch.mean(loss.Entropy(softmax_out)) 192 | if args.gent: 193 | msoftmax = softmax_out.mean(dim=0) 194 | gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon)) 195 | entropy_loss -= gentropy_loss 196 | classifier_loss += entropy_loss * args.ent_par 197 | 198 | optimizer.zero_grad() 199 | classifier_loss.backward() 200 | optimizer.step() 201 | 202 | if iter_num % interval_iter == 0 or iter_num == max_iter: 203 | netF.eval() 204 | netB.eval() 205 | acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False) 206 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name, iter_num, max_iter, acc) 207 | args.out_file.write(log_str + '\n') 208 | args.out_file.flush() 209 | print(log_str+'\n') 210 | netF.train() 211 | netB.train() 212 | 213 | if args.issave: 214 | torch.save(netF.state_dict(), osp.join(args.output_dir, "target_F_" + args.savename + ".pt")) 215 | torch.save(netB.state_dict(), osp.join(args.output_dir, "target_B_" + args.savename + ".pt")) 216 | torch.save(netC.state_dict(), osp.join(args.output_dir, "target_C_" + args.savename + ".pt")) 217 | 218 | return netF, netB, netC 219 | 220 | def obtain_label(loader, netF, netB, netC, args): 221 | start_test = True 222 | with torch.no_grad(): 223 | iter_test = iter(loader) 224 | for _ in range(len(loader)): 225 | data = iter_test.next() 226 | inputs = data[0] 227 | labels = data[1] 228 | inputs = inputs.cuda() 229 | feas = netB(netF(inputs)) 230 | outputs = netC(feas) 231 | if start_test: 232 | all_fea = feas.float().cpu() 233 | all_output = outputs.float().cpu() 234 | all_label = labels.float() 235 | start_test = False 236 | else: 237 | all_fea = torch.cat((all_fea, feas.float().cpu()), 0) 238 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 239 | all_label = torch.cat((all_label, labels.float()), 0) 240 | 241 | all_output = nn.Softmax(dim=1)(all_output) 242 | ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1) 243 | unknown_weight = 1 - ent / np.log(args.class_num) 244 | _, predict = torch.max(all_output, 1) 245 | 246 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 247 | if args.distance == 'cosine': 248 | all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1) 249 | 
all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t() 250 | 251 | all_fea = all_fea.float().cpu().numpy() 252 | K = all_output.size(1) 253 | aff = all_output.float().cpu().numpy() 254 | initc = aff.transpose().dot(all_fea) 255 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None]) 256 | cls_count = np.eye(K)[predict].sum(axis=0) 257 | labelset = np.where(cls_count>args.threshold) 258 | labelset = labelset[0] 259 | # print(labelset) 260 | 261 | dd = cdist(all_fea, initc[labelset], args.distance) 262 | pred_label = dd.argmin(axis=1) 263 | pred_label = labelset[pred_label] 264 | 265 | for round in range(1): 266 | aff = np.eye(K)[pred_label] 267 | initc = aff.transpose().dot(all_fea) 268 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None]) 269 | dd = cdist(all_fea, initc[labelset], args.distance) 270 | pred_label = dd.argmin(axis=1) 271 | pred_label = labelset[pred_label] 272 | 273 | acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea) 274 | log_str = 'Accuracy = {:.2f}% -> {:.2f}%'.format(accuracy*100, acc*100) 275 | 276 | args.out_file.write(log_str + '\n') 277 | args.out_file.flush() 278 | print(log_str+'\n') 279 | 280 | return pred_label.astype('int') #, labelset 281 | 282 | if __name__ == "__main__": 283 | parser = argparse.ArgumentParser(description='Conditional Domain Adversarial Network') 284 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run") 285 | parser.add_argument('--s', type=int, default=0, help="source") 286 | parser.add_argument('--t', type=int, default=1, help="target") 287 | parser.add_argument('--max_epoch', type=int, default=15, help="max iterations") 288 | parser.add_argument('--interval', type=int, default=15) 289 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size") 290 | parser.add_argument('--worker', type=int, default=4, help="number of workers") 291 | parser.add_argument('--dset', type=str, default='office-caltech', choices=['office-caltech']) 292 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate") 293 | parser.add_argument('--net', type=str, default='resnet50', help="vgg16, resnet50, resnet101") 294 | parser.add_argument('--seed', type=int, default=2020, help="random seed") 295 | 296 | parser.add_argument('--gent', type=bool, default=True) 297 | parser.add_argument('--ent', type=bool, default=True) 298 | parser.add_argument('--threshold', type=int, default=-1) 299 | parser.add_argument('--cls_par', type=float, default=0.3) 300 | parser.add_argument('--ent_par', type=float, default=1.0) 301 | parser.add_argument('--lr_decay1', type=float, default=0.1) 302 | parser.add_argument('--lr_decay2', type=float, default=1.0) 303 | 304 | parser.add_argument('--bottleneck', type=int, default=256) 305 | parser.add_argument('--epsilon', type=float, default=1e-5) 306 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"]) 307 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"]) 308 | parser.add_argument('--distance', type=str, default='cosine', choices=["euclidean", "cosine"]) 309 | parser.add_argument('--output', type=str, default='san') 310 | parser.add_argument('--output_src', type=str, default='ckps') 311 | parser.add_argument('--da', type=str, default='uda', choices=['uda']) 312 | parser.add_argument('--issave', type=bool, default=True) 313 | args = parser.parse_args() 314 | 315 | if args.dset == 'office-caltech': 316 | names = ['amazon', 'caltech', 'dslr', 'webcam'] 317 | args.class_num = 10 318 | 319 | 
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 320 | SEED = args.seed 321 | torch.manual_seed(SEED) 322 | torch.cuda.manual_seed(SEED) 323 | np.random.seed(SEED) 324 | random.seed(SEED) 325 | # torch.backends.cudnn.deterministic = True 326 | 327 | t_dset = [] 328 | for i in range(len(names)): 329 | if i == args.s: 330 | continue 331 | args.t = i 332 | 333 | folder = './data/' 334 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt' 335 | args.t_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt' 336 | t_dset.append(args.t_dset_path) 337 | 338 | args.t_dset_path = t_dset 339 | args.test_dset_path = args.t_dset_path 340 | 341 | args.output_dir_src = osp.join(args.output_src, args.da, args.dset, names[args.s][0].upper()) 342 | args.output_dir = osp.join(args.output, args.da, args.dset, names[args.s][0].upper() + str(0)) 343 | args.name = names[args.s][0].upper() + str(0) 344 | 345 | if not osp.exists(args.output_dir): 346 | os.system('mkdir -p ' + args.output_dir) 347 | if not osp.exists(args.output_dir): 348 | os.mkdir(args.output_dir) 349 | 350 | args.savename = 'par_' + str(args.cls_par) 351 | args.out_file = open(osp.join(args.output_dir, 'log_' + args.savename + '.txt'), 'w') 352 | args.out_file.write(print_args(args)+'\n') 353 | args.out_file.flush() 354 | train_target(args) -------------------------------------------------------------------------------- /object/image_pretrained.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import torchvision 5 | from torchvision import transforms 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import network, loss 11 | from torch.utils.data import DataLoader 12 | from data_list import ImageList, ImageList_idx 13 | import random, pdb, math, copy 14 | from tqdm import tqdm 15 | from scipy.spatial.distance import cdist 16 | from sklearn.metrics import confusion_matrix 17 | 18 | def op_copy(optimizer): 19 | for param_group in optimizer.param_groups: 20 | param_group['lr0'] = param_group['lr'] 21 | return optimizer 22 | 23 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75): 24 | decay = (1 + gamma * iter_num / max_iter) ** (-power) 25 | for param_group in optimizer.param_groups: 26 | param_group['lr'] = param_group['lr0'] * decay 27 | param_group['weight_decay'] = 1e-3 28 | param_group['momentum'] = 0.9 29 | param_group['nesterov'] = True 30 | return optimizer 31 | 32 | def image_train(resize_size=256, crop_size=224, alexnet=False): 33 | if not alexnet: 34 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 35 | std=[0.229, 0.224, 0.225]) 36 | else: 37 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 38 | return transforms.Compose([ 39 | transforms.Resize((resize_size, resize_size)), 40 | transforms.RandomCrop(crop_size), 41 | transforms.RandomHorizontalFlip(), 42 | transforms.ToTensor(), 43 | normalize 44 | ]) 45 | 46 | def image_test(resize_size=256, crop_size=224, alexnet=False): 47 | if not alexnet: 48 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 49 | std=[0.229, 0.224, 0.225]) 50 | else: 51 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 52 | return transforms.Compose([ 53 | transforms.Resize((resize_size, resize_size)), 54 | transforms.CenterCrop(crop_size), 55 | transforms.ToTensor(), 56 | normalize 57 | ]) 58 | 59 | def data_load(args): 60 | dsets = {} 61 | dset_loaders = {} 62 | 
train_bs = args.batch_size 63 | txt_tar = open(args.t_dset_path).readlines() 64 | txt_test = open(args.test_dset_path).readlines() 65 | 66 | dsets["target"] = ImageList_idx(txt_tar, transform=image_train()) 67 | dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False) 68 | dsets["test"] = ImageList_idx(txt_test, transform=image_test()) 69 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*3, shuffle=False, num_workers=args.worker, drop_last=False) 70 | 71 | return dset_loaders 72 | 73 | def cal_acc(loader, net, flag=False): 74 | start_test = True 75 | with torch.no_grad(): 76 | iter_test = iter(loader) 77 | for i in range(len(loader)): 78 | data = iter_test.next() 79 | inputs = data[0] 80 | labels = data[1] 81 | inputs = inputs.cuda() 82 | _, outputs = net(inputs) 83 | if start_test: 84 | all_output = outputs.float().cpu() 85 | all_label = labels.float() 86 | start_test = False 87 | else: 88 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 89 | all_label = torch.cat((all_label, labels.float()), 0) 90 | _, predict = torch.max(all_output, 1) 91 | all_output = nn.Softmax(dim=1)(all_output) 92 | ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1) / np.log(all_output.size(1)) 93 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 94 | mean_ent = torch.mean(loss.Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item() 95 | 96 | return accuracy, mean_ent 97 | 98 | def train_target(args): 99 | dset_loaders = data_load(args) 100 | netF = network.Res50().cuda() 101 | 102 | param_group = [] 103 | for k, v in netF.named_parameters(): 104 | if k.__contains__("fc"): 105 | v.requires_grad = False 106 | else: 107 | param_group += [{'params': v, 'lr': args.lr*args.lr_decay1}] 108 | 109 | optimizer = optim.SGD(param_group) 110 | optimizer = op_copy(optimizer) 111 | 112 | max_iter = args.max_epoch * len(dset_loaders["target"]) 113 | interval_iter = max_iter // args.interval 114 | iter_num = 0 115 | 116 | netF.train() 117 | while iter_num < max_iter: 118 | try: 119 | inputs_test, _, tar_idx = iter_test.next() 120 | except: 121 | iter_test = iter(dset_loaders["target"]) 122 | inputs_test, _, tar_idx = iter_test.next() 123 | 124 | if inputs_test.size(0) == 1: 125 | continue 126 | 127 | if iter_num % interval_iter == 0 and args.cls_par > 0: 128 | netF.eval() 129 | mem_label = obtain_label(dset_loaders['test'], netF, args) 130 | mem_label = torch.from_numpy(mem_label).cuda() 131 | netF.train() 132 | 133 | inputs_test = inputs_test.cuda() 134 | iter_num += 1 135 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter) 136 | 137 | features_test, outputs_test = netF(inputs_test) 138 | 139 | if args.cls_par > 0: 140 | pred = mem_label[tar_idx] 141 | classifier_loss = nn.CrossEntropyLoss()(outputs_test, pred) 142 | classifier_loss *= args.cls_par 143 | else: 144 | classifier_loss = torch.tensor(0.0).cuda() 145 | 146 | if args.ent: 147 | softmax_out = nn.Softmax(dim=1)(outputs_test) 148 | entropy_loss = torch.mean(loss.Entropy(softmax_out)) 149 | if args.gent: 150 | msoftmax = softmax_out.mean(dim=0) 151 | gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon)) 152 | entropy_loss -= gentropy_loss 153 | classifier_loss += entropy_loss * args.ent_par 154 | 155 | optimizer.zero_grad() 156 | classifier_loss.backward() 157 | optimizer.step() 158 | 159 | if iter_num % interval_iter == 0 or iter_num == max_iter: 160 | 
netF.eval() 161 | acc, ment = cal_acc(dset_loaders['test'], netF) 162 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.dset, iter_num, max_iter, acc*100) 163 | args.out_file.write(log_str + '\n') 164 | args.out_file.flush() 165 | print(log_str+'\n') 166 | netF.train() 167 | 168 | if args.issave: 169 | torch.save(netF.state_dict(), osp.join(args.output_dir, "target" + args.savename + ".pt")) 170 | 171 | return netF 172 | 173 | def print_args(args): 174 | s = "==========================================\n" 175 | for arg, content in args.__dict__.items(): 176 | s += "{}:{}\n".format(arg, content) 177 | return s 178 | 179 | def obtain_label(loader, net, args): 180 | start_test = True 181 | with torch.no_grad(): 182 | iter_test = iter(loader) 183 | for _ in range(len(loader)): 184 | data = iter_test.next() 185 | inputs = data[0] 186 | labels = data[1] 187 | inputs = inputs.cuda() 188 | feas, outputs = net(inputs) 189 | if start_test: 190 | all_fea = feas.float().cpu() 191 | all_output = outputs.float().cpu() 192 | all_label = labels.float() 193 | start_test = False 194 | else: 195 | all_fea = torch.cat((all_fea, feas.float().cpu()), 0) 196 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 197 | all_label = torch.cat((all_label, labels.float()), 0) 198 | 199 | all_output = nn.Softmax(dim=1)(all_output) 200 | ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1) 201 | unknown_weight = 1 - ent / np.log(args.class_num) 202 | _, predict = torch.max(all_output, 1) 203 | 204 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 205 | if args.distance == 'cosine': 206 | all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1) 207 | all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t() 208 | 209 | all_fea = all_fea.float().cpu().numpy() 210 | K = all_output.size(1) 211 | aff = all_output.float().cpu().numpy() 212 | initc = aff.transpose().dot(all_fea) 213 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None]) 214 | cls_count = np.eye(K)[predict].sum(axis=0) 215 | labelset = np.where(cls_count>args.threshold) 216 | labelset = labelset[0] 217 | # print(labelset) 218 | 219 | dd = cdist(all_fea, initc[labelset], args.distance) 220 | pred_label = dd.argmin(axis=1) 221 | pred_label = labelset[pred_label] 222 | 223 | for round in range(1): 224 | aff = np.eye(K)[pred_label] 225 | initc = aff.transpose().dot(all_fea) 226 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None]) 227 | dd = cdist(all_fea, initc[labelset], args.distance) 228 | pred_label = dd.argmin(axis=1) 229 | pred_label = labelset[pred_label] 230 | 231 | acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea) 232 | log_str = 'Accuracy = {:.2f}% -> {:.2f}%'.format(accuracy*100, acc*100) 233 | 234 | args.out_file.write(log_str + '\n') 235 | args.out_file.flush() 236 | print(log_str+'\n') 237 | 238 | return pred_label.astype('int') #, labelset 239 | 240 | 241 | if __name__ == "__main__": 242 | parser = argparse.ArgumentParser(description='SHOT') 243 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run") 244 | parser.add_argument('--max_epoch', type=int, default=15, help="max iterations") 245 | parser.add_argument('--interval', type=int, default=15, help="max iterations") 246 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size") 247 | parser.add_argument('--worker', type=int, default=4, help="number of workers") 248 | parser.add_argument('--distance', type=str, 
default='cosine', choices=["euclidean", "cosine"]) 249 | parser.add_argument('--dset', type=str, default='imagenet_caltech', choices=['imagenet_caltech']) 250 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate") 251 | parser.add_argument('--net', type=str, default='resnet50', help="vgg16, resnet50, resnet101") 252 | parser.add_argument('--seed', type=int, default=2019, help="random seed") 253 | parser.add_argument('--epsilon', type=float, default=1e-5) 254 | parser.add_argument('--gent', type=bool, default=False) 255 | parser.add_argument('--ent', type=bool, default=True) 256 | parser.add_argument('--threshold', type=int, default=30) 257 | 258 | parser.add_argument('--cls_par', type=float, default=0.3) 259 | parser.add_argument('--ent_par', type=float, default=1.0) 260 | parser.add_argument('--output', type=str, default='seed') 261 | parser.add_argument('--da', type=str, default='pda', choices=['pda']) 262 | parser.add_argument('--issave', type=bool, default=True) 263 | parser.add_argument('--lr_decay1', type=float, default=0.1) 264 | 265 | args = parser.parse_args() 266 | 267 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 268 | SEED = args.seed 269 | torch.manual_seed(SEED) 270 | torch.cuda.manual_seed(SEED) 271 | np.random.seed(SEED) 272 | random.seed(SEED) 273 | # torch.backends.cudnn.deterministic = True 274 | 275 | args.class_num = 1000 276 | folder = './data/' 277 | if args.da == 'pda': 278 | args.t_dset_path = folder + args.dset + '/' + 'caltech_84' + '_list.txt' 279 | args.test_dset_path = args.t_dset_path 280 | 281 | args.output_dir = osp.join(args.output, args.da, args.dset) 282 | args.name = args.dset 283 | 284 | if not osp.exists(args.output_dir): 285 | os.system('mkdir -p ' + args.output_dir) 286 | if not osp.exists(args.output_dir): 287 | os.mkdir(args.output_dir) 288 | 289 | args.savename = 'par_' + str(args.cls_par) 290 | if args.da == 'pda': 291 | args.savename = 'par_' + str(args.cls_par) + '_thr' + str(args.threshold) 292 | args.out_file = open(osp.join(args.output_dir, 'log_' + args.savename + '.txt'), 'w') 293 | args.out_file.write(print_args(args)+'\n') 294 | args.out_file.flush() 295 | train_target(args) -------------------------------------------------------------------------------- /object/image_source.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import torchvision 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torchvision import transforms 10 | import network, loss 11 | from torch.utils.data import DataLoader 12 | from data_list import ImageList 13 | import random, pdb, math, copy 14 | from tqdm import tqdm 15 | from loss import CrossEntropyLabelSmooth 16 | from scipy.spatial.distance import cdist 17 | from sklearn.metrics import confusion_matrix 18 | from sklearn.cluster import KMeans 19 | 20 | def op_copy(optimizer): 21 | for param_group in optimizer.param_groups: 22 | param_group['lr0'] = param_group['lr'] 23 | return optimizer 24 | 25 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75): 26 | decay = (1 + gamma * iter_num / max_iter) ** (-power) 27 | for param_group in optimizer.param_groups: 28 | param_group['lr'] = param_group['lr0'] * decay 29 | param_group['weight_decay'] = 1e-3 30 | param_group['momentum'] = 0.9 31 | param_group['nesterov'] = True 32 | return optimizer 33 | 34 | def image_train(resize_size=256, crop_size=224, alexnet=False): 
35 | if not alexnet: 36 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 37 | std=[0.229, 0.224, 0.225]) 38 | else: 39 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 40 | return transforms.Compose([ 41 | transforms.Resize((resize_size, resize_size)), 42 | transforms.RandomCrop(crop_size), 43 | transforms.RandomHorizontalFlip(), 44 | transforms.ToTensor(), 45 | normalize 46 | ]) 47 | 48 | def image_test(resize_size=256, crop_size=224, alexnet=False): 49 | if not alexnet: 50 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 51 | std=[0.229, 0.224, 0.225]) 52 | else: 53 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 54 | return transforms.Compose([ 55 | transforms.Resize((resize_size, resize_size)), 56 | transforms.CenterCrop(crop_size), 57 | transforms.ToTensor(), 58 | normalize 59 | ]) 60 | 61 | def data_load(args): 62 | ## prepare data 63 | dsets = {} 64 | dset_loaders = {} 65 | train_bs = args.batch_size 66 | txt_src = open(args.s_dset_path).readlines() 67 | txt_test = open(args.test_dset_path).readlines() 68 | 69 | if not args.da == 'uda': 70 | label_map_s = {} 71 | for i in range(len(args.src_classes)): 72 | label_map_s[args.src_classes[i]] = i 73 | 74 | new_src = [] 75 | for i in range(len(txt_src)): 76 | rec = txt_src[i] 77 | reci = rec.strip().split(' ') 78 | if int(reci[1]) in args.src_classes: 79 | line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n' 80 | new_src.append(line) 81 | txt_src = new_src.copy() 82 | 83 | new_tar = [] 84 | for i in range(len(txt_test)): 85 | rec = txt_test[i] 86 | reci = rec.strip().split(' ') 87 | if int(reci[1]) in args.tar_classes: 88 | if int(reci[1]) in args.src_classes: 89 | line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n' 90 | new_tar.append(line) 91 | else: 92 | line = reci[0] + ' ' + str(len(label_map_s)) + '\n' 93 | new_tar.append(line) 94 | txt_test = new_tar.copy() 95 | 96 | if args.trte == "val": 97 | dsize = len(txt_src) 98 | tr_size = int(0.9*dsize) 99 | # print(dsize, tr_size, dsize - tr_size) 100 | tr_txt, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size]) 101 | else: 102 | dsize = len(txt_src) 103 | tr_size = int(0.9*dsize) 104 | _, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size]) 105 | tr_txt = txt_src 106 | 107 | dsets["source_tr"] = ImageList(tr_txt, transform=image_train()) 108 | dset_loaders["source_tr"] = DataLoader(dsets["source_tr"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False) 109 | dsets["source_te"] = ImageList(te_txt, transform=image_test()) 110 | dset_loaders["source_te"] = DataLoader(dsets["source_te"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False) 111 | dsets["test"] = ImageList(txt_test, transform=image_test()) 112 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*2, shuffle=True, num_workers=args.worker, drop_last=False) 113 | 114 | return dset_loaders 115 | 116 | def cal_acc(loader, netF, netB, netC, flag=False): 117 | start_test = True 118 | with torch.no_grad(): 119 | iter_test = iter(loader) 120 | for i in range(len(loader)): 121 | data = iter_test.next() 122 | inputs = data[0] 123 | labels = data[1] 124 | inputs = inputs.cuda() 125 | outputs = netC(netB(netF(inputs))) 126 | if start_test: 127 | all_output = outputs.float().cpu() 128 | all_label = labels.float() 129 | start_test = False 130 | else: 131 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 132 | all_label = torch.cat((all_label, 
labels.float()), 0) 133 | 134 | all_output = nn.Softmax(dim=1)(all_output) 135 | _, predict = torch.max(all_output, 1) 136 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 137 | mean_ent = torch.mean(loss.Entropy(all_output)).cpu().data.item() 138 | 139 | if flag: 140 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float()) 141 | acc = matrix.diagonal()/matrix.sum(axis=1) * 100 142 | aacc = acc.mean() 143 | aa = [str(np.round(i, 2)) for i in acc] 144 | acc = ' '.join(aa) 145 | return aacc, acc 146 | else: 147 | return accuracy*100, mean_ent 148 | 149 | def cal_acc_oda(loader, netF, netB, netC): 150 | start_test = True 151 | with torch.no_grad(): 152 | iter_test = iter(loader) 153 | for i in range(len(loader)): 154 | data = iter_test.next() 155 | inputs = data[0] 156 | labels = data[1] 157 | inputs = inputs.cuda() 158 | outputs = netC(netB(netF(inputs))) 159 | if start_test: 160 | all_output = outputs.float().cpu() 161 | all_label = labels.float() 162 | start_test = False 163 | else: 164 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 165 | all_label = torch.cat((all_label, labels.float()), 0) 166 | 167 | all_output = nn.Softmax(dim=1)(all_output) 168 | _, predict = torch.max(all_output, 1) 169 | ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1) / np.log(args.class_num) 170 | ent = ent.float().cpu() 171 | initc = np.array([[0], [1]]) 172 | kmeans = KMeans(n_clusters=2, random_state=0, init=initc, n_init=1).fit(ent.reshape(-1,1)) 173 | threshold = (kmeans.cluster_centers_).mean() 174 | 175 | predict[ent>threshold] = args.class_num 176 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float()) 177 | matrix = matrix[np.unique(all_label).astype(int),:] 178 | 179 | acc = matrix.diagonal()/matrix.sum(axis=1) * 100 180 | unknown_acc = acc[-1:].item() 181 | 182 | return np.mean(acc[:-1]), np.mean(acc), unknown_acc 183 | # return np.mean(acc), np.mean(acc[:-1]) 184 | 185 | def train_source(args): 186 | dset_loaders = data_load(args) 187 | ## set base network 188 | if args.net[0:3] == 'res': 189 | netF = network.ResBase(res_name=args.net).cuda() 190 | elif args.net[0:3] == 'vgg': 191 | netF = network.VGGBase(vgg_name=args.net).cuda() 192 | 193 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda() 194 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda() 195 | 196 | param_group = [] 197 | learning_rate = args.lr 198 | for k, v in netF.named_parameters(): 199 | param_group += [{'params': v, 'lr': learning_rate*0.1}] 200 | for k, v in netB.named_parameters(): 201 | param_group += [{'params': v, 'lr': learning_rate}] 202 | for k, v in netC.named_parameters(): 203 | param_group += [{'params': v, 'lr': learning_rate}] 204 | optimizer = optim.SGD(param_group) 205 | optimizer = op_copy(optimizer) 206 | 207 | acc_init = 0 208 | max_iter = args.max_epoch * len(dset_loaders["source_tr"]) 209 | interval_iter = max_iter // 10 210 | iter_num = 0 211 | 212 | netF.train() 213 | netB.train() 214 | netC.train() 215 | 216 | while iter_num < max_iter: 217 | try: 218 | inputs_source, labels_source = iter_source.next() 219 | except: 220 | iter_source = iter(dset_loaders["source_tr"]) 221 | inputs_source, labels_source = iter_source.next() 222 | 223 | if inputs_source.size(0) == 1: 224 | continue 225 | 226 | iter_num += 1 227 | lr_scheduler(optimizer, iter_num=iter_num, 
max_iter=max_iter) 228 | 229 | inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda() 230 | outputs_source = netC(netB(netF(inputs_source))) 231 | classifier_loss = CrossEntropyLabelSmooth(num_classes=args.class_num, epsilon=args.smooth)(outputs_source, labels_source) 232 | 233 | optimizer.zero_grad() 234 | classifier_loss.backward() 235 | optimizer.step() 236 | 237 | if iter_num % interval_iter == 0 or iter_num == max_iter: 238 | netF.eval() 239 | netB.eval() 240 | netC.eval() 241 | if args.dset=='VISDA-C': 242 | acc_s_te, acc_list = cal_acc(dset_loaders['source_te'], netF, netB, netC, True) 243 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name_src, iter_num, max_iter, acc_s_te) + '\n' + acc_list 244 | else: 245 | acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC, False) 246 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name_src, iter_num, max_iter, acc_s_te) 247 | args.out_file.write(log_str + '\n') 248 | args.out_file.flush() 249 | print(log_str+'\n') 250 | 251 | if acc_s_te >= acc_init: 252 | acc_init = acc_s_te 253 | best_netF = copy.deepcopy(netF.state_dict()) 254 | best_netB = copy.deepcopy(netB.state_dict()) 255 | best_netC = copy.deepcopy(netC.state_dict()) 256 | 257 | netF.train() 258 | netB.train() 259 | netC.train() 260 | 261 | torch.save(best_netF, osp.join(args.output_dir_src, "source_F.pt")) 262 | torch.save(best_netB, osp.join(args.output_dir_src, "source_B.pt")) 263 | torch.save(best_netC, osp.join(args.output_dir_src, "source_C.pt")) 264 | 265 | return netF, netB, netC 266 | 267 | def test_target(args): 268 | dset_loaders = data_load(args) 269 | ## set base network 270 | if args.net[0:3] == 'res': 271 | netF = network.ResBase(res_name=args.net).cuda() 272 | elif args.net[0:3] == 'vgg': 273 | netF = network.VGGBase(vgg_name=args.net).cuda() 274 | 275 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda() 276 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda() 277 | 278 | args.modelpath = args.output_dir_src + '/source_F.pt' 279 | netF.load_state_dict(torch.load(args.modelpath)) 280 | args.modelpath = args.output_dir_src + '/source_B.pt' 281 | netB.load_state_dict(torch.load(args.modelpath)) 282 | args.modelpath = args.output_dir_src + '/source_C.pt' 283 | netC.load_state_dict(torch.load(args.modelpath)) 284 | netF.eval() 285 | netB.eval() 286 | netC.eval() 287 | 288 | if args.da == 'oda': 289 | acc_os1, acc_os2, acc_unknown = cal_acc_oda(dset_loaders['test'], netF, netB, netC) 290 | log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}% / {:.2f}% / {:.2f}%'.format(args.trte, args.name, acc_os2, acc_os1, acc_unknown) 291 | else: 292 | if args.dset=='VISDA-C': 293 | acc, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC, True) 294 | log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}%'.format(args.trte, args.name, acc) + '\n' + acc_list 295 | else: 296 | acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False) 297 | log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}%'.format(args.trte, args.name, acc) 298 | 299 | args.out_file.write(log_str) 300 | args.out_file.flush() 301 | print(log_str) 302 | 303 | def print_args(args): 304 | s = "==========================================\n" 305 | for arg, content in args.__dict__.items(): 306 | s += "{}:{}\n".format(arg, content) 307 | return s 308 | 309 | if __name__ == "__main__": 310 | parser = 
argparse.ArgumentParser(description='SHOT') 311 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run") 312 | parser.add_argument('--s', type=int, default=0, help="source") 313 | parser.add_argument('--t', type=int, default=1, help="target") 314 | parser.add_argument('--max_epoch', type=int, default=20, help="max iterations") 315 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size") 316 | parser.add_argument('--worker', type=int, default=4, help="number of workers") 317 | parser.add_argument('--dset', type=str, default='office-home', choices=['VISDA-C', 'office', 'office-home', 'office-caltech']) 318 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate") 319 | parser.add_argument('--net', type=str, default='resnet50', help="vgg16, resnet50, resnet101") 320 | parser.add_argument('--seed', type=int, default=2020, help="random seed") 321 | parser.add_argument('--bottleneck', type=int, default=256) 322 | parser.add_argument('--epsilon', type=float, default=1e-5) 323 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"]) 324 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"]) 325 | parser.add_argument('--smooth', type=float, default=0.1) 326 | parser.add_argument('--output', type=str, default='san') 327 | parser.add_argument('--da', type=str, default='uda', choices=['uda', 'pda', 'oda']) 328 | parser.add_argument('--trte', type=str, default='val', choices=['full', 'val']) 329 | args = parser.parse_args() 330 | 331 | if args.dset == 'office-home': 332 | names = ['Art', 'Clipart', 'Product', 'RealWorld'] 333 | args.class_num = 65 334 | if args.dset == 'office': 335 | names = ['amazon', 'dslr', 'webcam'] 336 | args.class_num = 31 337 | if args.dset == 'VISDA-C': 338 | names = ['train', 'validation'] 339 | args.class_num = 12 340 | if args.dset == 'office-caltech': 341 | names = ['amazon', 'caltech', 'dslr', 'webcam'] 342 | args.class_num = 10 343 | 344 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 345 | SEED = args.seed 346 | torch.manual_seed(SEED) 347 | torch.cuda.manual_seed(SEED) 348 | np.random.seed(SEED) 349 | random.seed(SEED) 350 | # torch.backends.cudnn.deterministic = True 351 | 352 | folder = './data/' 353 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt' 354 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt' 355 | 356 | if args.dset == 'office-home': 357 | if args.da == 'pda': 358 | args.class_num = 65 359 | args.src_classes = [i for i in range(65)] 360 | args.tar_classes = [i for i in range(25)] 361 | if args.da == 'oda': 362 | args.class_num = 25 363 | args.src_classes = [i for i in range(25)] 364 | args.tar_classes = [i for i in range(65)] 365 | 366 | args.output_dir_src = osp.join(args.output, args.da, args.dset, names[args.s][0].upper()) 367 | args.name_src = names[args.s][0].upper() 368 | if not osp.exists(args.output_dir_src): 369 | os.system('mkdir -p ' + args.output_dir_src) 370 | if not osp.exists(args.output_dir_src): 371 | os.mkdir(args.output_dir_src) 372 | 373 | args.out_file = open(osp.join(args.output_dir_src, 'log.txt'), 'w') 374 | args.out_file.write(print_args(args)+'\n') 375 | args.out_file.flush() 376 | train_source(args) 377 | 378 | args.out_file = open(osp.join(args.output_dir_src, 'log_test.txt'), 'w') 379 | for i in range(len(names)): 380 | if i == args.s: 381 | continue 382 | args.t = i 383 | args.name = names[args.s][0].upper() + names[args.t][0].upper() 384 
| 385 | folder = '/Checkpoint/liangjian/tran/data/' 386 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt' 387 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt' 388 | 389 | if args.dset == 'office-home': 390 | if args.da == 'pda': 391 | args.class_num = 65 392 | args.src_classes = [i for i in range(65)] 393 | args.tar_classes = [i for i in range(25)] 394 | if args.da == 'oda': 395 | args.class_num = 25 396 | args.src_classes = [i for i in range(25)] 397 | args.tar_classes = [i for i in range(65)] 398 | 399 | test_target(args) 400 | -------------------------------------------------------------------------------- /object/image_target.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import torchvision 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torchvision import transforms 10 | import network, loss 11 | from torch.utils.data import DataLoader 12 | from data_list import ImageList, ImageList_idx 13 | import random, pdb, math, copy 14 | from tqdm import tqdm 15 | from scipy.spatial.distance import cdist 16 | from sklearn.metrics import confusion_matrix 17 | 18 | def op_copy(optimizer): 19 | for param_group in optimizer.param_groups: 20 | param_group['lr0'] = param_group['lr'] 21 | return optimizer 22 | 23 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75): 24 | decay = (1 + gamma * iter_num / max_iter) ** (-power) 25 | for param_group in optimizer.param_groups: 26 | param_group['lr'] = param_group['lr0'] * decay 27 | param_group['weight_decay'] = 1e-3 28 | param_group['momentum'] = 0.9 29 | param_group['nesterov'] = True 30 | return optimizer 31 | 32 | def image_train(resize_size=256, crop_size=224, alexnet=False): 33 | if not alexnet: 34 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 35 | std=[0.229, 0.224, 0.225]) 36 | else: 37 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 38 | return transforms.Compose([ 39 | transforms.Resize((resize_size, resize_size)), 40 | transforms.RandomCrop(crop_size), 41 | transforms.RandomHorizontalFlip(), 42 | transforms.ToTensor(), 43 | normalize 44 | ]) 45 | 46 | def image_test(resize_size=256, crop_size=224, alexnet=False): 47 | if not alexnet: 48 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 49 | std=[0.229, 0.224, 0.225]) 50 | else: 51 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 52 | return transforms.Compose([ 53 | transforms.Resize((resize_size, resize_size)), 54 | transforms.CenterCrop(crop_size), 55 | transforms.ToTensor(), 56 | normalize 57 | ]) 58 | 59 | def data_load(args): 60 | ## prepare data 61 | dsets = {} 62 | dset_loaders = {} 63 | train_bs = args.batch_size 64 | txt_tar = open(args.t_dset_path).readlines() 65 | txt_test = open(args.test_dset_path).readlines() 66 | 67 | if not args.da == 'uda': 68 | label_map_s = {} 69 | for i in range(len(args.src_classes)): 70 | label_map_s[args.src_classes[i]] = i 71 | 72 | new_tar = [] 73 | for i in range(len(txt_tar)): 74 | rec = txt_tar[i] 75 | reci = rec.strip().split(' ') 76 | if int(reci[1]) in args.tar_classes: 77 | if int(reci[1]) in args.src_classes: 78 | line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n' 79 | new_tar.append(line) 80 | else: 81 | line = reci[0] + ' ' + str(len(label_map_s)) + '\n' 82 | new_tar.append(line) 83 | txt_tar = new_tar.copy() 84 | txt_test = txt_tar.copy() 85 | 86 | 
dsets["target"] = ImageList_idx(txt_tar, transform=image_train()) 87 | dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False) 88 | dsets["test"] = ImageList_idx(txt_test, transform=image_test()) 89 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*3, shuffle=False, num_workers=args.worker, drop_last=False) 90 | 91 | return dset_loaders 92 | 93 | def cal_acc(loader, netF, netB, netC, flag=False): 94 | start_test = True 95 | with torch.no_grad(): 96 | iter_test = iter(loader) 97 | for i in range(len(loader)): 98 | data = iter_test.next() 99 | inputs = data[0] 100 | labels = data[1] 101 | inputs = inputs.cuda() 102 | outputs = netC(netB(netF(inputs))) 103 | if start_test: 104 | all_output = outputs.float().cpu() 105 | all_label = labels.float() 106 | start_test = False 107 | else: 108 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 109 | all_label = torch.cat((all_label, labels.float()), 0) 110 | _, predict = torch.max(all_output, 1) 111 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 112 | mean_ent = torch.mean(loss.Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item() 113 | 114 | if flag: 115 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float()) 116 | acc = matrix.diagonal()/matrix.sum(axis=1) * 100 117 | aacc = acc.mean() 118 | aa = [str(np.round(i, 2)) for i in acc] 119 | acc = ' '.join(aa) 120 | return aacc, acc 121 | else: 122 | return accuracy*100, mean_ent 123 | 124 | def train_target(args): 125 | dset_loaders = data_load(args) 126 | ## set base network 127 | if args.net[0:3] == 'res': 128 | netF = network.ResBase(res_name=args.net).cuda() 129 | elif args.net[0:3] == 'vgg': 130 | netF = network.VGGBase(vgg_name=args.net).cuda() 131 | 132 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda() 133 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda() 134 | 135 | modelpath = args.output_dir_src + '/source_F.pt' 136 | netF.load_state_dict(torch.load(modelpath)) 137 | modelpath = args.output_dir_src + '/source_B.pt' 138 | netB.load_state_dict(torch.load(modelpath)) 139 | modelpath = args.output_dir_src + '/source_C.pt' 140 | netC.load_state_dict(torch.load(modelpath)) 141 | netC.eval() 142 | for k, v in netC.named_parameters(): 143 | v.requires_grad = False 144 | 145 | param_group = [] 146 | for k, v in netF.named_parameters(): 147 | if args.lr_decay1 > 0: 148 | param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}] 149 | else: 150 | v.requires_grad = False 151 | for k, v in netB.named_parameters(): 152 | if args.lr_decay2 > 0: 153 | param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}] 154 | else: 155 | v.requires_grad = False 156 | 157 | optimizer = optim.SGD(param_group) 158 | optimizer = op_copy(optimizer) 159 | 160 | max_iter = args.max_epoch * len(dset_loaders["target"]) 161 | interval_iter = max_iter // args.interval 162 | iter_num = 0 163 | 164 | while iter_num < max_iter: 165 | try: 166 | inputs_test, _, tar_idx = iter_test.next() 167 | except: 168 | iter_test = iter(dset_loaders["target"]) 169 | inputs_test, _, tar_idx = iter_test.next() 170 | 171 | if inputs_test.size(0) == 1: 172 | continue 173 | 174 | if iter_num % interval_iter == 0 and args.cls_par > 0: 175 | netF.eval() 176 | netB.eval() 177 | mem_label = obtain_label(dset_loaders['test'], 
netF, netB, netC, args) 178 | mem_label = torch.from_numpy(mem_label).cuda() 179 | netF.train() 180 | netB.train() 181 | 182 | inputs_test = inputs_test.cuda() 183 | 184 | iter_num += 1 185 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter) 186 | 187 | features_test = netB(netF(inputs_test)) 188 | outputs_test = netC(features_test) 189 | 190 | if args.cls_par > 0: 191 | pred = mem_label[tar_idx] 192 | classifier_loss = nn.CrossEntropyLoss()(outputs_test, pred) 193 | classifier_loss *= args.cls_par 194 | if iter_num < interval_iter and args.dset == "VISDA-C": 195 | classifier_loss *= 0 196 | else: 197 | classifier_loss = torch.tensor(0.0).cuda() 198 | 199 | if args.ent: 200 | softmax_out = nn.Softmax(dim=1)(outputs_test) 201 | entropy_loss = torch.mean(loss.Entropy(softmax_out)) 202 | if args.gent: 203 | msoftmax = softmax_out.mean(dim=0) 204 | gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon)) 205 | entropy_loss -= gentropy_loss 206 | im_loss = entropy_loss * args.ent_par 207 | classifier_loss += im_loss 208 | 209 | optimizer.zero_grad() 210 | classifier_loss.backward() 211 | optimizer.step() 212 | 213 | if iter_num % interval_iter == 0 or iter_num == max_iter: 214 | netF.eval() 215 | netB.eval() 216 | if args.dset=='VISDA-C': 217 | acc_s_te, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC, True) 218 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name, iter_num, max_iter, acc_s_te) + '\n' + acc_list 219 | else: 220 | acc_s_te, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False) 221 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name, iter_num, max_iter, acc_s_te) 222 | 223 | args.out_file.write(log_str + '\n') 224 | args.out_file.flush() 225 | print(log_str+'\n') 226 | netF.train() 227 | netB.train() 228 | 229 | if args.issave: 230 | torch.save(netF.state_dict(), osp.join(args.output_dir, "target_F_" + args.savename + ".pt")) 231 | torch.save(netB.state_dict(), osp.join(args.output_dir, "target_B_" + args.savename + ".pt")) 232 | torch.save(netC.state_dict(), osp.join(args.output_dir, "target_C_" + args.savename + ".pt")) 233 | 234 | return netF, netB, netC 235 | 236 | def print_args(args): 237 | s = "==========================================\n" 238 | for arg, content in args.__dict__.items(): 239 | s += "{}:{}\n".format(arg, content) 240 | return s 241 | 242 | def obtain_label(loader, netF, netB, netC, args): 243 | start_test = True 244 | with torch.no_grad(): 245 | iter_test = iter(loader) 246 | for _ in range(len(loader)): 247 | data = iter_test.next() 248 | inputs = data[0] 249 | labels = data[1] 250 | inputs = inputs.cuda() 251 | feas = netB(netF(inputs)) 252 | outputs = netC(feas) 253 | if start_test: 254 | all_fea = feas.float().cpu() 255 | all_output = outputs.float().cpu() 256 | all_label = labels.float() 257 | start_test = False 258 | else: 259 | all_fea = torch.cat((all_fea, feas.float().cpu()), 0) 260 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 261 | all_label = torch.cat((all_label, labels.float()), 0) 262 | 263 | all_output = nn.Softmax(dim=1)(all_output) 264 | ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1) 265 | unknown_weight = 1 - ent / np.log(args.class_num) 266 | _, predict = torch.max(all_output, 1) 267 | 268 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 269 | if args.distance == 'cosine': 270 | all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1) 271 | 
all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t() 272 | 273 | all_fea = all_fea.float().cpu().numpy() 274 | K = all_output.size(1) 275 | aff = all_output.float().cpu().numpy() 276 | 277 | for _ in range(2): 278 | initc = aff.transpose().dot(all_fea) 279 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None]) 280 | cls_count = np.eye(K)[predict].sum(axis=0) 281 | labelset = np.where(cls_count>args.threshold) 282 | labelset = labelset[0] 283 | 284 | dd = cdist(all_fea, initc[labelset], args.distance) 285 | pred_label = dd.argmin(axis=1) 286 | predict = labelset[pred_label] 287 | 288 | aff = np.eye(K)[predict] 289 | 290 | acc = np.sum(predict == all_label.float().numpy()) / len(all_fea) 291 | log_str = 'Accuracy = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc * 100) 292 | 293 | args.out_file.write(log_str + '\n') 294 | args.out_file.flush() 295 | print(log_str+'\n') 296 | 297 | return predict.astype('int') 298 | 299 | 300 | if __name__ == "__main__": 301 | parser = argparse.ArgumentParser(description='SHOT') 302 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run") 303 | parser.add_argument('--s', type=int, default=0, help="source") 304 | parser.add_argument('--t', type=int, default=1, help="target") 305 | parser.add_argument('--max_epoch', type=int, default=15, help="max iterations") 306 | parser.add_argument('--interval', type=int, default=15) 307 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size") 308 | parser.add_argument('--worker', type=int, default=4, help="number of workers") 309 | parser.add_argument('--dset', type=str, default='office-home', choices=['VISDA-C', 'office', 'office-home', 'office-caltech']) 310 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate") 311 | parser.add_argument('--net', type=str, default='resnet50', help="alexnet, vgg16, resnet50, res101") 312 | parser.add_argument('--seed', type=int, default=2020, help="random seed") 313 | 314 | parser.add_argument('--gent', type=bool, default=True) 315 | parser.add_argument('--ent', type=bool, default=True) 316 | parser.add_argument('--threshold', type=int, default=0) 317 | parser.add_argument('--cls_par', type=float, default=0.3) 318 | parser.add_argument('--ent_par', type=float, default=1.0) 319 | parser.add_argument('--lr_decay1', type=float, default=0.1) 320 | parser.add_argument('--lr_decay2', type=float, default=1.0) 321 | 322 | parser.add_argument('--bottleneck', type=int, default=256) 323 | parser.add_argument('--epsilon', type=float, default=1e-5) 324 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"]) 325 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"]) 326 | parser.add_argument('--distance', type=str, default='cosine', choices=["euclidean", "cosine"]) 327 | parser.add_argument('--output', type=str, default='san') 328 | parser.add_argument('--output_src', type=str, default='san') 329 | parser.add_argument('--da', type=str, default='uda', choices=['uda', 'pda']) 330 | parser.add_argument('--issave', type=bool, default=True) 331 | args = parser.parse_args() 332 | 333 | if args.dset == 'office-home': 334 | names = ['Art', 'Clipart', 'Product', 'RealWorld'] 335 | args.class_num = 65 336 | if args.dset == 'office': 337 | names = ['amazon', 'dslr', 'webcam'] 338 | args.class_num = 31 339 | if args.dset == 'VISDA-C': 340 | names = ['train', 'validation'] 341 | args.class_num = 12 342 | if args.dset == 'office-caltech': 343 | names = ['amazon', 'caltech', 
'dslr', 'webcam'] 344 | args.class_num = 10 345 | 346 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 347 | SEED = args.seed 348 | torch.manual_seed(SEED) 349 | torch.cuda.manual_seed(SEED) 350 | np.random.seed(SEED) 351 | random.seed(SEED) 352 | # torch.backends.cudnn.deterministic = True 353 | 354 | for i in range(len(names)): 355 | if i == args.s: 356 | continue 357 | args.t = i 358 | 359 | folder = './data/' 360 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt' 361 | args.t_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt' 362 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt' 363 | 364 | if args.dset == 'office-home': 365 | if args.da == 'pda': 366 | args.class_num = 65 367 | args.src_classes = [i for i in range(65)] 368 | args.tar_classes = [i for i in range(25)] 369 | 370 | args.output_dir_src = osp.join(args.output_src, args.da, args.dset, names[args.s][0].upper()) 371 | args.output_dir = osp.join(args.output, args.da, args.dset, names[args.s][0].upper()+names[args.t][0].upper()) 372 | args.name = names[args.s][0].upper()+names[args.t][0].upper() 373 | 374 | if not osp.exists(args.output_dir): 375 | os.system('mkdir -p ' + args.output_dir) 376 | if not osp.exists(args.output_dir): 377 | os.mkdir(args.output_dir) 378 | 379 | args.savename = 'par_' + str(args.cls_par) 380 | if args.da == 'pda': 381 | args.gent = '' 382 | args.savename = 'par_' + str(args.cls_par) + '_thr' + str(args.threshold) 383 | args.out_file = open(osp.join(args.output_dir, 'log_' + args.savename + '.txt'), 'w') 384 | args.out_file.write(print_args(args)+'\n') 385 | args.out_file.flush() 386 | train_target(args) -------------------------------------------------------------------------------- /object/image_target_oda.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import torchvision 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torchvision import transforms 10 | import network, loss 11 | from torch.utils.data import DataLoader 12 | from data_list import ImageList, ImageList_idx 13 | import random, pdb, math, copy 14 | from tqdm import tqdm 15 | from scipy.spatial.distance import cdist 16 | from sklearn.metrics import confusion_matrix 17 | from sklearn.cluster import KMeans 18 | 19 | def op_copy(optimizer): 20 | for param_group in optimizer.param_groups: 21 | param_group['lr0'] = param_group['lr'] 22 | return optimizer 23 | 24 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75): 25 | decay = (1 + gamma * iter_num / max_iter) ** (-power) 26 | for param_group in optimizer.param_groups: 27 | param_group['lr'] = param_group['lr0'] * decay 28 | param_group['weight_decay'] = 1e-3 29 | param_group['momentum'] = 0.9 30 | param_group['nesterov'] = True 31 | return optimizer 32 | 33 | def image_train(resize_size=256, crop_size=224, alexnet=False): 34 | if not alexnet: 35 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 36 | std=[0.229, 0.224, 0.225]) 37 | else: 38 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 39 | return transforms.Compose([ 40 | transforms.Resize((resize_size, resize_size)), 41 | transforms.RandomCrop(crop_size), 42 | transforms.RandomHorizontalFlip(), 43 | transforms.ToTensor(), 44 | normalize 45 | ]) 46 | 47 | def image_test(resize_size=256, crop_size=224, alexnet=False): 48 | if not alexnet: 49 | normalize = 
transforms.Normalize(mean=[0.485, 0.456, 0.406], 50 | std=[0.229, 0.224, 0.225]) 51 | else: 52 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 53 | return transforms.Compose([ 54 | transforms.Resize((resize_size, resize_size)), 55 | transforms.CenterCrop(crop_size), 56 | transforms.ToTensor(), 57 | normalize 58 | ]) 59 | 60 | def data_load(args): 61 | ## prepare data 62 | dsets = {} 63 | dset_loaders = {} 64 | train_bs = args.batch_size 65 | txt_src = open(args.s_dset_path).readlines() 66 | txt_tar = open(args.t_dset_path).readlines() 67 | txt_test = open(args.test_dset_path).readlines() 68 | 69 | if not args.da == 'uda': 70 | label_map_s = {} 71 | for i in range(len(args.src_classes)): 72 | label_map_s[args.src_classes[i]] = i 73 | 74 | new_tar = [] 75 | for i in range(len(txt_tar)): 76 | rec = txt_tar[i] 77 | reci = rec.strip().split(' ') 78 | if int(reci[1]) in args.tar_classes: 79 | if int(reci[1]) in args.src_classes: 80 | line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n' 81 | new_tar.append(line) 82 | else: 83 | line = reci[0] + ' ' + str(len(label_map_s)) + '\n' 84 | new_tar.append(line) 85 | txt_tar = new_tar.copy() 86 | txt_test = txt_tar.copy() 87 | 88 | dsets["target"] = ImageList_idx(txt_tar, transform=image_train()) 89 | dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False) 90 | dsets["test"] = ImageList(txt_test, transform=image_test()) 91 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*3, shuffle=False, num_workers=args.worker, drop_last=False) 92 | 93 | return dset_loaders 94 | 95 | def cal_acc(loader, netF, netB, netC, flag=False, threshold=0.1): 96 | start_test = True 97 | with torch.no_grad(): 98 | iter_test = iter(loader) 99 | for i in range(len(loader)): 100 | data = iter_test.next() 101 | inputs = data[0] 102 | labels = data[1] 103 | inputs = inputs.cuda() 104 | outputs = netC(netB(netF(inputs))) 105 | if start_test: 106 | all_output = outputs.float().cpu() 107 | all_label = labels.float() 108 | start_test = False 109 | else: 110 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 111 | all_label = torch.cat((all_label, labels.float()), 0) 112 | _, predict = torch.max(all_output, 1) 113 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 114 | mean_ent = torch.mean(loss.Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item() 115 | 116 | if flag: 117 | all_output = nn.Softmax(dim=1)(all_output) 118 | ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1) / np.log(args.class_num) 119 | 120 | from sklearn.cluster import KMeans 121 | kmeans = KMeans(2, random_state=0).fit(ent.reshape(-1,1)) 122 | labels = kmeans.predict(ent.reshape(-1,1)) 123 | 124 | idx = np.where(labels==1)[0] 125 | iidx = 0 126 | if ent[idx].mean() > ent.mean(): 127 | iidx = 1 128 | predict[np.where(labels==iidx)[0]] = args.class_num 129 | 130 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float()) 131 | matrix = matrix[np.unique(all_label).astype(int),:] 132 | 133 | acc = matrix.diagonal()/matrix.sum(axis=1) * 100 134 | unknown_acc = acc[-1:].item() 135 | return np.mean(acc[:-1]), np.mean(acc), unknown_acc 136 | else: 137 | return accuracy*100, mean_ent 138 | 139 | def print_args(args): 140 | s = "==========================================\n" 141 | for arg, content in args.__dict__.items(): 142 | s += "{}:{}\n".format(arg, content) 143 | return s 144 | 145 | def train_target(args): 
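    # Open-set (ODA) variant of SHOT's target adaptation. The loop below mirrors
    # image_target.py, except that obtain_label() additionally clusters the
    # per-sample prediction entropy with 2-means and pseudo-labels the
    # high-entropy cluster as unknown (label == args.class_num). Unknown samples
    # are excluded from the self-training cross-entropy and the conditional
    # entropy term; only the diversity (gentropy) term sees the whole batch.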
146 | dset_loaders = data_load(args) 147 | ## set base network 148 | if args.net[0:3] == 'res': 149 | netF = network.ResBase(res_name=args.net).cuda() 150 | elif args.net[0:3] == 'vgg': 151 | netF = network.VGGBase(vgg_name=args.net).cuda() 152 | 153 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda() 154 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda() 155 | 156 | args.modelpath = args.output_dir_src + '/source_F.pt' 157 | netF.load_state_dict(torch.load(args.modelpath)) 158 | args.modelpath = args.output_dir_src + '/source_B.pt' 159 | netB.load_state_dict(torch.load(args.modelpath)) 160 | args.modelpath = args.output_dir_src + '/source_C.pt' 161 | netC.load_state_dict(torch.load(args.modelpath)) 162 | netC.eval() 163 | for k, v in netC.named_parameters(): 164 | v.requires_grad = False 165 | 166 | param_group = [] 167 | for k, v in netF.named_parameters(): 168 | if args.lr_decay1 > 0: 169 | param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}] 170 | else: 171 | v.requires_grad = False 172 | for k, v in netB.named_parameters(): 173 | if args.lr_decay2 > 0: 174 | param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}] 175 | else: 176 | v.requires_grad = False 177 | 178 | optimizer = optim.SGD(param_group) 179 | optimizer = op_copy(optimizer) 180 | 181 | tt = 0 182 | iter_num = 0 183 | max_iter = args.max_epoch * len(dset_loaders["target"]) 184 | interval_iter = max_iter // args.interval 185 | 186 | while iter_num < max_iter: 187 | try: 188 | inputs_test, _, tar_idx = iter_test.next() 189 | except: 190 | iter_test = iter(dset_loaders["target"]) 191 | inputs_test, _, tar_idx = iter_test.next() 192 | 193 | if inputs_test.size(0) == 1: 194 | continue 195 | 196 | if iter_num % interval_iter == 0: 197 | netF.eval() 198 | netB.eval() 199 | mem_label, ENT_THRESHOLD = obtain_label(dset_loaders['test'], netF, netB, netC, args) 200 | mem_label = torch.from_numpy(mem_label).cuda() 201 | netF.train() 202 | netB.train() 203 | 204 | inputs_test = inputs_test.cuda() 205 | 206 | iter_num += 1 207 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter) 208 | 209 | pred = mem_label[tar_idx] 210 | features_test = netB(netF(inputs_test)) 211 | outputs_test = netC(features_test) 212 | 213 | softmax_out = nn.Softmax(dim=1)(outputs_test) 214 | outputs_test_known = outputs_test[pred < args.class_num, :] 215 | pred = pred[pred < args.class_num] 216 | 217 | if len(pred) == 0: 218 | print(tt) 219 | del features_test 220 | del outputs_test 221 | tt += 1 222 | continue 223 | 224 | if args.cls_par > 0: 225 | classifier_loss = nn.CrossEntropyLoss()(outputs_test_known, pred) 226 | classifier_loss *= args.cls_par 227 | else: 228 | classifier_loss = torch.tensor(0.0).cuda() 229 | 230 | if args.ent: 231 | softmax_out_known = nn.Softmax(dim=1)(outputs_test_known) 232 | entropy_loss = torch.mean(loss.Entropy(softmax_out_known)) 233 | if args.gent: 234 | msoftmax = softmax_out.mean(dim=0) 235 | gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon)) 236 | entropy_loss -= gentropy_loss 237 | classifier_loss += entropy_loss * args.ent_par 238 | 239 | optimizer.zero_grad() 240 | classifier_loss.backward() 241 | optimizer.step() 242 | 243 | if iter_num % interval_iter == 0 or iter_num == max_iter: 244 | netF.eval() 245 | netB.eval() 246 | acc_os1, acc_os2, acc_unknown = cal_acc(dset_loaders['test'], netF, netB, netC, True, ENT_THRESHOLD) 247 | log_str = 
'Task: {}, Iter:{}/{}; Accuracy = {:.2f}% / {:.2f}% / {:.2f}%'.format(args.name, iter_num, max_iter, acc_os2, acc_os1, acc_unknown) 248 | args.out_file.write(log_str + '\n') 249 | args.out_file.flush() 250 | print(log_str+'\n') 251 | netF.train() 252 | netB.train() 253 | 254 | if args.issave: 255 | torch.save(netF.state_dict(), osp.join(args.output_dir, "target_F_" + args.savename + ".pt")) 256 | torch.save(netB.state_dict(), osp.join(args.output_dir, "target_B_" + args.savename + ".pt")) 257 | torch.save(netC.state_dict(), osp.join(args.output_dir, "target_C_" + args.savename + ".pt")) 258 | 259 | return netF, netB, netC 260 | 261 | def obtain_label(loader, netF, netB, netC, args): 262 | start_test = True 263 | with torch.no_grad(): 264 | iter_test = iter(loader) 265 | for _ in range(len(loader)): 266 | data = iter_test.next() 267 | inputs = data[0] 268 | labels = data[1] 269 | inputs = inputs.cuda() 270 | feas = netB(netF(inputs)) 271 | outputs = netC(feas) 272 | if start_test: 273 | all_fea = feas.float().cpu() 274 | all_output = outputs.float().cpu() 275 | all_label = labels.float() 276 | start_test = False 277 | else: 278 | all_fea = torch.cat((all_fea, feas.float().cpu()), 0) 279 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 280 | all_label = torch.cat((all_label, labels.float()), 0) 281 | 282 | all_output = nn.Softmax(dim=1)(all_output) 283 | _, predict = torch.max(all_output, 1) 284 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 285 | if args.distance == 'cosine': 286 | all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1) 287 | all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t() 288 | 289 | ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1) / np.log(args.class_num) 290 | ent = ent.float().cpu() 291 | 292 | from sklearn.cluster import KMeans 293 | kmeans = KMeans(2, random_state=0).fit(ent.reshape(-1,1)) 294 | labels = kmeans.predict(ent.reshape(-1,1)) 295 | 296 | idx = np.where(labels==1)[0] 297 | iidx = 0 298 | if ent[idx].mean() > ent.mean(): 299 | iidx = 1 300 | known_idx = np.where(kmeans.labels_ != iidx)[0] 301 | 302 | all_fea = all_fea[known_idx,:] 303 | all_output = all_output[known_idx,:] 304 | predict = predict[known_idx] 305 | all_label_idx = all_label[known_idx] 306 | ENT_THRESHOLD = (kmeans.cluster_centers_).mean() 307 | 308 | all_fea = all_fea.float().cpu().numpy() 309 | K = all_output.size(1) 310 | aff = all_output.float().cpu().numpy() 311 | initc = aff.transpose().dot(all_fea) 312 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None]) 313 | cls_count = np.eye(K)[predict].sum(axis=0) 314 | labelset = np.where(cls_count>args.threshold) 315 | labelset = labelset[0] 316 | 317 | dd = cdist(all_fea, initc[labelset], args.distance) 318 | pred_label = dd.argmin(axis=1) 319 | pred_label = labelset[pred_label] 320 | 321 | for round in range(1): 322 | aff = np.eye(K)[pred_label] 323 | initc = aff.transpose().dot(all_fea) 324 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None]) 325 | dd = cdist(all_fea, initc[labelset], args.distance) 326 | pred_label = dd.argmin(axis=1) 327 | pred_label = labelset[pred_label] 328 | 329 | guess_label = args.class_num * np.ones(len(all_label), ) 330 | guess_label[known_idx] = pred_label 331 | 332 | acc = np.sum(guess_label == all_label.float().numpy()) / len(all_label_idx) 333 | log_str = 'Threshold = {:.2f}, Accuracy = {:.2f}% -> {:.2f}%'.format(ENT_THRESHOLD, accuracy*100, acc*100) 334 | 335 | return guess_label.astype('int'), 
ENT_THRESHOLD 336 | 337 | 338 | if __name__ == "__main__": 339 | parser = argparse.ArgumentParser(description='SHOT') 340 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run") 341 | parser.add_argument('--s', type=int, default=0, help="source") 342 | parser.add_argument('--t', type=int, default=1, help="target") 343 | parser.add_argument('--max_epoch', type=int, default=15, help="max iterations") 344 | parser.add_argument('--interval', type=int, default=15) 345 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size") 346 | parser.add_argument('--worker', type=int, default=4, help="number of workers") 347 | parser.add_argument('--dset', type=str, default='office-home', choices=['office-home']) 348 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate") 349 | parser.add_argument('--net', type=str, default='resnet50', help="vgg16, resnet50, resnet101") 350 | parser.add_argument('--seed', type=int, default=2020, help="random seed") 351 | 352 | parser.add_argument('--gent', type=bool, default=True) 353 | parser.add_argument('--ent', type=bool, default=True) 354 | parser.add_argument('--threshold', type=int, default=0) 355 | parser.add_argument('--cls_par', type=float, default=0.3) 356 | parser.add_argument('--ent_par', type=float, default=1.0) 357 | parser.add_argument('--lr_decay1', type=float, default=0.1) 358 | parser.add_argument('--lr_decay2', type=float, default=1.0) 359 | 360 | parser.add_argument('--bottleneck', type=int, default=256) 361 | parser.add_argument('--epsilon', type=float, default=1e-5) 362 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"]) 363 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"]) 364 | parser.add_argument('--distance', type=str, default='cosine', choices=["euclidean", "cosine"]) 365 | parser.add_argument('--output', type=str, default='san') 366 | parser.add_argument('--output_src', type=str, default='san') 367 | parser.add_argument('--da', type=str, default='oda', choices=['oda']) 368 | parser.add_argument('--issave', type=bool, default=True) 369 | args = parser.parse_args() 370 | 371 | if args.dset == 'office-home': 372 | names = ['Art', 'Clipart', 'Product', 'RealWorld'] 373 | args.class_num = 65 374 | 375 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 376 | SEED = args.seed 377 | torch.manual_seed(SEED) 378 | torch.cuda.manual_seed(SEED) 379 | np.random.seed(SEED) 380 | random.seed(SEED) 381 | # torch.backends.cudnn.deterministic = True 382 | 383 | for i in range(len(names)): 384 | if i == args.s: 385 | continue 386 | args.t = i 387 | 388 | folder = './data/' 389 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt' 390 | args.t_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt' 391 | args.test_dset_path = args.t_dset_path 392 | 393 | if args.dset == 'office-home': 394 | if args.da == 'oda': 395 | args.class_num = 25 396 | args.src_classes = [i for i in range(25)] 397 | args.tar_classes = [i for i in range(65)] 398 | 399 | args.output_dir_src = osp.join(args.output_src, args.da, args.dset, names[args.s][0].upper()) 400 | args.output_dir = osp.join(args.output, args.da, args.dset, names[args.s][0].upper()+names[args.t][0].upper()) 401 | args.name = names[args.s][0].upper()+names[args.t][0].upper() 402 | 403 | if not osp.exists(args.output_dir): 404 | os.system('mkdir -p ' + args.output_dir) 405 | if not osp.exists(args.output_dir): 406 | os.mkdir(args.output_dir) 407 | 
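    # 'par_<cls_par>' tags both the log file opened below and the target_{F,B,C}
    # checkpoints saved by train_target(), so sweeps over --cls_par keep
    # separate logs and do not overwrite each other's models.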
408 | args.savename = 'par_' + str(args.cls_par) 409 | args.out_file = open(osp.join(args.output_dir, 'log_' + args.savename + '.txt'), 'w') 410 | args.out_file.write(print_args(args)+'\n') 411 | args.out_file.flush() 412 | train_target(args) -------------------------------------------------------------------------------- /object/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import math 6 | import torch.nn.functional as F 7 | import pdb 8 | 9 | def Entropy(input_): 10 | bs = input_.size(0) 11 | epsilon = 1e-5 12 | entropy = -input_ * torch.log(input_ + epsilon) 13 | entropy = torch.sum(entropy, dim=1) 14 | return entropy 15 | 16 | def grl_hook(coeff): 17 | def fun1(grad): 18 | return -coeff*grad.clone() 19 | return fun1 20 | 21 | def CDAN(input_list, ad_net, entropy=None, coeff=None, random_layer=None): 22 | softmax_output = input_list[1].detach() 23 | feature = input_list[0] 24 | if random_layer is None: 25 | op_out = torch.bmm(softmax_output.unsqueeze(2), feature.unsqueeze(1)) 26 | ad_out = ad_net(op_out.view(-1, softmax_output.size(1) * feature.size(1))) 27 | else: 28 | random_out = random_layer.forward([feature, softmax_output]) 29 | ad_out = ad_net(random_out.view(-1, random_out.size(1))) 30 | batch_size = softmax_output.size(0) // 2 31 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda() 32 | if entropy is not None: 33 | entropy.register_hook(grl_hook(coeff)) 34 | entropy = 1.0+torch.exp(-entropy) 35 | source_mask = torch.ones_like(entropy) 36 | source_mask[feature.size(0)//2:] = 0 37 | source_weight = entropy*source_mask 38 | target_mask = torch.ones_like(entropy) 39 | target_mask[0:feature.size(0)//2] = 0 40 | target_weight = entropy*target_mask 41 | weight = source_weight / torch.sum(source_weight).detach().item() + \ 42 | target_weight / torch.sum(target_weight).detach().item() 43 | return torch.sum(weight.view(-1, 1) * nn.BCELoss(reduction='none')(ad_out, dc_target)) / torch.sum(weight).detach().item() 44 | else: 45 | return nn.BCELoss()(ad_out, dc_target) 46 | 47 | def DANN(features, ad_net): 48 | ad_out = ad_net(features) 49 | batch_size = ad_out.size(0) // 2 50 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda() 51 | return nn.BCELoss()(ad_out, dc_target) 52 | 53 | 54 | class CrossEntropyLabelSmooth(nn.Module): 55 | """Cross entropy loss with label smoothing regularizer. 56 | Reference: 57 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 58 | Equation: y = (1 - epsilon) * y + epsilon / K. 59 | Args: 60 | num_classes (int): number of classes. 61 | epsilon (float): weight. 
62 | """ 63 | 64 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, reduction=True): 65 | super(CrossEntropyLabelSmooth, self).__init__() 66 | self.num_classes = num_classes 67 | self.epsilon = epsilon 68 | self.use_gpu = use_gpu 69 | self.reduction = reduction 70 | self.logsoftmax = nn.LogSoftmax(dim=1) 71 | 72 | def forward(self, inputs, targets): 73 | """ 74 | Args: 75 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 76 | targets: ground truth labels with shape (batch_size) 77 | """ 78 | log_probs = self.logsoftmax(inputs) 79 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1) 80 | if self.use_gpu: targets = targets.cuda() 81 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 82 | loss = (- targets * log_probs).sum(dim=1) 83 | if self.reduction: 84 | return loss.mean() 85 | else: 86 | return loss 87 | -------------------------------------------------------------------------------- /object/network.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torchvision 5 | from torchvision import models 6 | from torch.autograd import Variable 7 | import math 8 | import torch.nn.utils.weight_norm as weightNorm 9 | from collections import OrderedDict 10 | 11 | def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0): 12 | return np.float(2.0 * (high - low) / (1.0 + np.exp(-alpha*iter_num / max_iter)) - (high - low) + low) 13 | 14 | def init_weights(m): 15 | classname = m.__class__.__name__ 16 | if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1: 17 | nn.init.kaiming_uniform_(m.weight) 18 | nn.init.zeros_(m.bias) 19 | elif classname.find('BatchNorm') != -1: 20 | nn.init.normal_(m.weight, 1.0, 0.02) 21 | nn.init.zeros_(m.bias) 22 | elif classname.find('Linear') != -1: 23 | nn.init.xavier_normal_(m.weight) 24 | nn.init.zeros_(m.bias) 25 | 26 | vgg_dict = {"vgg11":models.vgg11, "vgg13":models.vgg13, "vgg16":models.vgg16, "vgg19":models.vgg19, 27 | "vgg11bn":models.vgg11_bn, "vgg13bn":models.vgg13_bn, "vgg16bn":models.vgg16_bn, "vgg19bn":models.vgg19_bn} 28 | class VGGBase(nn.Module): 29 | def __init__(self, vgg_name): 30 | super(VGGBase, self).__init__() 31 | model_vgg = vgg_dict[vgg_name](pretrained=True) 32 | self.features = model_vgg.features 33 | self.classifier = nn.Sequential() 34 | for i in range(6): 35 | self.classifier.add_module("classifier"+str(i), model_vgg.classifier[i]) 36 | self.in_features = model_vgg.classifier[6].in_features 37 | 38 | def forward(self, x): 39 | x = self.features(x) 40 | x = x.view(x.size(0), -1) 41 | x = self.classifier(x) 42 | return x 43 | 44 | res_dict = {"resnet18":models.resnet18, "resnet34":models.resnet34, "resnet50":models.resnet50, 45 | "resnet101":models.resnet101, "resnet152":models.resnet152, "resnext50":models.resnext50_32x4d, "resnext101":models.resnext101_32x8d} 46 | 47 | class ResBase(nn.Module): 48 | def __init__(self, res_name): 49 | super(ResBase, self).__init__() 50 | model_resnet = res_dict[res_name](pretrained=True) 51 | self.conv1 = model_resnet.conv1 52 | self.bn1 = model_resnet.bn1 53 | self.relu = model_resnet.relu 54 | self.maxpool = model_resnet.maxpool 55 | self.layer1 = model_resnet.layer1 56 | self.layer2 = model_resnet.layer2 57 | self.layer3 = model_resnet.layer3 58 | self.layer4 = model_resnet.layer4 59 | self.avgpool = model_resnet.avgpool 60 | self.in_features = 
model_resnet.fc.in_features 61 | 62 | def forward(self, x): 63 | x = self.conv1(x) 64 | x = self.bn1(x) 65 | x = self.relu(x) 66 | x = self.maxpool(x) 67 | x = self.layer1(x) 68 | x = self.layer2(x) 69 | x = self.layer3(x) 70 | x = self.layer4(x) 71 | x = self.avgpool(x) 72 | x = x.view(x.size(0), -1) 73 | return x 74 | 75 | class feat_bottleneck(nn.Module): 76 | def __init__(self, feature_dim, bottleneck_dim=256, type="ori"): 77 | super(feat_bottleneck, self).__init__() 78 | self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.dropout = nn.Dropout(p=0.5) 81 | self.bottleneck = nn.Linear(feature_dim, bottleneck_dim) 82 | self.bottleneck.apply(init_weights) 83 | self.type = type 84 | 85 | def forward(self, x): 86 | x = self.bottleneck(x) 87 | if self.type == "bn": 88 | x = self.bn(x) 89 | return x 90 | 91 | class feat_classifier(nn.Module): 92 | def __init__(self, class_num, bottleneck_dim=256, type="linear"): 93 | super(feat_classifier, self).__init__() 94 | self.type = type 95 | if type == 'wn': 96 | self.fc = weightNorm(nn.Linear(bottleneck_dim, class_num), name="weight") 97 | self.fc.apply(init_weights) 98 | else: 99 | self.fc = nn.Linear(bottleneck_dim, class_num) 100 | self.fc.apply(init_weights) 101 | 102 | def forward(self, x): 103 | x = self.fc(x) 104 | return x 105 | 106 | class feat_classifier_two(nn.Module): 107 | def __init__(self, class_num, input_dim, bottleneck_dim=256): 108 | super(feat_classifier_two, self).__init__() 109 | self.type = type 110 | self.fc0 = nn.Linear(input_dim, bottleneck_dim) 111 | self.fc0.apply(init_weights) 112 | self.fc1 = nn.Linear(bottleneck_dim, class_num) 113 | self.fc1.apply(init_weights) 114 | 115 | def forward(self, x): 116 | x = self.fc0(x) 117 | x = self.fc1(x) 118 | return x 119 | 120 | class Res50(nn.Module): 121 | def __init__(self): 122 | super(Res50, self).__init__() 123 | model_resnet = models.resnet50(pretrained=True) 124 | self.conv1 = model_resnet.conv1 125 | self.bn1 = model_resnet.bn1 126 | self.relu = model_resnet.relu 127 | self.maxpool = model_resnet.maxpool 128 | self.layer1 = model_resnet.layer1 129 | self.layer2 = model_resnet.layer2 130 | self.layer3 = model_resnet.layer3 131 | self.layer4 = model_resnet.layer4 132 | self.avgpool = model_resnet.avgpool 133 | self.in_features = model_resnet.fc.in_features 134 | self.fc = model_resnet.fc 135 | 136 | def forward(self, x): 137 | x = self.conv1(x) 138 | x = self.bn1(x) 139 | x = self.relu(x) 140 | x = self.maxpool(x) 141 | x = self.layer1(x) 142 | x = self.layer2(x) 143 | x = self.layer3(x) 144 | x = self.layer4(x) 145 | x = self.avgpool(x) 146 | x = x.view(x.size(0), -1) 147 | y = self.fc(x) 148 | return x, y -------------------------------------------------------------------------------- /object/run.sh: -------------------------------------------------------------------------------- 1 | # Table 3 A->D,W 2 | ~/anaconda3/envs/pytorch/bin/python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset office --max_epoch 100 --s 0 3 | ~/anaconda3/envs/pytorch/bin/python image_target.py --cls_par 0.3 --da uda --dset office --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/ 4 | 5 | # Table 4 A->C,P,R 6 | ~/anaconda3/envs/pytorch/bin/python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset office-home --max_epoch 50 --s 0 7 | ~/anaconda3/envs/pytorch/bin/python image_target.py --cls_par 0.3 --da uda --dset office-home --gpu_id 0 --s 0 --output_src ckps/source/ --output 
ckps/target/ 8 | 9 | # Table 5 VisDA-C 10 | ~/anaconda3/envs/pytorch/bin/python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset VISDA-C --net resnet101 --lr 1e-3 --max_epoch 10 --s 0 11 | ~/anaconda3/envs/pytorch/bin/python image_target.py --cls_par 0.3 --da uda --dset VISDA-C --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/ --net resnet101 --lr 1e-3 12 | 13 | # Table 7 A->C,P,R (PDA) 14 | ~/anaconda3/envs/pytorch/bin/python image_source.py --trte val --output ckps/source/ --da pda --gpu_id 0 --dset office-home --max_epoch 50 --s 0 15 | ~/anaconda3/envs/pytorch/bin/python image_target.py --cls_par 0.3 --threshold 10 --da pda --dset office-home --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/ 16 | 17 | # Table 7 A->C,P,R (ODA) 18 | ~/anaconda3/envs/pytorch/bin/python image_source.py --trte val --output ckps/source/ --da oda --gpu_id 0 --dset office-home --max_epoch 50 --s 0 19 | ~/anaconda3/envs/pytorch/bin/python image_target_oda.py --cls_par 0.3 --da oda --dset office-home --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/ 20 | 21 | 22 | # Table 8 C,D,W->A (MSDA) 23 | ~/anaconda3/envs/pytorch/bin/python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset office-caltech --net resnet101 --max_epoch 100 --s 1 24 | ~/anaconda3/envs/pytorch/bin/python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset office-caltech --net resnet101 --max_epoch 100 --s 2 25 | ~/anaconda3/envs/pytorch/bin/python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset office-caltech --net resnet101 --max_epoch 100 --s 3 26 | 27 | ~/anaconda3/envs/pytorch/bin/python image_target.py --cls_par 0.3 --da uda --dset office-caltech --net resnet101 --gpu_id 0 --s 1 --output_src ckps/source/ --output ckps/target/ 28 | ~/anaconda3/envs/pytorch/bin/python image_target.py --cls_par 0.3 --da uda --dset office-caltech --net resnet101 --gpu_id 0 --s 2 --output_src ckps/source/ --output ckps/target/ 29 | ~/anaconda3/envs/pytorch/bin/python image_target.py --cls_par 0.3 --da uda --dset office-caltech --net resnet101 --gpu_id 0 --s 3 --output_src ckps/source/ --output ckps/target/ 30 | 31 | ~/anaconda3/envs/pytorch/bin/python image_multisource.py --cls_par 0.3 --da uda --dset office-caltech --gpu_id 0 --t 0 --output_src ckps/source/ --output ckps/target/ 32 | 33 | # Table 8 A->(C,D,W)(MTDA) 34 | ~/anaconda3/envs/pytorch/bin/python image_multitarget.py --cls_par 0.3 --da uda --dset office-caltech --net resnet101 --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/ 35 | 36 | 37 | # Table 9 ImageNet->Caltech(PDA) 38 | ~/anaconda3/envs/pytorch/bin/python image_pretrained.py --gpu_id 0 --output ckps/target --cls_par 0.3 39 | 40 | -------------------------------------------------------------------------------- /pretrained-models.md: -------------------------------------------------------------------------------- 1 | https://drive.google.com/drive/folders/1Hn3MXbwQF-A6UTBZG3L3ZBiwSrxctB35?usp=sharing 2 | 3 | All the pre-trained source models are provided here. 
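To sanity-check a downloaded checkpoint before running adaptation, the three sub-networks can be rebuilt and restored the same way `image_target.py` does. Below is a minimal sketch, assuming the Drive folder has been unpacked to the hypothetical path `ckps/source/uda/office/A` and that the model was trained with the repository defaults (`resnet50` backbone, bottleneck size 256, `--classifier bn`, `--layer wn`, and 31 classes for Office):

```python
import torch
import network  # run from inside ./object so that network.py is importable

ckpt_dir = 'ckps/source/uda/office/A'  # hypothetical path to the unpacked folder

# rebuild the feature extractor, bottleneck, and classifier with the default settings
netF = network.ResBase(res_name='resnet50').cuda()
netB = network.feat_bottleneck(type='bn', feature_dim=netF.in_features, bottleneck_dim=256).cuda()
netC = network.feat_classifier(type='wn', class_num=31, bottleneck_dim=256).cuda()

# restore the three checkpoints written by image_source.py
netF.load_state_dict(torch.load(ckpt_dir + '/source_F.pt'))
netB.load_state_dict(torch.load(ckpt_dir + '/source_B.pt'))
netC.load_state_dict(torch.load(ckpt_dir + '/source_C.pt'))
netF.eval(); netB.eval(); netC.eval()
```

The tables below list the checkpoints by source domain and by the random seed used for source training.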
4 | 5 | | source (resnet50)\ seed | 2019 | 2020 | 2021 | 6 | | :---------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | 7 | | Amazon | [pretrained source model](https://drive.google.com/drive/folders/1tR5nzN6GSSzwJBQpeMNdwlFmgudf1LMt) | [pretrained source model](https://drive.google.com/drive/folders/12Tt9xjoCPoouNvxyaYefjA8utu_Ns15o) | [pretrained source model](https://drive.google.com/drive/folders/1Ky-dryAkIFanjZG8zvpwtFKqcKfXoJwX) | 8 | | Dslr | [pretrained source model](https://drive.google.com/drive/folders/1gyRALSpKlPPBtj8fpk722s3JTEjD2JR_) | [pretrained source model](https://drive.google.com/drive/folders/1EO2ZN4fuWEM5uH0yowZpxqIbAm3G0Qgf) | [pretrained source model](https://drive.google.com/drive/folders/1cPUKwimnK4dfT4K-FbjgRnLsD1iqjysA) | 9 | | Webcam | [pretrained source model](https://drive.google.com/drive/folders/1P4GH-BOoFoVWRqrV2h5sZhSvwmetVuWU) | [pretrained source model](https://drive.google.com/drive/folders/1EVxejkRJAvgdR_PleWo6WC42yVusAbae) | [pretrained source model](https://drive.google.com/drive/folders/1US_yJD0dDubyjKT2vXOIVyZXkQJ3WbU7) | 10 | 11 | | source (resnet50)\ seed | 2019 | 2020 | 2021 | 12 | | :---------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | 13 | | Art | [pretrained source model](https://drive.google.com/drive/folders/1t44cr406AKNwhMg0TmHWyXKoUyOAZhKL) | [pretrained source model](https://drive.google.com/drive/folders/1stSi6Lx5T-PRp-Hxc29wdvUdlXjPMiDZ) | [pretrained source model](https://drive.google.com/drive/folders/1xEkqFyTDoj5rBmf-XnnE5uuiNMVdz28W) | 14 | | Clipart | [pretrained source model](https://drive.google.com/drive/folders/1K7NXVqKwCG0HZlYPGLqa0Klpa47HmXcT) | [pretrained source model](https://drive.google.com/drive/folders/1mZK0v1XlocKWezvd5bdrM28_6hI5A543) | [pretrained source model](https://drive.google.com/drive/folders/1HrsVlb5KnBmxGQ-ZyepNKGyZr4wRSzyq) | 15 | | Product | [pretrained source model](https://drive.google.com/drive/folders/18-JZDSyrahcSx4IV-H2lPiLptgWmOfFy) | [pretrained source model](https://drive.google.com/drive/folders/1Liyf9VGepW2ulBp7EHWjN6ho14S-Q_3i) | [pretrained source model](https://drive.google.com/drive/folders/1ej81eKV8gBfos4byUjg9TZSAez13RkbO) | 16 | | RealWorld | [pretrained source model](https://drive.google.com/drive/folders/1f_s4i3l1HZl2HrovSQBJpwamFqsZjBEE) | [pretrained source model](https://drive.google.com/drive/folders/1jn-pi2IIWIbVQ_cMXCKd3UwuCiDrQnz3) | [pretrained source model](https://drive.google.com/drive/folders/18RTAq7wlyloZn00QR03DQvmKG0VX63rg) | 17 | 18 | | source (resnet101)\ seed | 2019 | 2020 | 2021 | 19 | | :----------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | 20 | | train | [pretrained source model](https://drive.google.com/drive/folders/1Dev9TFgdyw1hcc8F9ngjJ6omHuRpWliK) | [pretrained source model](https://drive.google.com/drive/folders/1AeTt5sPbo-7oNX5u7Jbm8LSf3Pp8buTd) | [pretrained source model](https://drive.google.com/drive/folders/1JGQGCQwLLI5A2FNAHeEJgtjqAXkrco6d) | 21 | 22 | -------------------------------------------------------------------------------- /readme.md: 
-------------------------------------------------------------------------------- 1 | # Official implementation for **SHOT** 2 | 3 | ## [**[ICML-2020] Do We Really Need to Access the Source Data? Source Hypothesis Transfer for Unsupervised Domain Adaptation**](http://proceedings.mlr.press/v119/liang20a.html) 4 | 5 | 6 | 7 | - **2022/6/6 We corrected a bug in the pseudo-labeling function (def obtain_label); many thanks to @TomSheng21.** 8 | - **2022/2/8 We uploaded the pretrained source models via Google Drive in [pretrained-models.md](./pretrained-models.md).** 9 | 10 | 11 | 12 | ### Attention-v2: ***We release the code of our recent black-box UDA method (DINE, https://arxiv.org/pdf/2104.01539.pdf) in the following repository (https://github.com/tim-learn/DINE).*** 13 | 14 | #### Attention: ***The code of our stronger TPAMI extension (SHOT++, https://arxiv.org/pdf/2012.07297.pdf) has been released in a new repository (https://github.com/tim-learn/SHOT-plus).*** 15 | 16 | 17 | 18 | ### Results: 19 | 20 | #### **Note that we updated the code to adopt the standard learning-rate scheduler (as in DANN) and report new results in the final camera-ready version.** Please refer to [results.md](./results.md) for the detailed results on various datasets. 21 | 22 | *We have updated the results for **Digits**. Now the results of SHOT-IM for **Digits** are stable and promising. (Thanks to @wengzejia1 for pointing out the bugs in **uda_digit.py**.)* 23 | 24 | 25 | ### Framework: 26 | 27 | 28 | 29 | ### Prerequisites: 30 | - python == 3.6.8 31 | - pytorch == 1.1.0 32 | - torchvision == 0.3.0 33 | - numpy, scipy, sklearn, PIL, argparse, tqdm 34 | 35 | ### Dataset: 36 | 37 | - Please manually download the datasets [Office](https://drive.google.com/file/d/0B4IapRTv9pJ1WGZVd1VDMmhwdlE/view), [Office-Home](https://drive.google.com/file/d/0B81rNlvomiwed0V1YUxQdC1uOTg/view), [VisDA-C](https://github.com/VisionLearningGroup/taskcv-2017-public/tree/master/classification), [Office-Caltech](http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar) from the official websites, and modify the path of images in each '.txt' under the folder './object/data/'. [**How to generate such txt files can be found at https://github.com/tim-learn/Generate_list**] 38 | 39 | - Concerning the **Digits** datasets, the code will automatically download the three digit datasets (i.e., MNIST, USPS, and SVHN) into './digit/data/'. 40 | 41 | 42 | ### Training: 43 | 1. ##### Unsupervised Closed-set Domain Adaptation (UDA) on the Digits dataset 44 | - MNIST -> USPS (**m2u**): SHOT (**cls_par = 0.1**) and SHOT-IM (**cls_par = 0.0**) 45 | ```python 46 | cd digit/ 47 | python uda_digit.py --dset m2u --gpu_id 0 --output ckps_digits --cls_par 0.0 48 | python uda_digit.py --dset m2u --gpu_id 0 --output ckps_digits --cls_par 0.1 49 | ``` 50 | 51 | 2. ##### Unsupervised Closed-set Domain Adaptation (UDA) on the Office / Office-Home dataset 52 | - Train model on the source domain **A** (**s = 0**) 53 | ```python 54 | cd object/ 55 | python image_source.py --trte val --da uda --output ckps/source/ --gpu_id 0 --dset office --max_epoch 100 --s 0 56 | ``` 57 | 58 | - Adaptation to other target domains **D and W**, respectively 59 | ```python 60 | python image_target.py --cls_par 0.3 --da uda --output_src ckps/source/ --output ckps/target/ --gpu_id 0 --dset office --s 0 61 | ``` 62 | 63 | 3. 
##### Unsupervised Closed-set Domain Adaptation (UDA) on the VisDA-C dataset 64 | - Synthetic-to-real 65 | ```python 66 | cd object/ 67 | python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset VISDA-C --net resnet101 --lr 1e-3 --max_epoch 10 --s 0 68 | python image_target.py --cls_par 0.3 --da uda --dset VISDA-C --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/ --net resnet101 --lr 1e-3 69 | ``` 70 | 71 | 4. ##### Unsupervised Partial-set Domain Adaptation (PDA) on the Office-Home dataset 72 | - Train model on the source domain **A** (**s = 0**) 73 | ```python 74 | cd object/ 75 | python image_source.py --trte val --da pda --output ckps/source/ --gpu_id 0 --dset office-home --max_epoch 50 --s 0 76 | ``` 77 | 78 | - Adaptation to other target domains **C and P and R**, respectively 79 | ```python 80 | python image_target.py --cls_par 0.3 --threshold 10 --da pda --dset office-home --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/ 81 | ``` 82 | 83 | 5. ##### Unsupervised Open-set Domain Adaptation (ODA) on the Office-Home dataset 84 | - Train model on the source domain **A** (**s = 0**) 85 | ```python 86 | cd object/ 87 | python image_source.py --trte val --da oda --output ckps/source/ --gpu_id 0 --dset office-home --max_epoch 50 --s 0 88 | ``` 89 | 90 | - Adaptation to other target domains **C and P and R**, respectively 91 | ```python 92 | python image_target_oda.py --cls_par 0.3 --da oda --dset office-home --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/ 93 | ``` 94 | 95 | 6. ##### Unsupervised Multi-source Domain Adaptation (MSDA) on the Office-Caltech dataset 96 | - Train model on the source domains **A** (**s = 0**), **C** (**s = 1**), **D** (**s = 2**), respectively 97 | ```python 98 | cd object/ 99 | python image_source.py --trte val --da uda --output ckps/source/ --gpu_id 0 --dset office-caltech --max_epoch 100 --s 0 100 | python image_source.py --trte val --da uda --output ckps/source/ --gpu_id 0 --dset office-caltech --max_epoch 100 --s 1 101 | python image_source.py --trte val --da uda --output ckps/source/ --gpu_id 0 --dset office-caltech --max_epoch 100 --s 2 102 | ``` 103 | 104 | - Adaptation to the target domain **W** (**t = 3**) 105 | ```python 106 | python image_target.py --cls_par 0.3 --da uda --output_src ckps/source/ --output ckps/target/ --gpu_id 0 --dset office-caltech --s 0 107 | python image_target.py --cls_par 0.3 --da uda --output_src ckps/source/ --output ckps/target/ --gpu_id 0 --dset office-caltech --s 1 108 | python image_target.py --cls_par 0.3 --da uda --output_src ckps/source/ --output ckps/target/ --gpu_id 0 --dset office-caltech --s 2 109 | python image_multisource.py --cls_par 0.0 --da uda --dset office-caltech --gpu_id 0 --t 3 --output_src ckps/source/ --output ckps/target/ 110 | ``` 111 | 112 | 7. ##### Unsupervised Multi-target Domain Adaptation (MTDA) on the Office-Caltech dataset 113 | - Train model on the source domain **A** (**s = 0**) 114 | ```python 115 | cd object/ 116 | python image_source.py --trte val --da uda --output ckps/source/ --gpu_id 0 --dset office-caltech --max_epoch 100 --s 0 117 | ``` 118 | 119 | - Adaptation to multiple target domains **C and D and W** at the same time 120 | ```python 121 | python image_multitarget.py --cls_par 0.3 --da uda --dset office-caltech --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/ 122 | ``` 123 | 124 | 8. 
63 | 3. ##### Unsupervised Closed-set Domain Adaptation (UDA) on the VisDA-C dataset
64 | - Synthetic-to-real
65 | ```python
66 | cd object/
67 | python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset VISDA-C --net resnet101 --lr 1e-3 --max_epoch 10 --s 0
68 | python image_target.py --cls_par 0.3 --da uda --dset VISDA-C --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/ --net resnet101 --lr 1e-3
69 | ```
70 | 
71 | 4. ##### Unsupervised Partial-set Domain Adaptation (PDA) on the Office-Home dataset
72 | - Train the model on the source domain **A** (**s = 0**)
73 | ```python
74 | cd object/
75 | python image_source.py --trte val --da pda --output ckps/source/ --gpu_id 0 --dset office-home --max_epoch 50 --s 0
76 | ```
77 | 
78 | - Adaptation to the other target domains **Cl, Pr, and Re**, respectively
79 | ```python
80 | python image_target.py --cls_par 0.3 --threshold 10 --da pda --dset office-home --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/
81 | ```
82 | 
83 | 5. ##### Unsupervised Open-set Domain Adaptation (ODA) on the Office-Home dataset
84 | - Train the model on the source domain **A** (**s = 0**)
85 | ```python
86 | cd object/
87 | python image_source.py --trte val --da oda --output ckps/source/ --gpu_id 0 --dset office-home --max_epoch 50 --s 0
88 | ```
89 | 
90 | - Adaptation to the other target domains **Cl, Pr, and Re**, respectively
91 | ```python
92 | python image_target_oda.py --cls_par 0.3 --da oda --dset office-home --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/
93 | ```
94 | 
95 | 6. ##### Unsupervised Multi-source Domain Adaptation (MSDA) on the Office-Caltech dataset
96 | - Train models on the source domains **A** (**s = 0**), **C** (**s = 1**), and **D** (**s = 2**), respectively
97 | ```python
98 | cd object/
99 | python image_source.py --trte val --da uda --output ckps/source/ --gpu_id 0 --dset office-caltech --max_epoch 100 --s 0
100 | python image_source.py --trte val --da uda --output ckps/source/ --gpu_id 0 --dset office-caltech --max_epoch 100 --s 1
101 | python image_source.py --trte val --da uda --output ckps/source/ --gpu_id 0 --dset office-caltech --max_epoch 100 --s 2
102 | ```
103 | 
104 | - Adaptation to the target domain **W** (**t = 3**)
105 | ```python
106 | python image_target.py --cls_par 0.3 --da uda --output_src ckps/source/ --output ckps/target/ --gpu_id 0 --dset office-caltech --s 0
107 | python image_target.py --cls_par 0.3 --da uda --output_src ckps/source/ --output ckps/target/ --gpu_id 0 --dset office-caltech --s 1
108 | python image_target.py --cls_par 0.3 --da uda --output_src ckps/source/ --output ckps/target/ --gpu_id 0 --dset office-caltech --s 2
109 | python image_multisource.py --cls_par 0.0 --da uda --dset office-caltech --gpu_id 0 --t 3 --output_src ckps/source/ --output ckps/target/
110 | ```
111 | 
112 | 7. ##### Unsupervised Multi-target Domain Adaptation (MTDA) on the Office-Caltech dataset
113 | - Train the model on the source domain **A** (**s = 0**)
114 | ```python
115 | cd object/
116 | python image_source.py --trte val --da uda --output ckps/source/ --gpu_id 0 --dset office-caltech --max_epoch 100 --s 0
117 | ```
118 | 
119 | - Adaptation to the multiple target domains **C, D, and W** at the same time
120 | ```python
121 | python image_multitarget.py --cls_par 0.3 --da uda --dset office-caltech --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/
122 | ```
123 | 
124 | 8. ##### Unsupervised Partial-set Domain Adaptation (PDA) on the ImageNet-Caltech dataset without training the source model ourselves (directly using the ImageNet-pretrained PyTorch ResNet-50)
125 | - ImageNet -> Caltech (84 classes) [following the protocol in [PADA](https://github.com/thuml/PADA/tree/master/pytorch/data/imagenet-caltech)]
126 | ```python
127 | cd object/
128 | python image_pretrained.py --gpu_id 0 --output ckps/target/ --cls_par 0.3
129 | ```
130 | 
131 | **Please refer to *./object/run.sh* for the complete settings of the different methods and scenarios.**
132 | 
133 | ### Citation
134 | 
135 | If you find this code useful for your research, please cite our papers:
136 | 
137 | ```
138 | @inproceedings{liang2020we,
139 |   title={Do We Really Need to Access the Source Data? Source Hypothesis Transfer for Unsupervised Domain Adaptation},
140 |   author={Liang, Jian and Hu, Dapeng and Feng, Jiashi},
141 |   booktitle={International Conference on Machine Learning (ICML)},
142 |   pages={6028--6039},
143 |   year={2020}
144 | }
145 | 
146 | @article{liang2021source,
147 |   title={Source Data-absent Unsupervised Domain Adaptation through Hypothesis Transfer and Labeling Transfer},
148 |   author={Liang, Jian and Hu, Dapeng and Wang, Yunbo and He, Ran and Feng, Jiashi},
149 |   journal={IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)},
150 |   year={2021},
151 |   note={In Press}
152 | }
153 | ```
154 | 
155 | 
156 | ### Contact
157 | 
158 | - [liangjian92@gmail.com](mailto:liangjian92@gmail.com)
159 | - [dapeng.hu@u.nus.edu](mailto:dapeng.hu@u.nus.edu)
160 | - [elefjia@nus.edu.sg](mailto:elefjia@nus.edu.sg)
161 | 
-------------------------------------------------------------------------------- /results.md: --------------------------------------------------------------------------------
1 | Code for our ICML-2020 paper [**Do We Really Need to Access the Source Data? Source Hypothesis Transfer for Unsupervised Domain Adaptation**](https://arxiv.org/abs/2002.08546).
2 | 
3 | ### Framework:
4 | 
5 | ![SHOT framework](./figs/shot.jpg)
6 | 
7 | ### Results:
8 | 
9 | #### Table 2 [UDA results on Digits]
10 | 
11 | | Methods | S->M | U->M | M->U | Avg. |
12 | | -------------- | ---- | ---- | ---- | ---- |
13 | | srconly (2019) | 71.5 | 85.5 | 82.5 | |
14 | | srconly (2020) | 69.2 | 89.8 | 77.6 | |
15 | | srconly (2021) | 69.7 | 88.7 | 79.0 | |
16 | | srconly (Avg.) | 70.2 | 88.0 | 79.7 | 79.3 |
17 | | SHOT-IM (2019) | 98.9 | 98.6 | 97.8 | |
18 | | SHOT-IM (2020) | 99.0 | 97.8 | 97.7 | |
19 | | SHOT-IM (2021) | 98.9 | 97.6 | 97.7 | |
20 | | SHOT-IM (Avg.) | 99.0 | 97.6 | 97.7 | 98.2 |
21 | | SHOT (2019) | 98.8 | 98.6 | 98.0 | |
22 | | SHOT (2020) | 99.0 | 97.6 | 97.8 | |
23 | | SHOT (2021) | 99.0 | 97.7 | 97.7 | |
24 | | SHOT (Avg.) | 98.9 | 98.0 | 97.9 | 98.3 |
25 | | Oracle (2019) | 99.2 | 99.2 | 97.1 | |
26 | | Oracle (2020) | 99.2 | 99.2 | 97.0 | |
27 | | Oracle (2021) | 99.3 | 99.3 | 97.0 | |
28 | | Oracle (Avg.) | 99.2 | 99.2 | 97.0 | 98.8 |
29 | 
30 | #### Table 3 [UDA results on Office]
31 | 
32 | | Methods | A->D | A->W | D->A | D->W | W->A | W->D | Avg. |
33 | | -------------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
34 | | srconly (2019) | 79.9 | 77.5 | 58.9 | 95.0 | 64.6 | 98.4 | |
35 | | srconly (2020) | 81.5 | 75.8 | 61.6 | 96.0 | 63.3 | 99.0 | |
36 | | srconly (2021) | 80.9 | 77.5 | 60.2 | 94.8 | 62.9 | 98.8 | |
37 | | srconly (Avg.) | 80.8 | 76.9 | 60.3 | 95.3 | 63.6 | 98.7 | 79.3 |
38 | | SHOT-IM (2019) | 88.8 | 90.7 | 71.7 | 98.5 | 71.7 | 99.8 | |
39 | | SHOT-IM (2020) | 92.6 | 92.2 | 72.4 | 98.4 | 71.1 | 100. 
| | 40 | | SHOT-IM (2021) | 90.6 | 90.7 | 73.3 | 98.0 | 71.2 | 99.8 | | 41 | | SHOT-IM (Avg.) | 90.6 | 91.2 | 72.5 | 98.3 | 71.4 | 99.9 | 87.3 | 42 | | SHOT (2019) | 93.4 | 88.8 | 74.9 | 98.5 | 74.6 | 99.8 | | 43 | | SHOT (2020) | 95.0 | 92.0 | 75.7 | 98.6 | 73.7 | 100. | | 44 | | SHOT (2021) | 93.8 | 89.7 | 73.6 | 98.2 | 74.6 | 99.8 | | 45 | | SHOT (Avg.) | 94.0 | 90.1 | 74.7 | 98.4 | 74.3 | 99.9 | 88.6 | 46 | 47 | #### Table 4 [UDA results on Office-Home] 48 | 49 | | Methods |Ar->Cl|Ar->Pr|Ar->Re|Cl->Ar|Cl->Pr|Cl->Re|Pr->Ar|Pr->Cl|Pr->Re|Re->Ar|Re->Cl|Re->Pr| Avg. | 50 | | -------------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | 51 | | srconly (2019) | 45.2 | 67.2 | 75.0 | 52.4 | 62.6 | 64.9 | 52.4 | 40.6 | 73.0 | 65.0 | 43.8 | 78.1 | | 52 | | srconly (2020) | 44.2 | 67.2 | 74.4 | 52.3 | 63.1 | 64.5 | 53.1 | 41.0 | 73.7 | 65.3 | 46.8 | 77.9 | | 53 | | srconly (2021) | 44.5 | 67.7 | 74.8 | 53.4 | 62.4 | 64.9 | 53.4 | 40.4 | 72.9 | 65.7 | 45.8 | 78.1 | | 54 | | srconly (Avg.) | 44.6 | 67.3 | 74.8 | 52.7 | 62.7 | 64.8 | 53.0 | 40.6 | 73.2 | 65.3 | 45.4 | 78.0 | 60.2 | 55 | | SHOT-IM (2019) | 56.5 | 77.1 | 80.8 | 67.7 | 73.3 | 75.1 | 65.5 | 54.5 | 80.6 | 73.4 | 57.2 | 84.0 | | 56 | | SHOT-IM (2020) | 54.7 | 76.3 | 80.2 | 66.8 | 75.8 | 76.2 | 65.6 | 53.9 | 80.7 | 73.6 | 58.3 | 83.5 | | 57 | | SHOT-IM (2021) | 54.9 | 76.4 | 80.1 | 66.2 | 73.8 | 75.0 | 65.7 | 56.1 | 80.7 | 74.2 | 59.6 | 82.9 | | 58 | | SHOT-IM (Avg.) | 55.4 | 76.6 | 80.4 | 66.9 | 74.3 | 75.4 | 65.6 | 54.8 | 80.7 | 73.7 | 58.4 | 83.4 | 70.5 | 59 | | SHOT (2019) | 57.3 | 79.3 | 81.8 | 68.1 | 77.1 | 78.0 | 67.8 | 55.0 | 82.5 | 73.2 | 58.5 | 84.1 | | 60 | | SHOT (2020) | 57.1 | 77.5 | 81.6 | 68.4 | 78.2 | 77.9 | 67.0 | 55.6 | 82.4 | 73.6 | 60.2 | 84.6 | | 61 | | SHOT (2021) | 57.0 | 77.6 | 81.0 | 67.5 | 79.2 | 78.3 | 67.3 | 54.1 | 81.6 | 73.0 | 57.8 | 84.2 | | 62 | | SHOT (Avg.) | 57.1 | 78.1 | 81.5 | 68.0 | 78.2 | 78.1 | 67.4 | 54.9 | 82.2 | 73.3 | 58.8 | 84.3 | 71.8 | 63 | 64 | #### Table 5 [UDA results on VisDA-C] 65 | 66 | | Methods | plane | bcycl | bus | car | horse | knife | mcycl | person | plant | sktbrd | train | truck | Per-class | 67 | | -------------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | 68 | | srconly (2019) | 57.1 | 20.5 | 48.6 | 60.8 | 66.2 | 3.6 | 80.7 | 23.9 | 38.5 | 31.0 | 87.0 | 10.7 | | 69 | | srconly (2020) | 65.1 | 18.9 | 57.2 | 66.9 | 69.9 | 11.0| 84.7 | 23.9 | 69.4 | 34.0 | 83.8 | 9.3 | | 70 | | srconly (2021) | 60.5 | 25.5 | 47.0 | 75.2 | 61.3 | 4.2 | 81.1 | 21.9 | 63.9 | 26.7 | 83.1 | 4.0 | | 71 | | srconly (Avg.) | 60.9 | 21.6 | 50.9 | 67.6 | 65.8 | 6.3 | 82.2 | 23.2 | 57.3 | 30.6 | 84.6 | 8.0 | 46.6 | 72 | | SHOT-IM (2019) | 94.3 | 86.6 | 78.1 | 54.0 | 91.0 | 92.3 | 79.1 | 78.9 | 88.4 | 86.0 | 88.0 | 50.7 | | 73 | | SHOT-IM (2020) | 93.4 | 87.1 | 80.4 | 51.7 | 91.5 | 92.9 | 80.0 | 78.0 | 89.6 | 85.1 | 87.2 | 51.3 | | 74 | | SHOT-IM (2021) | 93.5 | 85.7 | 77.6 | 46.3 | 90.5 | 95.1 | 77.9 | 78.1 | 89.7 | 85.0 | 88.5 | 51.2 | | 75 | | SHOT-IM (Avg.) | 93.7 | 86.4 | 78.7 | 50.7 | 91.0 | 93.5 | 79.0 | 78.3 | 89.2 | 85.4 | 87.9 | 51.1 | 80.4 | 76 | | SHOT (2019) | 93.8 | 89.0 | 81.4 | 57.0 | 93.4 | 94.7 | 81.3 | 80.3 | 90.5 | 89.1 | 85.3 | 58.4 | | 77 | | SHOT (2020) | 94.5 | 87.3 | 80.0 | 57.1 | 93.1 | 94.5 | 82.0 | 80.7 | 91.7 | 89.4 | 87.0 | 58.3 | | 78 | | SHOT (2021) | 94.7 | 89.1 | 78.7 | 57.8 | 92.8 | 95.5 | 78.8 | 79.9 | 92.4 | 89.0 | 86.6 | 57.9 | | 79 | | SHOT (Avg.) 
| 94.3 | 88.5 | 80.1 | 57.3 | 93.1 | 94.9 | 80.7 | 80.3 | 91.5 | 89.1 | 86.3 | 58.2 | 82.9 | 80 | 81 | #### Table 7 [PDA/ ODA results on Office-Home] 82 | 83 | | Methods@PDA |Ar->Cl|Ar->Pr|Ar->Re|Cl->Ar|Cl->Pr|Cl->Re|Pr->Ar|Pr->Cl|Pr->Re|Re->Ar|Re->Cl|Re->Pr| Avg. | 84 | | -------------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | 85 | | srconly (2019) | 46.0 | 69.7 | 80.7 | 56.3 | 60.4 | 66.9 | 60.2 | 40.6 | 76.0 | 70.8 | 48.6 | 78.5 | | 86 | | srconly (2020) | 45.1 | 71.0 | 80.8 | 55.7 | 61.8 | 66.4 | 61.4 | 39.7 | 76.1 | 70.6 | 49.7 | 76.3 | | 87 | | srconly (2021) | 44.5 | 70.5 | 81.3 | 56.8 | 60.2 | 65.2 | 61.2 | 40.0 | 76.5 | 70.9 | 47.2 | 77.2 | | 88 | | srconly (Avg.) | 45.2 | 70.4 | 81.0 | 56.2 | 60.8 | 66.2 | 60.9 | 40.1 | 76.2 | 70.8 | 48.5 | 77.3 | 62.8 | 89 | | SHOT-IM (2019) | 57.5 | 86.2 | 88.2 | 69.3 | 73.6 | 79.9 | 79.7 | 62.2 | 89.0 | 80.8 | 66.6 | 91.0 | | 90 | | SHOT-IM (2020) | 61.2 | 82.0 | 87.8 | 73.3 | 74.4 | 80.6 | 74.1 | 58.8 | 90.0 | 81.7 | 70.8 | 87.1 | | 91 | | SHOT-IM (2021) | 55.0 | 82.6 | 90.3 | 74.5 | 74.0 | 76.5 | 74.4 | 60.8 | 91.4 | 83.0 | 67.5 | 87.3 | | 92 | | SHOT-IM (Avg.) | 57.9 | 83.6 | 88.8 | 72.4 | 74.0 | 79.0 | 76.1 | 60.6 | 90.1 | 81.9 | 68.3 | 88.5 | 76.8 | 93 | | SHOT (2019) | 65.0 | 85.0 | 93.3 | 75.7 | 79.3 | 88.9 | 80.5 | 65.3 | 90.1 | 80.9 | 67.0 | 86.3 | | 94 | | SHOT (2020) | 64.1 | 82.0 | 92.7 | 77.6 | 74.8 | 90.7 | 80.0 | 63.5 | 88.4 | 79.9 | 66.8 | 85.0 | | 95 | | SHOT (2021) | 65.2 | 88.7 | 92.2 | 75.7 | 78.8 | 86.8 | 78.5 | 64.1 | 90.1 | 80.9 | 65.3 | 86.0 | | 96 | | SHOT (Avg.) | 64.8 | 85.2 | 92.7 | 76.3 | 77.6 | 88.8 | 79.7 | 64.3 | 89.5 | 80.6 | 66.4 | 85.8 | 79.3 | 97 | 98 | 99 | | Methods@ODA |Ar->Cl|Ar->Pr|Ar->Re|Cl->Ar|Cl->Pr|Cl->Re|Pr->Ar|Pr->Cl|Pr->Re|Re->Ar|Re->Cl|Re->Pr| Avg. | 100 | | -------------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | 101 | | srconly (2019) | 37.4 | 54.7 | 69.9 | 34.2 | 44.3 | 49.7 | 37.7 | 30.1 | 56.2 | 50.6 | 35.2 | 61.6 | | 102 | | srconly (2020) | 36.4 | 55.0 | 69.0 | 33.3 | 44.7 | 47.8 | 34.6 | 29.2 | 55.7 | 53.2 | 36.0 | 62.4 | | 103 | | srconly (2021) | 35.1 | 54.8 | 68.4 | 33.8 | 44.1 | 50.1 | 38.2 | 28.1 | 58.3 | 50.3 | 34.1 | 62.9 | | 104 | | srconly (Avg.) | 36.3 | 54.8 | 69.1 | 33.8 | 44.4 | 49.2 | 36.8 | 29.2 | 56.8 | 51.4 | 35.1 | 62.3 | 46.6 | 105 | | SHOT-IM (2019) | 61.6 | 80.1 | 84.4 | 61.8 | 74.0 | 81.9 | 63.6 | 58.5 | 83.1 | 68.4 | 63.7 | 82.2 | | 106 | | SHOT-IM (2020) | 63.4 | 76.0 | 83.2 | 61.4 | 74.3 | 78.7 | 63.8 | 59.6 | 83.1 | 70.0 | 61.8 | 82.7 | | 107 | | SHOT-IM (2021) | 62.4 | 77.3 | 84.1 | 59.6 | 71.9 | 77.7 | 66.7 | 58.0 | 83.0 | 68.9 | 60.6 | 81.6 | | 108 | | SHOT-IM (Avg.) | 62.5 | 77.8 | 83.9 | 60.9 | 73.4 | 79.4 | 64.7 | 58.7 | 83.1 | 69.1 | 62.0 | 82.1 | 71.5 | 109 | | SHOT (2019) | 63.9 | 80.6 | 85.6 | 63.6 | 77.1 | 83.2 | 64.9 | 58.3 | 83.2 | 69.7 | 65.2 | 82.8 | | 110 | | SHOT (2020) | 64.0 | 80.4 | 84.7 | 63.4 | 75.3 | 81.6 | 65.1 | 60.9 | 82.8 | 69.9 | 64.4 | 82.4 | | 111 | | SHOT (2021) | 65.6 | 80.2 | 83.8 | 62.2 | 73.7 | 78.8 | 65.9 | 58.8 | 83.9 | 69.2 | 64.1 | 81.7 | | 112 | | SHOT (Avg.) | 64.5 | 80.4 | 84.7 | 63.1 | 75.4 | 81.2 | 65.3 | 59.3 | 83.3 | 69.6 | 64.6 | 82.3 | 72.8 | 113 | 114 | #### Table 8 [MSDA/ MTDA results on Office-Caltech] 115 | 116 | | Methods@MSDA | ->A | ->C | ->D | ->W | Avg. 
| 117 | | -------------- | ---- | ---- | ---- | ---- | ---- | 118 | | srconly (2019) | 95.2 | 93.9 | 99.4 | 98.0 | | 119 | | srconly (2020) | 95.4 | 93.5 | 98.7 | 98.6 | | 120 | | srconly (2021) | 95.6 | 93.7 | 98.7 | 98.3 | | 121 | | srconly (Avg.) | 95.4 | 93.7 | 98.9 | 98.3 | 96.6 | 122 | | SHOT-IM (2019) | 95.8 | 96.0 | 99.4 | 99.7 | | 123 | | SHOT-IM (2020) | 96.5 | 95.9 | 97.5 | 99.7 | | 124 | | SHOT-IM (2021) | 96.4 | 96.3 | 98.7 | 99.7 | | 125 | | SHOT-IM (Avg.) | 96.2 | 96.1 | 98.5 | 99.7 | 97.6 | 126 | | SHOT (2019) | 96.2 | 95.9 | 98.7 | 99.7 | | 127 | | SHOT (2020) | 96.5 | 96.1 | 98.7 | 99.7 || 128 | | SHOT (2021) | 96.6 | 96.6 | 98.1 | 99.7 | | 129 | | SHOT (Avg.) | 96.4 | 96.2 | 98.5 | 99.7 | 97.7 | 130 | 131 | | Methods@MTDA | A-> | C-> | D-> | W-> | Avg. | 132 | | -------------- | ---- | ---- | ---- | ---- | ---- | 133 | | srconly (2019) | 90.4 | 95.9 | 90.3 | 90.6 | | 134 | | srconly (2020) | 91.2 | 95.9 | 90.2 | 91.1 | | 135 | | srconly (2021) | 90.5 | 96.5 | 90.2 | 91.1 | | 136 | | srconly (Avg.) | 90.7 | 96.1 | 90.2 | 90.9 | 92.0 | 137 | | SHOT-IM (2019) | 96.6 | 97.5 | 96.3 | 96.0 | | 138 | | SHOT-IM (2020) | 95.1 | 96.7 | 96.3 | 96.4 | | 139 | | SHOT-IM (2021) | 95.4 | 97.3 | 96.3 | 96.0 | | 140 | | SHOT-IM (Avg.) | 95.7 | 97.2 | 96.3 | 96.1 | 96.3 | 141 | | SHOT (2019) | 96.6 | 97.5 | 96.4 | 96.0 | | 142 | | SHOT (2020) | 95.4 | 97.0 | 96.5 | 96.7 | | 143 | | SHOT (2021) | 96.6 | 97.5 | 96.0 | 96.0 | | 144 | | SHOT (Avg.) | 96.2 | 97.3 | 96.3 | 96.2 | 96.5 | 145 | 146 | 147 | #### Table 9 [PDA results on ImageNet->Caltech] 148 | 149 | | Methods@PDA | 2019 | 2020 | 2021 | Avg. | 150 | | -------------- | ---- | ---- | ---- | ---- | 151 | | srconly | 69.7 | 69.7 | 69.7 | 69.7 | 152 | | SHOT-IM | 81.1 | 82.2 | 81.8 | 81.7 | 153 | | SHOT | 83.2 | 83.3 | 83.4 | 83.3 | 154 | 155 | 156 | 157 | 158 | ### Citation 159 | 160 | If you find this code useful for your research, please cite our paper 161 | 162 | > @inproceedings{liang2020shot, 163 | >     title={Do We Really Need to Access the Source Data? Source Hypothesis Transfer for Unsupervised Domain Adaptation}, 164 | >     author={Liang, Jian and Hu, Dapeng and Feng, Jiashi}, 165 | >     booktitle={International Conference on Machine Learning (ICML)}, 166 | >     pages={xx-xx}, 167 | >     month = {July}, 168 | >     year={2020} 169 | > } 170 | 171 | ### Contact 172 | 173 | - [liangjian92@gmail.com](mailto:liangjian92@gmail.com) 174 | - [dapeng.hu@u.nus.edu](mailto:dapeng.hu@u.nus.edu) 175 | - [elefjia@nus.edu.sg](mailto:elefjia@nus.edu.sg) 176 | --------------------------------------------------------------------------------