├── figs ├── shot.jpg ├── shot_plus.jpg └── shot_ssl.jpg ├── supp └── shot++_supp.pdf ├── code ├── digit │ ├── data_load │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── mnist.cpython-36.pyc │ │ │ ├── mnist.cpython-37.pyc │ │ │ ├── svhn.cpython-36.pyc │ │ │ ├── svhn.cpython-37.pyc │ │ │ ├── usps.cpython-36.pyc │ │ │ ├── usps.cpython-37.pyc │ │ │ ├── utils.cpython-36.pyc │ │ │ ├── utils.cpython-37.pyc │ │ │ ├── vision.cpython-36.pyc │ │ │ ├── vision.cpython-37.pyc │ │ │ ├── __init__.cpython-36.pyc │ │ │ └── __init__.cpython-37.pyc │ │ ├── vision.py │ │ ├── utils.py │ │ ├── usps.py │ │ └── svhn.py │ ├── run_digit.sh │ ├── loss.py │ ├── rotation.py │ └── network.py ├── ssda │ ├── run_ssda.sh │ ├── rotation.py │ ├── loss.py │ ├── data_list.py │ └── image_source.py ├── pda │ ├── run_pda.sh │ ├── rotation.py │ ├── loss.py │ └── data_list.py ├── msda │ ├── rotation.py │ ├── loss.py │ ├── run_msda.sh │ ├── data_list.py │ ├── image_ms.py │ └── image_source.py ├── uda │ ├── rotation.py │ ├── loss.py │ ├── run_uda.sh │ ├── data_list.py │ └── image_source.py └── data │ ├── ssda │ └── office-home │ │ ├── labeled_target_images_Art_1.txt │ │ ├── labeled_target_images_Clipart_1.txt │ │ ├── labeled_target_images_Product_1.txt │ │ └── labeled_target_images_Real_1.txt │ ├── pacs │ └── photo_crossval_kfold.txt │ └── office-caltech │ └── dslr_list.txt ├── LICENSE └── README.md /figs/shot.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/figs/shot.jpg -------------------------------------------------------------------------------- /figs/shot_plus.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/figs/shot_plus.jpg -------------------------------------------------------------------------------- /figs/shot_ssl.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/figs/shot_ssl.jpg -------------------------------------------------------------------------------- /supp/shot++_supp.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/supp/shot++_supp.pdf -------------------------------------------------------------------------------- /code/digit/data_load/__init__.py: -------------------------------------------------------------------------------- 1 | from .svhn import * 2 | from .mnist import * 3 | from .usps import * -------------------------------------------------------------------------------- /code/digit/data_load/__pycache__/mnist.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/code/digit/data_load/__pycache__/mnist.cpython-36.pyc -------------------------------------------------------------------------------- /code/digit/data_load/__pycache__/mnist.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/code/digit/data_load/__pycache__/mnist.cpython-37.pyc -------------------------------------------------------------------------------- /code/digit/data_load/__pycache__/svhn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/code/digit/data_load/__pycache__/svhn.cpython-36.pyc -------------------------------------------------------------------------------- /code/digit/data_load/__pycache__/svhn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/code/digit/data_load/__pycache__/svhn.cpython-37.pyc 
-------------------------------------------------------------------------------- /code/digit/data_load/__pycache__/usps.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/code/digit/data_load/__pycache__/usps.cpython-36.pyc -------------------------------------------------------------------------------- /code/digit/data_load/__pycache__/usps.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/code/digit/data_load/__pycache__/usps.cpython-37.pyc -------------------------------------------------------------------------------- /code/digit/data_load/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/code/digit/data_load/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /code/digit/data_load/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/code/digit/data_load/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /code/digit/data_load/__pycache__/vision.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/code/digit/data_load/__pycache__/vision.cpython-36.pyc -------------------------------------------------------------------------------- /code/digit/data_load/__pycache__/vision.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/code/digit/data_load/__pycache__/vision.cpython-37.pyc 
-------------------------------------------------------------------------------- /code/digit/data_load/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/code/digit/data_load/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /code/digit/data_load/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/code/digit/data_load/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 tim-learn 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /code/ssda/run_ssda.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # id seed 4 | for ((s=0;s<=3;s++)) 5 | do 6 | python image_source.py --gpu_id $1 --seed $2 --output "ckps/s"$2 --dset office-home --max_epoch 50 --s $s 7 | 8 | for ((t=0;t<=3;t++)) 9 | do 10 | if [ $s -eq $t ];then 11 | echo "skipped" 12 | else 13 | echo "okay" 14 | python image_target.py --ssl 0.0 --cls_par 0.0 --ent '' --gpu_id $1 --s $s --t $t --output_src "ckps/s"$2 --output "ckps/st"$2 --seed $2 --dset office-home 15 | python image_mixmatch.py --ps 0.0 --cls_par 0.0 --gpu_id $1 --s $s --t $t --output_tar "ckps/st"$2 --output "ckps/mm_st"$2 --seed $2 --dset office-home --max_epoch 50 16 | 17 | python image_target.py --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --t $t --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset office-home 18 | python image_mixmatch.py --ps 0.0 --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --t $t --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset office-home --max_epoch 50 19 | 20 | python image_target.py --ssl 0.2 --cls_par 0.1 --gpu_id $1 --s $s --t $t --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset office-home 21 | python image_mixmatch.py --ps 0.0 --ssl 0.2 --cls_par 0.1 --gpu_id $1 --s $s --t $t --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset office-home --max_epoch 50 22 | 23 | fi 24 | done 25 | done -------------------------------------------------------------------------------- /code/digit/run_digit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | 
python uda_digit.py --dset m2u --gpu_id $1 --seed $2 --cls_par 0.0 --output ckps_digits 4 | python uda_digit.py --dset u2m --gpu_id $1 --seed $2 --cls_par 0.0 --output ckps_digits 5 | python uda_digit.py --dset s2m --gpu_id $1 --seed $2 --cls_par 0.0 --output ckps_digits 6 | 7 | python uda_digit.py --dset m2u --gpu_id $1 --seed $2 --cls_par 0.1 --ssl 0.2 --output ckps_digits 8 | python uda_digit.py --dset u2m --gpu_id $1 --seed $2 --cls_par 0.1 --ssl 0.2 --output ckps_digits 9 | python uda_digit.py --dset s2m --gpu_id $1 --seed $2 --cls_par 0.1 --ssl 0.2 --output ckps_digits 10 | 11 | python digit_mixmatch.py --dset m2u --gpu_id $1 --seed $2 --output ckps_mm --output_tar ckps_digits --model source 12 | python digit_mixmatch.py --dset u2m --gpu_id $1 --seed $2 --output ckps_mm --output_tar ckps_digits --model source 13 | python digit_mixmatch.py --dset s2m --gpu_id $1 --seed $2 --output ckps_mm --output_tar ckps_digits --model source 14 | 15 | python digit_mixmatch.py --dset m2u --gpu_id $1 --seed $2 --output ckps_mm --output_tar ckps_digits --cls_par 0.0 16 | python digit_mixmatch.py --dset u2m --gpu_id $1 --seed $2 --output ckps_mm --output_tar ckps_digits --cls_par 0.0 17 | python digit_mixmatch.py --dset s2m --gpu_id $1 --seed $2 --output ckps_mm --output_tar ckps_digits --cls_par 0.0 18 | 19 | python digit_mixmatch.py --dset m2u --gpu_id $1 --seed $2 --output ckps_mm --output_tar ckps_digits --cls_par 0.1 --ssl 0.2 20 | python digit_mixmatch.py --dset u2m --gpu_id $1 --seed $2 --output ckps_mm --output_tar ckps_digits --cls_par 0.1 --ssl 0.2 21 | python digit_mixmatch.py --dset s2m --gpu_id $1 --seed $2 --output ckps_mm --output_tar ckps_digits --cls_par 0.1 --ssl 0.2 -------------------------------------------------------------------------------- /code/digit/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | 
import math 6 | import torch.nn.functional as F 7 | import pdb 8 | 9 | def Entropy(input_): 10 | bs = input_.size(0) 11 | epsilon = 1e-5 12 | entropy = -input_ * torch.log(input_ + epsilon) 13 | entropy = torch.sum(entropy, dim=1) 14 | return entropy 15 | 16 | class CrossEntropyLabelSmooth(nn.Module): 17 | """Cross entropy loss with label smoothing regularizer. 18 | Reference: 19 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 20 | Equation: y = (1 - epsilon) * y + epsilon / K. 21 | Args: 22 | num_classes (int): number of classes. 23 | epsilon (float): weight. 24 | """ 25 | 26 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, size_average=True): 27 | super(CrossEntropyLabelSmooth, self).__init__() 28 | self.num_classes = num_classes 29 | self.epsilon = epsilon 30 | self.use_gpu = use_gpu 31 | self.size_average = size_average 32 | self.logsoftmax = nn.LogSoftmax(dim=1) 33 | 34 | def forward(self, inputs, targets): 35 | """ 36 | Args: 37 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 38 | targets: ground truth labels with shape (num_classes) 39 | """ 40 | log_probs = self.logsoftmax(inputs) 41 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1) 42 | if self.use_gpu: targets = targets.cuda() 43 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 44 | if self.size_average: 45 | loss = (- targets * log_probs).mean(0).sum() 46 | else: 47 | loss = (- targets * log_probs).sum(1) 48 | return loss -------------------------------------------------------------------------------- /code/digit/rotation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | from torchvision import datasets 4 | import numpy as np 5 | 6 | # Assumes that tensor is (nchannels, height, width) 7 | def tensor_rot_90(x): 8 | return x.flip(2).transpose(1, 2) 9 | 10 | def tensor_rot_180(x): 
11 | return x.flip(2).flip(1) 12 | 13 | def tensor_rot_270(x): 14 | return x.transpose(1, 2).flip(2) 15 | 16 | def rotate_single_with_label(img, label): 17 | if label == 1: 18 | img = tensor_rot_90(img) 19 | elif label == 2: 20 | img = tensor_rot_180(img) 21 | elif label == 3: 22 | img = tensor_rot_270(img) 23 | return img 24 | 25 | def rotate_batch_with_labels(batch, labels): 26 | images = [] 27 | for img, label in zip(batch, labels): 28 | img = rotate_single_with_label(img, label) 29 | images.append(img.unsqueeze(0)) 30 | return torch.cat(images) 31 | 32 | def rotate_batch(batch, label='rand'): 33 | if label == 'rand': 34 | labels = torch.randint(4, (len(batch),), dtype=torch.long) 35 | else: 36 | assert isinstance(label, int) 37 | labels = torch.zeros((len(batch),), dtype=torch.long) + label 38 | return rotate_batch_with_labels(batch, labels), labels 39 | 40 | class RotateImageFolder(datasets.ImageFolder): 41 | def __init__(self, traindir, train_transform, original=True, rotation=True, rotation_transform=None): 42 | super(RotateImageFolder, self).__init__(traindir, train_transform) 43 | self.original = original 44 | self.rotation = rotation 45 | self.rotation_transform = rotation_transform 46 | 47 | def __getitem__(self, index): 48 | path, target = self.imgs[index] 49 | img_input = self.loader(path) 50 | 51 | if self.transform is not None: 52 | img = self.transform(img_input) 53 | else: 54 | img = img_input 55 | 56 | results = [] 57 | if self.original: 58 | results.append(img) 59 | results.append(target) 60 | if self.rotation: 61 | if self.rotation_transform is not None: 62 | img = self.rotation_transform(img_input) 63 | target_ssh = np.random.randint(0, 4, 1)[0] 64 | img_ssh = rotate_single_with_label(img, target_ssh) 65 | results.append(img_ssh) 66 | results.append(target_ssh) 67 | return results 68 | 69 | def switch_mode(self, original, rotation): 70 | self.original = original 71 | self.rotation = rotation 72 | 
-------------------------------------------------------------------------------- /code/pda/run_pda.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | python image_pretrained.py --ssl 0.0 --cls_par 0.0 --gpu_id $1 --seed $2 --output "ckps/s"$2 4 | python image_pretrained.py --ssl 0.0 --cls_par 0.3 --gpu_id $1 --seed $2 --output "ckps/s"$2 5 | python image_pretrained.py --ssl 0.6 --cls_par 0.3 --gpu_id $1 --seed $2 --output "ckps/s"$2 6 | 7 | for ((s=0;s<=3;s++)) 8 | do 9 | python image_source.py --da pda --gpu_id $1 --seed $2 --output "ckps/s"$2 --dset office-home --max_epoch 50 --s $s 10 | python image_mixmatch.py --model source --da pda --ps 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_tar "ckps/s"$2 --output "ckps/mm"$2 --seed $2 --dset office-home --max_epoch 50 11 | 12 | python image_target.py --da pda --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset office-home 13 | python image_mixmatch.py --da pda --ps 0.0 --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset office-home --max_epoch 50 14 | 15 | python image_target.py --da pda --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset office-home 16 | python image_mixmatch.py --da pda --ps 0.0 --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset office-home --max_epoch 50 17 | 18 | done 19 | 20 | 21 | for ((s=0;s<=1;s++)) 22 | do 23 | python image_source.py --da pda --gpu_id $1 --seed $2 --output "ckps/s"$2 --dset VISDA-C --max_epoch 10 --s $s 24 | python image_mixmatch.py --model source --da pda --ps 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_tar "ckps/s"$2 --output "ckps/mm"$2 --seed $2 --dset VISDA-C --max_epoch 10 25 | 26 | python image_target.py --da pda --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 
--seed $2 --dset VISDA-C 27 | python image_mixmatch.py --da pda --ps 0.0 --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset VISDA-C --max_epoch 10 28 | 29 | python image_target.py --da pda --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset VISDA-C 30 | python image_mixmatch.py --da pda --ps 0.0 --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset VISDA-C --max_epoch 10 31 | 32 | done -------------------------------------------------------------------------------- /code/msda/rotation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | from torchvision import datasets 4 | import numpy as np 5 | 6 | # Assumes that tensor is (nchannels, height, width) 7 | def tensor_rot_90(x): 8 | return x.flip(2).transpose(1, 2) 9 | 10 | def tensor_rot_180(x): 11 | return x.flip(2).flip(1) 12 | 13 | def tensor_rot_270(x): 14 | return x.transpose(1, 2).flip(2) 15 | 16 | def rotate_single_with_label(img, label): 17 | if label == 1: 18 | img = tensor_rot_90(img) 19 | elif label == 2: 20 | img = tensor_rot_180(img) 21 | elif label == 3: 22 | img = tensor_rot_270(img) 23 | return img 24 | 25 | def rotate_batch_with_labels(batch, labels): 26 | images = [] 27 | for img, label in zip(batch, labels): 28 | img = rotate_single_with_label(img, label) 29 | images.append(img.unsqueeze(0)) 30 | return torch.cat(images) 31 | 32 | 33 | def rotate_single_with_label2(img, label): 34 | if label == 1: 35 | img = tensor_rot_180(img) 36 | return img 37 | 38 | def rotate_batch_with_labels2(batch, labels): 39 | images = [] 40 | for img, label in zip(batch, labels): 41 | img = rotate_single_with_label2(img, label) 42 | images.append(img.unsqueeze(0)) 43 | return torch.cat(images) 44 | 45 | 46 | def rotate_batch(batch, label='rand'): 47 | if label == 'rand': 48 | labels = 
torch.randint(4, (len(batch),), dtype=torch.long) 49 | else: 50 | assert isinstance(label, int) 51 | labels = torch.zeros((len(batch),), dtype=torch.long) + label 52 | return rotate_batch_with_labels(batch, labels), labels 53 | 54 | class RotateImageFolder(datasets.ImageFolder): 55 | def __init__(self, traindir, train_transform, original=True, rotation=True, rotation_transform=None): 56 | super(RotateImageFolder, self).__init__(traindir, train_transform) 57 | self.original = original 58 | self.rotation = rotation 59 | self.rotation_transform = rotation_transform 60 | 61 | def __getitem__(self, index): 62 | path, target = self.imgs[index] 63 | img_input = self.loader(path) 64 | 65 | if self.transform is not None: 66 | img = self.transform(img_input) 67 | else: 68 | img = img_input 69 | 70 | results = [] 71 | if self.original: 72 | results.append(img) 73 | results.append(target) 74 | if self.rotation: 75 | if self.rotation_transform is not None: 76 | img = self.rotation_transform(img_input) 77 | target_ssh = np.random.randint(0, 4, 1)[0] 78 | img_ssh = rotate_single_with_label(img, target_ssh) 79 | results.append(img_ssh) 80 | results.append(target_ssh) 81 | return results 82 | 83 | def switch_mode(self, original, rotation): 84 | self.original = original 85 | self.rotation = rotation 86 | -------------------------------------------------------------------------------- /code/pda/rotation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | from torchvision import datasets 4 | import numpy as np 5 | 6 | # Assumes that tensor is (nchannels, height, width) 7 | def tensor_rot_90(x): 8 | return x.flip(2).transpose(1, 2) 9 | 10 | def tensor_rot_180(x): 11 | return x.flip(2).flip(1) 12 | 13 | def tensor_rot_270(x): 14 | return x.transpose(1, 2).flip(2) 15 | 16 | def rotate_single_with_label(img, label): 17 | if label == 1: 18 | img = tensor_rot_90(img) 19 | elif label == 2: 20 | img = 
tensor_rot_180(img) 21 | elif label == 3: 22 | img = tensor_rot_270(img) 23 | return img 24 | 25 | def rotate_batch_with_labels(batch, labels): 26 | images = [] 27 | for img, label in zip(batch, labels): 28 | img = rotate_single_with_label(img, label) 29 | images.append(img.unsqueeze(0)) 30 | return torch.cat(images) 31 | 32 | 33 | def rotate_single_with_label2(img, label): 34 | if label == 1: 35 | img = tensor_rot_180(img) 36 | return img 37 | 38 | def rotate_batch_with_labels2(batch, labels): 39 | images = [] 40 | for img, label in zip(batch, labels): 41 | img = rotate_single_with_label2(img, label) 42 | images.append(img.unsqueeze(0)) 43 | return torch.cat(images) 44 | 45 | 46 | def rotate_batch(batch, label='rand'): 47 | if label == 'rand': 48 | labels = torch.randint(4, (len(batch),), dtype=torch.long) 49 | else: 50 | assert isinstance(label, int) 51 | labels = torch.zeros((len(batch),), dtype=torch.long) + label 52 | return rotate_batch_with_labels(batch, labels), labels 53 | 54 | class RotateImageFolder(datasets.ImageFolder): 55 | def __init__(self, traindir, train_transform, original=True, rotation=True, rotation_transform=None): 56 | super(RotateImageFolder, self).__init__(traindir, train_transform) 57 | self.original = original 58 | self.rotation = rotation 59 | self.rotation_transform = rotation_transform 60 | 61 | def __getitem__(self, index): 62 | path, target = self.imgs[index] 63 | img_input = self.loader(path) 64 | 65 | if self.transform is not None: 66 | img = self.transform(img_input) 67 | else: 68 | img = img_input 69 | 70 | results = [] 71 | if self.original: 72 | results.append(img) 73 | results.append(target) 74 | if self.rotation: 75 | if self.rotation_transform is not None: 76 | img = self.rotation_transform(img_input) 77 | target_ssh = np.random.randint(0, 4, 1)[0] 78 | img_ssh = rotate_single_with_label(img, target_ssh) 79 | results.append(img_ssh) 80 | results.append(target_ssh) 81 | return results 82 | 83 | def switch_mode(self, 
original, rotation): 84 | self.original = original 85 | self.rotation = rotation 86 | -------------------------------------------------------------------------------- /code/ssda/rotation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | from torchvision import datasets 4 | import numpy as np 5 | 6 | # Assumes that tensor is (nchannels, height, width) 7 | def tensor_rot_90(x): 8 | return x.flip(2).transpose(1, 2) 9 | 10 | def tensor_rot_180(x): 11 | return x.flip(2).flip(1) 12 | 13 | def tensor_rot_270(x): 14 | return x.transpose(1, 2).flip(2) 15 | 16 | def rotate_single_with_label(img, label): 17 | if label == 1: 18 | img = tensor_rot_90(img) 19 | elif label == 2: 20 | img = tensor_rot_180(img) 21 | elif label == 3: 22 | img = tensor_rot_270(img) 23 | return img 24 | 25 | def rotate_batch_with_labels(batch, labels): 26 | images = [] 27 | for img, label in zip(batch, labels): 28 | img = rotate_single_with_label(img, label) 29 | images.append(img.unsqueeze(0)) 30 | return torch.cat(images) 31 | 32 | 33 | def rotate_single_with_label2(img, label): 34 | if label == 1: 35 | img = tensor_rot_180(img) 36 | return img 37 | 38 | def rotate_batch_with_labels2(batch, labels): 39 | images = [] 40 | for img, label in zip(batch, labels): 41 | img = rotate_single_with_label2(img, label) 42 | images.append(img.unsqueeze(0)) 43 | return torch.cat(images) 44 | 45 | 46 | def rotate_batch(batch, label='rand'): 47 | if label == 'rand': 48 | labels = torch.randint(4, (len(batch),), dtype=torch.long) 49 | else: 50 | assert isinstance(label, int) 51 | labels = torch.zeros((len(batch),), dtype=torch.long) + label 52 | return rotate_batch_with_labels(batch, labels), labels 53 | 54 | class RotateImageFolder(datasets.ImageFolder): 55 | def __init__(self, traindir, train_transform, original=True, rotation=True, rotation_transform=None): 56 | super(RotateImageFolder, self).__init__(traindir, train_transform) 57 
| self.original = original 58 | self.rotation = rotation 59 | self.rotation_transform = rotation_transform 60 | 61 | def __getitem__(self, index): 62 | path, target = self.imgs[index] 63 | img_input = self.loader(path) 64 | 65 | if self.transform is not None: 66 | img = self.transform(img_input) 67 | else: 68 | img = img_input 69 | 70 | results = [] 71 | if self.original: 72 | results.append(img) 73 | results.append(target) 74 | if self.rotation: 75 | if self.rotation_transform is not None: 76 | img = self.rotation_transform(img_input) 77 | target_ssh = np.random.randint(0, 4, 1)[0] 78 | img_ssh = rotate_single_with_label(img, target_ssh) 79 | results.append(img_ssh) 80 | results.append(target_ssh) 81 | return results 82 | 83 | def switch_mode(self, original, rotation): 84 | self.original = original 85 | self.rotation = rotation 86 | -------------------------------------------------------------------------------- /code/uda/rotation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | from torchvision import datasets 4 | import numpy as np 5 | 6 | # Assumes that tensor is (nchannels, height, width) 7 | def tensor_rot_90(x): 8 | return x.flip(2).transpose(1, 2) 9 | 10 | def tensor_rot_180(x): 11 | return x.flip(2).flip(1) 12 | 13 | def tensor_rot_270(x): 14 | return x.transpose(1, 2).flip(2) 15 | 16 | def rotate_single_with_label(img, label): 17 | if label == 1: 18 | img = tensor_rot_90(img) 19 | elif label == 2: 20 | img = tensor_rot_180(img) 21 | elif label == 3: 22 | img = tensor_rot_270(img) 23 | return img 24 | 25 | def rotate_batch_with_labels(batch, labels): 26 | images = [] 27 | for img, label in zip(batch, labels): 28 | img = rotate_single_with_label(img, label) 29 | images.append(img.unsqueeze(0)) 30 | return torch.cat(images) 31 | 32 | 33 | def rotate_single_with_label2(img, label): 34 | if label == 1: 35 | img = tensor_rot_180(img) 36 | return img 37 | 38 | def 
rotate_batch_with_labels2(batch, labels): 39 | images = [] 40 | for img, label in zip(batch, labels): 41 | img = rotate_single_with_label2(img, label) 42 | images.append(img.unsqueeze(0)) 43 | return torch.cat(images) 44 | 45 | 46 | def rotate_batch(batch, label='rand'): 47 | if label == 'rand': 48 | labels = torch.randint(4, (len(batch),), dtype=torch.long) 49 | else: 50 | assert isinstance(label, int) 51 | labels = torch.zeros((len(batch),), dtype=torch.long) + label 52 | return rotate_batch_with_labels(batch, labels), labels 53 | 54 | class RotateImageFolder(datasets.ImageFolder): 55 | def __init__(self, traindir, train_transform, original=True, rotation=True, rotation_transform=None): 56 | super(RotateImageFolder, self).__init__(traindir, train_transform) 57 | self.original = original 58 | self.rotation = rotation 59 | self.rotation_transform = rotation_transform 60 | 61 | def __getitem__(self, index): 62 | path, target = self.imgs[index] 63 | img_input = self.loader(path) 64 | 65 | if self.transform is not None: 66 | img = self.transform(img_input) 67 | else: 68 | img = img_input 69 | 70 | results = [] 71 | if self.original: 72 | results.append(img) 73 | results.append(target) 74 | if self.rotation: 75 | if self.rotation_transform is not None: 76 | img = self.rotation_transform(img_input) 77 | target_ssh = np.random.randint(0, 4, 1)[0] 78 | img_ssh = rotate_single_with_label(img, target_ssh) 79 | results.append(img_ssh) 80 | results.append(target_ssh) 81 | return results 82 | 83 | def switch_mode(self, original, rotation): 84 | self.original = original 85 | self.rotation = rotation 86 | -------------------------------------------------------------------------------- /code/digit/data_load/vision.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | 5 | 6 | class VisionDataset(data.Dataset): 7 | _repr_indent = 4 8 | 9 | def __init__(self, root, transforms=None, 
transform=None, target_transform=None): 10 | if isinstance(root, torch._six.string_classes): 11 | root = os.path.expanduser(root) 12 | self.root = root 13 | 14 | has_transforms = transforms is not None 15 | has_separate_transform = transform is not None or target_transform is not None 16 | if has_transforms and has_separate_transform: 17 | raise ValueError("Only transforms or transform/target_transform can " 18 | "be passed as argument") 19 | 20 | # for backwards-compatibility 21 | self.transform = transform 22 | self.target_transform = target_transform 23 | 24 | if has_separate_transform: 25 | transforms = StandardTransform(transform, target_transform) 26 | self.transforms = transforms 27 | 28 | def __getitem__(self, index): 29 | raise NotImplementedError 30 | 31 | def __len__(self): 32 | raise NotImplementedError 33 | 34 | def __repr__(self): 35 | head = "Dataset " + self.__class__.__name__ 36 | body = ["Number of datapoints: {}".format(self.__len__())] 37 | if self.root is not None: 38 | body.append("Root location: {}".format(self.root)) 39 | body += self.extra_repr().splitlines() 40 | if hasattr(self, "transforms") and self.transforms is not None: 41 | body += [repr(self.transforms)] 42 | lines = [head] + [" " * self._repr_indent + line for line in body] 43 | return '\n'.join(lines) 44 | 45 | def _format_transform_repr(self, transform, head): 46 | lines = transform.__repr__().splitlines() 47 | return (["{}{}".format(head, lines[0])] + 48 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 49 | 50 | def extra_repr(self): 51 | return "" 52 | 53 | 54 | class StandardTransform(object): 55 | def __init__(self, transform=None, target_transform=None): 56 | self.transform = transform 57 | self.target_transform = target_transform 58 | 59 | def __call__(self, input, target): 60 | if self.transform is not None: 61 | input = self.transform(input) 62 | if self.target_transform is not None: 63 | target = self.target_transform(target) 64 | return input, target 65 
| 66 | def _format_transform_repr(self, transform, head): 67 | lines = transform.__repr__().splitlines() 68 | return (["{}{}".format(head, lines[0])] + 69 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 70 | 71 | def __repr__(self): 72 | body = [self.__class__.__name__] 73 | if self.transform is not None: 74 | body += self._format_transform_repr(self.transform, 75 | "Transform: ") 76 | if self.target_transform is not None: 77 | body += self._format_transform_repr(self.target_transform, 78 | "Target transform: ") 79 | 80 | return '\n'.join(body) -------------------------------------------------------------------------------- /code/digit/network.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torchvision 5 | from torchvision import models 6 | from torch.autograd import Variable 7 | import math 8 | import pdb 9 | import torch.nn.utils.weight_norm as weightNorm 10 | from collections import OrderedDict 11 | 12 | def init_weights(m): 13 | classname = m.__class__.__name__ 14 | if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1: 15 | nn.init.kaiming_uniform_(m.weight) 16 | nn.init.zeros_(m.bias) 17 | elif classname.find('BatchNorm') != -1: 18 | nn.init.normal_(m.weight, 1.0, 0.02) 19 | nn.init.zeros_(m.bias) 20 | elif classname.find('Linear') != -1: 21 | nn.init.xavier_normal_(m.weight) 22 | nn.init.zeros_(m.bias) 23 | 24 | class feat_bottleneck(nn.Module): 25 | def __init__(self, feature_dim, bottleneck_dim=256, type="ori"): 26 | super(feat_bottleneck, self).__init__() 27 | self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True) 28 | self.relu = nn.ReLU(inplace=True) 29 | self.dropout = nn.Dropout(p=0.5) 30 | self.bottleneck = nn.Linear(feature_dim, bottleneck_dim) 31 | self.bottleneck.apply(init_weights) 32 | self.type = type 33 | 34 | def forward(self, x): 35 | x = self.bottleneck(x) 36 | if self.type == "bn": 37 | x = 
self.bn(x) 38 | x = self.dropout(x) 39 | return x 40 | 41 | class feat_classifier(nn.Module): 42 | def __init__(self, class_num, bottleneck_dim=256, type="linear"): 43 | super(feat_classifier, self).__init__() 44 | if type == "linear": 45 | self.fc = nn.Linear(bottleneck_dim, class_num) 46 | else: 47 | self.fc = weightNorm(nn.Linear(bottleneck_dim, class_num), name="weight") 48 | self.fc.apply(init_weights) 49 | 50 | def forward(self, x): 51 | x = self.fc(x) 52 | return x 53 | 54 | class DTNBase(nn.Module): 55 | def __init__(self): 56 | super(DTNBase, self).__init__() 57 | self.conv_params = nn.Sequential( 58 | nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2), 59 | nn.BatchNorm2d(64), 60 | nn.Dropout2d(0.1), 61 | nn.ReLU(), 62 | nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2), 63 | nn.BatchNorm2d(128), 64 | nn.Dropout2d(0.3), 65 | nn.ReLU(), 66 | nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2), 67 | nn.BatchNorm2d(256), 68 | nn.Dropout2d(0.5), 69 | nn.ReLU() 70 | ) 71 | self.in_features = 256*4*4 72 | 73 | def forward(self, x): 74 | x = self.conv_params(x) 75 | x = x.view(x.size(0), -1) 76 | return x 77 | 78 | class LeNetBase(nn.Module): 79 | def __init__(self): 80 | super(LeNetBase, self).__init__() 81 | self.conv_params = nn.Sequential( 82 | nn.Conv2d(1, 20, kernel_size=5), 83 | nn.MaxPool2d(2), 84 | nn.ReLU(), 85 | nn.Conv2d(20, 50, kernel_size=5), 86 | nn.Dropout2d(p=0.5), 87 | nn.MaxPool2d(2), 88 | nn.ReLU(), 89 | ) 90 | self.in_features = 50*4*4 91 | 92 | def forward(self, x): 93 | x = self.conv_params(x) 94 | x = x.view(x.size(0), -1) 95 | return x -------------------------------------------------------------------------------- /code/pda/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import math 6 | import torch.nn.functional as F 7 | import pdb 8 | 9 | def Entropy(input_): 10 | bs = 
input_.size(0) 11 | epsilon = 1e-5 12 | entropy = -input_ * torch.log(input_ + epsilon) 13 | entropy = torch.sum(entropy, dim=1) 14 | return entropy 15 | 16 | def grl_hook(coeff): 17 | def fun1(grad): 18 | return -coeff*grad.clone() 19 | return fun1 20 | 21 | def CDAN(input_list, ad_net, entropy=None, coeff=None, random_layer=None): 22 | softmax_output = input_list[1].detach() 23 | feature = input_list[0] 24 | if random_layer is None: 25 | op_out = torch.bmm(softmax_output.unsqueeze(2), feature.unsqueeze(1)) 26 | ad_out = ad_net(op_out.view(-1, softmax_output.size(1) * feature.size(1))) 27 | else: 28 | random_out = random_layer.forward([feature, softmax_output]) 29 | ad_out = ad_net(random_out.view(-1, random_out.size(1))) 30 | batch_size = softmax_output.size(0) // 2 31 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda() 32 | if entropy is not None: 33 | entropy.register_hook(grl_hook(coeff)) 34 | entropy = 1.0+torch.exp(-entropy) 35 | source_mask = torch.ones_like(entropy) 36 | source_mask[feature.size(0)//2:] = 0 37 | source_weight = entropy*source_mask 38 | target_mask = torch.ones_like(entropy) 39 | target_mask[0:feature.size(0)//2] = 0 40 | target_weight = entropy*target_mask 41 | weight = source_weight / torch.sum(source_weight).detach().item() + \ 42 | target_weight / torch.sum(target_weight).detach().item() 43 | return torch.sum(weight.view(-1, 1) * nn.BCELoss(reduction='none')(ad_out, dc_target)) / torch.sum(weight).detach().item() 44 | else: 45 | return nn.BCELoss()(ad_out, dc_target) 46 | 47 | def DANN(features, ad_net): 48 | ad_out = ad_net(features) 49 | batch_size = ad_out.size(0) // 2 50 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda() 51 | return nn.BCELoss()(ad_out, dc_target) 52 | 53 | 54 | class CrossEntropyLabelSmooth(nn.Module): 55 | """Cross entropy loss with label smoothing regularizer. 56 | Reference: 57 | Szegedy et al. 
Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 58 | Equation: y = (1 - epsilon) * y + epsilon / K. 59 | Args: 60 | num_classes (int): number of classes. 61 | epsilon (float): weight. 62 | """ 63 | 64 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, reduction=True): 65 | super(CrossEntropyLabelSmooth, self).__init__() 66 | self.num_classes = num_classes 67 | self.epsilon = epsilon 68 | self.use_gpu = use_gpu 69 | self.reduction = reduction 70 | self.logsoftmax = nn.LogSoftmax(dim=1) 71 | 72 | def forward(self, inputs, targets): 73 | """ 74 | Args: 75 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 76 | targets: ground truth labels with shape (num_classes) 77 | """ 78 | log_probs = self.logsoftmax(inputs) 79 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1) 80 | if self.use_gpu: targets = targets.cuda() 81 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 82 | loss = (- targets * log_probs).sum(dim=1) 83 | if self.reduction: 84 | return loss.mean() 85 | else: 86 | return loss 87 | return loss -------------------------------------------------------------------------------- /code/uda/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import math 6 | import torch.nn.functional as F 7 | import pdb 8 | 9 | def Entropy(input_): 10 | bs = input_.size(0) 11 | epsilon = 1e-5 12 | entropy = -input_ * torch.log(input_ + epsilon) 13 | entropy = torch.sum(entropy, dim=1) 14 | return entropy 15 | 16 | def grl_hook(coeff): 17 | def fun1(grad): 18 | return -coeff*grad.clone() 19 | return fun1 20 | 21 | def CDAN(input_list, ad_net, entropy=None, coeff=None, random_layer=None): 22 | softmax_output = input_list[1].detach() 23 | feature = input_list[0] 24 | if random_layer is None: 25 | op_out = 
torch.bmm(softmax_output.unsqueeze(2), feature.unsqueeze(1)) 26 | ad_out = ad_net(op_out.view(-1, softmax_output.size(1) * feature.size(1))) 27 | else: 28 | random_out = random_layer.forward([feature, softmax_output]) 29 | ad_out = ad_net(random_out.view(-1, random_out.size(1))) 30 | batch_size = softmax_output.size(0) // 2 31 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda() 32 | if entropy is not None: 33 | entropy.register_hook(grl_hook(coeff)) 34 | entropy = 1.0+torch.exp(-entropy) 35 | source_mask = torch.ones_like(entropy) 36 | source_mask[feature.size(0)//2:] = 0 37 | source_weight = entropy*source_mask 38 | target_mask = torch.ones_like(entropy) 39 | target_mask[0:feature.size(0)//2] = 0 40 | target_weight = entropy*target_mask 41 | weight = source_weight / torch.sum(source_weight).detach().item() + \ 42 | target_weight / torch.sum(target_weight).detach().item() 43 | return torch.sum(weight.view(-1, 1) * nn.BCELoss(reduction='none')(ad_out, dc_target)) / torch.sum(weight).detach().item() 44 | else: 45 | return nn.BCELoss()(ad_out, dc_target) 46 | 47 | def DANN(features, ad_net): 48 | ad_out = ad_net(features) 49 | batch_size = ad_out.size(0) // 2 50 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda() 51 | return nn.BCELoss()(ad_out, dc_target) 52 | 53 | 54 | class CrossEntropyLabelSmooth(nn.Module): 55 | """Cross entropy loss with label smoothing regularizer. 56 | Reference: 57 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 58 | Equation: y = (1 - epsilon) * y + epsilon / K. 59 | Args: 60 | num_classes (int): number of classes. 61 | epsilon (float): weight. 
62 | """ 63 | 64 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, reduction=True): 65 | super(CrossEntropyLabelSmooth, self).__init__() 66 | self.num_classes = num_classes 67 | self.epsilon = epsilon 68 | self.use_gpu = use_gpu 69 | self.reduction = reduction 70 | self.logsoftmax = nn.LogSoftmax(dim=1) 71 | 72 | def forward(self, inputs, targets): 73 | """ 74 | Args: 75 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 76 | targets: ground truth labels with shape (num_classes) 77 | """ 78 | log_probs = self.logsoftmax(inputs) 79 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1) 80 | if self.use_gpu: targets = targets.cuda() 81 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 82 | loss = (- targets * log_probs).sum(dim=1) 83 | if self.reduction: 84 | return loss.mean() 85 | else: 86 | return loss 87 | return loss -------------------------------------------------------------------------------- /code/msda/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import math 6 | import torch.nn.functional as F 7 | import pdb 8 | 9 | def Entropy(input_): 10 | bs = input_.size(0) 11 | epsilon = 1e-5 12 | entropy = -input_ * torch.log(input_ + epsilon) 13 | entropy = torch.sum(entropy, dim=1) 14 | return entropy 15 | 16 | def grl_hook(coeff): 17 | def fun1(grad): 18 | return -coeff*grad.clone() 19 | return fun1 20 | 21 | def CDAN(input_list, ad_net, entropy=None, coeff=None, random_layer=None): 22 | softmax_output = input_list[1].detach() 23 | feature = input_list[0] 24 | if random_layer is None: 25 | op_out = torch.bmm(softmax_output.unsqueeze(2), feature.unsqueeze(1)) 26 | ad_out = ad_net(op_out.view(-1, softmax_output.size(1) * feature.size(1))) 27 | else: 28 | random_out = random_layer.forward([feature, softmax_output]) 29 | 
ad_out = ad_net(random_out.view(-1, random_out.size(1))) 30 | batch_size = softmax_output.size(0) // 2 31 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda() 32 | if entropy is not None: 33 | entropy.register_hook(grl_hook(coeff)) 34 | entropy = 1.0+torch.exp(-entropy) 35 | source_mask = torch.ones_like(entropy) 36 | source_mask[feature.size(0)//2:] = 0 37 | source_weight = entropy*source_mask 38 | target_mask = torch.ones_like(entropy) 39 | target_mask[0:feature.size(0)//2] = 0 40 | target_weight = entropy*target_mask 41 | weight = source_weight / torch.sum(source_weight).detach().item() + \ 42 | target_weight / torch.sum(target_weight).detach().item() 43 | return torch.sum(weight.view(-1, 1) * nn.BCELoss(reduction='none')(ad_out, dc_target)) / torch.sum(weight).detach().item() 44 | else: 45 | return nn.BCELoss()(ad_out, dc_target) 46 | 47 | def DANN(features, ad_net): 48 | ad_out = ad_net(features) 49 | batch_size = ad_out.size(0) // 2 50 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda() 51 | return nn.BCELoss()(ad_out, dc_target) 52 | 53 | 54 | class CrossEntropyLabelSmooth(nn.Module): 55 | """Cross entropy loss with label smoothing regularizer. 56 | Reference: 57 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 58 | Equation: y = (1 - epsilon) * y + epsilon / K. 59 | Args: 60 | num_classes (int): number of classes. 61 | epsilon (float): weight. 
62 | """ 63 | 64 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, reduction=True): 65 | super(CrossEntropyLabelSmooth, self).__init__() 66 | self.num_classes = num_classes 67 | self.epsilon = epsilon 68 | self.use_gpu = use_gpu 69 | self.reduction = reduction 70 | self.logsoftmax = nn.LogSoftmax(dim=1) 71 | 72 | def forward(self, inputs, targets): 73 | """ 74 | Args: 75 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 76 | targets: ground truth labels with shape (num_classes) 77 | """ 78 | log_probs = self.logsoftmax(inputs) 79 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1) 80 | if self.use_gpu: targets = targets.cuda() 81 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 82 | loss = (- targets * log_probs).sum(dim=1) 83 | if self.reduction: 84 | return loss.mean() 85 | else: 86 | return loss 87 | return loss -------------------------------------------------------------------------------- /code/ssda/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import math 6 | import torch.nn.functional as F 7 | import pdb 8 | 9 | def Entropy(input_): 10 | bs = input_.size(0) 11 | epsilon = 1e-5 12 | entropy = -input_ * torch.log(input_ + epsilon) 13 | entropy = torch.sum(entropy, dim=1) 14 | return entropy 15 | 16 | def grl_hook(coeff): 17 | def fun1(grad): 18 | return -coeff*grad.clone() 19 | return fun1 20 | 21 | def CDAN(input_list, ad_net, entropy=None, coeff=None, random_layer=None): 22 | softmax_output = input_list[1].detach() 23 | feature = input_list[0] 24 | if random_layer is None: 25 | op_out = torch.bmm(softmax_output.unsqueeze(2), feature.unsqueeze(1)) 26 | ad_out = ad_net(op_out.view(-1, softmax_output.size(1) * feature.size(1))) 27 | else: 28 | random_out = random_layer.forward([feature, softmax_output]) 29 | 
ad_out = ad_net(random_out.view(-1, random_out.size(1))) 30 | batch_size = softmax_output.size(0) // 2 31 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda() 32 | if entropy is not None: 33 | entropy.register_hook(grl_hook(coeff)) 34 | entropy = 1.0+torch.exp(-entropy) 35 | source_mask = torch.ones_like(entropy) 36 | source_mask[feature.size(0)//2:] = 0 37 | source_weight = entropy*source_mask 38 | target_mask = torch.ones_like(entropy) 39 | target_mask[0:feature.size(0)//2] = 0 40 | target_weight = entropy*target_mask 41 | weight = source_weight / torch.sum(source_weight).detach().item() + \ 42 | target_weight / torch.sum(target_weight).detach().item() 43 | return torch.sum(weight.view(-1, 1) * nn.BCELoss(reduction='none')(ad_out, dc_target)) / torch.sum(weight).detach().item() 44 | else: 45 | return nn.BCELoss()(ad_out, dc_target) 46 | 47 | def DANN(features, ad_net): 48 | ad_out = ad_net(features) 49 | batch_size = ad_out.size(0) // 2 50 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda() 51 | return nn.BCELoss()(ad_out, dc_target) 52 | 53 | 54 | class CrossEntropyLabelSmooth(nn.Module): 55 | """Cross entropy loss with label smoothing regularizer. 56 | Reference: 57 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 58 | Equation: y = (1 - epsilon) * y + epsilon / K. 59 | Args: 60 | num_classes (int): number of classes. 61 | epsilon (float): weight. 
62 | """ 63 | 64 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, reduction=True): 65 | super(CrossEntropyLabelSmooth, self).__init__() 66 | self.num_classes = num_classes 67 | self.epsilon = epsilon 68 | self.use_gpu = use_gpu 69 | self.reduction = reduction 70 | self.logsoftmax = nn.LogSoftmax(dim=1) 71 | 72 | def forward(self, inputs, targets): 73 | """ 74 | Args: 75 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 76 | targets: ground truth labels with shape (num_classes) 77 | """ 78 | log_probs = self.logsoftmax(inputs) 79 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1) 80 | if self.use_gpu: targets = targets.cuda() 81 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 82 | loss = (- targets * log_probs).sum(dim=1) 83 | if self.reduction: 84 | return loss.mean() 85 | else: 86 | return loss 87 | return loss -------------------------------------------------------------------------------- /code/msda/run_msda.sh: -------------------------------------------------------------------------------- 1 | # !/bin/sh 2 | 3 | 4 | for ((s=0;s<=3;s++)) 5 | do 6 | python image_source.py --gpu_id $1 --seed $2 --output "ckps/s"$2 --dset office-caltech --net resnet101 --max_epoch 100 --s $s 7 | python image_mixmatch.py --ps 0.0 --cls_par 0.0 --model source --gpu_id $1 --s $s --output_tar "ckps/s"$2 --output "ckps/mm"$2 --seed $2 --dset office-caltech --net resnet101 --max_epoch 100 8 | 9 | python image_target.py --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset office-caltech --net resnet101 10 | python image_mixmatch.py --ps 0.0 --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset office-caltech --net resnet101 --max_epoch 100 11 | 12 | python image_target.py --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset 
office-caltech --net resnet101 13 | python image_mixmatch.py --ps 0.0 --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset office-caltech --net resnet101 --max_epoch 100 14 | 15 | done 16 | 17 | 18 | for ((s=0;s<=3;s++)) 19 | do 20 | python image_source.py --gpu_id $1 --seed $2 --output "ckps/s"$2 --dset office-home --net resnet50 --max_epoch 50 --s $s 21 | python image_mixmatch.py --ps 0.0 --cls_par 0.0 --model source --gpu_id $1 --s $s --output_tar "ckps/s"$2 --output "ckps/mm"$2 --seed $2 --dset office-home --net resnet50 --max_epoch 50 22 | 23 | python image_target.py --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset office-home --net resnet50 24 | python image_mixmatch.py --ps 0.0 --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset office-home --net resnet50 --max_epoch 50 25 | 26 | python image_target.py --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset office-home --net resnet50 27 | python image_mixmatch.py --ps 0.0 --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset office-home --net resnet50 --max_epoch 50 28 | 29 | done 30 | 31 | 32 | for ((s=0;s<=3;s++)) 33 | do 34 | python image_source.py --gpu_id $1 --seed $2 --output "ckps/s"$2 --dset pacs --net resnet18 --max_epoch 100 --s $s 35 | python image_mixmatch.py --ps 0.0 --cls_par 0.0 --model source --gpu_id $1 --s $s --output_tar "ckps/s"$2 --output "ckps/mm"$2 --seed $2 --dset pacs --net resnet18 --max_epoch 100 36 | 37 | python image_target.py --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset pacs --net resnet18 38 | python image_mixmatch.py --ps 0.0 --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset pacs --net resnet18 --max_epoch 100 
39 | 40 | python image_target.py --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset pacs --net resnet18 41 | python image_mixmatch.py --ps 0.0 --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset pacs --net resnet18 --max_epoch 100 42 | 43 | done 44 | 45 | 46 | 47 | 48 | for ((y=2019;y<=2021;y++)) 49 | do 50 | for ((t=0;t<=3;t++)) 51 | do 52 | python image_ms.py --output "san_ms/s"$y --output_src "ckps/s"$y --output_tar "ckps/t"$y --output_mm "ckps/mm"$y --dset office-caltech --t $t --cls_par 0.0 --ssl 0.0 53 | python image_ms.py --output "san_ms/s"$y --output_src "ckps/s"$y --output_tar "ckps/t"$y --output_mm "ckps/mm"$y --dset office-caltech --t $t --cls_par 0.3 --ssl 0.6 54 | 55 | python image_ms.py --output "san_ms/s"$y --output_src "ckps/s"$y --output_tar "ckps/t"$y --output_mm "ckps/mm"$y --dset pacs --net resnet18 --t $t --cls_par 0.0 --ssl 0.0 56 | python image_ms.py --output "san_ms/s"$y --output_src "ckps/s"$y --output_tar "ckps/t"$y --output_mm "ckps/mm"$y --dset pacs --net resnet18 --t $t --cls_par 0.3 --ssl 0.6 57 | 58 | python image_ms.py --output "san_ms/s"$y --output_src "ckps/s"$y --output_tar "ckps/t"$y --output_mm "ckps/mm"$y --dset office-home --net resnet50 --t $t --cls_par 0.0 --ssl 0.0 59 | python image_ms.py --output "san_ms/s"$y --output_src "ckps/s"$y --output_tar "ckps/t"$y --output_mm "ckps/mm"$y --dset office-home --net resnet50 --t $t --cls_par 0.3 --ssl 0.6 60 | 61 | done 62 | done -------------------------------------------------------------------------------- /code/data/ssda/office-home/labeled_target_images_Art_1.txt: -------------------------------------------------------------------------------- 1 | /Checkpoint/liangjian/da_dataset/office_home/Art/Alarm_Clock/00045.jpg 0 2 | /Checkpoint/liangjian/da_dataset/office_home/Art/Backpack/00019.jpg 1 3 | /Checkpoint/liangjian/da_dataset/office_home/Art/Batteries/00023.jpg 2 4 | 
/Checkpoint/liangjian/da_dataset/office_home/Art/Bed/00033.jpg 3 5 | /Checkpoint/liangjian/da_dataset/office_home/Art/Bike/00020.jpg 4 6 | /Checkpoint/liangjian/da_dataset/office_home/Art/Bottle/00044.jpg 5 7 | /Checkpoint/liangjian/da_dataset/office_home/Art/Bucket/00022.jpg 6 8 | /Checkpoint/liangjian/da_dataset/office_home/Art/Calculator/00009.jpg 7 9 | /Checkpoint/liangjian/da_dataset/office_home/Art/Calendar/00005.jpg 8 10 | /Checkpoint/liangjian/da_dataset/office_home/Art/Candles/00038.jpg 9 11 | /Checkpoint/liangjian/da_dataset/office_home/Art/Chair/00008.jpg 10 12 | /Checkpoint/liangjian/da_dataset/office_home/Art/Clipboards/00016.jpg 11 13 | /Checkpoint/liangjian/da_dataset/office_home/Art/Computer/00004.jpg 12 14 | /Checkpoint/liangjian/da_dataset/office_home/Art/Couch/00011.jpg 13 15 | /Checkpoint/liangjian/da_dataset/office_home/Art/Curtains/00012.jpg 14 16 | /Checkpoint/liangjian/da_dataset/office_home/Art/Desk_Lamp/00016.jpg 15 17 | /Checkpoint/liangjian/da_dataset/office_home/Art/Drill/00004.jpg 16 18 | /Checkpoint/liangjian/da_dataset/office_home/Art/Eraser/00016.jpg 17 19 | /Checkpoint/liangjian/da_dataset/office_home/Art/Exit_Sign/00016.jpg 18 20 | /Checkpoint/liangjian/da_dataset/office_home/Art/Fan/00024.jpg 19 21 | /Checkpoint/liangjian/da_dataset/office_home/Art/File_Cabinet/00002.jpg 20 22 | /Checkpoint/liangjian/da_dataset/office_home/Art/Flipflops/00018.jpg 21 23 | /Checkpoint/liangjian/da_dataset/office_home/Art/Flowers/00064.jpg 22 24 | /Checkpoint/liangjian/da_dataset/office_home/Art/Folder/00006.jpg 23 25 | /Checkpoint/liangjian/da_dataset/office_home/Art/Fork/00046.jpg 24 26 | /Checkpoint/liangjian/da_dataset/office_home/Art/Glasses/00015.jpg 25 27 | /Checkpoint/liangjian/da_dataset/office_home/Art/Hammer/00021.jpg 26 28 | /Checkpoint/liangjian/da_dataset/office_home/Art/Helmet/00074.jpg 27 29 | /Checkpoint/liangjian/da_dataset/office_home/Art/Kettle/00017.jpg 28 30 | /Checkpoint/liangjian/da_dataset/office_home/Art/Keyboard/00008.jpg 
29 31 | /Checkpoint/liangjian/da_dataset/office_home/Art/Knives/00037.jpg 30 32 | /Checkpoint/liangjian/da_dataset/office_home/Art/Lamp_Shade/00029.jpg 31 33 | /Checkpoint/liangjian/da_dataset/office_home/Art/Laptop/00005.jpg 32 34 | /Checkpoint/liangjian/da_dataset/office_home/Art/Marker/00004.jpg 33 35 | /Checkpoint/liangjian/da_dataset/office_home/Art/Monitor/00017.jpg 34 36 | /Checkpoint/liangjian/da_dataset/office_home/Art/Mop/00004.jpg 35 37 | /Checkpoint/liangjian/da_dataset/office_home/Art/Mouse/00002.jpg 36 38 | /Checkpoint/liangjian/da_dataset/office_home/Art/Mug/00033.jpg 37 39 | /Checkpoint/liangjian/da_dataset/office_home/Art/Notebook/00016.jpg 38 40 | /Checkpoint/liangjian/da_dataset/office_home/Art/Oven/00016.jpg 39 41 | /Checkpoint/liangjian/da_dataset/office_home/Art/Pan/00018.jpg 40 42 | /Checkpoint/liangjian/da_dataset/office_home/Art/Paper_Clip/00016.jpg 41 43 | /Checkpoint/liangjian/da_dataset/office_home/Art/Pen/00012.jpg 42 44 | /Checkpoint/liangjian/da_dataset/office_home/Art/Pencil/00005.jpg 43 45 | /Checkpoint/liangjian/da_dataset/office_home/Art/Postit_Notes/00017.jpg 44 46 | /Checkpoint/liangjian/da_dataset/office_home/Art/Printer/00009.jpg 45 47 | /Checkpoint/liangjian/da_dataset/office_home/Art/Push_Pin/00024.jpg 46 48 | /Checkpoint/liangjian/da_dataset/office_home/Art/Radio/00036.jpg 47 49 | /Checkpoint/liangjian/da_dataset/office_home/Art/Refrigerator/00022.jpg 48 50 | /Checkpoint/liangjian/da_dataset/office_home/Art/Ruler/00010.jpg 49 51 | /Checkpoint/liangjian/da_dataset/office_home/Art/Scissors/00019.jpg 50 52 | /Checkpoint/liangjian/da_dataset/office_home/Art/Screwdriver/00007.jpg 51 53 | /Checkpoint/liangjian/da_dataset/office_home/Art/Shelf/00008.jpg 52 54 | /Checkpoint/liangjian/da_dataset/office_home/Art/Sink/00003.jpg 53 55 | /Checkpoint/liangjian/da_dataset/office_home/Art/Sneakers/00017.jpg 54 56 | /Checkpoint/liangjian/da_dataset/office_home/Art/Soda/00001.jpg 55 57 | 
/Checkpoint/liangjian/da_dataset/office_home/Art/Speaker/00005.jpg 56 58 | /Checkpoint/liangjian/da_dataset/office_home/Art/Spoon/00045.jpg 57 59 | /Checkpoint/liangjian/da_dataset/office_home/Art/Table/00005.jpg 58 60 | /Checkpoint/liangjian/da_dataset/office_home/Art/Telephone/00020.jpg 59 61 | /Checkpoint/liangjian/da_dataset/office_home/Art/ToothBrush/00026.jpg 60 62 | /Checkpoint/liangjian/da_dataset/office_home/Art/Toys/00017.jpg 61 63 | /Checkpoint/liangjian/da_dataset/office_home/Art/Trash_Can/00019.jpg 62 64 | /Checkpoint/liangjian/da_dataset/office_home/Art/TV/00010.jpg 63 65 | /Checkpoint/liangjian/da_dataset/office_home/Art/Webcam/00016.jpg 64 66 | -------------------------------------------------------------------------------- /code/uda/run_uda.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # id seed 4 | 5 | for ((s=0;s<=2;s++)) 6 | python image_source.py --gpu_id $1 --seed $2 --output "ckps/s"$2 --dset office --max_epoch 100 --s $s 7 | python image_mixmatch.py --ps 0.0 --cls_par 0.0 --model source --gpu_id $1 --s $s --output_tar "ckps/s"$2 --output "ckps/mm"$2 --seed $2 --dset office --max_epoch 100 8 | 9 | python image_target.py --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset office 10 | python image_mixmatch.py --ps 0.0 --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset office --max_epoch 100 11 | 12 | python image_target.py --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset office 13 | python image_mixmatch.py --ps 0.0 --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset office --max_epoch 100 14 | done 15 | 16 | 17 | for ((s=0;s<=3;s++)) 18 | do 19 | python image_source.py --gpu_id $1 --seed $2 --output "ckps/s"$2 --dset office-home --max_epoch 50 --s $s 20 | python 
image_mixmatch.py --ps 0.0 --cls_par 0.0 --model source --gpu_id $1 --s $s --output_tar "ckps/s"$2 --output "ckps/mm"$2 --seed $2 --dset office-home --max_epoch 50 21 | 22 | python image_target.py --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset office-home 23 | python image_mixmatch.py --ps 0.0 --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset office-home --max_epoch 50 24 | 25 | python image_target.py --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset office-home 26 | python image_mixmatch.py --ps 0.0 --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset office-home --max_epoch 50 27 | done 28 | 29 | python image_source.py --gpu_id $1 --seed $2 --output "ckps/res101/s"$2 --dset VISDA-C --max_epoch 10 --s 0 --net resnet101 30 | python image_mixmatch.py --ps 0.0 --cls_par 0.0 --model source --gpu_id $1 --s 0 --t 1 --output_tar "ckps/res101/s"$2 --output "ckps/res101/mm"$2 --seed $2 --dset VISDA-C --max_epoch 10 --net resnet101 31 | 32 | python image_target.py --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s 0 --t 1 --output_src "ckps/res101/s"$2 --output "ckps/res101/t"$2 --seed $2 --dset VISDA-C --net resnet101 33 | python image_mixmatch.py --ps 0.0 --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s 0 --t 1 --output_tar "ckps/res101/t"$2 --output "ckps/res101/mm"$2 --seed $2 --dset VISDA-C --max_epoch 10 --net resnet101 34 | 35 | python image_target.py --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s 0 --t 1 --output_src "ckps/res101/s"$2 --output "ckps/res101/t"$2 --seed $2 --dset VISDA-C --net resnet101 36 | python image_mixmatch.py --ps 0.0 --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s 0 --t 1 --output_tar "ckps/res101/t"$2 --output "ckps/res101/mm"$2 --seed $2 --dset VISDA-C --max_epoch 10 --net resnet101 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | python image_source.py --gpu_id $1 
--seed $2 --output "ckps/s"$2 --dset VISDA-C --max_epoch 10 --s 0 45 | python image_mixmatch.py --ps 0.0 --cls_par 0.0 --model source --gpu_id $1 --s 0 --t 1 --output_tar "ckps/s"$2 --output "ckps/mm"$2 --seed $2 --dset VISDA-C --max_epoch 10 46 | 47 | python image_target.py --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s 0 --t 1 --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset VISDA-C 48 | python image_mixmatch.py --ps 0.0 --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s 0 --t 1 --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset VISDA-C --max_epoch 10 49 | 50 | python image_target.py --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s 0 --t 1 --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset VISDA-C 51 | python image_mixmatch.py --ps 0.0 --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s 0 --t 1 --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset VISDA-C --max_epoch 10 52 | 53 | python image_target.py --gent '' --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s 0 --t 1 --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset VISDA-C 54 | python image_mixmatch.py --gent '' --ps 0.0 --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s 0 --t 1 --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset VISDA-C --max_epoch 10 55 | 56 | python image_target.py --ssl 0.0 --cls_par 0.3 --gpu_id $1 --s 0 --t 1 --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset VISDA-C 57 | python image_mixmatch.py --ps 0.0 --ssl 0.0 --cls_par 0.3 --gpu_id $1 --s 0 --t 1 --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset VISDA-C --max_epoch 10 58 | 59 | python image_target.py --ssl 0.6 --cls_par 0.0 --gpu_id $1 --s 0 --t 1 --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset VISDA-C 60 | python image_mixmatch.py --ps 0.0 --ssl 0.6 --cls_par 0.0 --gpu_id $1 --s 0 --t 1 --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset VISDA-C --max_epoch 10 61 | -------------------------------------------------------------------------------- 
/code/data/ssda/office-home/labeled_target_images_Clipart_1.txt: -------------------------------------------------------------------------------- 1 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Alarm_Clock/00032.jpg 0 2 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Backpack/00052.jpg 1 3 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Batteries/00043.jpg 2 4 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Bed/00073.jpg 3 5 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Bike/00023.jpg 4 6 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Bottle/00059.jpg 5 7 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Bucket/00006.jpg 6 8 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Calculator/00007.jpg 7 9 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Calendar/00009.jpg 8 10 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Candles/00004.jpg 9 11 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Chair/00060.jpg 10 12 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Clipboards/00033.jpg 11 13 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Computer/00036.jpg 12 14 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Couch/00063.jpg 13 15 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Curtains/00032.jpg 14 16 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Desk_Lamp/00004.jpg 15 17 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Drill/00035.jpg 16 18 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Eraser/00015.jpg 17 19 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Exit_Sign/00032.jpg 18 20 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Fan/00027.jpg 19 21 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/File_Cabinet/00030.jpg 20 22 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Flipflops/00012.jpg 21 23 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Flowers/00039.jpg 22 24 | 
/Checkpoint/liangjian/da_dataset/office_home/Clipart/Folder/00095.jpg 23 25 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Fork/00018.jpg 24 26 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Glasses/00014.jpg 25 27 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Hammer/00098.jpg 26 28 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Helmet/00051.jpg 27 29 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Kettle/00027.jpg 28 30 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Keyboard/00053.jpg 29 31 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Knives/00037.jpg 30 32 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Lamp_Shade/00031.jpg 31 33 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Laptop/00061.jpg 32 34 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Marker/00065.jpg 33 35 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Monitor/00089.jpg 34 36 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Mop/00032.jpg 35 37 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Mouse/00044.jpg 36 38 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Mug/00088.jpg 37 39 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Notebook/00065.jpg 38 40 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Oven/00008.jpg 39 41 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Pan/00035.jpg 40 42 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Paper_Clip/00015.jpg 41 43 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Pen/00004.jpg 42 44 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Pencil/00072.jpg 43 45 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Postit_Notes/00002.jpg 44 46 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Printer/00056.jpg 45 47 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Push_Pin/00025.jpg 46 48 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Radio/00030.jpg 47 49 | 
/Checkpoint/liangjian/da_dataset/office_home/Clipart/Refrigerator/00017.jpg 48 50 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Ruler/00050.jpg 49 51 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Scissors/00073.jpg 50 52 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Screwdriver/00062.jpg 51 53 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Shelf/00021.jpg 52 54 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Sink/00023.jpg 53 55 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Sneakers/00019.jpg 54 56 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Soda/00029.jpg 55 57 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Speaker/00037.jpg 56 58 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Spoon/00046.jpg 57 59 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Table/00041.jpg 58 60 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Telephone/00048.jpg 59 61 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/ToothBrush/00007.jpg 60 62 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Toys/00021.jpg 61 63 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Trash_Can/00003.jpg 62 64 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/TV/00012.jpg 63 65 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Webcam/00006.jpg 64 66 | -------------------------------------------------------------------------------- /code/data/ssda/office-home/labeled_target_images_Product_1.txt: -------------------------------------------------------------------------------- 1 | /Checkpoint/liangjian/da_dataset/office_home/Product/Alarm_Clock/00003.jpg 0 2 | /Checkpoint/liangjian/da_dataset/office_home/Product/Backpack/00095.jpg 1 3 | /Checkpoint/liangjian/da_dataset/office_home/Product/Batteries/00050.jpg 2 4 | /Checkpoint/liangjian/da_dataset/office_home/Product/Bed/00008.jpg 3 5 | /Checkpoint/liangjian/da_dataset/office_home/Product/Bike/00004.jpg 4 6 | 
/Checkpoint/liangjian/da_dataset/office_home/Product/Bottle/00028.jpg 5 7 | /Checkpoint/liangjian/da_dataset/office_home/Product/Bucket/00019.jpg 6 8 | /Checkpoint/liangjian/da_dataset/office_home/Product/Calculator/00014.jpg 7 9 | /Checkpoint/liangjian/da_dataset/office_home/Product/Calendar/00067.jpg 8 10 | /Checkpoint/liangjian/da_dataset/office_home/Product/Candles/00052.jpg 9 11 | /Checkpoint/liangjian/da_dataset/office_home/Product/Chair/00038.jpg 10 12 | /Checkpoint/liangjian/da_dataset/office_home/Product/Clipboards/00049.jpg 11 13 | /Checkpoint/liangjian/da_dataset/office_home/Product/Computer/00081.jpg 12 14 | /Checkpoint/liangjian/da_dataset/office_home/Product/Couch/00039.jpg 13 15 | /Checkpoint/liangjian/da_dataset/office_home/Product/Curtains/00058.jpg 14 16 | /Checkpoint/liangjian/da_dataset/office_home/Product/Desk_Lamp/00036.jpg 15 17 | /Checkpoint/liangjian/da_dataset/office_home/Product/Drill/00063.jpg 16 18 | /Checkpoint/liangjian/da_dataset/office_home/Product/Eraser/00027.jpg 17 19 | /Checkpoint/liangjian/da_dataset/office_home/Product/Exit_Sign/00004.jpg 18 20 | /Checkpoint/liangjian/da_dataset/office_home/Product/Fan/00045.jpg 19 21 | /Checkpoint/liangjian/da_dataset/office_home/Product/File_Cabinet/00058.jpg 20 22 | /Checkpoint/liangjian/da_dataset/office_home/Product/Flipflops/00039.jpg 21 23 | /Checkpoint/liangjian/da_dataset/office_home/Product/Flowers/00079.jpg 22 24 | /Checkpoint/liangjian/da_dataset/office_home/Product/Folder/00034.jpg 23 25 | /Checkpoint/liangjian/da_dataset/office_home/Product/Fork/00022.jpg 24 26 | /Checkpoint/liangjian/da_dataset/office_home/Product/Glasses/00053.jpg 25 27 | /Checkpoint/liangjian/da_dataset/office_home/Product/Hammer/00005.jpg 26 28 | /Checkpoint/liangjian/da_dataset/office_home/Product/Helmet/00046.jpg 27 29 | /Checkpoint/liangjian/da_dataset/office_home/Product/Kettle/00062.jpg 28 30 | /Checkpoint/liangjian/da_dataset/office_home/Product/Keyboard/00033.jpg 29 31 | 
/Checkpoint/liangjian/da_dataset/office_home/Product/Knives/00039.jpg 30 32 | /Checkpoint/liangjian/da_dataset/office_home/Product/Lamp_Shade/00017.jpg 31 33 | /Checkpoint/liangjian/da_dataset/office_home/Product/Laptop/00048.jpg 32 34 | /Checkpoint/liangjian/da_dataset/office_home/Product/Marker/00003.jpg 33 35 | /Checkpoint/liangjian/da_dataset/office_home/Product/Monitor/00013.jpg 34 36 | /Checkpoint/liangjian/da_dataset/office_home/Product/Mop/00036.jpg 35 37 | /Checkpoint/liangjian/da_dataset/office_home/Product/Mouse/00023.jpg 36 38 | /Checkpoint/liangjian/da_dataset/office_home/Product/Mug/00039.jpg 37 39 | /Checkpoint/liangjian/da_dataset/office_home/Product/Notebook/00010.jpg 38 40 | /Checkpoint/liangjian/da_dataset/office_home/Product/Oven/00016.jpg 39 41 | /Checkpoint/liangjian/da_dataset/office_home/Product/Pan/00053.jpg 40 42 | /Checkpoint/liangjian/da_dataset/office_home/Product/Paper_Clip/00002.jpg 41 43 | /Checkpoint/liangjian/da_dataset/office_home/Product/Pen/00048.jpg 42 44 | /Checkpoint/liangjian/da_dataset/office_home/Product/Pencil/00002.jpg 43 45 | /Checkpoint/liangjian/da_dataset/office_home/Product/Postit_Notes/00001.jpg 44 46 | /Checkpoint/liangjian/da_dataset/office_home/Product/Printer/00018.jpg 45 47 | /Checkpoint/liangjian/da_dataset/office_home/Product/Push_Pin/00039.jpg 46 48 | /Checkpoint/liangjian/da_dataset/office_home/Product/Radio/00011.jpg 47 49 | /Checkpoint/liangjian/da_dataset/office_home/Product/Refrigerator/00024.jpg 48 50 | /Checkpoint/liangjian/da_dataset/office_home/Product/Ruler/00001.jpg 49 51 | /Checkpoint/liangjian/da_dataset/office_home/Product/Scissors/00072.jpg 50 52 | /Checkpoint/liangjian/da_dataset/office_home/Product/Screwdriver/00015.jpg 51 53 | /Checkpoint/liangjian/da_dataset/office_home/Product/Shelf/00037.jpg 52 54 | /Checkpoint/liangjian/da_dataset/office_home/Product/Sink/00024.jpg 53 55 | /Checkpoint/liangjian/da_dataset/office_home/Product/Sneakers/00061.jpg 54 56 | 
/Checkpoint/liangjian/da_dataset/office_home/Product/Soda/00041.jpg 55 57 | /Checkpoint/liangjian/da_dataset/office_home/Product/Speaker/00071.jpg 56 58 | /Checkpoint/liangjian/da_dataset/office_home/Product/Spoon/00038.jpg 57 59 | /Checkpoint/liangjian/da_dataset/office_home/Product/Table/00043.jpg 58 60 | /Checkpoint/liangjian/da_dataset/office_home/Product/Telephone/00011.jpg 59 61 | /Checkpoint/liangjian/da_dataset/office_home/Product/ToothBrush/00002.jpg 60 62 | /Checkpoint/liangjian/da_dataset/office_home/Product/Toys/00037.jpg 61 63 | /Checkpoint/liangjian/da_dataset/office_home/Product/Trash_Can/00090.jpg 62 64 | /Checkpoint/liangjian/da_dataset/office_home/Product/TV/00029.jpg 63 65 | /Checkpoint/liangjian/da_dataset/office_home/Product/Webcam/00064.jpg 64 66 | -------------------------------------------------------------------------------- /code/data/ssda/office-home/labeled_target_images_Real_1.txt: -------------------------------------------------------------------------------- 1 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Alarm_Clock/00085.jpg 0 2 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Backpack/00062.jpg 1 3 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Batteries/00041.jpg 2 4 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Bed/00055.jpg 3 5 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Bike/00056.jpg 4 6 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Bottle/00078.jpg 5 7 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Bucket/00064.jpg 6 8 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Calculator/00044.jpg 7 9 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Calendar/00064.jpg 8 10 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Candles/00097.jpg 9 11 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Chair/00069.jpg 10 12 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Clipboards/00040.jpg 11 13 | 
/Checkpoint/liangjian/da_dataset/office_home/RealWorld/Computer/00004.jpg 12 14 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Couch/00037.jpg 13 15 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Curtains/00001.jpg 14 16 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Desk_Lamp/00003.jpg 15 17 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Drill/00038.jpg 16 18 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Eraser/00040.jpg 17 19 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Exit_Sign/00014.jpg 18 20 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Fan/00050.jpg 19 21 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/File_Cabinet/00032.jpg 20 22 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Flipflops/00077.jpg 21 23 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Flowers/00028.jpg 22 24 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Folder/00020.jpg 23 25 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Fork/00011.jpg 24 26 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Glasses/00031.jpg 25 27 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Hammer/00022.jpg 26 28 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Helmet/00034.jpg 27 29 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Kettle/00055.jpg 28 30 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Keyboard/00064.jpg 29 31 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Knives/00010.jpg 30 32 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Lamp_Shade/00026.jpg 31 33 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Laptop/00051.jpg 32 34 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Marker/00009.jpg 33 35 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Monitor/00055.jpg 34 36 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Mop/00044.jpg 35 37 | 
/Checkpoint/liangjian/da_dataset/office_home/RealWorld/Mouse/00028.jpg 36 38 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Mug/00042.jpg 37 39 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Notebook/00013.jpg 38 40 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Oven/00036.jpg 39 41 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Pan/00011.jpg 40 42 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Paper_Clip/00034.jpg 41 43 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Pen/00064.jpg 42 44 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Pencil/00046.jpg 43 45 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Postit_Notes/00064.jpg 44 46 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Printer/00027.jpg 45 47 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Push_Pin/00034.jpg 46 48 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Radio/00003.jpg 47 49 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Refrigerator/00039.jpg 48 50 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Ruler/00035.jpg 49 51 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Scissors/00036.jpg 50 52 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Screwdriver/00038.jpg 51 53 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Shelf/00018.jpg 52 54 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Sink/00063.jpg 53 55 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Sneakers/00072.jpg 54 56 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Soda/00052.jpg 55 57 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Speaker/00014.jpg 56 58 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Spoon/00019.jpg 57 59 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Table/00025.jpg 58 60 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Telephone/00041.jpg 59 61 | 
/Checkpoint/liangjian/da_dataset/office_home/RealWorld/ToothBrush/00051.jpg 60 62 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Toys/00002.jpg 61 63 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Trash_Can/00077.jpg 62 64 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/TV/00015.jpg 63 65 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Webcam/00033.jpg 64 66 | -------------------------------------------------------------------------------- /code/msda/data_list.py: -------------------------------------------------------------------------------- 1 | #from __future__ import print_function, division 2 | 3 | import torch 4 | import numpy as np 5 | import random 6 | from PIL import Image 7 | from torch.utils.data import Dataset 8 | import os 9 | import os.path 10 | 11 | import cv2 12 | import torchvision 13 | 14 | def make_dataset(image_list, labels): 15 | if labels: 16 | len_ = len(image_list) 17 | images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)] 18 | else: 19 | if len(image_list[0].split()) > 2: 20 | images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list] 21 | else: 22 | images = [(val.split()[0], int(val.split()[1])) for val in image_list] 23 | return images 24 | 25 | 26 | def rgb_loader(path): 27 | with open(path, 'rb') as f: 28 | with Image.open(f) as img: 29 | return img.convert('RGB') 30 | 31 | def l_loader(path): 32 | with open(path, 'rb') as f: 33 | with Image.open(f) as img: 34 | return img.convert('L') 35 | 36 | class ImageList(Dataset): 37 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'): 38 | imgs = make_dataset(image_list, labels) 39 | if len(imgs) == 0: 40 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 41 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 42 | 43 | self.imgs = imgs 44 | self.transform = transform 45 | self.target_transform = target_transform 46 | if mode 
== 'RGB': 47 | self.loader = rgb_loader 48 | elif mode == 'L': 49 | self.loader = l_loader 50 | 51 | def __getitem__(self, index): 52 | path, target = self.imgs[index] 53 | img = self.loader(path) 54 | if self.transform is not None: 55 | img = self.transform(img) 56 | if self.target_transform is not None: 57 | target = self.target_transform(target) 58 | 59 | return img, target 60 | 61 | def __len__(self): 62 | return len(self.imgs) 63 | 64 | 65 | class ImageList_twice(Dataset): 66 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'): 67 | imgs = make_dataset(image_list, labels) 68 | if len(imgs) == 0: 69 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 70 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 71 | 72 | self.imgs = imgs 73 | self.transform = transform 74 | self.target_transform = target_transform 75 | if mode == 'RGB': 76 | self.loader = rgb_loader 77 | elif mode == 'L': 78 | self.loader = l_loader 79 | 80 | def __getitem__(self, index): 81 | path, target = self.imgs[index] 82 | img = self.loader(path) 83 | if self.target_transform is not None: 84 | target = self.target_transform(target) 85 | if self.transform is not None: 86 | if type(self.transform).__name__=='list': 87 | img = [t(img) for t in self.transform] 88 | else: 89 | img = self.transform(img) 90 | 91 | return img, target, index 92 | 93 | def __len__(self): 94 | return len(self.imgs) 95 | 96 | class ImageList_idx(Dataset): 97 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'): 98 | imgs = make_dataset(image_list, labels) 99 | if len(imgs) == 0: 100 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 101 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 102 | 103 | self.imgs = imgs 104 | self.transform = transform 105 | self.target_transform = target_transform 106 | if mode == 'RGB': 107 | self.loader = rgb_loader 108 | elif mode == 
'L': 109 | self.loader = l_loader 110 | 111 | def __getitem__(self, index): 112 | path, target = self.imgs[index] 113 | img = self.loader(path) 114 | if self.transform is not None: 115 | img = self.transform(img) 116 | if self.target_transform is not None: 117 | target = self.target_transform(target) 118 | 119 | return img, target, index 120 | 121 | def __len__(self): 122 | return len(self.imgs) 123 | 124 | 125 | class ImageValueList(Dataset): 126 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, 127 | loader=rgb_loader): 128 | imgs = make_dataset(image_list, labels) 129 | if len(imgs) == 0: 130 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 131 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 132 | 133 | self.imgs = imgs 134 | self.values = [1.0] * len(imgs) 135 | self.transform = transform 136 | self.target_transform = target_transform 137 | self.loader = loader 138 | 139 | def set_values(self, values): 140 | self.values = values 141 | 142 | def __getitem__(self, index): 143 | path, target = self.imgs[index] 144 | img = self.loader(path) 145 | if self.transform is not None: 146 | img = self.transform(img) 147 | if self.target_transform is not None: 148 | target = self.target_transform(target) 149 | 150 | return img, target 151 | 152 | def __len__(self): 153 | return len(self.imgs) 154 | 155 | class alexnetlist(Dataset): 156 | 157 | def __init__(self, list, training=True): 158 | self.images = [] 159 | self.labels = [] 160 | self.multi_scale = [256, 257] 161 | self.output_size = [227, 227] 162 | self.training = training 163 | self.mean_color=[104.006, 116.668, 122.678] 164 | 165 | list_file = open(list) 166 | lines = list_file.readlines() 167 | for line in lines: 168 | fields = line.split() 169 | self.images.append(fields[0]) 170 | self.labels.append(int(fields[1])) 171 | 172 | def __len__(self): 173 | return len(self.images) 174 | 175 | def __getitem__(self, index): 176 | image_path = 
self.images[index] 177 | label = self.labels[index] 178 | img = cv2.imread(image_path) 179 | if type(img) == None: 180 | print('Error: Image at {} not found.'.format(image_path)) 181 | 182 | if self.training and np.random.random() < 0.5: 183 | img = cv2.flip(img, 1) 184 | new_size = np.random.randint(self.multi_scale[0], self.multi_scale[1], 1)[0] 185 | 186 | img = cv2.resize(img, (new_size, new_size)) 187 | img = img.astype(np.float32) 188 | 189 | # cropping 190 | if self.training: 191 | diff = new_size - self.output_size[0] 192 | offset_x = np.random.randint(0, diff, 1)[0] 193 | offset_y = np.random.randint(0, diff, 1)[0] 194 | else: 195 | offset_x = img.shape[0]//2 - self.output_size[0] // 2 196 | offset_y = img.shape[1]//2 - self.output_size[1] // 2 197 | 198 | img = img[offset_x:(offset_x+self.output_size[0]), 199 | offset_y:(offset_y+self.output_size[1])] 200 | 201 | # substract mean 202 | img -= np.array(self.mean_color) 203 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 204 | 205 | # ToTensor transform cv2 HWC->CHW, only byteTensor will be div by 255. 
206 | tensor = torchvision.transforms.ToTensor() 207 | img = tensor(img) 208 | # img = np.transpose(img, (2, 0, 1)) 209 | 210 | return img, label -------------------------------------------------------------------------------- /code/pda/data_list.py: -------------------------------------------------------------------------------- 1 | #from __future__ import print_function, division 2 | 3 | import torch 4 | import numpy as np 5 | import random 6 | from PIL import Image 7 | from torch.utils.data import Dataset 8 | import os 9 | import os.path 10 | 11 | import cv2 12 | import torchvision 13 | 14 | def make_dataset(image_list, labels): 15 | if labels: 16 | len_ = len(image_list) 17 | images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)] 18 | else: 19 | if len(image_list[0].split()) > 2: 20 | images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list] 21 | else: 22 | images = [(val.split()[0], int(val.split()[1])) for val in image_list] 23 | return images 24 | 25 | 26 | def rgb_loader(path): 27 | with open(path, 'rb') as f: 28 | with Image.open(f) as img: 29 | return img.convert('RGB') 30 | 31 | def l_loader(path): 32 | with open(path, 'rb') as f: 33 | with Image.open(f) as img: 34 | return img.convert('L') 35 | 36 | class ImageList(Dataset): 37 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'): 38 | imgs = make_dataset(image_list, labels) 39 | if len(imgs) == 0: 40 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 41 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 42 | 43 | self.imgs = imgs 44 | self.transform = transform 45 | self.target_transform = target_transform 46 | if mode == 'RGB': 47 | self.loader = rgb_loader 48 | elif mode == 'L': 49 | self.loader = l_loader 50 | 51 | def __getitem__(self, index): 52 | path, target = self.imgs[index] 53 | img = self.loader(path) 54 | if self.transform is not None: 55 | img = 
self.transform(img) 56 | if self.target_transform is not None: 57 | target = self.target_transform(target) 58 | 59 | return img, target 60 | 61 | def __len__(self): 62 | return len(self.imgs) 63 | 64 | 65 | class ImageList_twice(Dataset): 66 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'): 67 | imgs = make_dataset(image_list, labels) 68 | if len(imgs) == 0: 69 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 70 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 71 | 72 | self.imgs = imgs 73 | self.transform = transform 74 | self.target_transform = target_transform 75 | if mode == 'RGB': 76 | self.loader = rgb_loader 77 | elif mode == 'L': 78 | self.loader = l_loader 79 | 80 | def __getitem__(self, index): 81 | path, target = self.imgs[index] 82 | img = self.loader(path) 83 | if self.target_transform is not None: 84 | target = self.target_transform(target) 85 | if self.transform is not None: 86 | if type(self.transform).__name__=='list': 87 | img = [t(img) for t in self.transform] 88 | else: 89 | img = self.transform(img) 90 | 91 | return img, target, index 92 | 93 | def __len__(self): 94 | return len(self.imgs) 95 | 96 | class ImageList_idx(Dataset): 97 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'): 98 | imgs = make_dataset(image_list, labels) 99 | if len(imgs) == 0: 100 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 101 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 102 | 103 | self.imgs = imgs 104 | self.transform = transform 105 | self.target_transform = target_transform 106 | if mode == 'RGB': 107 | self.loader = rgb_loader 108 | elif mode == 'L': 109 | self.loader = l_loader 110 | 111 | def __getitem__(self, index): 112 | path, target = self.imgs[index] 113 | img = self.loader(path) 114 | if self.transform is not None: 115 | img = self.transform(img) 116 | if self.target_transform is 
not None: 117 | target = self.target_transform(target) 118 | 119 | return img, target, index 120 | 121 | def __len__(self): 122 | return len(self.imgs) 123 | 124 | 125 | class ImageValueList(Dataset): 126 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, 127 | loader=rgb_loader): 128 | imgs = make_dataset(image_list, labels) 129 | if len(imgs) == 0: 130 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 131 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 132 | 133 | self.imgs = imgs 134 | self.values = [1.0] * len(imgs) 135 | self.transform = transform 136 | self.target_transform = target_transform 137 | self.loader = loader 138 | 139 | def set_values(self, values): 140 | self.values = values 141 | 142 | def __getitem__(self, index): 143 | path, target = self.imgs[index] 144 | img = self.loader(path) 145 | if self.transform is not None: 146 | img = self.transform(img) 147 | if self.target_transform is not None: 148 | target = self.target_transform(target) 149 | 150 | return img, target 151 | 152 | def __len__(self): 153 | return len(self.imgs) 154 | 155 | class alexnetlist(Dataset): 156 | 157 | def __init__(self, list, training=True): 158 | self.images = [] 159 | self.labels = [] 160 | self.multi_scale = [256, 257] 161 | self.output_size = [227, 227] 162 | self.training = training 163 | self.mean_color=[104.006, 116.668, 122.678] 164 | 165 | list_file = open(list) 166 | lines = list_file.readlines() 167 | for line in lines: 168 | fields = line.split() 169 | self.images.append(fields[0]) 170 | self.labels.append(int(fields[1])) 171 | 172 | def __len__(self): 173 | return len(self.images) 174 | 175 | def __getitem__(self, index): 176 | image_path = self.images[index] 177 | label = self.labels[index] 178 | img = cv2.imread(image_path) 179 | if type(img) == None: 180 | print('Error: Image at {} not found.'.format(image_path)) 181 | 182 | if self.training and np.random.random() < 0.5: 183 | img = 
cv2.flip(img, 1) 184 | new_size = np.random.randint(self.multi_scale[0], self.multi_scale[1], 1)[0] 185 | 186 | img = cv2.resize(img, (new_size, new_size)) 187 | img = img.astype(np.float32) 188 | 189 | # cropping 190 | if self.training: 191 | diff = new_size - self.output_size[0] 192 | offset_x = np.random.randint(0, diff, 1)[0] 193 | offset_y = np.random.randint(0, diff, 1)[0] 194 | else: 195 | offset_x = img.shape[0]//2 - self.output_size[0] // 2 196 | offset_y = img.shape[1]//2 - self.output_size[1] // 2 197 | 198 | img = img[offset_x:(offset_x+self.output_size[0]), 199 | offset_y:(offset_y+self.output_size[1])] 200 | 201 | # substract mean 202 | img -= np.array(self.mean_color) 203 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 204 | 205 | # ToTensor transform cv2 HWC->CHW, only byteTensor will be div by 255. 206 | tensor = torchvision.transforms.ToTensor() 207 | img = tensor(img) 208 | # img = np.transpose(img, (2, 0, 1)) 209 | 210 | return img, label -------------------------------------------------------------------------------- /code/ssda/data_list.py: -------------------------------------------------------------------------------- 1 | #from __future__ import print_function, division 2 | 3 | import torch 4 | import numpy as np 5 | import random 6 | from PIL import Image 7 | from torch.utils.data import Dataset 8 | import os 9 | import os.path 10 | 11 | import cv2 12 | import torchvision 13 | 14 | def make_dataset(image_list, labels): 15 | if labels: 16 | len_ = len(image_list) 17 | images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)] 18 | else: 19 | if len(image_list[0].split()) > 2: 20 | images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list] 21 | else: 22 | images = [(val.split()[0], int(val.split()[1])) for val in image_list] 23 | return images 24 | 25 | 26 | def rgb_loader(path): 27 | with open(path, 'rb') as f: 28 | with Image.open(f) as img: 29 | return img.convert('RGB') 30 | 31 | def 
l_loader(path): 32 | with open(path, 'rb') as f: 33 | with Image.open(f) as img: 34 | return img.convert('L') 35 | 36 | class ImageList(Dataset): 37 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'): 38 | imgs = make_dataset(image_list, labels) 39 | if len(imgs) == 0: 40 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 41 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 42 | 43 | self.imgs = imgs 44 | self.transform = transform 45 | self.target_transform = target_transform 46 | if mode == 'RGB': 47 | self.loader = rgb_loader 48 | elif mode == 'L': 49 | self.loader = l_loader 50 | 51 | def __getitem__(self, index): 52 | path, target = self.imgs[index] 53 | img = self.loader(path) 54 | if self.transform is not None: 55 | img = self.transform(img) 56 | if self.target_transform is not None: 57 | target = self.target_transform(target) 58 | 59 | return img, target 60 | 61 | def __len__(self): 62 | return len(self.imgs) 63 | 64 | 65 | class ImageList_twice(Dataset): 66 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'): 67 | imgs = make_dataset(image_list, labels) 68 | if len(imgs) == 0: 69 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 70 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 71 | 72 | self.imgs = imgs 73 | self.transform = transform 74 | self.target_transform = target_transform 75 | if mode == 'RGB': 76 | self.loader = rgb_loader 77 | elif mode == 'L': 78 | self.loader = l_loader 79 | 80 | def __getitem__(self, index): 81 | path, target = self.imgs[index] 82 | img = self.loader(path) 83 | if self.target_transform is not None: 84 | target = self.target_transform(target) 85 | if self.transform is not None: 86 | if type(self.transform).__name__=='list': 87 | img = [t(img) for t in self.transform] 88 | else: 89 | img = self.transform(img) 90 | 91 | return img, target, index 92 | 93 | def 
__len__(self): 94 | return len(self.imgs) 95 | 96 | class ImageList_idx(Dataset): 97 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'): 98 | imgs = make_dataset(image_list, labels) 99 | if len(imgs) == 0: 100 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 101 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 102 | 103 | self.imgs = imgs 104 | self.transform = transform 105 | self.target_transform = target_transform 106 | if mode == 'RGB': 107 | self.loader = rgb_loader 108 | elif mode == 'L': 109 | self.loader = l_loader 110 | 111 | def __getitem__(self, index): 112 | path, target = self.imgs[index] 113 | img = self.loader(path) 114 | if self.transform is not None: 115 | img = self.transform(img) 116 | if self.target_transform is not None: 117 | target = self.target_transform(target) 118 | 119 | return img, target, index 120 | 121 | def __len__(self): 122 | return len(self.imgs) 123 | 124 | 125 | class ImageValueList(Dataset): 126 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, 127 | loader=rgb_loader): 128 | imgs = make_dataset(image_list, labels) 129 | if len(imgs) == 0: 130 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 131 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 132 | 133 | self.imgs = imgs 134 | self.values = [1.0] * len(imgs) 135 | self.transform = transform 136 | self.target_transform = target_transform 137 | self.loader = loader 138 | 139 | def set_values(self, values): 140 | self.values = values 141 | 142 | def __getitem__(self, index): 143 | path, target = self.imgs[index] 144 | img = self.loader(path) 145 | if self.transform is not None: 146 | img = self.transform(img) 147 | if self.target_transform is not None: 148 | target = self.target_transform(target) 149 | 150 | return img, target 151 | 152 | def __len__(self): 153 | return len(self.imgs) 154 | 155 | class alexnetlist(Dataset): 
156 | 157 | def __init__(self, list, training=True): 158 | self.images = [] 159 | self.labels = [] 160 | self.multi_scale = [256, 257] 161 | self.output_size = [227, 227] 162 | self.training = training 163 | self.mean_color=[104.006, 116.668, 122.678] 164 | 165 | list_file = open(list) 166 | lines = list_file.readlines() 167 | for line in lines: 168 | fields = line.split() 169 | self.images.append(fields[0]) 170 | self.labels.append(int(fields[1])) 171 | 172 | def __len__(self): 173 | return len(self.images) 174 | 175 | def __getitem__(self, index): 176 | image_path = self.images[index] 177 | label = self.labels[index] 178 | img = cv2.imread(image_path) 179 | if type(img) == None: 180 | print('Error: Image at {} not found.'.format(image_path)) 181 | 182 | if self.training and np.random.random() < 0.5: 183 | img = cv2.flip(img, 1) 184 | new_size = np.random.randint(self.multi_scale[0], self.multi_scale[1], 1)[0] 185 | 186 | img = cv2.resize(img, (new_size, new_size)) 187 | img = img.astype(np.float32) 188 | 189 | # cropping 190 | if self.training: 191 | diff = new_size - self.output_size[0] 192 | offset_x = np.random.randint(0, diff, 1)[0] 193 | offset_y = np.random.randint(0, diff, 1)[0] 194 | else: 195 | offset_x = img.shape[0]//2 - self.output_size[0] // 2 196 | offset_y = img.shape[1]//2 - self.output_size[1] // 2 197 | 198 | img = img[offset_x:(offset_x+self.output_size[0]), 199 | offset_y:(offset_y+self.output_size[1])] 200 | 201 | # substract mean 202 | img -= np.array(self.mean_color) 203 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 204 | 205 | # ToTensor transform cv2 HWC->CHW, only byteTensor will be div by 255. 
206 | tensor = torchvision.transforms.ToTensor() 207 | img = tensor(img) 208 | # img = np.transpose(img, (2, 0, 1)) 209 | 210 | return img, label -------------------------------------------------------------------------------- /code/uda/data_list.py: -------------------------------------------------------------------------------- 1 | #from __future__ import print_function, division 2 | 3 | import torch 4 | import numpy as np 5 | import random 6 | from PIL import Image 7 | from torch.utils.data import Dataset 8 | import os 9 | import os.path 10 | 11 | import cv2 12 | import torchvision 13 | 14 | def make_dataset(image_list, labels): 15 | if labels: 16 | len_ = len(image_list) 17 | images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)] 18 | else: 19 | if len(image_list[0].split()) > 2: 20 | images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list] 21 | else: 22 | images = [(val.split()[0], int(val.split()[1])) for val in image_list] 23 | return images 24 | 25 | 26 | def rgb_loader(path): 27 | with open(path, 'rb') as f: 28 | with Image.open(f) as img: 29 | return img.convert('RGB') 30 | 31 | def l_loader(path): 32 | with open(path, 'rb') as f: 33 | with Image.open(f) as img: 34 | return img.convert('L') 35 | 36 | class ImageList(Dataset): 37 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'): 38 | imgs = make_dataset(image_list, labels) 39 | if len(imgs) == 0: 40 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 41 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 42 | 43 | self.imgs = imgs 44 | self.transform = transform 45 | self.target_transform = target_transform 46 | if mode == 'RGB': 47 | self.loader = rgb_loader 48 | elif mode == 'L': 49 | self.loader = l_loader 50 | 51 | def __getitem__(self, index): 52 | path, target = self.imgs[index] 53 | img = self.loader(path) 54 | if self.transform is not None: 55 | img = 
self.transform(img) 56 | if self.target_transform is not None: 57 | target = self.target_transform(target) 58 | 59 | return img, target 60 | 61 | def __len__(self): 62 | return len(self.imgs) 63 | 64 | 65 | class ImageList_twice(Dataset): 66 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'): 67 | imgs = make_dataset(image_list, labels) 68 | if len(imgs) == 0: 69 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 70 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 71 | 72 | self.imgs = imgs 73 | self.transform = transform 74 | self.target_transform = target_transform 75 | if mode == 'RGB': 76 | self.loader = rgb_loader 77 | elif mode == 'L': 78 | self.loader = l_loader 79 | 80 | def __getitem__(self, index): 81 | path, target = self.imgs[index] 82 | img = self.loader(path) 83 | if self.target_transform is not None: 84 | target = self.target_transform(target) 85 | if self.transform is not None: 86 | if type(self.transform).__name__=='list': 87 | img = [t(img) for t in self.transform] 88 | else: 89 | img = self.transform(img) 90 | 91 | return img, target, index 92 | 93 | def __len__(self): 94 | return len(self.imgs) 95 | 96 | class ImageList_idx(Dataset): 97 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'): 98 | imgs = make_dataset(image_list, labels) 99 | if len(imgs) == 0: 100 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 101 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 102 | 103 | self.imgs = imgs 104 | self.transform = transform 105 | self.target_transform = target_transform 106 | if mode == 'RGB': 107 | self.loader = rgb_loader 108 | elif mode == 'L': 109 | self.loader = l_loader 110 | 111 | def __getitem__(self, index): 112 | path, target = self.imgs[index] 113 | img = self.loader(path) 114 | if self.transform is not None: 115 | img = self.transform(img) 116 | if self.target_transform is 
not None: 117 | target = self.target_transform(target) 118 | 119 | return img, target, index 120 | 121 | def __len__(self): 122 | return len(self.imgs) 123 | 124 | 125 | class ImageValueList(Dataset): 126 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, 127 | loader=rgb_loader): 128 | imgs = make_dataset(image_list, labels) 129 | if len(imgs) == 0: 130 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 131 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 132 | 133 | self.imgs = imgs 134 | self.values = [1.0] * len(imgs) 135 | self.transform = transform 136 | self.target_transform = target_transform 137 | self.loader = loader 138 | 139 | def set_values(self, values): 140 | self.values = values 141 | 142 | def __getitem__(self, index): 143 | path, target = self.imgs[index] 144 | img = self.loader(path) 145 | if self.transform is not None: 146 | img = self.transform(img) 147 | if self.target_transform is not None: 148 | target = self.target_transform(target) 149 | 150 | return img, target 151 | 152 | def __len__(self): 153 | return len(self.imgs) 154 | 155 | class alexnetlist(Dataset): 156 | 157 | def __init__(self, list, training=True): 158 | self.images = [] 159 | self.labels = [] 160 | self.multi_scale = [256, 257] 161 | self.output_size = [227, 227] 162 | self.training = training 163 | self.mean_color=[104.006, 116.668, 122.678] 164 | 165 | list_file = open(list) 166 | lines = list_file.readlines() 167 | for line in lines: 168 | fields = line.split() 169 | self.images.append(fields[0]) 170 | self.labels.append(int(fields[1])) 171 | 172 | def __len__(self): 173 | return len(self.images) 174 | 175 | def __getitem__(self, index): 176 | image_path = self.images[index] 177 | label = self.labels[index] 178 | img = cv2.imread(image_path) 179 | if type(img) == None: 180 | print('Error: Image at {} not found.'.format(image_path)) 181 | 182 | if self.training and np.random.random() < 0.5: 183 | img = 
cv2.flip(img, 1) 184 | new_size = np.random.randint(self.multi_scale[0], self.multi_scale[1], 1)[0] 185 | 186 | img = cv2.resize(img, (new_size, new_size)) 187 | img = img.astype(np.float32) 188 | 189 | # cropping 190 | if self.training: 191 | diff = new_size - self.output_size[0] 192 | offset_x = np.random.randint(0, diff, 1)[0] 193 | offset_y = np.random.randint(0, diff, 1)[0] 194 | else: 195 | offset_x = img.shape[0]//2 - self.output_size[0] // 2 196 | offset_y = img.shape[1]//2 - self.output_size[1] // 2 197 | 198 | img = img[offset_x:(offset_x+self.output_size[0]), 199 | offset_y:(offset_y+self.output_size[1])] 200 | 201 | # substract mean 202 | img -= np.array(self.mean_color) 203 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 204 | 205 | # ToTensor transform cv2 HWC->CHW, only byteTensor will be div by 255. 206 | tensor = torchvision.transforms.ToTensor() 207 | img = tensor(img) 208 | # img = np.transpose(img, (2, 0, 1)) 209 | 210 | return img, label -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Official implementation for **SHOT++** 2 | 3 | ## [**[TPAMI-2021] Source Data-absent Unsupervised Domain Adaptation through Hypothesis Transfer and Labeling Transfer**](https://ieeexplore.ieee.org/abstract/document/9512429/) 4 | 5 | ### Framework: 6 | 7 | 1. train on the source domain; (Section 3.1) 8 | 2. **hypothesis transfer with information maximization and *self-supervised learning*; (Section 3.2 & Section 3.3)** 9 | (note that SHOT here means results after step 2, which contains an additional rotation-driven self-supervised objective compared with the original SHOT in ICML 2020) 10 | 11 | 12 | 13 | 3. **labeling transfer with semi-supervised learning. 
(Section 3.4)** 14 | (note that SHOT++ has an extra semi-supervised learning step via MixMatch) 15 | 16 | 17 | 18 | ### Prerequisites: 19 | - python == 3.6.8 20 | - pytorch ==1.1.0 21 | - torchvision == 0.3.0 22 | - numpy, scipy, sklearn, PIL, argparse, tqdm 23 | 24 | ### Dataset: 25 | 26 | - Please manually download the datasets [Office](https://drive.google.com/file/d/0B4IapRTv9pJ1WGZVd1VDMmhwdlE/view), [Office-Home](https://drive.google.com/file/d/0B81rNlvomiwed0V1YUxQdC1uOTg/view), [VisDA-C](https://github.com/VisionLearningGroup/taskcv-2017-public/tree/master/classification), [Office-Caltech](http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar) from the official websites, and modify the path of images in each '.txt' under the folder './object/data/'. [**Instructions for generating such txt files can be found at <https://github.com/tim-learn/Generate_list>**] 27 | 28 | - Concerning the **Digits** datasets, the code will automatically download three digit datasets (i.e., MNIST, USPS, and SVHN) into './digit/data/'. 29 | 30 | ### Training: 31 | 1. ##### Unsupervised Closed-set Domain Adaptation (UDA) on the Digits dataset 32 | - MNIST -> USPS (**m2u**) 33 | ```python 34 | cd digit/ 35 | python uda_digit.py --gpu_id 0 --seed 2021 --dset m2u --output ckps_digits --cls_par 0.1 --ssl 0.2 36 | python digit_mixmatch.py --gpu_id 0 --seed 2021 --dset m2u --output ckps_mm --output_tar ckps_digits --cls_par 0.1 --ssl 0.2 --alpha 0.1 37 | ``` 38 | 2.
##### Unsupervised Closed-set Domain Adaptation (UDA) on the Office/ Office-Home dataset 39 | - Train model on the source domain **A** (**s = 0**) [--max_epoch 50 for Office-Home] 40 | ```python 41 | cd uda/ 42 | python image_source.py --gpu_id 0 --seed 2021 --trte val --da uda --output ckps/source/ --dset office --max_epoch 100 --s 0 43 | ``` 44 | - Adaptation to other target domains **D and W** (hypothesis transfer) 45 | ```python 46 | python image_target.py --gpu_id 0 --seed 2021 --da uda --output ckps/target/ --dset office --s 0 --cls_par 0.3 --ssl 0.6 47 | ``` 48 | - Adaptation to other target domains **D and W** (following labeling transfer) [--max_epoch 50 for Office-Home] 49 | ```python 50 | python image_mixmatch.py --gpu_id 0 --seed 2021 --da uda --dset office --max_epoch 100 --s 0 --output_tar ckps/target/ --output ckps/mixmatch/ --cls_par 0.3 --ssl 0.6 --choice ent --ps 0.0 51 | ``` 52 | 53 | 3. ##### Unsupervised Closed-set Domain Adaptation (UDA) on the VISDA-C dataset 54 | - Train model on the Synthetic domain [--max_epoch 10 --lr 1e-3] 55 | ```python 56 | cd uda/ 57 | python image_source.py --gpu_id 0 --seed 2021 --trte val --da uda --output ckps/source/ --dset VISDA-C --net resnet101 --lr 1e-3 --max_epoch 10 --s 0 58 | ``` 59 | - Adaptation to the real domain (hypothesis transfer) 60 | ```python 61 | python image_target.py --gpu_id 0 --seed 2021 --da uda --output ckps/target/ --dset VISDA-C --s 0 --net resnet101 --cls_par 0.3 --ssl 0.6 62 | ``` 63 | - Adaptation to the real domain (following labeling transfer) 64 | ```python 65 | python image_mixmatch.py --gpu_id 0 --seed 2021 --da uda --dset VISDA-C --max_epoch 10 --s 0 --output_tar ckps/target/ --output ckps/mixmatch/ --net resnet101 --cls_par 0.3 --ssl 0.6 --choice ent --ps 0.0 66 | ``` 67 | 68 | 4. 
##### Unsupervised Partial-set Domain Adaptation (PDA) on the Office-Home dataset 69 | - Train model on the source domain **Ar** (**s = 0**) 70 | ```python 71 | cd pda/ 72 | python image_source.py --gpu_id 0 --seed 2021 --trte val --da pda --output ckps/source/ --dset office-home --max_epoch 50 --s 0 73 | ``` 74 | 75 | - Adaptation to other target domains (hypothesis transfer) 76 | ```python 77 | python image_target.py --gpu_id 0 --seed 2021 --da pda --dset office-home --s 0 --output_src ckps/source/ --output ckps/target/ --cls_par 0.3 --ssl 0.6 78 | ``` 79 | 80 | - Adaptation to other target domains (following labeling transfer) 81 | ```python 82 | python image_mixmatch.py --gpu_id 0 --seed 2021 --da pda --dset office-home --max_epoch 50 --s 0 --output_tar ckps/target/ --output ckps/mixmatch/ --cls_par 0.3 --ssl 0.6 --choice ent --ps 0.0 83 | ``` 84 | 85 | 5.
--gpu_id 0 --seed 2021 --cls_par 0.3 --ssl 0.6 --da uda --dset office-home --output_src ckps/source/ --output ckps/target/ --net resnet50 --s 2 99 | ``` 100 | 101 | - Adaptation to the target domain (labeling transfer) 102 | ```python 103 | python image_mixmatch.py --gpu_id 0 --seed 2021 --da uda --dset office-home --max_epoch 50 --output_tar ckps/target/ --output ckps/mixmatch/ --cls_par 0.3 --ssl 0.6 --choice ent --ps 0.0 --net resnet50 --s 0 104 | python image_mixmatch.py --gpu_id 0 --seed 2021 --da uda --dset office-home --max_epoch 50 --output_tar ckps/target/ --output ckps/mixmatch/ --cls_par 0.3 --ssl 0.6 --choice ent --ps 0.0 --net resnet50 --s 1 105 | python image_mixmatch.py --gpu_id 0 --seed 2021 --da uda --dset office-home --max_epoch 50 --output_tar ckps/target/ --output ckps/mixmatch/ --cls_par 0.3 --ssl 0.6 --choice ent --ps 0.0 --net resnet50 --s 2 106 | ``` 107 | 108 | - Combine the domain-specific scores 109 | ```python 110 | python image_ms.py --gpu_id 0 --seed 2021 --cls_par 0.3 --ssl 0.6 --da uda --dset office-home --output_src ckps/source/ --output ckps/target/ --output_mm ckps/mixmatch/ --net resnet50 --t 3 111 | ``` 112 | 113 | 6.
##### Semi-supervised Domain Adaptation (SSDA) on the Office-Home dataset 114 | - Train model on the source domain **Ar** (**s = 0**) 115 | ```python 116 | cd ssda/ 117 | python image_source.py --gpu_id 0 --seed 2021 --output ckps/source/ --dset office-home --max_epoch 50 --s 0 118 | ``` 119 | 120 | - Adaptation to the target domain **Cl** (**t = 1**) [hypothesis transfer] 121 | ```python 122 | python image_target.py --gpu_id 0 --seed 2021 --cls_par 0.1 --ssl 0.2 --output_src ckps/source --output ckps/target --dset office-home --s 0 --t 1 123 | ``` 124 | 125 | - Adaptation to the target domain **Cl** (**t = 1**) [labeling transfer] 126 | ```python 127 | python image_mixmatch.py --gpu_id 0 --seed 2021 --ps 0.0 --cls_par 0.1 --ssl 0.2 --output_tar ckps/target --output ckps/mixmatch --dset office-home --max_epoch 50 --s 0 --t 1 128 | ``` 129 | 130 | **Please refer to *./xxda/run_xxda.sh*** for all the settings of different methods and scenarios. 131 | 132 | ### Citation 133 | 134 | If you find this code useful for your research, please cite our papers: 135 | ``` 136 | @article{liang2021source, 137 | title={Source Data-absent Unsupervised Domain Adaptation through Hypothesis Transfer and Labeling Transfer}, 138 | author={Liang, Jian and Hu, Dapeng and Wang, Yunbo and He, Ran and Feng, Jiashi}, 139 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)}, 140 | year={2021}, 141 | note={In Press} 142 | } 143 | 144 | @inproceedings{liang2020we, 145 | title={Do We Really Need to Access the Source Data?
Source Hypothesis Transfer for Unsupervised Domain Adaptation}, 146 | author={Liang, Jian and Hu, Dapeng and Feng, Jiashi}, 147 | booktitle={International Conference on Machine Learning (ICML)}, 148 | pages={6028--6039}, 149 | year={2020} 150 | } 151 | ``` 152 | 153 | ### Contact 154 | 155 | - [liangjian92@gmail.com](mailto:liangjian92@gmail.com) 156 | - [dapeng.hu@u.nus.edu](mailto:dapeng.hu@u.nus.edu) 157 | - [elefjia@nus.edu.sg](mailto:elefjia@nus.edu.sg) 158 | -------------------------------------------------------------------------------- /code/digit/data_load/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import hashlib 4 | import gzip 5 | import errno 6 | import tarfile 7 | import zipfile 8 | 9 | import torch 10 | from torch.utils.model_zoo import tqdm 11 | 12 | 13 | def gen_bar_updater(): 14 | pbar = tqdm(total=None) 15 | 16 | def bar_update(count, block_size, total_size): 17 | if pbar.total is None and total_size: 18 | pbar.total = total_size 19 | progress_bytes = count * block_size 20 | pbar.update(progress_bytes - pbar.n) 21 | 22 | return bar_update 23 | 24 | 25 | def calculate_md5(fpath, chunk_size=1024 * 1024): 26 | md5 = hashlib.md5() 27 | with open(fpath, 'rb') as f: 28 | for chunk in iter(lambda: f.read(chunk_size), b''): 29 | md5.update(chunk) 30 | return md5.hexdigest() 31 | 32 | 33 | def check_md5(fpath, md5, **kwargs): 34 | return md5 == calculate_md5(fpath, **kwargs) 35 | 36 | 37 | def check_integrity(fpath, md5=None): 38 | if not os.path.isfile(fpath): 39 | return False 40 | if md5 is None: 41 | return True 42 | return check_md5(fpath, md5) 43 | 44 | 45 | def download_url(url, root, filename=None, md5=None): 46 | """Download a file from a url and place it in root. 47 | 48 | Args: 49 | url (str): URL to download file from 50 | root (str): Directory to place downloaded file in 51 | filename (str, optional): Name to save the file under. 
If None, use the basename of the URL 52 | md5 (str, optional): MD5 checksum of the download. If None, do not check 53 | """ 54 | import urllib 55 | 56 | root = os.path.expanduser(root) 57 | if not filename: 58 | filename = os.path.basename(url) 59 | fpath = os.path.join(root, filename) 60 | 61 | os.makedirs(root, exist_ok=True) 62 | 63 | # check if file is already present locally 64 | if check_integrity(fpath, md5): 65 | print('Using downloaded and verified file: ' + fpath) 66 | else: # download the file 67 | try: 68 | print('Downloading ' + url + ' to ' + fpath) 69 | urllib.request.urlretrieve( 70 | url, fpath, 71 | reporthook=gen_bar_updater() 72 | ) 73 | except (urllib.error.URLError, IOError) as e: 74 | if url[:5] == 'https': 75 | url = url.replace('https:', 'http:') 76 | print('Failed download. Trying https -> http instead.' 77 | ' Downloading ' + url + ' to ' + fpath) 78 | urllib.request.urlretrieve( 79 | url, fpath, 80 | reporthook=gen_bar_updater() 81 | ) 82 | else: 83 | raise e 84 | # check integrity of downloaded file 85 | if not check_integrity(fpath, md5): 86 | raise RuntimeError("File not found or corrupted.") 87 | 88 | 89 | def list_dir(root, prefix=False): 90 | """List all directories at a given root 91 | 92 | Args: 93 | root (str): Path to directory whose folders need to be listed 94 | prefix (bool, optional): If true, prepends the path to each result, otherwise 95 | only returns the name of the directories found 96 | """ 97 | root = os.path.expanduser(root) 98 | directories = list( 99 | filter( 100 | lambda p: os.path.isdir(os.path.join(root, p)), 101 | os.listdir(root) 102 | ) 103 | ) 104 | 105 | if prefix is True: 106 | directories = [os.path.join(root, d) for d in directories] 107 | 108 | return directories 109 | 110 | 111 | def list_files(root, suffix, prefix=False): 112 | """List all files ending with a suffix at a given root 113 | 114 | Args: 115 | root (str): Path to directory whose folders need to be listed 116 | suffix (str or tuple): 
Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). 117 | It uses the Python "str.endswith" method and is passed directly 118 | prefix (bool, optional): If true, prepends the path to each result, otherwise 119 | only returns the name of the files found 120 | """ 121 | root = os.path.expanduser(root) 122 | files = list( 123 | filter( 124 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix), 125 | os.listdir(root) 126 | ) 127 | ) 128 | 129 | if prefix is True: 130 | files = [os.path.join(root, d) for d in files] 131 | 132 | return files 133 | 134 | 135 | def download_file_from_google_drive(file_id, root, filename=None, md5=None): 136 | """Download a Google Drive file from and place it in root. 137 | 138 | Args: 139 | file_id (str): id of file to be downloaded 140 | root (str): Directory to place downloaded file in 141 | filename (str, optional): Name to save the file under. If None, use the id of the file. 142 | md5 (str, optional): MD5 checksum of the download. 
If None, do not check 143 | """ 144 | # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url 145 | import requests 146 | url = "https://docs.google.com/uc?export=download" 147 | 148 | root = os.path.expanduser(root) 149 | if not filename: 150 | filename = file_id 151 | fpath = os.path.join(root, filename) 152 | 153 | os.makedirs(root, exist_ok=True) 154 | 155 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 156 | print('Using downloaded and verified file: ' + fpath) 157 | else: 158 | session = requests.Session() 159 | 160 | response = session.get(url, params={'id': file_id}, stream=True) 161 | token = _get_confirm_token(response) 162 | 163 | if token: 164 | params = {'id': file_id, 'confirm': token} 165 | response = session.get(url, params=params, stream=True) 166 | 167 | _save_response_content(response, fpath) 168 | 169 | 170 | def _get_confirm_token(response): 171 | for key, value in response.cookies.items(): 172 | if key.startswith('download_warning'): 173 | return value 174 | 175 | return None 176 | 177 | 178 | def _save_response_content(response, destination, chunk_size=32768): 179 | with open(destination, "wb") as f: 180 | pbar = tqdm(total=None) 181 | progress = 0 182 | for chunk in response.iter_content(chunk_size): 183 | if chunk: # filter out keep-alive new chunks 184 | f.write(chunk) 185 | progress += len(chunk) 186 | pbar.update(progress - pbar.n) 187 | pbar.close() 188 | 189 | 190 | def _is_tarxz(filename): 191 | return filename.endswith(".tar.xz") 192 | 193 | 194 | def _is_tar(filename): 195 | return filename.endswith(".tar") 196 | 197 | 198 | def _is_targz(filename): 199 | return filename.endswith(".tar.gz") 200 | 201 | 202 | def _is_tgz(filename): 203 | return filename.endswith(".tgz") 204 | 205 | 206 | def _is_gzip(filename): 207 | return filename.endswith(".gz") and not filename.endswith(".tar.gz") 208 | 209 | 210 | def _is_zip(filename): 211 | return filename.endswith(".zip") 212 | 213 
| 214 | def extract_archive(from_path, to_path=None, remove_finished=False): 215 | if to_path is None: 216 | to_path = os.path.dirname(from_path) 217 | 218 | if _is_tar(from_path): 219 | with tarfile.open(from_path, 'r') as tar: 220 | tar.extractall(path=to_path) 221 | elif _is_targz(from_path) or _is_tgz(from_path): 222 | with tarfile.open(from_path, 'r:gz') as tar: 223 | tar.extractall(path=to_path) 224 | elif _is_tarxz(from_path): 225 | with tarfile.open(from_path, 'r:xz') as tar: 226 | tar.extractall(path=to_path) 227 | elif _is_gzip(from_path): 228 | to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0]) 229 | with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f: 230 | out_f.write(zip_f.read()) 231 | elif _is_zip(from_path): 232 | with zipfile.ZipFile(from_path, 'r') as z: 233 | z.extractall(to_path) 234 | else: 235 | raise ValueError("Extraction of {} not supported".format(from_path)) 236 | 237 | if remove_finished: 238 | os.remove(from_path) 239 | 240 | 241 | def download_and_extract_archive(url, download_root, extract_root=None, filename=None, 242 | md5=None, remove_finished=False): 243 | download_root = os.path.expanduser(download_root) 244 | if extract_root is None: 245 | extract_root = download_root 246 | if not filename: 247 | filename = os.path.basename(url) 248 | 249 | download_url(url, download_root, filename, md5) 250 | 251 | archive = os.path.join(download_root, filename) 252 | print("Extracting {} to {}".format(archive, extract_root)) 253 | extract_archive(archive, extract_root, remove_finished) 254 | 255 | 256 | def iterable_to_str(iterable): 257 | return "'" + "', '".join([str(item) for item in iterable]) + "'" 258 | 259 | 260 | def verify_str_arg(value, arg=None, valid_values=None, custom_msg=None): 261 | if not isinstance(value, torch._six.string_classes): 262 | if arg is None: 263 | msg = "Expected type str, but got type {type}." 
264 | else: 265 | msg = "Expected type str for argument {arg}, but got type {type}." 266 | msg = msg.format(type=type(value), arg=arg) 267 | raise ValueError(msg) 268 | 269 | if valid_values is None: 270 | return value 271 | 272 | if value not in valid_values: 273 | if custom_msg is not None: 274 | msg = custom_msg 275 | else: 276 | msg = ("Unknown value '{value}' for argument {arg}. " 277 | "Valid values are {{{valid_values}}}.") 278 | msg = msg.format(value=value, arg=arg, 279 | valid_values=iterable_to_str(valid_values)) 280 | raise ValueError(msg) 281 | 282 | return value -------------------------------------------------------------------------------- /code/digit/data_load/usps.py: -------------------------------------------------------------------------------- 1 | """Dataset setting and data loader for USPS. 2 | Modified from 3 | https://github.com/mingyuliutw/CoGAN/blob/master/cogan_pytorch/src/dataset_usps.py 4 | """ 5 | 6 | import gzip 7 | import os 8 | import pickle 9 | import urllib 10 | from PIL import Image 11 | 12 | import numpy as np 13 | import torch 14 | import torch.utils.data as data 15 | from torch.utils.data.sampler import WeightedRandomSampler 16 | from torchvision import datasets, transforms 17 | 18 | 19 | class USPS(data.Dataset): 20 | """USPS Dataset. 21 | Args: 22 | root (string): Root directory of dataset where dataset file exist. 23 | train (bool, optional): If True, resample from dataset randomly. 24 | download (bool, optional): If true, downloads the dataset 25 | from the internet and puts it in root directory. 26 | If dataset is already downloaded, it is not downloaded again. 27 | transform (callable, optional): A function/transform that takes in 28 | an PIL image and returns a transformed version. 
29 | E.g, ``transforms.RandomCrop`` 30 | """ 31 | 32 | url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl" 33 | 34 | def __init__(self, root, train=True, transform=None, download=False): 35 | """Init USPS dataset.""" 36 | # init params 37 | self.root = os.path.expanduser(root) 38 | self.filename = "usps_28x28.pkl" 39 | self.train = train 40 | # Num of Train = 7438, Num ot Test 1860 41 | self.transform = transform 42 | self.dataset_size = None 43 | 44 | # download dataset. 45 | if download: 46 | self.download() 47 | if not self._check_exists(): 48 | raise RuntimeError("Dataset not found." + 49 | " You can use download=True to download it") 50 | 51 | self.train_data, self.train_labels = self.load_samples() 52 | if self.train: 53 | total_num_samples = self.train_labels.shape[0] 54 | indices = np.arange(total_num_samples) 55 | self.train_data = self.train_data[indices[0:self.dataset_size], ::] 56 | self.train_labels = self.train_labels[indices[0:self.dataset_size]] 57 | self.train_data *= 255.0 58 | self.train_data = np.squeeze(self.train_data).astype(np.uint8) 59 | 60 | def __getitem__(self, index): 61 | """Get images and target for data loader. 62 | Args: 63 | index (int): Index 64 | Returns: 65 | tuple: (image, target) where target is index of the target class. 
66 | """ 67 | img, label = self.train_data[index], self.train_labels[index] 68 | img = Image.fromarray(img, mode='L') 69 | img = img.copy() 70 | if self.transform is not None: 71 | img = self.transform(img) 72 | return img, label.astype("int64") 73 | 74 | def __len__(self): 75 | """Return size of dataset.""" 76 | return len(self.train_data) 77 | 78 | def _check_exists(self): 79 | """Check if dataset is download and in right place.""" 80 | return os.path.exists(os.path.join(self.root, self.filename)) 81 | 82 | def download(self): 83 | """Download dataset.""" 84 | filename = os.path.join(self.root, self.filename) 85 | dirname = os.path.dirname(filename) 86 | if not os.path.isdir(dirname): 87 | os.makedirs(dirname) 88 | if os.path.isfile(filename): 89 | return 90 | print("Download %s to %s" % (self.url, os.path.abspath(filename))) 91 | urllib.request.urlretrieve(self.url, filename) 92 | print("[DONE]") 93 | return 94 | 95 | def load_samples(self): 96 | """Load sample images from dataset.""" 97 | filename = os.path.join(self.root, self.filename) 98 | f = gzip.open(filename, "rb") 99 | data_set = pickle.load(f, encoding="bytes") 100 | f.close() 101 | if self.train: 102 | images = data_set[0][0] 103 | labels = data_set[0][1] 104 | self.dataset_size = labels.shape[0] 105 | else: 106 | images = data_set[1][0] 107 | labels = data_set[1][1] 108 | self.dataset_size = labels.shape[0] 109 | return images, labels 110 | 111 | 112 | class USPS_idx(data.Dataset): 113 | """USPS Dataset. 114 | Args: 115 | root (string): Root directory of dataset where dataset file exist. 116 | train (bool, optional): If True, resample from dataset randomly. 117 | download (bool, optional): If true, downloads the dataset 118 | from the internet and puts it in root directory. 119 | If dataset is already downloaded, it is not downloaded again. 120 | transform (callable, optional): A function/transform that takes in 121 | an PIL image and returns a transformed version. 
122 | E.g, ``transforms.RandomCrop`` 123 | """ 124 | 125 | url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl" 126 | 127 | def __init__(self, root, train=True, transform=None, download=False): 128 | """Init USPS dataset.""" 129 | # init params 130 | self.root = os.path.expanduser(root) 131 | self.filename = "usps_28x28.pkl" 132 | self.train = train 133 | # Num of Train = 7438, Num ot Test 1860 134 | self.transform = transform 135 | self.dataset_size = None 136 | 137 | # download dataset. 138 | if download: 139 | self.download() 140 | if not self._check_exists(): 141 | raise RuntimeError("Dataset not found." + 142 | " You can use download=True to download it") 143 | 144 | self.train_data, self.train_labels = self.load_samples() 145 | if self.train: 146 | total_num_samples = self.train_labels.shape[0] 147 | indices = np.arange(total_num_samples) 148 | self.train_data = self.train_data[indices[0:self.dataset_size], ::] 149 | self.train_labels = self.train_labels[indices[0:self.dataset_size]] 150 | self.train_data *= 255.0 151 | self.train_data = np.squeeze(self.train_data).astype(np.uint8) 152 | 153 | def __getitem__(self, index): 154 | """Get images and target for data loader. 155 | Args: 156 | index (int): Index 157 | Returns: 158 | tuple: (image, target) where target is index of the target class. 
159 | """ 160 | img, label = self.train_data[index], self.train_labels[index] 161 | img = Image.fromarray(img, mode='L') 162 | img = img.copy() 163 | if self.transform is not None: 164 | img = self.transform(img) 165 | return img, label.astype("int64"), index 166 | 167 | def __len__(self): 168 | """Return size of dataset.""" 169 | return len(self.train_data) 170 | 171 | def _check_exists(self): 172 | """Check if dataset is download and in right place.""" 173 | return os.path.exists(os.path.join(self.root, self.filename)) 174 | 175 | def download(self): 176 | """Download dataset.""" 177 | filename = os.path.join(self.root, self.filename) 178 | dirname = os.path.dirname(filename) 179 | if not os.path.isdir(dirname): 180 | os.makedirs(dirname) 181 | if os.path.isfile(filename): 182 | return 183 | print("Download %s to %s" % (self.url, os.path.abspath(filename))) 184 | urllib.request.urlretrieve(self.url, filename) 185 | print("[DONE]") 186 | return 187 | 188 | def load_samples(self): 189 | """Load sample images from dataset.""" 190 | filename = os.path.join(self.root, self.filename) 191 | f = gzip.open(filename, "rb") 192 | data_set = pickle.load(f, encoding="bytes") 193 | f.close() 194 | if self.train: 195 | images = data_set[0][0] 196 | labels = data_set[0][1] 197 | self.dataset_size = labels.shape[0] 198 | else: 199 | images = data_set[1][0] 200 | labels = data_set[1][1] 201 | self.dataset_size = labels.shape[0] 202 | return images, labels 203 | 204 | 205 | class USPS_twice(USPS): 206 | url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl" 207 | 208 | def __init__(self, root, train=True, transform=None, download=False): 209 | """Init USPS dataset.""" 210 | # init params 211 | self.root = os.path.expanduser(root) 212 | self.filename = "usps_28x28.pkl" 213 | self.train = train 214 | # Num of Train = 7438, Num ot Test 1860 215 | self.transform = transform 216 | self.dataset_size = None 217 | 218 | # download 
dataset. 219 | if download: 220 | self.download() 221 | if not self._check_exists(): 222 | raise RuntimeError("Dataset not found." + 223 | " You can use download=True to download it") 224 | 225 | self.train_data, self.train_labels = self.load_samples() 226 | if self.train: 227 | total_num_samples = self.train_labels.shape[0] 228 | indices = np.arange(total_num_samples) 229 | self.train_data = self.train_data[indices[0:self.dataset_size], ::] 230 | self.train_labels = self.train_labels[indices[0:self.dataset_size]] 231 | self.train_data *= 255.0 232 | self.train_data = np.squeeze(self.train_data).astype(np.uint8) 233 | 234 | def __getitem__(self, index): 235 | """Get images and target for data loader. 236 | Args: 237 | index (int): Index 238 | Returns: 239 | tuple: (image, target) where target is index of the target class. 240 | """ 241 | img, label = self.train_data[index], self.train_labels[index] 242 | img = Image.fromarray(img, mode='L') 243 | img = img.copy() 244 | if self.transform is not None: 245 | img = [self.transform(img), self.transform(img)] 246 | return img, label.astype("int64"), index 247 | 248 | def __len__(self): 249 | """Return size of dataset.""" 250 | return len(self.train_data) 251 | -------------------------------------------------------------------------------- /code/msda/image_ms.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import torchvision 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torchvision import transforms 10 | import network, loss 11 | from torch.utils.data import DataLoader 12 | from data_list import ImageList, ImageList_idx 13 | import random, pdb, math, copy 14 | from tqdm import tqdm 15 | from loss import CrossEntropyLabelSmooth 16 | from scipy.spatial.distance import cdist 17 | from sklearn.metrics import confusion_matrix 18 | 19 | def 
image_train(resize_size=256, crop_size=224, alexnet=False): 20 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 21 | std=[0.229, 0.224, 0.225]) 22 | return transforms.Compose([ 23 | transforms.Resize((resize_size, resize_size)), 24 | transforms.RandomCrop(crop_size), 25 | transforms.RandomHorizontalFlip(), 26 | transforms.ToTensor(), 27 | normalize 28 | ]) 29 | 30 | def image_test(resize_size=256, crop_size=224, alexnet=False): 31 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 32 | std=[0.229, 0.224, 0.225]) 33 | return transforms.Compose([ 34 | transforms.Resize((resize_size, resize_size)), 35 | transforms.CenterCrop(crop_size), 36 | transforms.ToTensor(), 37 | normalize 38 | ]) 39 | 40 | def data_load(args): 41 | ## prepare data 42 | dsets = {} 43 | dset_loaders = {} 44 | train_bs = args.batch_size 45 | txt_test = open(args.test_dset_path).readlines() 46 | 47 | dsets["test"] = ImageList_idx(txt_test, transform=image_test()) 48 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*3, 49 | shuffle=False, num_workers=args.worker, drop_last=False) 50 | 51 | return dset_loaders 52 | 53 | def cal_acc(loader, netF, netB, netC): 54 | start_test = True 55 | with torch.no_grad(): 56 | iter_test = iter(loader) 57 | for i in range(len(loader)): 58 | data = iter_test.next() 59 | inputs = data[0] 60 | labels = data[1] 61 | inputs = inputs.cuda() 62 | outputs = netC(netB(netF(inputs))) 63 | if start_test: 64 | all_output = outputs.float().cpu() 65 | all_label = labels.float() 66 | start_test = False 67 | else: 68 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 69 | all_label = torch.cat((all_label, labels.float()), 0) 70 | _, predict = torch.max(all_output, 1) 71 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 72 | mean_ent = torch.mean(loss.Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item() 73 | 74 | return accuracy, all_label, 
nn.Softmax(dim=1)(all_output) 75 | 76 | def print_args(args): 77 | s = "==========================================\n" 78 | for arg, content in args.__dict__.items(): 79 | s += "{}:{}\n".format(arg, content) 80 | return s 81 | 82 | def test_target_srconly(args): 83 | dset_loaders = data_load(args) 84 | ## set base network 85 | if args.net[0:3] == 'res': 86 | netF = network.ResBase(res_name=args.net).cuda() 87 | elif args.net[0:3] == 'vgg': 88 | netF = network.VGGBase(vgg_name=args.net).cuda() 89 | 90 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda() 91 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda() 92 | 93 | args.modelpath = args.output_dir_src + '/source_F.pt' 94 | netF.load_state_dict(torch.load(args.modelpath)) 95 | args.modelpath = args.output_dir_src + '/source_B.pt' 96 | netB.load_state_dict(torch.load(args.modelpath)) 97 | args.modelpath = args.output_dir_src + '/source_C.pt' 98 | netC.load_state_dict(torch.load(args.modelpath)) 99 | netF.eval() 100 | netB.eval() 101 | netC.eval() 102 | 103 | acc, y, py = cal_acc(dset_loaders['test'], netF, netB, netC) 104 | log_str = '\nTask: {}, Accuracy = {:.2f}%'.format(args.name, acc*100) 105 | args.out_file.write(log_str) 106 | args.out_file.flush() 107 | print(log_str) 108 | 109 | return y, py 110 | 111 | def test_target(args): 112 | dset_loaders = data_load(args) 113 | ## set base network 114 | if args.net[0:3] == 'res': 115 | netF = network.ResBase(res_name=args.net).cuda() 116 | elif args.net[0:3] == 'vgg': 117 | netF = network.VGGBase(vgg_name=args.net).cuda() 118 | 119 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda() 120 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda() 121 | 122 | args.modelpath = args.output_dir_ori + "/target_F_" 
+ args.savename + ".pt" 123 | netF.load_state_dict(torch.load(args.modelpath)) 124 | args.modelpath = args.output_dir_ori + "/target_B_" + args.savename + ".pt" 125 | netB.load_state_dict(torch.load(args.modelpath)) 126 | args.modelpath = args.output_dir_ori + "/target_C_" + args.savename + ".pt" 127 | netC.load_state_dict(torch.load(args.modelpath)) 128 | netF.eval() 129 | netB.eval() 130 | netC.eval() 131 | 132 | acc, y, py = cal_acc(dset_loaders['test'], netF, netB, netC) 133 | log_str = '\nTask: {}, Accuracy = {:.2f}%'.format(args.name, acc*100) 134 | args.out_file.write(log_str) 135 | args.out_file.flush() 136 | print(log_str) 137 | 138 | return y, py 139 | 140 | if __name__ == "__main__": 141 | parser = argparse.ArgumentParser(description='SHOT++') 142 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run") 143 | parser.add_argument('--s', type=int, default=0, help="source") 144 | parser.add_argument('--t', type=int, default=1, help="target") 145 | parser.add_argument('--max_epoch', type=int, default=20, help="max iterations") 146 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size") 147 | parser.add_argument('--worker', type=int, default=4, help="number of workers") 148 | parser.add_argument('--dset', type=str, default='office-caltech', choices=['pacs', 'office-home', 'office-caltech']) 149 | parser.add_argument('--net', type=str, default='resnet101', help="alexnet, vgg16, resnet50") 150 | parser.add_argument('--seed', type=int, default=2020, help="random seed") 151 | 152 | parser.add_argument('--gent', type=bool, default=True) 153 | parser.add_argument('--ent', type=bool, default=True) 154 | parser.add_argument('--threshold', type=int, default=0) 155 | parser.add_argument('--cls_par', type=float, default=0.3) 156 | 157 | parser.add_argument('--bottleneck', type=int, default=256) 158 | parser.add_argument('--epsilon', type=float, default=1e-5) 159 | parser.add_argument('--layer', type=str, 
default="wn", choices=["linear", "wn"]) 160 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"]) 161 | parser.add_argument('--output', type=str, default='san') 162 | parser.add_argument('--output_src', type=str, default='san') 163 | parser.add_argument('--output_tar', type=str, default='san') 164 | parser.add_argument('--output_mm', type=str, default='san') 165 | parser.add_argument('--da', type=str, default='uda', choices=['uda']) 166 | parser.add_argument('--ssl', type=float, default=0.0) 167 | parser.add_argument('--savename', type=str, default='san') 168 | args = parser.parse_args() 169 | 170 | if args.dset == 'office-home': 171 | names = ['Art', 'Clipart', 'Product', 'RealWorld'] 172 | args.class_num = 65 173 | elif args.dset == 'office-caltech': 174 | names = ['amazon', 'caltech', 'dslr', 'webcam'] 175 | args.class_num = 10 176 | elif args.dset == 'pacs': 177 | names = ['art_painting', 'cartoon', 'photo', 'sketch'] 178 | args.class_num = 7 179 | 180 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 181 | SEED = args.seed 182 | torch.manual_seed(SEED) 183 | torch.cuda.manual_seed(SEED) 184 | np.random.seed(SEED) 185 | random.seed(SEED) 186 | # torch.backends.cudnn.deterministic = True 187 | 188 | args.output_dir = osp.join(args.output, args.da, args.dset, names[args.t][0].upper()) 189 | if not osp.exists(args.output_dir): 190 | os.system('mkdir -p ' + args.output_dir) 191 | if not osp.exists(args.output_dir): 192 | os.mkdir(args.output_dir) 193 | 194 | args.savename = 'par_' + str(args.cls_par) 195 | if args.ssl > 0: 196 | args.savename += ('_ssl_' + str(args.ssl)) 197 | 198 | args.out_file = open(osp.join(args.output_dir, 'log_' + args.savename + '.txt'), 'w') 199 | args.out_file.write(print_args(args)+'\n') 200 | args.out_file.flush() 201 | 202 | score_srconly = 0 203 | score_srconly_mm = 0 204 | score = 0 205 | score_mm = 0 206 | 207 | for i in range(len(names)): 208 | if i == args.t: 209 | continue 210 | args.s = i 211 | 
212 | folder = '../data/' 213 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt' 214 | args.name = names[args.s][0].upper()+names[args.t][0].upper() 215 | 216 | if args.dset == 'pacs': 217 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_test_kfold.txt' 218 | 219 | args.output_dir_src = osp.join(args.output_src, args.da, args.dset, names[args.s][0].upper()) 220 | label, output_srconly = test_target_srconly(args) 221 | score_srconly += output_srconly 222 | 223 | args.output_dir_ori = osp.join(args.output_tar, args.da, args.dset, names[args.s][0].upper()+names[args.t][0].upper()) 224 | _, output = test_target(args) 225 | score += output 226 | 227 | args.output_dir = osp.join(args.output_mm, args.da, args.dset, names[args.s][0].upper()+names[args.t][0].upper()) 228 | mm = np.load(args.output_dir + '/ps_0.0_' + args.savename + '.npz') 229 | score_mm += torch.from_numpy(mm['score']) 230 | 231 | args.output_dir = osp.join(args.output_mm, args.da, args.dset, names[args.s][0].upper()+names[args.t][0].upper()) 232 | mm = np.load(args.output_dir + '/ps_0.0_srconly.npz') 233 | score_srconly_mm += torch.from_numpy(mm['score']) 234 | 235 | _, predict = torch.max(score_srconly, 1) 236 | acc = torch.sum(torch.squeeze(predict).float() == label).item() / float(label.size()[0]) 237 | log_str = '\nTask: {}, Accuracy = {:.2f}%'.format(names[args.t][0].upper(), acc*100) 238 | args.out_file.write(log_str) 239 | args.out_file.flush() 240 | print(log_str) 241 | 242 | _, predict = torch.max(score_srconly_mm, 1) 243 | acc = torch.sum(torch.squeeze(predict).float() == label).item() / float(label.size()[0]) 244 | log_str = '\nTask: {}, Accuracy = {:.2f}%'.format(names[args.t][0].upper(), acc*100) 245 | args.out_file.write(log_str) 246 | args.out_file.flush() 247 | print(log_str) 248 | 249 | _, predict = torch.max(score, 1) 250 | acc = torch.sum(torch.squeeze(predict).float() == label).item() / float(label.size()[0]) 251 | log_str = '\nTask: {}, 
Accuracy = {:.2f}%'.format(names[args.t][0].upper(), acc*100) 252 | args.out_file.write(log_str) 253 | args.out_file.flush() 254 | print(log_str) 255 | 256 | _, predict = torch.max(score_mm, 1) 257 | acc = torch.sum(torch.squeeze(predict).float() == label).item() / float(label.size()[0]) 258 | log_str = '\nTask: {}, Accuracy = {:.2f}%'.format(names[args.t][0].upper(), acc*100) 259 | args.out_file.write(log_str) 260 | args.out_file.flush() 261 | print(log_str) 262 | 263 | args.out_file.close() -------------------------------------------------------------------------------- /code/ssda/image_source.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import torchvision 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torchvision import transforms 10 | import network 11 | import loss 12 | from torch.utils.data import DataLoader 13 | from data_list import ImageList 14 | import random, pdb, math, copy 15 | from tqdm import tqdm 16 | from loss import CrossEntropyLabelSmooth 17 | from scipy.spatial.distance import cdist 18 | from sklearn.metrics import confusion_matrix 19 | from sklearn.cluster import KMeans 20 | 21 | def op_copy(optimizer): 22 | for param_group in optimizer.param_groups: 23 | param_group['lr0'] = param_group['lr'] 24 | return optimizer 25 | 26 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75): 27 | decay = (1 + gamma * iter_num / max_iter) ** (-power) 28 | for param_group in optimizer.param_groups: 29 | param_group['lr'] = param_group['lr0'] * decay 30 | param_group['weight_decay'] = 1e-3 31 | param_group['momentum'] = 0.9 32 | param_group['nesterov'] = True 33 | return optimizer 34 | 35 | def image_train(resize_size=256, crop_size=224, alexnet=False): 36 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 37 | std=[0.229, 0.224, 0.225]) 38 | return 
transforms.Compose([ 39 | transforms.Resize((resize_size, resize_size)), 40 | transforms.RandomCrop(crop_size), 41 | transforms.RandomHorizontalFlip(), 42 | transforms.ToTensor(), 43 | normalize 44 | ]) 45 | 46 | def image_test(resize_size=256, crop_size=224, alexnet=False): 47 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 48 | std=[0.229, 0.224, 0.225]) 49 | return transforms.Compose([ 50 | transforms.Resize((resize_size, resize_size)), 51 | transforms.CenterCrop(crop_size), 52 | transforms.ToTensor(), 53 | normalize 54 | ]) 55 | 56 | def data_load(args): 57 | ## prepare data 58 | dsets = {} 59 | dset_loaders = {} 60 | train_bs = args.batch_size 61 | txt_src = open(args.s_dset_path).readlines() 62 | txt_test = open(args.test_dset_path).readlines() 63 | 64 | dsize = len(txt_src) 65 | tr_size = int(0.9*dsize) 66 | tr_txt, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size]) 67 | 68 | dsets["source_tr"] = ImageList(tr_txt, transform=image_train()) 69 | dset_loaders["source_tr"] = DataLoader(dsets["source_tr"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False) 70 | dsets["source_te"] = ImageList(te_txt, transform=image_test()) 71 | dset_loaders["source_te"] = DataLoader(dsets["source_te"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False) 72 | dsets["test"] = ImageList(txt_test, transform=image_test()) 73 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*2, shuffle=True, num_workers=args.worker, drop_last=False) 74 | 75 | return dset_loaders 76 | 77 | def cal_acc(loader, netF, netB, netC, flag=False): 78 | start_test = True 79 | with torch.no_grad(): 80 | iter_test = iter(loader) 81 | for i in range(len(loader)): 82 | data = iter_test.next() 83 | inputs = data[0] 84 | labels = data[1] 85 | inputs = inputs.cuda() 86 | outputs = netC(netB(netF(inputs))) 87 | if start_test: 88 | all_output = outputs.float().cpu() 89 | all_label = labels.float() 90 | 
start_test = False 91 | else: 92 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 93 | all_label = torch.cat((all_label, labels.float()), 0) 94 | 95 | all_output = nn.Softmax(dim=1)(all_output) 96 | _, predict = torch.max(all_output, 1) 97 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 98 | mean_ent = torch.mean(loss.Entropy(all_output)).cpu().data.item() 99 | 100 | return accuracy*100, mean_ent 101 | 102 | def train_source(args): 103 | dset_loaders = data_load(args) 104 | ## set base network 105 | if args.net[0:3] == 'res': 106 | netF = network.ResBase(res_name=args.net).cuda() 107 | elif args.net[0:3] == 'vgg': 108 | netF = network.VGGBase(vgg_name=args.net).cuda() 109 | 110 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda() 111 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda() 112 | 113 | param_group = [] 114 | learning_rate = args.lr 115 | for k, v in netF.named_parameters(): 116 | param_group += [{'params': v, 'lr': learning_rate*0.1}] 117 | for k, v in netB.named_parameters(): 118 | param_group += [{'params': v, 'lr': learning_rate}] 119 | for k, v in netC.named_parameters(): 120 | param_group += [{'params': v, 'lr': learning_rate}] 121 | optimizer = optim.SGD(param_group) 122 | optimizer = op_copy(optimizer) 123 | 124 | acc_init = 0 125 | max_iter = args.max_epoch * len(dset_loaders["source_tr"]) 126 | interval_iter = max_iter // 10 127 | iter_num = 0 128 | 129 | netF.train() 130 | netB.train() 131 | netC.train() 132 | 133 | while iter_num < max_iter: 134 | try: 135 | inputs_source, labels_source = iter_source.next() 136 | except: 137 | iter_source = iter(dset_loaders["source_tr"]) 138 | inputs_source, labels_source = iter_source.next() 139 | 140 | if inputs_source.size(0) == 1: 141 | continue 142 | 143 | iter_num += 1 144 | lr_scheduler(optimizer, 
iter_num=iter_num, max_iter=max_iter) 145 | 146 | inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda() 147 | outputs_source = netC(netB(netF(inputs_source))) 148 | classifier_loss = CrossEntropyLabelSmooth(num_classes=args.class_num, epsilon=args.smooth)(outputs_source, labels_source) 149 | 150 | optimizer.zero_grad() 151 | classifier_loss.backward() 152 | optimizer.step() 153 | 154 | if iter_num % interval_iter == 0 or iter_num == max_iter: 155 | netF.eval() 156 | netB.eval() 157 | netC.eval() 158 | 159 | acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC, False) 160 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name_src, iter_num, max_iter, acc_s_te) 161 | 162 | args.out_file.write(log_str + '\n') 163 | args.out_file.flush() 164 | print(log_str+'\n') 165 | 166 | if acc_s_te >= acc_init: 167 | acc_init = acc_s_te 168 | best_netF = netF.state_dict() 169 | best_netB = netB.state_dict() 170 | best_netC = netC.state_dict() 171 | 172 | netF.train() 173 | netB.train() 174 | netC.train() 175 | 176 | torch.save(best_netF, osp.join(args.output_dir_src, "source_F.pt")) 177 | torch.save(best_netB, osp.join(args.output_dir_src, "source_B.pt")) 178 | torch.save(best_netC, osp.join(args.output_dir_src, "source_C.pt")) 179 | 180 | return netF, netB, netC 181 | 182 | def test_target(args): 183 | dset_loaders = data_load(args) 184 | ## set base network 185 | if args.net[0:3] == 'res': 186 | netF = network.ResBase(res_name=args.net).cuda() 187 | elif args.net[0:3] == 'vgg': 188 | netF = network.VGGBase(vgg_name=args.net).cuda() 189 | 190 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda() 191 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda() 192 | 193 | args.modelpath = args.output_dir_src + '/source_F.pt' 194 | netF.load_state_dict(torch.load(args.modelpath)) 195 | args.modelpath = 
args.output_dir_src + '/source_B.pt' 196 | netB.load_state_dict(torch.load(args.modelpath)) 197 | args.modelpath = args.output_dir_src + '/source_C.pt' 198 | netC.load_state_dict(torch.load(args.modelpath)) 199 | netF.eval() 200 | netB.eval() 201 | netC.eval() 202 | 203 | acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False) 204 | log_str = '\nTask: {}, Accuracy = {:.2f}%'.format(args.name, acc) 205 | 206 | args.out_file.write(log_str) 207 | args.out_file.flush() 208 | print(log_str) 209 | 210 | def print_args(args): 211 | s = "==========================================\n" 212 | for arg, content in args.__dict__.items(): 213 | s += "{}:{}\n".format(arg, content) 214 | return s 215 | 216 | if __name__ == "__main__": 217 | parser = argparse.ArgumentParser(description='SHOT++') 218 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run") 219 | parser.add_argument('--s', type=int, default=0, help="source") 220 | parser.add_argument('--t', type=int, default=1, help="target") 221 | parser.add_argument('--max_epoch', type=int, default=20, help="max iterations") 222 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size") 223 | parser.add_argument('--worker', type=int, default=4, help="number of workers") 224 | parser.add_argument('--dset', type=str, default='office-home', choices=['office-home']) 225 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate") 226 | parser.add_argument('--net', type=str, default='vgg16', help="resnet34, vgg16, resnet50, resnet101") 227 | parser.add_argument('--seed', type=int, default=2020, help="random seed") 228 | parser.add_argument('--bottleneck', type=int, default=256) 229 | parser.add_argument('--epsilon', type=float, default=1e-5) 230 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"]) 231 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"]) 232 | parser.add_argument('--smooth', 
type=float, default=0.1) 233 | parser.add_argument('--output', type=str, default='san') 234 | parser.add_argument('--da', type=str, default='ssda') 235 | parser.add_argument('--shot', type=int, default=1, choices=[1, 3]) 236 | 237 | args = parser.parse_args() 238 | 239 | if args.dset == 'office-home': 240 | names = ['Art', 'Clipart', 'Product', 'Real'] 241 | args.class_num = 65 242 | 243 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 244 | SEED = args.seed 245 | torch.manual_seed(SEED) 246 | torch.cuda.manual_seed(SEED) 247 | np.random.seed(SEED) 248 | random.seed(SEED) 249 | # torch.backends.cudnn.deterministic = True 250 | 251 | folder = '../data/ssda/' 252 | args.s_dset_path = folder + args.dset + '/labeled_source_images_' + names[args.s] + '.txt' 253 | args.test_dset_path = folder + args.dset + '/unlabeled_target_images_' + names[args.t] + '_' + str(args.shot) + '.txt' 254 | 255 | args.output_dir_src = osp.join(args.output, args.da, args.dset, names[args.s][0].upper()) 256 | args.name_src = names[args.s][0].upper() 257 | if not osp.exists(args.output_dir_src): 258 | os.system('mkdir -p ' + args.output_dir_src) 259 | if not osp.exists(args.output_dir_src): 260 | os.mkdir(args.output_dir_src) 261 | 262 | args.out_file = open(osp.join(args.output_dir_src, 'log.txt'), 'w') 263 | args.out_file.write(print_args(args)+'\n') 264 | args.out_file.flush() 265 | train_source(args) 266 | 267 | args.out_file = open(osp.join(args.output_dir_src, 'log_test.txt'), 'w') 268 | for i in range(len(names)): 269 | if i == args.s: 270 | continue 271 | args.t = i 272 | args.name = names[args.s][0].upper() + names[args.t][0].upper() 273 | 274 | args.s_dset_path = folder + args.dset + '/labeled_source_images_' + names[args.s] + '.txt' 275 | args.test_dset_path = folder + args.dset + '/unlabeled_target_images_' + names[args.t] + '_' + str(args.shot) + '.txt' 276 | test_target(args) 277 | args.out_file.close() 
-------------------------------------------------------------------------------- /code/data/pacs/photo_crossval_kfold.txt: -------------------------------------------------------------------------------- 1 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0001.jpg 1 2 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0002.jpg 1 3 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0003.jpg 1 4 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0004.jpg 1 5 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0005.jpg 1 6 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0006.jpg 1 7 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0007.jpg 1 8 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0009.jpg 1 9 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0010.jpg 1 10 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0011.jpg 1 11 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0012.jpg 1 12 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0013.jpg 1 13 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0014.jpg 1 14 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0015.jpg 1 15 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0016.jpg 1 16 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0017.jpg 1 17 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0018.jpg 1 18 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0020.jpg 1 19 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0021.jpg 1 20 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0001.jpg 2 21 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0002.jpg 2 22 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0003.jpg 2 23 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0004.jpg 2 24 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0005.jpg 2 25 
| /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0006.jpg 2 26 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0007.jpg 2 27 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0008.jpg 2 28 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0009.jpg 2 29 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0010.jpg 2 30 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0011.jpg 2 31 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0012.jpg 2 32 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0013.jpg 2 33 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0014.jpg 2 34 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0015.jpg 2 35 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0016.jpg 2 36 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0017.jpg 2 37 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0018.jpg 2 38 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0019.jpg 2 39 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0020.jpg 2 40 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0021.jpg 2 41 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0001.jpg 3 42 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0002.jpg 3 43 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0003.jpg 3 44 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0004.jpg 3 45 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0005.jpg 3 46 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0006.jpg 3 47 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0007.jpg 3 48 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0008.jpg 3 49 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0009.jpg 3 50 | 
/Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0010.jpg 3 51 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0011.jpg 3 52 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0012.jpg 3 53 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0013.jpg 3 54 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0014.jpg 3 55 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0015.jpg 3 56 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0016.jpg 3 57 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0017.jpg 3 58 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0018.jpg 3 59 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0019.jpg 3 60 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0001.jpg 4 61 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0002.jpg 4 62 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0003.jpg 4 63 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0004.jpg 4 64 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0005.jpg 4 65 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0006.jpg 4 66 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0007.jpg 4 67 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0008.jpg 4 68 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0009.jpg 4 69 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0010.jpg 4 70 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0012.jpg 4 71 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0013.jpg 4 72 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0016.jpg 4 73 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0018.jpg 4 74 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0019.jpg 4 75 | 
/Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0020.jpg 4 76 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0021.jpg 4 77 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0022.jpg 4 78 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0023.jpg 4 79 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0002.jpg 5 80 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0003.jpg 5 81 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0007.jpg 5 82 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0008.jpg 5 83 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0009.jpg 5 84 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0010.jpg 5 85 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0012.jpg 5 86 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0013.jpg 5 87 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0022.jpg 5 88 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0025.jpg 5 89 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0028.jpg 5 90 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0029.jpg 5 91 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0030.jpg 5 92 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0033.jpg 5 93 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0037.jpg 5 94 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0038.jpg 5 95 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0041.jpg 5 96 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0042.jpg 5 97 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0047.jpg 5 98 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0048.jpg 5 99 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_010.jpg 6 100 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_011.jpg 6 101 | 
/Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_012.jpg 6 102 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_013.jpg 6 103 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_014.jpg 6 104 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_015.jpg 6 105 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_016.jpg 6 106 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_017.jpg 6 107 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_018.jpg 6 108 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_021.jpg 6 109 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_019.jpg 6 110 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_022.jpg 6 111 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_020.jpg 6 112 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_023.jpg 6 113 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_024.jpg 6 114 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_026.jpg 6 115 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_025.jpg 6 116 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_027.jpg 6 117 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_028.jpg 6 118 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_029.jpg 6 119 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_031.jpg 6 120 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_239.jpg 6 121 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_240.jpg 6 122 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_241.jpg 6 123 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_242.jpg 6 124 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_248.jpg 6 125 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_246.jpg 6 126 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_247.jpg 6 127 | 
/Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_244.jpg 6 128 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0001.jpg 0 129 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0002.jpg 0 130 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0003.jpg 0 131 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0004.jpg 0 132 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0005.jpg 0 133 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0006.jpg 0 134 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0007.jpg 0 135 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0008.jpg 0 136 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0009.jpg 0 137 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0010.jpg 0 138 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0011.jpg 0 139 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0012.jpg 0 140 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0013.jpg 0 141 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0014.jpg 0 142 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0015.jpg 0 143 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0016.jpg 0 144 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0017.jpg 0 145 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0018.jpg 0 146 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0019.jpg 0 147 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0020.jpg 0 148 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0021.jpg 0 149 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0022.jpg 0 150 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0023.jpg 0 151 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0024.jpg 0 152 | 
/Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0025.jpg 0 153 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0026.jpg 0 154 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0027.jpg 0 155 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0028.jpg 0 156 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0029.jpg 0 157 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0030.jpg 0 158 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0031.jpg 0 159 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0032.jpg 0 160 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0033.jpg 0 161 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0034.jpg 0 162 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0035.jpg 0 163 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0036.jpg 0 164 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0037.jpg 0 165 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0038.jpg 0 166 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0039.jpg 0 167 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0040.jpg 0 168 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0041.jpg 0 169 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0042.jpg 0 170 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0043.jpg 0 171 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0044.jpg 0 172 | -------------------------------------------------------------------------------- /code/msda/image_source.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import torchvision 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torchvision import transforms 10 | import network 11 | 
import loss 12 | from torch.utils.data import DataLoader 13 | from data_list import ImageList 14 | import random, pdb, math, copy 15 | from tqdm import tqdm 16 | from loss import CrossEntropyLabelSmooth 17 | from scipy.spatial.distance import cdist 18 | from sklearn.metrics import confusion_matrix 19 | from sklearn.cluster import KMeans 20 | 21 | def op_copy(optimizer): 22 | for param_group in optimizer.param_groups: 23 | param_group['lr0'] = param_group['lr'] 24 | return optimizer 25 | 26 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75): 27 | decay = (1 + gamma * iter_num / max_iter) ** (-power) 28 | for param_group in optimizer.param_groups: 29 | param_group['lr'] = param_group['lr0'] * decay 30 | param_group['weight_decay'] = 1e-3 31 | param_group['momentum'] = 0.9 32 | param_group['nesterov'] = True 33 | return optimizer 34 | 35 | def image_train(resize_size=256, crop_size=224, alexnet=False): 36 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 37 | std=[0.229, 0.224, 0.225]) 38 | return transforms.Compose([ 39 | transforms.Resize((resize_size, resize_size)), 40 | transforms.RandomCrop(crop_size), 41 | transforms.RandomHorizontalFlip(), 42 | transforms.ToTensor(), 43 | normalize 44 | ]) 45 | 46 | def image_test(resize_size=256, crop_size=224, alexnet=False): 47 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 48 | std=[0.229, 0.224, 0.225]) 49 | return transforms.Compose([ 50 | transforms.Resize((resize_size, resize_size)), 51 | transforms.CenterCrop(crop_size), 52 | transforms.ToTensor(), 53 | normalize 54 | ]) 55 | 56 | def data_load(args): 57 | ## prepare data 58 | dsets = {} 59 | dset_loaders = {} 60 | train_bs = args.batch_size 61 | txt_src = open(args.s_dset_path).readlines() 62 | txt_test = open(args.test_dset_path).readlines() 63 | 64 | dsize = len(txt_src) 65 | tr_size = int(0.9*dsize) 66 | tr_txt, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size]) 67 | 68 | dsets["source_tr"] 
= ImageList(tr_txt, transform=image_train()) 69 | dset_loaders["source_tr"] = DataLoader(dsets["source_tr"], batch_size=train_bs, shuffle=True, 70 | num_workers=args.worker, drop_last=False) 71 | dsets["source_te"] = ImageList(te_txt, transform=image_test()) 72 | dset_loaders["source_te"] = DataLoader(dsets["source_te"], batch_size=train_bs*3, shuffle=False, 73 | num_workers=args.worker, drop_last=False) 74 | dsets["test"] = ImageList(txt_test, transform=image_test()) 75 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*3, shuffle=False, 76 | num_workers=args.worker, drop_last=False) 77 | 78 | return dset_loaders 79 | 80 | def cal_acc(loader, netF, netB, netC): 81 | start_test = True 82 | with torch.no_grad(): 83 | iter_test = iter(loader) 84 | for i in range(len(loader)): 85 | data = iter_test.next() 86 | inputs = data[0] 87 | labels = data[1] 88 | inputs = inputs.cuda() 89 | outputs = netC(netB(netF(inputs))) 90 | if start_test: 91 | all_output = outputs.float().cpu() 92 | all_label = labels.float() 93 | start_test = False 94 | else: 95 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 96 | all_label = torch.cat((all_label, labels.float()), 0) 97 | 98 | all_output = nn.Softmax(dim=1)(all_output) 99 | _, predict = torch.max(all_output, 1) 100 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 101 | mean_ent = torch.mean(loss.Entropy(all_output)).cpu().data.item() 102 | 103 | return accuracy*100, mean_ent 104 | 105 | def train_source(args): 106 | dset_loaders = data_load(args) 107 | ## set base network 108 | if args.net[0:3] == 'res': 109 | netF = network.ResBase(res_name=args.net).cuda() 110 | elif args.net[0:3] == 'vgg': 111 | netF = network.VGGBase(vgg_name=args.net).cuda() 112 | 113 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda() 114 | netC = network.feat_classifier(type=args.layer, class_num = 
args.class_num, bottleneck_dim=args.bottleneck).cuda() 115 | 116 | param_group = [] 117 | learning_rate = args.lr 118 | for k, v in netF.named_parameters(): 119 | param_group += [{'params': v, 'lr': learning_rate*0.1}] 120 | for k, v in netB.named_parameters(): 121 | param_group += [{'params': v, 'lr': learning_rate}] 122 | for k, v in netC.named_parameters(): 123 | param_group += [{'params': v, 'lr': learning_rate}] 124 | optimizer = optim.SGD(param_group) 125 | optimizer = op_copy(optimizer) 126 | 127 | acc_init = 0 128 | max_iter = args.max_epoch * len(dset_loaders["source_tr"]) 129 | interval_iter = max_iter // 10 130 | iter_num = 0 131 | 132 | netF.train() 133 | netB.train() 134 | netC.train() 135 | 136 | while iter_num < max_iter: 137 | try: 138 | inputs_source, labels_source = iter_source.next() 139 | except: 140 | iter_source = iter(dset_loaders["source_tr"]) 141 | inputs_source, labels_source = iter_source.next() 142 | 143 | if inputs_source.size(0) == 1: 144 | continue 145 | 146 | iter_num += 1 147 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter) 148 | 149 | inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda() 150 | outputs_source = netC(netB(netF(inputs_source))) 151 | classifier_loss = CrossEntropyLabelSmooth(num_classes=args.class_num, epsilon=args.smooth)(outputs_source, labels_source) 152 | 153 | optimizer.zero_grad() 154 | classifier_loss.backward() 155 | optimizer.step() 156 | 157 | if iter_num % interval_iter == 0 or iter_num == max_iter: 158 | netF.eval() 159 | netB.eval() 160 | netC.eval() 161 | acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC) 162 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name_src, iter_num, max_iter, acc_s_te) 163 | args.out_file.write(log_str + '\n') 164 | args.out_file.flush() 165 | print(log_str+'\n') 166 | 167 | if acc_s_te >= acc_init: 168 | acc_init = acc_s_te 169 | best_netF = netF.state_dict() 170 | best_netB = netB.state_dict() 171 | 
best_netC = netC.state_dict() 172 | 173 | netF.train() 174 | netB.train() 175 | netC.train() 176 | 177 | torch.save(best_netF, osp.join(args.output_dir_src, "source_F.pt")) 178 | torch.save(best_netB, osp.join(args.output_dir_src, "source_B.pt")) 179 | torch.save(best_netC, osp.join(args.output_dir_src, "source_C.pt")) 180 | 181 | return netF, netB, netC 182 | 183 | def test_target(args): 184 | dset_loaders = data_load(args) 185 | ## set base network 186 | if args.net[0:3] == 'res': 187 | netF = network.ResBase(res_name=args.net).cuda() 188 | elif args.net[0:3] == 'vgg': 189 | netF = network.VGGBase(vgg_name=args.net).cuda() 190 | 191 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda() 192 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda() 193 | 194 | args.modelpath = args.output_dir_src + '/source_F.pt' 195 | netF.load_state_dict(torch.load(args.modelpath)) 196 | args.modelpath = args.output_dir_src + '/source_B.pt' 197 | netB.load_state_dict(torch.load(args.modelpath)) 198 | args.modelpath = args.output_dir_src + '/source_C.pt' 199 | netC.load_state_dict(torch.load(args.modelpath)) 200 | netF.eval() 201 | netB.eval() 202 | netC.eval() 203 | 204 | acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC) 205 | log_str = '\nTask: {}, Accuracy = {:.2f}%'.format(args.name, acc) 206 | 207 | args.out_file.write(log_str) 208 | args.out_file.flush() 209 | print(log_str) 210 | 211 | def print_args(args): 212 | s = "==========================================\n" 213 | for arg, content in args.__dict__.items(): 214 | s += "{}:{}\n".format(arg, content) 215 | return s 216 | 217 | if __name__ == "__main__": 218 | parser = argparse.ArgumentParser(description='SHOT') 219 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run") 220 | parser.add_argument('--s', type=int, default=0, help="source") 221 | 
parser.add_argument('--t', type=int, default=1, help="target") 222 | parser.add_argument('--max_epoch', type=int, default=20, help="max iterations") 223 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size") 224 | parser.add_argument('--worker', type=int, default=4, help="number of workers") 225 | parser.add_argument('--dset', type=str, default='office-home', choices=['pacs', 'office-home', 'office-caltech']) 226 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate") 227 | parser.add_argument('--net', type=str, default='resnet50', help="vgg16, resnet18, resnet34, resnet50, resnet101") 228 | parser.add_argument('--seed', type=int, default=2020, help="random seed") 229 | parser.add_argument('--bottleneck', type=int, default=256) 230 | parser.add_argument('--epsilon', type=float, default=1e-5) 231 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"]) 232 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"]) 233 | parser.add_argument('--smooth', type=float, default=0.1) 234 | parser.add_argument('--output', type=str, default='san') 235 | parser.add_argument('--da', type=str, default='uda', choices=['uda', 'pda']) 236 | args = parser.parse_args() 237 | 238 | if args.dset == 'office-home': 239 | names = ['Art', 'Clipart', 'Product', 'RealWorld'] 240 | args.class_num = 65 241 | elif args.dset == 'office-caltech': 242 | names = ['amazon', 'caltech', 'dslr', 'webcam'] 243 | args.class_num = 10 244 | elif args.dset == 'pacs': 245 | names = ['art_painting', 'cartoon', 'photo', 'sketch'] 246 | args.class_num = 7 247 | 248 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 249 | SEED = args.seed 250 | torch.manual_seed(SEED) 251 | torch.cuda.manual_seed(SEED) 252 | np.random.seed(SEED) 253 | random.seed(SEED) 254 | # torch.backends.cudnn.deterministic = True 255 | 256 | folder = '/Checkpoint/liangjian/tran/data/' 257 | args.s_dset_path = folder + args.dset + '/' + 
names[args.s] + '_list.txt' 258 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt' 259 | 260 | if args.dset == 'pacs': 261 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_test_kfold.txt' 262 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_test_kfold.txt' 263 | 264 | args.output_dir_src = osp.join(args.output, args.da, args.dset, names[args.s][0].upper()) 265 | args.name_src = names[args.s][0].upper() 266 | if not osp.exists(args.output_dir_src): 267 | os.system('mkdir -p ' + args.output_dir_src) 268 | if not osp.exists(args.output_dir_src): 269 | os.mkdir(args.output_dir_src) 270 | 271 | args.out_file = open(osp.join(args.output_dir_src, 'log.txt'), 'w') 272 | args.out_file.write(print_args(args)+'\n') 273 | args.out_file.flush() 274 | train_source(args) 275 | 276 | args.out_file = open(osp.join(args.output_dir_src, 'log_test.txt'), 'w') 277 | for i in range(len(names)): 278 | if i == args.s: 279 | continue 280 | args.t = i 281 | args.name = names[args.s][0].upper() + names[args.t][0].upper() 282 | 283 | folder = '../data/' 284 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt' 285 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt' 286 | 287 | if args.dset == 'pacs': 288 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_test_kfold.txt' 289 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_test_kfold.txt' 290 | 291 | test_target(args) 292 | 293 | args.out_file.close() -------------------------------------------------------------------------------- /code/digit/data_load/svhn.py: -------------------------------------------------------------------------------- 1 | from .vision import VisionDataset 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import numpy as np 6 | from .utils import download_url, check_integrity, verify_str_arg 7 | 8 | 9 | class SVHN(VisionDataset): 10 | """`SVHN `_ Dataset. 
11 | Note: The SVHN dataset assigns the label `10` to the digit `0`. However, in this Dataset, 12 | we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which 13 | expect the class labels to be in the range `[0, C-1]` 14 | 15 | .. warning:: 16 | 17 | This class needs `scipy `_ to load data from `.mat` format. 18 | 19 | Args: 20 | root (string): Root directory of dataset where directory 21 | ``SVHN`` exists. 22 | split (string): One of {'train', 'test', 'extra'}. 23 | Accordingly dataset is selected. 'extra' is Extra training set. 24 | transform (callable, optional): A function/transform that takes in an PIL image 25 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 26 | target_transform (callable, optional): A function/transform that takes in the 27 | target and transforms it. 28 | download (bool, optional): If true, downloads the dataset from the internet and 29 | puts it in root directory. If dataset is already downloaded, it is not 30 | downloaded again. 
31 | 32 | """ 33 | 34 | split_list = { 35 | 'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat", 36 | "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"], 37 | 'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat", 38 | "test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"], 39 | 'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat", 40 | "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]} 41 | 42 | def __init__(self, root, split='train', transform=None, target_transform=None, 43 | download=False): 44 | super(SVHN, self).__init__(root, transform=transform, 45 | target_transform=target_transform) 46 | self.split = verify_str_arg(split, "split", tuple(self.split_list.keys())) 47 | self.url = self.split_list[split][0] 48 | self.filename = self.split_list[split][1] 49 | self.file_md5 = self.split_list[split][2] 50 | 51 | if download: 52 | self.download() 53 | 54 | if not self._check_integrity(): 55 | raise RuntimeError('Dataset not found or corrupted.' 
+ 56 | ' You can use download=True to download it') 57 | 58 | # import here rather than at top of file because this is 59 | # an optional dependency for torchvision 60 | import scipy.io as sio 61 | 62 | # reading(loading) mat file as array 63 | loaded_mat = sio.loadmat(os.path.join(self.root, self.filename)) 64 | 65 | self.data = loaded_mat['X'] 66 | # loading from the .mat file gives an np array of type np.uint8 67 | # converting to np.int64, so that we have a LongTensor after 68 | # the conversion from the numpy array 69 | # the squeeze is needed to obtain a 1D tensor 70 | self.labels = loaded_mat['y'].astype(np.int64).squeeze() 71 | 72 | # the svhn dataset assigns the class label "10" to the digit 0 73 | # this makes it inconsistent with several loss functions 74 | # which expect the class labels to be in the range [0, C-1] 75 | np.place(self.labels, self.labels == 10, 0) 76 | self.data = np.transpose(self.data, (3, 2, 0, 1)) 77 | 78 | def __getitem__(self, index): 79 | """ 80 | Args: 81 | index (int): Index 82 | 83 | Returns: 84 | tuple: (image, target) where target is index of the target class. 
85 | """ 86 | img, target = self.data[index], int(self.labels[index]) 87 | 88 | # doing this so that it is consistent with all other datasets 89 | # to return a PIL Image 90 | img = Image.fromarray(np.transpose(img, (1, 2, 0))) 91 | 92 | if self.transform is not None: 93 | img = self.transform(img) 94 | 95 | if self.target_transform is not None: 96 | target = self.target_transform(target) 97 | 98 | return img, target 99 | 100 | def __len__(self): 101 | return len(self.data) 102 | 103 | def _check_integrity(self): 104 | root = self.root 105 | md5 = self.split_list[self.split][2] 106 | fpath = os.path.join(root, self.filename) 107 | return check_integrity(fpath, md5) 108 | 109 | def download(self): 110 | md5 = self.split_list[self.split][2] 111 | download_url(self.url, self.root, self.filename, md5) 112 | 113 | def extra_repr(self): 114 | return "Split: {split}".format(**self.__dict__) 115 | 116 | class SVHN_idx(VisionDataset): 117 | """`SVHN `_ Dataset. 118 | Note: The SVHN dataset assigns the label `10` to the digit `0`. However, in this Dataset, 119 | we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which 120 | expect the class labels to be in the range `[0, C-1]` 121 | 122 | .. warning:: 123 | 124 | This class needs `scipy `_ to load data from `.mat` format. 125 | 126 | Args: 127 | root (string): Root directory of dataset where directory 128 | ``SVHN`` exists. 129 | split (string): One of {'train', 'test', 'extra'}. 130 | Accordingly dataset is selected. 'extra' is Extra training set. 131 | transform (callable, optional): A function/transform that takes in an PIL image 132 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 133 | target_transform (callable, optional): A function/transform that takes in the 134 | target and transforms it. 135 | download (bool, optional): If true, downloads the dataset from the internet and 136 | puts it in root directory. 
If dataset is already downloaded, it is not 137 | downloaded again. 138 | 139 | """ 140 | 141 | split_list = { 142 | 'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat", 143 | "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"], 144 | 'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat", 145 | "test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"], 146 | 'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat", 147 | "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]} 148 | 149 | def __init__(self, root, split='train', transform=None, target_transform=None, 150 | download=False): 151 | super(SVHN_idx, self).__init__(root, transform=transform, 152 | target_transform=target_transform) 153 | self.split = verify_str_arg(split, "split", tuple(self.split_list.keys())) 154 | self.url = self.split_list[split][0] 155 | self.filename = self.split_list[split][1] 156 | self.file_md5 = self.split_list[split][2] 157 | 158 | if download: 159 | self.download() 160 | 161 | if not self._check_integrity(): 162 | raise RuntimeError('Dataset not found or corrupted.' 
+ 163 | ' You can use download=True to download it') 164 | 165 | # import here rather than at top of file because this is 166 | # an optional dependency for torchvision 167 | import scipy.io as sio 168 | 169 | # reading(loading) mat file as array 170 | loaded_mat = sio.loadmat(os.path.join(self.root, self.filename)) 171 | 172 | self.data = loaded_mat['X'] 173 | # loading from the .mat file gives an np array of type np.uint8 174 | # converting to np.int64, so that we have a LongTensor after 175 | # the conversion from the numpy array 176 | # the squeeze is needed to obtain a 1D tensor 177 | self.labels = loaded_mat['y'].astype(np.int64).squeeze() 178 | 179 | # the svhn dataset assigns the class label "10" to the digit 0 180 | # this makes it inconsistent with several loss functions 181 | # which expect the class labels to be in the range [0, C-1] 182 | np.place(self.labels, self.labels == 10, 0) 183 | self.data = np.transpose(self.data, (3, 2, 0, 1)) 184 | 185 | def __getitem__(self, index): 186 | """ 187 | Args: 188 | index (int): Index 189 | 190 | Returns: 191 | tuple: (image, target) where target is index of the target class. 
192 | """ 193 | img, target = self.data[index], int(self.labels[index]) 194 | 195 | # doing this so that it is consistent with all other datasets 196 | # to return a PIL Image 197 | img = Image.fromarray(np.transpose(img, (1, 2, 0))) 198 | 199 | if self.transform is not None: 200 | img = self.transform(img) 201 | 202 | if self.target_transform is not None: 203 | target = self.target_transform(target) 204 | 205 | return img, target, index 206 | 207 | def __len__(self): 208 | return len(self.data) 209 | 210 | def _check_integrity(self): 211 | root = self.root 212 | md5 = self.split_list[self.split][2] 213 | fpath = os.path.join(root, self.filename) 214 | return check_integrity(fpath, md5) 215 | 216 | def download(self): 217 | md5 = self.split_list[self.split][2] 218 | download_url(self.url, self.root, self.filename, md5) 219 | 220 | def extra_repr(self): 221 | return "Split: {split}".format(**self.__dict__) 222 | 223 | 224 | class SVHN_twice(SVHN): 225 | split_list = { 226 | 'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat", 227 | "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"], 228 | 'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat", 229 | "test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"], 230 | 'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat", 231 | "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]} 232 | 233 | def __init__(self, root, split='train', transform=None, target_transform=None, 234 | download=False): 235 | super(SVHN_idx, self).__init__(root, transform=transform, 236 | target_transform=target_transform) 237 | self.split = verify_str_arg(split, "split", tuple(self.split_list.keys())) 238 | self.url = self.split_list[split][0] 239 | self.filename = self.split_list[split][1] 240 | self.file_md5 = self.split_list[split][2] 241 | 242 | if download: 243 | self.download() 244 | 245 | if not self._check_integrity(): 246 | raise RuntimeError('Dataset not found or corrupted.' 
+ 247 | ' You can use download=True to download it') 248 | 249 | # import here rather than at top of file because this is 250 | # an optional dependency for torchvision 251 | import scipy.io as sio 252 | 253 | # reading(loading) mat file as array 254 | loaded_mat = sio.loadmat(os.path.join(self.root, self.filename)) 255 | 256 | self.data = loaded_mat['X'] 257 | # loading from the .mat file gives an np array of type np.uint8 258 | # converting to np.int64, so that we have a LongTensor after 259 | # the conversion from the numpy array 260 | # the squeeze is needed to obtain a 1D tensor 261 | self.labels = loaded_mat['y'].astype(np.int64).squeeze() 262 | 263 | # the svhn dataset assigns the class label "10" to the digit 0 264 | # this makes it inconsistent with several loss functions 265 | # which expect the class labels to be in the range [0, C-1] 266 | np.place(self.labels, self.labels == 10, 0) 267 | self.data = np.transpose(self.data, (3, 2, 0, 1)) 268 | 269 | def __getitem__(self, index): 270 | """ 271 | Args: 272 | index (int): Index 273 | 274 | Returns: 275 | tuple: (image, target) where target is index of the target class. 
276 | """ 277 | img, target = self.data[index], int(self.labels[index]) 278 | 279 | # doing this so that it is consistent with all other datasets 280 | # to return a PIL Image 281 | img = Image.fromarray(np.transpose(img, (1, 2, 0))) 282 | 283 | if self.transform is not None: 284 | img = [self.transform(img), self.transform(img)] 285 | 286 | if self.target_transform is not None: 287 | target = self.target_transform(target) 288 | 289 | return img, target, index 290 | 291 | def __len__(self): 292 | return len(self.data) -------------------------------------------------------------------------------- /code/data/office-caltech/dslr_list.txt: -------------------------------------------------------------------------------- 1 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/back_pack/frame_0001.jpg 0 2 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/back_pack/frame_0002.jpg 0 3 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/back_pack/frame_0003.jpg 0 4 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/back_pack/frame_0004.jpg 0 5 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/back_pack/frame_0005.jpg 0 6 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/back_pack/frame_0006.jpg 0 7 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/back_pack/frame_0007.jpg 0 8 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/back_pack/frame_0008.jpg 0 9 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/back_pack/frame_0009.jpg 0 10 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/back_pack/frame_0010.jpg 0 11 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/back_pack/frame_0011.jpg 0 12 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/back_pack/frame_0012.jpg 0 13 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0001.jpg 1 14 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0002.jpg 1 15 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0003.jpg 1 16 | 
/Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0004.jpg 1 17 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0005.jpg 1 18 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0006.jpg 1 19 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0007.jpg 1 20 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0008.jpg 1 21 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0009.jpg 1 22 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0010.jpg 1 23 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0011.jpg 1 24 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0012.jpg 1 25 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0013.jpg 1 26 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0014.jpg 1 27 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0015.jpg 1 28 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0016.jpg 1 29 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0017.jpg 1 30 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0018.jpg 1 31 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0019.jpg 1 32 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0020.jpg 1 33 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0021.jpg 1 34 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/calculator/frame_0001.jpg 2 35 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/calculator/frame_0002.jpg 2 36 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/calculator/frame_0003.jpg 2 37 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/calculator/frame_0004.jpg 2 38 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/calculator/frame_0005.jpg 2 39 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/calculator/frame_0006.jpg 2 40 | 
/Checkpoint/liangjian/da_dataset/office31/dslr/images/calculator/frame_0007.jpg 2 41 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/calculator/frame_0008.jpg 2 42 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/calculator/frame_0009.jpg 2 43 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/calculator/frame_0010.jpg 2 44 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/calculator/frame_0011.jpg 2 45 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/calculator/frame_0012.jpg 2 46 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/headphones/frame_0001.jpg 3 47 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/headphones/frame_0002.jpg 3 48 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/headphones/frame_0003.jpg 3 49 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/headphones/frame_0004.jpg 3 50 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/headphones/frame_0005.jpg 3 51 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/headphones/frame_0006.jpg 3 52 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/headphones/frame_0007.jpg 3 53 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/headphones/frame_0008.jpg 3 54 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/headphones/frame_0009.jpg 3 55 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/headphones/frame_0010.jpg 3 56 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/headphones/frame_0011.jpg 3 57 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/headphones/frame_0012.jpg 3 58 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/headphones/frame_0013.jpg 3 59 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/keyboard/frame_0001.jpg 4 60 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/keyboard/frame_0002.jpg 4 61 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/keyboard/frame_0003.jpg 4 62 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/keyboard/frame_0004.jpg 4 63 | 
/Checkpoint/liangjian/da_dataset/office31/dslr/images/keyboard/frame_0005.jpg 4 64 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/keyboard/frame_0006.jpg 4 65 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/keyboard/frame_0007.jpg 4 66 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/keyboard/frame_0008.jpg 4 67 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/keyboard/frame_0009.jpg 4 68 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/keyboard/frame_0010.jpg 4 69 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0001.jpg 5 70 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0002.jpg 5 71 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0003.jpg 5 72 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0004.jpg 5 73 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0005.jpg 5 74 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0006.jpg 5 75 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0007.jpg 5 76 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0008.jpg 5 77 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0009.jpg 5 78 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0010.jpg 5 79 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0011.jpg 5 80 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0012.jpg 5 81 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0013.jpg 5 82 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0014.jpg 5 83 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0015.jpg 5 84 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0016.jpg 5 85 | 
/Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0017.jpg 5 86 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0018.jpg 5 87 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0019.jpg 5 88 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0020.jpg 5 89 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0021.jpg 5 90 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0022.jpg 5 91 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0023.jpg 5 92 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0024.jpg 5 93 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0001.jpg 6 94 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0002.jpg 6 95 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0003.jpg 6 96 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0004.jpg 6 97 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0005.jpg 6 98 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0006.jpg 6 99 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0007.jpg 6 100 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0008.jpg 6 101 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0009.jpg 6 102 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0010.jpg 6 103 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0011.jpg 6 104 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0012.jpg 6 105 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0013.jpg 6 106 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0014.jpg 6 107 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0015.jpg 6 
108 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0016.jpg 6 109 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0017.jpg 6 110 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0018.jpg 6 111 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0019.jpg 6 112 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0020.jpg 6 113 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0021.jpg 6 114 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0022.jpg 6 115 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mouse/frame_0001.jpg 7 116 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mouse/frame_0002.jpg 7 117 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mouse/frame_0003.jpg 7 118 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mouse/frame_0004.jpg 7 119 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mouse/frame_0005.jpg 7 120 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mouse/frame_0006.jpg 7 121 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mouse/frame_0007.jpg 7 122 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mouse/frame_0008.jpg 7 123 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mouse/frame_0009.jpg 7 124 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mouse/frame_0010.jpg 7 125 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mouse/frame_0011.jpg 7 126 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mouse/frame_0012.jpg 7 127 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mug/frame_0001.jpg 8 128 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mug/frame_0002.jpg 8 129 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mug/frame_0003.jpg 8 130 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mug/frame_0004.jpg 8 131 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mug/frame_0005.jpg 8 132 
| /Checkpoint/liangjian/da_dataset/office31/dslr/images/mug/frame_0006.jpg 8 133 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mug/frame_0007.jpg 8 134 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mug/frame_0008.jpg 8 135 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0001.jpg 9 136 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0002.jpg 9 137 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0003.jpg 9 138 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0004.jpg 9 139 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0005.jpg 9 140 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0006.jpg 9 141 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0007.jpg 9 142 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0008.jpg 9 143 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0009.jpg 9 144 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0010.jpg 9 145 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0011.jpg 9 146 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0012.jpg 9 147 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0013.jpg 9 148 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0014.jpg 9 149 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0015.jpg 9 150 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0016.jpg 9 151 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0017.jpg 9 152 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0018.jpg 9 153 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0019.jpg 9 154 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0020.jpg 9 155 | 
/Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0021.jpg 9 156 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0022.jpg 9 157 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0023.jpg 9 158 | -------------------------------------------------------------------------------- /code/uda/image_source.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import torchvision 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torchvision import transforms 10 | import network 11 | import loss 12 | from torch.utils.data import DataLoader 13 | from data_list import ImageList 14 | import random, pdb, math, copy 15 | from tqdm import tqdm 16 | from loss import CrossEntropyLabelSmooth 17 | from scipy.spatial.distance import cdist 18 | from sklearn.metrics import confusion_matrix 19 | from sklearn.cluster import KMeans 20 | 21 | def op_copy(optimizer): 22 | for param_group in optimizer.param_groups: 23 | param_group['lr0'] = param_group['lr'] 24 | return optimizer 25 | 26 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75): 27 | decay = (1 + gamma * iter_num / max_iter) ** (-power) 28 | for param_group in optimizer.param_groups: 29 | param_group['lr'] = param_group['lr0'] * decay 30 | param_group['weight_decay'] = 1e-3 31 | param_group['momentum'] = 0.9 32 | param_group['nesterov'] = True 33 | return optimizer 34 | 35 | def image_train(resize_size=256, crop_size=224, alexnet=False): 36 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 37 | std=[0.229, 0.224, 0.225]) 38 | return transforms.Compose([ 39 | transforms.Resize((resize_size, resize_size)), 40 | transforms.RandomCrop(crop_size), 41 | transforms.RandomHorizontalFlip(), 42 | transforms.ToTensor(), 43 | normalize 44 | ]) 45 | 46 | def image_test(resize_size=256, 
crop_size=224, alexnet=False): 47 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 48 | std=[0.229, 0.224, 0.225]) 49 | return transforms.Compose([ 50 | transforms.Resize((resize_size, resize_size)), 51 | transforms.CenterCrop(crop_size), 52 | transforms.ToTensor(), 53 | normalize 54 | ]) 55 | 56 | def data_load(args): 57 | ## prepare data 58 | dsets = {} 59 | dset_loaders = {} 60 | train_bs = args.batch_size 61 | txt_src = open(args.s_dset_path).readlines() 62 | txt_test = open(args.test_dset_path).readlines() 63 | 64 | dsize = len(txt_src) 65 | tr_size = int(0.9*dsize) 66 | # print(dsize, tr_size, dsize - tr_size) 67 | tr_txt, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size]) 68 | 69 | dsets["source_tr"] = ImageList(tr_txt, transform=image_train()) 70 | dset_loaders["source_tr"] = DataLoader(dsets["source_tr"], batch_size=train_bs, shuffle=True, 71 | num_workers=args.worker, drop_last=False) 72 | dsets["source_te"] = ImageList(te_txt, transform=image_test()) 73 | dset_loaders["source_te"] = DataLoader(dsets["source_te"], batch_size=train_bs*3, shuffle=False, 74 | num_workers=args.worker, drop_last=False) 75 | dsets["test"] = ImageList(txt_test, transform=image_test()) 76 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*3, shuffle=False, 77 | num_workers=args.worker, drop_last=False) 78 | 79 | return dset_loaders 80 | 81 | def cal_acc(loader, netF, netB, netC, flag=False): 82 | start_test = True 83 | with torch.no_grad(): 84 | iter_test = iter(loader) 85 | for i in range(len(loader)): 86 | data = iter_test.next() 87 | inputs = data[0] 88 | labels = data[1] 89 | inputs = inputs.cuda() 90 | outputs = netC(netB(netF(inputs))) 91 | if start_test: 92 | all_output = outputs.float().cpu() 93 | all_label = labels.float() 94 | start_test = False 95 | else: 96 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 97 | all_label = torch.cat((all_label, labels.float()), 0) 98 | 99 | all_output = 
nn.Softmax(dim=1)(all_output) 100 | _, predict = torch.max(all_output, 1) 101 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 102 | mean_ent = torch.mean(loss.Entropy(all_output)).cpu().data.item() 103 | 104 | if flag: 105 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float()) 106 | matrix = matrix[np.unique(all_label).astype(int),:] 107 | acc = matrix.diagonal()/matrix.sum(axis=1) * 100 108 | aacc = acc.mean() 109 | aa = [str(np.round(i, 2)) for i in acc] 110 | acc = ' '.join(aa) 111 | return aacc, acc 112 | else: 113 | return accuracy*100, mean_ent 114 | 115 | def train_source(args): 116 | dset_loaders = data_load(args) 117 | ## set base network 118 | if args.net[0:3] == 'res': 119 | netF = network.ResBase(res_name=args.net).cuda() 120 | elif args.net[0:3] == 'vgg': 121 | netF = network.VGGBase(vgg_name=args.net).cuda() 122 | 123 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda() 124 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda() 125 | 126 | param_group = [] 127 | learning_rate = args.lr 128 | for k, v in netF.named_parameters(): 129 | param_group += [{'params': v, 'lr': learning_rate*0.1}] 130 | for k, v in netB.named_parameters(): 131 | param_group += [{'params': v, 'lr': learning_rate}] 132 | for k, v in netC.named_parameters(): 133 | param_group += [{'params': v, 'lr': learning_rate}] 134 | optimizer = optim.SGD(param_group) 135 | optimizer = op_copy(optimizer) 136 | 137 | acc_init = 0 138 | max_iter = args.max_epoch * len(dset_loaders["source_tr"]) 139 | interval_iter = max_iter // 10 140 | iter_num = 0 141 | 142 | netF.train() 143 | netB.train() 144 | netC.train() 145 | 146 | while iter_num < max_iter: 147 | try: 148 | inputs_source, labels_source = iter_source.next() 149 | except: 150 | iter_source = iter(dset_loaders["source_tr"]) 151 | 
inputs_source, labels_source = iter_source.next() 152 | 153 | if inputs_source.size(0) == 1: 154 | continue 155 | 156 | iter_num += 1 157 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter) 158 | 159 | inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda() 160 | outputs_source = netC(netB(netF(inputs_source))) 161 | classifier_loss = CrossEntropyLabelSmooth(num_classes=args.class_num, epsilon=args.smooth)(outputs_source, labels_source) 162 | 163 | optimizer.zero_grad() 164 | classifier_loss.backward() 165 | optimizer.step() 166 | 167 | if iter_num % interval_iter == 0 or iter_num == max_iter: 168 | netF.eval() 169 | netB.eval() 170 | netC.eval() 171 | if args.dset=='VISDA-C': 172 | acc_s_te, acc_list = cal_acc(dset_loaders['source_te'], netF, netB, netC, True) 173 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name_src, iter_num, max_iter, acc_s_te) + '\n' + acc_list 174 | else: 175 | acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC, False) 176 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name_src, iter_num, max_iter, acc_s_te) 177 | args.out_file.write(log_str + '\n') 178 | args.out_file.flush() 179 | print(log_str+'\n') 180 | 181 | if acc_s_te >= acc_init: 182 | acc_init = acc_s_te 183 | best_netF = netF.state_dict() 184 | best_netB = netB.state_dict() 185 | best_netC = netC.state_dict() 186 | 187 | netF.train() 188 | netB.train() 189 | netC.train() 190 | 191 | torch.save(best_netF, osp.join(args.output_dir_src, "source_F.pt")) 192 | torch.save(best_netB, osp.join(args.output_dir_src, "source_B.pt")) 193 | torch.save(best_netC, osp.join(args.output_dir_src, "source_C.pt")) 194 | 195 | return netF, netB, netC 196 | 197 | def test_target(args): 198 | dset_loaders = data_load(args) 199 | ## set base network 200 | if args.net[0:3] == 'res': 201 | netF = network.ResBase(res_name=args.net).cuda() 202 | elif args.net[0:3] == 'vgg': 203 | netF = 
network.VGGBase(vgg_name=args.net).cuda() 204 | 205 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda() 206 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda() 207 | 208 | args.modelpath = args.output_dir_src + '/source_F.pt' 209 | netF.load_state_dict(torch.load(args.modelpath)) 210 | args.modelpath = args.output_dir_src + '/source_B.pt' 211 | netB.load_state_dict(torch.load(args.modelpath)) 212 | args.modelpath = args.output_dir_src + '/source_C.pt' 213 | netC.load_state_dict(torch.load(args.modelpath)) 214 | netF.eval() 215 | netB.eval() 216 | netC.eval() 217 | 218 | if args.dset=='VISDA-C': 219 | acc, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC, True) 220 | log_str = '\nTask: {}, Accuracy = {:.2f}%'.format(args.name, acc) + '\n' + acc_list 221 | else: 222 | acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False) 223 | log_str = '\nTask: {}, Accuracy = {:.2f}%'.format(args.name, acc) 224 | 225 | args.out_file.write(log_str) 226 | args.out_file.flush() 227 | print(log_str) 228 | 229 | def print_args(args): 230 | s = "==========================================\n" 231 | for arg, content in args.__dict__.items(): 232 | s += "{}:{}\n".format(arg, content) 233 | return s 234 | 235 | if __name__ == "__main__": 236 | parser = argparse.ArgumentParser(description='SHOT++') 237 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run") 238 | parser.add_argument('--s', type=int, default=0, help="source") 239 | parser.add_argument('--t', type=int, default=1, help="target") 240 | parser.add_argument('--max_epoch', type=int, default=20, help="max iterations") 241 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size") 242 | parser.add_argument('--worker', type=int, default=4, help="number of workers") 243 | parser.add_argument('--dset', type=str, 
default='office-home', choices=['VISDA-C', 'office', 'office-home']) 244 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate") 245 | parser.add_argument('--net', type=str, default='resnet50', help="vgg16, resnet18, resnet34, resnet50, resnet101") 246 | parser.add_argument('--seed', type=int, default=2020, help="random seed") 247 | parser.add_argument('--bottleneck', type=int, default=256) 248 | parser.add_argument('--epsilon', type=float, default=1e-5) 249 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"]) 250 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"]) 251 | parser.add_argument('--smooth', type=float, default=0.1) 252 | parser.add_argument('--output', type=str, default='san') 253 | parser.add_argument('--da', type=str, default='uda', choices=['uda']) 254 | args = parser.parse_args() 255 | 256 | if args.dset == 'office-home': 257 | names = ['Art', 'Clipart', 'Product', 'RealWorld'] 258 | args.class_num = 65 259 | elif args.dset == 'office': 260 | names = ['amazon', 'dslr', 'webcam'] 261 | args.class_num = 31 262 | elif args.dset == 'VISDA-C': 263 | names = ['train', 'validation'] 264 | args.class_num = 12 265 | args.lr = 1e-3 266 | 267 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 268 | SEED = args.seed 269 | torch.manual_seed(SEED) 270 | torch.cuda.manual_seed(SEED) 271 | np.random.seed(SEED) 272 | random.seed(SEED) 273 | # torch.backends.cudnn.deterministic = True 274 | 275 | folder = '../data/' 276 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt' 277 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt' 278 | 279 | args.output_dir_src = osp.join(args.output, args.da, args.dset, names[args.s][0].upper()) 280 | args.name_src = names[args.s][0].upper() 281 | if not osp.exists(args.output_dir_src): 282 | os.system('mkdir -p ' + args.output_dir_src) 283 | if not osp.exists(args.output_dir_src): 284 | 
os.mkdir(args.output_dir_src) 285 | 286 | args.out_file = open(osp.join(args.output_dir_src, 'log.txt'), 'w') 287 | args.out_file.write(print_args(args)+'\n') 288 | args.out_file.flush() 289 | train_source(args) 290 | 291 | args.out_file = open(osp.join(args.output_dir_src, 'log_test.txt'), 'w') 292 | for i in range(len(names)): 293 | if i == args.s: 294 | continue 295 | args.t = i 296 | args.name = names[args.s][0].upper() + names[args.t][0].upper() 297 | 298 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt' 299 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt' 300 | 301 | test_target(args) 302 | 303 | args.out_file.close() --------------------------------------------------------------------------------