├── figs
│   ├── shot.jpg
│   ├── shot_plus.jpg
│   └── shot_ssl.jpg
├── supp
│   └── shot++_supp.pdf
├── code
│   ├── digit
│   │   ├── data_load
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── mnist.cpython-36.pyc
│   │   │   │   ├── mnist.cpython-37.pyc
│   │   │   │   ├── svhn.cpython-36.pyc
│   │   │   │   ├── svhn.cpython-37.pyc
│   │   │   │   ├── usps.cpython-36.pyc
│   │   │   │   ├── usps.cpython-37.pyc
│   │   │   │   ├── utils.cpython-36.pyc
│   │   │   │   ├── utils.cpython-37.pyc
│   │   │   │   ├── vision.cpython-36.pyc
│   │   │   │   ├── vision.cpython-37.pyc
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   └── __init__.cpython-37.pyc
│   │   │   ├── vision.py
│   │   │   ├── utils.py
│   │   │   ├── usps.py
│   │   │   └── svhn.py
│   │   ├── run_digit.sh
│   │   ├── loss.py
│   │   ├── rotation.py
│   │   └── network.py
│   ├── ssda
│   │   ├── run_ssda.sh
│   │   ├── rotation.py
│   │   ├── loss.py
│   │   ├── data_list.py
│   │   └── image_source.py
│   ├── pda
│   │   ├── run_pda.sh
│   │   ├── rotation.py
│   │   ├── loss.py
│   │   └── data_list.py
│   ├── msda
│   │   ├── rotation.py
│   │   ├── loss.py
│   │   ├── run_msda.sh
│   │   ├── data_list.py
│   │   ├── image_ms.py
│   │   └── image_source.py
│   ├── uda
│   │   ├── rotation.py
│   │   ├── loss.py
│   │   ├── run_uda.sh
│   │   ├── data_list.py
│   │   └── image_source.py
│   └── data
│       ├── ssda
│       │   └── office-home
│       │       ├── labeled_target_images_Art_1.txt
│       │       ├── labeled_target_images_Clipart_1.txt
│       │       ├── labeled_target_images_Product_1.txt
│       │       └── labeled_target_images_Real_1.txt
│       ├── pacs
│       │   └── photo_crossval_kfold.txt
│       └── office-caltech
│           └── dslr_list.txt
├── LICENSE
└── README.md
/figs/shot.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/figs/shot.jpg
--------------------------------------------------------------------------------
/figs/shot_plus.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/figs/shot_plus.jpg
--------------------------------------------------------------------------------
/figs/shot_ssl.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/figs/shot_ssl.jpg
--------------------------------------------------------------------------------
/supp/shot++_supp.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/supp/shot++_supp.pdf
--------------------------------------------------------------------------------
/code/digit/data_load/__init__.py:
--------------------------------------------------------------------------------
1 | from .svhn import *
2 | from .mnist import *
3 | from .usps import *
--------------------------------------------------------------------------------
/code/digit/data_load/__pycache__/mnist.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/code/digit/data_load/__pycache__/mnist.cpython-36.pyc
--------------------------------------------------------------------------------
/code/digit/data_load/__pycache__/mnist.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/code/digit/data_load/__pycache__/mnist.cpython-37.pyc
--------------------------------------------------------------------------------
/code/digit/data_load/__pycache__/svhn.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/code/digit/data_load/__pycache__/svhn.cpython-36.pyc
--------------------------------------------------------------------------------
/code/digit/data_load/__pycache__/svhn.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/code/digit/data_load/__pycache__/svhn.cpython-37.pyc
--------------------------------------------------------------------------------
/code/digit/data_load/__pycache__/usps.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/code/digit/data_load/__pycache__/usps.cpython-36.pyc
--------------------------------------------------------------------------------
/code/digit/data_load/__pycache__/usps.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/code/digit/data_load/__pycache__/usps.cpython-37.pyc
--------------------------------------------------------------------------------
/code/digit/data_load/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/code/digit/data_load/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/code/digit/data_load/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/code/digit/data_load/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/code/digit/data_load/__pycache__/vision.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/code/digit/data_load/__pycache__/vision.cpython-36.pyc
--------------------------------------------------------------------------------
/code/digit/data_load/__pycache__/vision.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/code/digit/data_load/__pycache__/vision.cpython-37.pyc
--------------------------------------------------------------------------------
/code/digit/data_load/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/code/digit/data_load/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/code/digit/data_load/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tim-learn/SHOT-plus/HEAD/code/digit/data_load/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 tim-learn
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/code/ssda/run_ssda.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # id seed
4 | for ((s=0;s<=3;s++))
5 | do
6 | python image_source.py --gpu_id $1 --seed $2 --output "ckps/s"$2 --dset office-home --max_epoch 50 --s $s
7 |
8 | for ((t=0;t<=3;t++))
9 | do
10 | if [ $s -eq $t ];then
11 | echo "skipped"
12 | else
13 | echo "okay"
14 | python image_target.py --ssl 0.0 --cls_par 0.0 --ent '' --gpu_id $1 --s $s --t $t --output_src "ckps/s"$2 --output "ckps/st"$2 --seed $2 --dset office-home
15 | python image_mixmatch.py --ps 0.0 --cls_par 0.0 --gpu_id $1 --s $s --t $t --output_tar "ckps/st"$2 --output "ckps/mm_st"$2 --seed $2 --dset office-home --max_epoch 50
16 |
17 | python image_target.py --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --t $t --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset office-home
18 | python image_mixmatch.py --ps 0.0 --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --t $t --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset office-home --max_epoch 50
19 |
20 | python image_target.py --ssl 0.2 --cls_par 0.1 --gpu_id $1 --s $s --t $t --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset office-home
21 | python image_mixmatch.py --ps 0.0 --ssl 0.2 --cls_par 0.1 --gpu_id $1 --s $s --t $t --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset office-home --max_epoch 50
22 |
23 | fi
24 | done
25 | done
--------------------------------------------------------------------------------
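
run_ssda.sh reads two positional arguments (see the `# id seed` comment): $1 is the GPU id and $2 is the random seed, and it sweeps all Office-Home source/target pairs. A hypothetical invocation, assuming the Office-Home image lists and data are already prepared:

    sh run_ssda.sh 0 2020
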
/code/digit/run_digit.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | python uda_digit.py --dset m2u --gpu_id $1 --seed $2 --cls_par 0.0 --output ckps_digits
4 | python uda_digit.py --dset u2m --gpu_id $1 --seed $2 --cls_par 0.0 --output ckps_digits
5 | python uda_digit.py --dset s2m --gpu_id $1 --seed $2 --cls_par 0.0 --output ckps_digits
6 |
7 | python uda_digit.py --dset m2u --gpu_id $1 --seed $2 --cls_par 0.1 --ssl 0.2 --output ckps_digits
8 | python uda_digit.py --dset u2m --gpu_id $1 --seed $2 --cls_par 0.1 --ssl 0.2 --output ckps_digits
9 | python uda_digit.py --dset s2m --gpu_id $1 --seed $2 --cls_par 0.1 --ssl 0.2 --output ckps_digits
10 |
11 | python digit_mixmatch.py --dset m2u --gpu_id $1 --seed $2 --output ckps_mm --output_tar ckps_digits --model source
12 | python digit_mixmatch.py --dset u2m --gpu_id $1 --seed $2 --output ckps_mm --output_tar ckps_digits --model source
13 | python digit_mixmatch.py --dset s2m --gpu_id $1 --seed $2 --output ckps_mm --output_tar ckps_digits --model source
14 |
15 | python digit_mixmatch.py --dset m2u --gpu_id $1 --seed $2 --output ckps_mm --output_tar ckps_digits --cls_par 0.0
16 | python digit_mixmatch.py --dset u2m --gpu_id $1 --seed $2 --output ckps_mm --output_tar ckps_digits --cls_par 0.0
17 | python digit_mixmatch.py --dset s2m --gpu_id $1 --seed $2 --output ckps_mm --output_tar ckps_digits --cls_par 0.0
18 |
19 | python digit_mixmatch.py --dset m2u --gpu_id $1 --seed $2 --output ckps_mm --output_tar ckps_digits --cls_par 0.1 --ssl 0.2
20 | python digit_mixmatch.py --dset u2m --gpu_id $1 --seed $2 --output ckps_mm --output_tar ckps_digits --cls_par 0.1 --ssl 0.2
21 | python digit_mixmatch.py --dset s2m --gpu_id $1 --seed $2 --output ckps_mm --output_tar ckps_digits --cls_par 0.1 --ssl 0.2
--------------------------------------------------------------------------------
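
run_digit.sh uses the same two positional arguments ($1 = GPU id, $2 = seed) and runs the three digit transfers (m2u, u2m, s2m) for the source-hypothesis transfer stage and the MixMatch-based second stage. A hypothetical invocation:

    sh run_digit.sh 0 2020
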
/code/digit/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import math
6 | import torch.nn.functional as F
7 | import pdb
8 |
9 | def Entropy(input_):
10 | bs = input_.size(0)
11 | epsilon = 1e-5
12 | entropy = -input_ * torch.log(input_ + epsilon)
13 | entropy = torch.sum(entropy, dim=1)
14 | return entropy
15 |
16 | class CrossEntropyLabelSmooth(nn.Module):
17 | """Cross entropy loss with label smoothing regularizer.
18 | Reference:
19 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
20 | Equation: y = (1 - epsilon) * y + epsilon / K.
21 | Args:
22 | num_classes (int): number of classes.
23 | epsilon (float): weight.
24 | """
25 |
26 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, size_average=True):
27 | super(CrossEntropyLabelSmooth, self).__init__()
28 | self.num_classes = num_classes
29 | self.epsilon = epsilon
30 | self.use_gpu = use_gpu
31 | self.size_average = size_average
32 | self.logsoftmax = nn.LogSoftmax(dim=1)
33 |
34 | def forward(self, inputs, targets):
35 | """
36 | Args:
37 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
38 |             targets: ground truth labels with shape (batch_size)
39 | """
40 | log_probs = self.logsoftmax(inputs)
41 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1)
42 | if self.use_gpu: targets = targets.cuda()
43 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
44 | if self.size_average:
45 | loss = (- targets * log_probs).mean(0).sum()
46 | else:
47 | loss = (- targets * log_probs).sum(1)
48 | return loss
--------------------------------------------------------------------------------
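
A minimal usage sketch for the two pieces defined above, assuming loss.py is on the import path; the shapes are illustrative:

    import torch
    from loss import CrossEntropyLabelSmooth, Entropy

    logits = torch.randn(8, 10)                  # (batch_size, num_classes)
    labels = torch.randint(0, 10, (8,))          # integer class indices
    criterion = CrossEntropyLabelSmooth(num_classes=10, use_gpu=False)
    ce = criterion(logits, labels)               # scalar, label-smoothed cross entropy
    ent = Entropy(torch.softmax(logits, dim=1))  # per-sample entropy, shape (8,)
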
/code/digit/rotation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | from torchvision import datasets
4 | import numpy as np
5 |
6 | # Assumes that tensor is (nchannels, height, width)
7 | def tensor_rot_90(x):
8 | return x.flip(2).transpose(1, 2)
9 |
10 | def tensor_rot_180(x):
11 | return x.flip(2).flip(1)
12 |
13 | def tensor_rot_270(x):
14 | return x.transpose(1, 2).flip(2)
15 |
16 | def rotate_single_with_label(img, label):
17 | if label == 1:
18 | img = tensor_rot_90(img)
19 | elif label == 2:
20 | img = tensor_rot_180(img)
21 | elif label == 3:
22 | img = tensor_rot_270(img)
23 | return img
24 |
25 | def rotate_batch_with_labels(batch, labels):
26 | images = []
27 | for img, label in zip(batch, labels):
28 | img = rotate_single_with_label(img, label)
29 | images.append(img.unsqueeze(0))
30 | return torch.cat(images)
31 |
32 | def rotate_batch(batch, label='rand'):
33 | if label == 'rand':
34 | labels = torch.randint(4, (len(batch),), dtype=torch.long)
35 | else:
36 | assert isinstance(label, int)
37 | labels = torch.zeros((len(batch),), dtype=torch.long) + label
38 | return rotate_batch_with_labels(batch, labels), labels
39 |
40 | class RotateImageFolder(datasets.ImageFolder):
41 | def __init__(self, traindir, train_transform, original=True, rotation=True, rotation_transform=None):
42 | super(RotateImageFolder, self).__init__(traindir, train_transform)
43 | self.original = original
44 | self.rotation = rotation
45 | self.rotation_transform = rotation_transform
46 |
47 | def __getitem__(self, index):
48 | path, target = self.imgs[index]
49 | img_input = self.loader(path)
50 |
51 | if self.transform is not None:
52 | img = self.transform(img_input)
53 | else:
54 | img = img_input
55 |
56 | results = []
57 | if self.original:
58 | results.append(img)
59 | results.append(target)
60 | if self.rotation:
61 | if self.rotation_transform is not None:
62 | img = self.rotation_transform(img_input)
63 | target_ssh = np.random.randint(0, 4, 1)[0]
64 | img_ssh = rotate_single_with_label(img, target_ssh)
65 | results.append(img_ssh)
66 | results.append(target_ssh)
67 | return results
68 |
69 | def switch_mode(self, original, rotation):
70 | self.original = original
71 | self.rotation = rotation
72 |
--------------------------------------------------------------------------------
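
A small sketch of the rotation-based self-supervision utilities above, applied to a dummy batch (shapes assumed):

    import torch
    from rotation import rotate_batch

    batch = torch.randn(4, 3, 32, 32)                  # (N, C, H, W)
    rotated, rot_labels = rotate_batch(batch, 'rand')  # labels drawn from {0, 1, 2, 3}
    # rotated has shape (4, 3, 32, 32); each image is rotated by 90 * label degrees
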
/code/pda/run_pda.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | python image_pretrained.py --ssl 0.0 --cls_par 0.0 --gpu_id $1 --seed $2 --output "ckps/s"$2
4 | python image_pretrained.py --ssl 0.0 --cls_par 0.3 --gpu_id $1 --seed $2 --output "ckps/s"$2
5 | python image_pretrained.py --ssl 0.6 --cls_par 0.3 --gpu_id $1 --seed $2 --output "ckps/s"$2
6 |
7 | for ((s=0;s<=3;s++))
8 | do
9 | python image_source.py --da pda --gpu_id $1 --seed $2 --output "ckps/s"$2 --dset office-home --max_epoch 50 --s $s
10 | python image_mixmatch.py --model source --da pda --ps 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_tar "ckps/s"$2 --output "ckps/mm"$2 --seed $2 --dset office-home --max_epoch 50
11 |
12 | python image_target.py --da pda --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset office-home
13 | python image_mixmatch.py --da pda --ps 0.0 --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset office-home --max_epoch 50
14 |
15 | python image_target.py --da pda --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset office-home
16 | python image_mixmatch.py --da pda --ps 0.0 --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset office-home --max_epoch 50
17 |
18 | done
19 |
20 |
21 | for ((s=0;s<=1;s++))
22 | do
23 | python image_source.py --da pda --gpu_id $1 --seed $2 --output "ckps/s"$2 --dset VISDA-C --max_epoch 10 --s $s
24 | python image_mixmatch.py --model source --da pda --ps 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_tar "ckps/s"$2 --output "ckps/mm"$2 --seed $2 --dset VISDA-C --max_epoch 10
25 |
26 | python image_target.py --da pda --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset VISDA-C
27 | python image_mixmatch.py --da pda --ps 0.0 --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset VISDA-C --max_epoch 10
28 |
29 | python image_target.py --da pda --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset VISDA-C
30 | python image_mixmatch.py --da pda --ps 0.0 --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset VISDA-C --max_epoch 10
31 |
32 | done
--------------------------------------------------------------------------------
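
run_pda.sh follows the same calling convention ($1 = GPU id, $2 = seed) and covers the partial-set setting on Office-Home and VISDA-C. A hypothetical invocation:

    sh run_pda.sh 0 2020
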
/code/msda/rotation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | from torchvision import datasets
4 | import numpy as np
5 |
6 | # Assumes that tensor is (nchannels, height, width)
7 | def tensor_rot_90(x):
8 | return x.flip(2).transpose(1, 2)
9 |
10 | def tensor_rot_180(x):
11 | return x.flip(2).flip(1)
12 |
13 | def tensor_rot_270(x):
14 | return x.transpose(1, 2).flip(2)
15 |
16 | def rotate_single_with_label(img, label):
17 | if label == 1:
18 | img = tensor_rot_90(img)
19 | elif label == 2:
20 | img = tensor_rot_180(img)
21 | elif label == 3:
22 | img = tensor_rot_270(img)
23 | return img
24 |
25 | def rotate_batch_with_labels(batch, labels):
26 | images = []
27 | for img, label in zip(batch, labels):
28 | img = rotate_single_with_label(img, label)
29 | images.append(img.unsqueeze(0))
30 | return torch.cat(images)
31 |
32 |
33 | def rotate_single_with_label2(img, label):
34 | if label == 1:
35 | img = tensor_rot_180(img)
36 | return img
37 |
38 | def rotate_batch_with_labels2(batch, labels):
39 | images = []
40 | for img, label in zip(batch, labels):
41 | img = rotate_single_with_label2(img, label)
42 | images.append(img.unsqueeze(0))
43 | return torch.cat(images)
44 |
45 |
46 | def rotate_batch(batch, label='rand'):
47 | if label == 'rand':
48 | labels = torch.randint(4, (len(batch),), dtype=torch.long)
49 | else:
50 | assert isinstance(label, int)
51 | labels = torch.zeros((len(batch),), dtype=torch.long) + label
52 | return rotate_batch_with_labels(batch, labels), labels
53 |
54 | class RotateImageFolder(datasets.ImageFolder):
55 | def __init__(self, traindir, train_transform, original=True, rotation=True, rotation_transform=None):
56 | super(RotateImageFolder, self).__init__(traindir, train_transform)
57 | self.original = original
58 | self.rotation = rotation
59 | self.rotation_transform = rotation_transform
60 |
61 | def __getitem__(self, index):
62 | path, target = self.imgs[index]
63 | img_input = self.loader(path)
64 |
65 | if self.transform is not None:
66 | img = self.transform(img_input)
67 | else:
68 | img = img_input
69 |
70 | results = []
71 | if self.original:
72 | results.append(img)
73 | results.append(target)
74 | if self.rotation:
75 | if self.rotation_transform is not None:
76 | img = self.rotation_transform(img_input)
77 | target_ssh = np.random.randint(0, 4, 1)[0]
78 | img_ssh = rotate_single_with_label(img, target_ssh)
79 | results.append(img_ssh)
80 | results.append(target_ssh)
81 | return results
82 |
83 | def switch_mode(self, original, rotation):
84 | self.original = original
85 | self.rotation = rotation
86 |
--------------------------------------------------------------------------------
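
This copy also defines a 2-way variant (rotate_single_with_label2 / rotate_batch_with_labels2) that only distinguishes 0 and 180 degree rotations. A minimal sketch, with assumed shapes:

    import torch
    from rotation import rotate_batch_with_labels2

    batch = torch.randn(4, 3, 224, 224)
    labels = torch.randint(2, (4,), dtype=torch.long)  # 0 keeps the image, 1 rotates it by 180 degrees
    flipped = rotate_batch_with_labels2(batch, labels)
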
/code/pda/rotation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | from torchvision import datasets
4 | import numpy as np
5 |
6 | # Assumes that tensor is (nchannels, height, width)
7 | def tensor_rot_90(x):
8 | return x.flip(2).transpose(1, 2)
9 |
10 | def tensor_rot_180(x):
11 | return x.flip(2).flip(1)
12 |
13 | def tensor_rot_270(x):
14 | return x.transpose(1, 2).flip(2)
15 |
16 | def rotate_single_with_label(img, label):
17 | if label == 1:
18 | img = tensor_rot_90(img)
19 | elif label == 2:
20 | img = tensor_rot_180(img)
21 | elif label == 3:
22 | img = tensor_rot_270(img)
23 | return img
24 |
25 | def rotate_batch_with_labels(batch, labels):
26 | images = []
27 | for img, label in zip(batch, labels):
28 | img = rotate_single_with_label(img, label)
29 | images.append(img.unsqueeze(0))
30 | return torch.cat(images)
31 |
32 |
33 | def rotate_single_with_label2(img, label):
34 | if label == 1:
35 | img = tensor_rot_180(img)
36 | return img
37 |
38 | def rotate_batch_with_labels2(batch, labels):
39 | images = []
40 | for img, label in zip(batch, labels):
41 | img = rotate_single_with_label2(img, label)
42 | images.append(img.unsqueeze(0))
43 | return torch.cat(images)
44 |
45 |
46 | def rotate_batch(batch, label='rand'):
47 | if label == 'rand':
48 | labels = torch.randint(4, (len(batch),), dtype=torch.long)
49 | else:
50 | assert isinstance(label, int)
51 | labels = torch.zeros((len(batch),), dtype=torch.long) + label
52 | return rotate_batch_with_labels(batch, labels), labels
53 |
54 | class RotateImageFolder(datasets.ImageFolder):
55 | def __init__(self, traindir, train_transform, original=True, rotation=True, rotation_transform=None):
56 | super(RotateImageFolder, self).__init__(traindir, train_transform)
57 | self.original = original
58 | self.rotation = rotation
59 | self.rotation_transform = rotation_transform
60 |
61 | def __getitem__(self, index):
62 | path, target = self.imgs[index]
63 | img_input = self.loader(path)
64 |
65 | if self.transform is not None:
66 | img = self.transform(img_input)
67 | else:
68 | img = img_input
69 |
70 | results = []
71 | if self.original:
72 | results.append(img)
73 | results.append(target)
74 | if self.rotation:
75 | if self.rotation_transform is not None:
76 | img = self.rotation_transform(img_input)
77 | target_ssh = np.random.randint(0, 4, 1)[0]
78 | img_ssh = rotate_single_with_label(img, target_ssh)
79 | results.append(img_ssh)
80 | results.append(target_ssh)
81 | return results
82 |
83 | def switch_mode(self, original, rotation):
84 | self.original = original
85 | self.rotation = rotation
86 |
--------------------------------------------------------------------------------
/code/ssda/rotation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | from torchvision import datasets
4 | import numpy as np
5 |
6 | # Assumes that tensor is (nchannels, height, width)
7 | def tensor_rot_90(x):
8 | return x.flip(2).transpose(1, 2)
9 |
10 | def tensor_rot_180(x):
11 | return x.flip(2).flip(1)
12 |
13 | def tensor_rot_270(x):
14 | return x.transpose(1, 2).flip(2)
15 |
16 | def rotate_single_with_label(img, label):
17 | if label == 1:
18 | img = tensor_rot_90(img)
19 | elif label == 2:
20 | img = tensor_rot_180(img)
21 | elif label == 3:
22 | img = tensor_rot_270(img)
23 | return img
24 |
25 | def rotate_batch_with_labels(batch, labels):
26 | images = []
27 | for img, label in zip(batch, labels):
28 | img = rotate_single_with_label(img, label)
29 | images.append(img.unsqueeze(0))
30 | return torch.cat(images)
31 |
32 |
33 | def rotate_single_with_label2(img, label):
34 | if label == 1:
35 | img = tensor_rot_180(img)
36 | return img
37 |
38 | def rotate_batch_with_labels2(batch, labels):
39 | images = []
40 | for img, label in zip(batch, labels):
41 | img = rotate_single_with_label2(img, label)
42 | images.append(img.unsqueeze(0))
43 | return torch.cat(images)
44 |
45 |
46 | def rotate_batch(batch, label='rand'):
47 | if label == 'rand':
48 | labels = torch.randint(4, (len(batch),), dtype=torch.long)
49 | else:
50 | assert isinstance(label, int)
51 | labels = torch.zeros((len(batch),), dtype=torch.long) + label
52 | return rotate_batch_with_labels(batch, labels), labels
53 |
54 | class RotateImageFolder(datasets.ImageFolder):
55 | def __init__(self, traindir, train_transform, original=True, rotation=True, rotation_transform=None):
56 | super(RotateImageFolder, self).__init__(traindir, train_transform)
57 | self.original = original
58 | self.rotation = rotation
59 | self.rotation_transform = rotation_transform
60 |
61 | def __getitem__(self, index):
62 | path, target = self.imgs[index]
63 | img_input = self.loader(path)
64 |
65 | if self.transform is not None:
66 | img = self.transform(img_input)
67 | else:
68 | img = img_input
69 |
70 | results = []
71 | if self.original:
72 | results.append(img)
73 | results.append(target)
74 | if self.rotation:
75 | if self.rotation_transform is not None:
76 | img = self.rotation_transform(img_input)
77 | target_ssh = np.random.randint(0, 4, 1)[0]
78 | img_ssh = rotate_single_with_label(img, target_ssh)
79 | results.append(img_ssh)
80 | results.append(target_ssh)
81 | return results
82 |
83 | def switch_mode(self, original, rotation):
84 | self.original = original
85 | self.rotation = rotation
86 |
--------------------------------------------------------------------------------
/code/uda/rotation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | from torchvision import datasets
4 | import numpy as np
5 |
6 | # Assumes that tensor is (nchannels, height, width)
7 | def tensor_rot_90(x):
8 | return x.flip(2).transpose(1, 2)
9 |
10 | def tensor_rot_180(x):
11 | return x.flip(2).flip(1)
12 |
13 | def tensor_rot_270(x):
14 | return x.transpose(1, 2).flip(2)
15 |
16 | def rotate_single_with_label(img, label):
17 | if label == 1:
18 | img = tensor_rot_90(img)
19 | elif label == 2:
20 | img = tensor_rot_180(img)
21 | elif label == 3:
22 | img = tensor_rot_270(img)
23 | return img
24 |
25 | def rotate_batch_with_labels(batch, labels):
26 | images = []
27 | for img, label in zip(batch, labels):
28 | img = rotate_single_with_label(img, label)
29 | images.append(img.unsqueeze(0))
30 | return torch.cat(images)
31 |
32 |
33 | def rotate_single_with_label2(img, label):
34 | if label == 1:
35 | img = tensor_rot_180(img)
36 | return img
37 |
38 | def rotate_batch_with_labels2(batch, labels):
39 | images = []
40 | for img, label in zip(batch, labels):
41 | img = rotate_single_with_label2(img, label)
42 | images.append(img.unsqueeze(0))
43 | return torch.cat(images)
44 |
45 |
46 | def rotate_batch(batch, label='rand'):
47 | if label == 'rand':
48 | labels = torch.randint(4, (len(batch),), dtype=torch.long)
49 | else:
50 | assert isinstance(label, int)
51 | labels = torch.zeros((len(batch),), dtype=torch.long) + label
52 | return rotate_batch_with_labels(batch, labels), labels
53 |
54 | class RotateImageFolder(datasets.ImageFolder):
55 | def __init__(self, traindir, train_transform, original=True, rotation=True, rotation_transform=None):
56 | super(RotateImageFolder, self).__init__(traindir, train_transform)
57 | self.original = original
58 | self.rotation = rotation
59 | self.rotation_transform = rotation_transform
60 |
61 | def __getitem__(self, index):
62 | path, target = self.imgs[index]
63 | img_input = self.loader(path)
64 |
65 | if self.transform is not None:
66 | img = self.transform(img_input)
67 | else:
68 | img = img_input
69 |
70 | results = []
71 | if self.original:
72 | results.append(img)
73 | results.append(target)
74 | if self.rotation:
75 | if self.rotation_transform is not None:
76 | img = self.rotation_transform(img_input)
77 | target_ssh = np.random.randint(0, 4, 1)[0]
78 | img_ssh = rotate_single_with_label(img, target_ssh)
79 | results.append(img_ssh)
80 | results.append(target_ssh)
81 | return results
82 |
83 | def switch_mode(self, original, rotation):
84 | self.original = original
85 | self.rotation = rotation
86 |
--------------------------------------------------------------------------------
/code/digit/data_load/vision.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.utils.data as data
4 |
5 |
6 | class VisionDataset(data.Dataset):
7 | _repr_indent = 4
8 |
9 | def __init__(self, root, transforms=None, transform=None, target_transform=None):
10 | if isinstance(root, torch._six.string_classes):
11 | root = os.path.expanduser(root)
12 | self.root = root
13 |
14 | has_transforms = transforms is not None
15 | has_separate_transform = transform is not None or target_transform is not None
16 | if has_transforms and has_separate_transform:
17 | raise ValueError("Only transforms or transform/target_transform can "
18 | "be passed as argument")
19 |
20 | # for backwards-compatibility
21 | self.transform = transform
22 | self.target_transform = target_transform
23 |
24 | if has_separate_transform:
25 | transforms = StandardTransform(transform, target_transform)
26 | self.transforms = transforms
27 |
28 | def __getitem__(self, index):
29 | raise NotImplementedError
30 |
31 | def __len__(self):
32 | raise NotImplementedError
33 |
34 | def __repr__(self):
35 | head = "Dataset " + self.__class__.__name__
36 | body = ["Number of datapoints: {}".format(self.__len__())]
37 | if self.root is not None:
38 | body.append("Root location: {}".format(self.root))
39 | body += self.extra_repr().splitlines()
40 | if hasattr(self, "transforms") and self.transforms is not None:
41 | body += [repr(self.transforms)]
42 | lines = [head] + [" " * self._repr_indent + line for line in body]
43 | return '\n'.join(lines)
44 |
45 | def _format_transform_repr(self, transform, head):
46 | lines = transform.__repr__().splitlines()
47 | return (["{}{}".format(head, lines[0])] +
48 | ["{}{}".format(" " * len(head), line) for line in lines[1:]])
49 |
50 | def extra_repr(self):
51 | return ""
52 |
53 |
54 | class StandardTransform(object):
55 | def __init__(self, transform=None, target_transform=None):
56 | self.transform = transform
57 | self.target_transform = target_transform
58 |
59 | def __call__(self, input, target):
60 | if self.transform is not None:
61 | input = self.transform(input)
62 | if self.target_transform is not None:
63 | target = self.target_transform(target)
64 | return input, target
65 |
66 | def _format_transform_repr(self, transform, head):
67 | lines = transform.__repr__().splitlines()
68 | return (["{}{}".format(head, lines[0])] +
69 | ["{}{}".format(" " * len(head), line) for line in lines[1:]])
70 |
71 | def __repr__(self):
72 | body = [self.__class__.__name__]
73 | if self.transform is not None:
74 | body += self._format_transform_repr(self.transform,
75 | "Transform: ")
76 | if self.target_transform is not None:
77 | body += self._format_transform_repr(self.target_transform,
78 | "Target transform: ")
79 |
80 | return '\n'.join(body)
--------------------------------------------------------------------------------
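
A minimal subclass sketch of the VisionDataset base above (the dataset itself is hypothetical; since the base class references torch._six, this assumes an older torch version where that module still exists):

    from vision import VisionDataset

    class ToyDataset(VisionDataset):
        def __init__(self, root, transform=None):
            super(ToyDataset, self).__init__(root, transform=transform)
            self.samples = list(range(10))

        def __getitem__(self, index):
            x = self.samples[index]
            return x if self.transform is None else self.transform(x)

        def __len__(self):
            return len(self.samples)
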
/code/digit/network.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torchvision
5 | from torchvision import models
6 | from torch.autograd import Variable
7 | import math
8 | import pdb
9 | import torch.nn.utils.weight_norm as weightNorm
10 | from collections import OrderedDict
11 |
12 | def init_weights(m):
13 | classname = m.__class__.__name__
14 | if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1:
15 | nn.init.kaiming_uniform_(m.weight)
16 | nn.init.zeros_(m.bias)
17 | elif classname.find('BatchNorm') != -1:
18 | nn.init.normal_(m.weight, 1.0, 0.02)
19 | nn.init.zeros_(m.bias)
20 | elif classname.find('Linear') != -1:
21 | nn.init.xavier_normal_(m.weight)
22 | nn.init.zeros_(m.bias)
23 |
24 | class feat_bottleneck(nn.Module):
25 | def __init__(self, feature_dim, bottleneck_dim=256, type="ori"):
26 | super(feat_bottleneck, self).__init__()
27 | self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True)
28 | self.relu = nn.ReLU(inplace=True)
29 | self.dropout = nn.Dropout(p=0.5)
30 | self.bottleneck = nn.Linear(feature_dim, bottleneck_dim)
31 | self.bottleneck.apply(init_weights)
32 | self.type = type
33 |
34 | def forward(self, x):
35 | x = self.bottleneck(x)
36 | if self.type == "bn":
37 | x = self.bn(x)
38 | x = self.dropout(x)
39 | return x
40 |
41 | class feat_classifier(nn.Module):
42 | def __init__(self, class_num, bottleneck_dim=256, type="linear"):
43 | super(feat_classifier, self).__init__()
44 | if type == "linear":
45 | self.fc = nn.Linear(bottleneck_dim, class_num)
46 | else:
47 | self.fc = weightNorm(nn.Linear(bottleneck_dim, class_num), name="weight")
48 | self.fc.apply(init_weights)
49 |
50 | def forward(self, x):
51 | x = self.fc(x)
52 | return x
53 |
54 | class DTNBase(nn.Module):
55 | def __init__(self):
56 | super(DTNBase, self).__init__()
57 | self.conv_params = nn.Sequential(
58 | nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2),
59 | nn.BatchNorm2d(64),
60 | nn.Dropout2d(0.1),
61 | nn.ReLU(),
62 | nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
63 | nn.BatchNorm2d(128),
64 | nn.Dropout2d(0.3),
65 | nn.ReLU(),
66 | nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
67 | nn.BatchNorm2d(256),
68 | nn.Dropout2d(0.5),
69 | nn.ReLU()
70 | )
71 | self.in_features = 256*4*4
72 |
73 | def forward(self, x):
74 | x = self.conv_params(x)
75 | x = x.view(x.size(0), -1)
76 | return x
77 |
78 | class LeNetBase(nn.Module):
79 | def __init__(self):
80 | super(LeNetBase, self).__init__()
81 | self.conv_params = nn.Sequential(
82 | nn.Conv2d(1, 20, kernel_size=5),
83 | nn.MaxPool2d(2),
84 | nn.ReLU(),
85 | nn.Conv2d(20, 50, kernel_size=5),
86 | nn.Dropout2d(p=0.5),
87 | nn.MaxPool2d(2),
88 | nn.ReLU(),
89 | )
90 | self.in_features = 50*4*4
91 |
92 | def forward(self, x):
93 | x = self.conv_params(x)
94 | x = x.view(x.size(0), -1)
95 | return x
--------------------------------------------------------------------------------
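
A sketch of how the modules above compose for SVHN-sized 3x32x32 input; the class names come from the file, but this particular wiring is only an assumed example:

    import torch
    from network import DTNBase, feat_bottleneck, feat_classifier

    netF = DTNBase()                                    # conv backbone, in_features = 256*4*4
    netB = feat_bottleneck(netF.in_features, bottleneck_dim=256, type="bn")
    netC = feat_classifier(class_num=10, bottleneck_dim=256, type="wn")

    x = torch.randn(4, 3, 32, 32)
    logits = netC(netB(netF(x)))                        # shape (4, 10)
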
/code/pda/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import math
6 | import torch.nn.functional as F
7 | import pdb
8 |
9 | def Entropy(input_):
10 | bs = input_.size(0)
11 | epsilon = 1e-5
12 | entropy = -input_ * torch.log(input_ + epsilon)
13 | entropy = torch.sum(entropy, dim=1)
14 | return entropy
15 |
16 | def grl_hook(coeff):
17 | def fun1(grad):
18 | return -coeff*grad.clone()
19 | return fun1
20 |
21 | def CDAN(input_list, ad_net, entropy=None, coeff=None, random_layer=None):
22 | softmax_output = input_list[1].detach()
23 | feature = input_list[0]
24 | if random_layer is None:
25 | op_out = torch.bmm(softmax_output.unsqueeze(2), feature.unsqueeze(1))
26 | ad_out = ad_net(op_out.view(-1, softmax_output.size(1) * feature.size(1)))
27 | else:
28 | random_out = random_layer.forward([feature, softmax_output])
29 | ad_out = ad_net(random_out.view(-1, random_out.size(1)))
30 | batch_size = softmax_output.size(0) // 2
31 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda()
32 | if entropy is not None:
33 | entropy.register_hook(grl_hook(coeff))
34 | entropy = 1.0+torch.exp(-entropy)
35 | source_mask = torch.ones_like(entropy)
36 | source_mask[feature.size(0)//2:] = 0
37 | source_weight = entropy*source_mask
38 | target_mask = torch.ones_like(entropy)
39 | target_mask[0:feature.size(0)//2] = 0
40 | target_weight = entropy*target_mask
41 | weight = source_weight / torch.sum(source_weight).detach().item() + \
42 | target_weight / torch.sum(target_weight).detach().item()
43 | return torch.sum(weight.view(-1, 1) * nn.BCELoss(reduction='none')(ad_out, dc_target)) / torch.sum(weight).detach().item()
44 | else:
45 | return nn.BCELoss()(ad_out, dc_target)
46 |
47 | def DANN(features, ad_net):
48 | ad_out = ad_net(features)
49 | batch_size = ad_out.size(0) // 2
50 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda()
51 | return nn.BCELoss()(ad_out, dc_target)
52 |
53 |
54 | class CrossEntropyLabelSmooth(nn.Module):
55 | """Cross entropy loss with label smoothing regularizer.
56 | Reference:
57 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
58 | Equation: y = (1 - epsilon) * y + epsilon / K.
59 | Args:
60 | num_classes (int): number of classes.
61 | epsilon (float): weight.
62 | """
63 |
64 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, reduction=True):
65 | super(CrossEntropyLabelSmooth, self).__init__()
66 | self.num_classes = num_classes
67 | self.epsilon = epsilon
68 | self.use_gpu = use_gpu
69 | self.reduction = reduction
70 | self.logsoftmax = nn.LogSoftmax(dim=1)
71 |
72 | def forward(self, inputs, targets):
73 | """
74 | Args:
75 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
76 |             targets: ground truth labels with shape (batch_size)
77 | """
78 | log_probs = self.logsoftmax(inputs)
79 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1)
80 | if self.use_gpu: targets = targets.cuda()
81 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
82 | loss = (- targets * log_probs).sum(dim=1)
83 | if self.reduction:
84 | return loss.mean()
85 | else:
86 | return loss
--------------------------------------------------------------------------------
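
A minimal sketch of the DANN adversarial loss defined above with a toy domain discriminator; it assumes a CUDA device is available, since the function builds its domain targets with .cuda():

    import torch
    import torch.nn as nn
    from loss import DANN

    ad_net = nn.Sequential(nn.Linear(256, 1), nn.Sigmoid()).cuda()  # toy domain discriminator
    feats = torch.randn(16, 256).cuda()   # first half source features, second half target
    adv_loss = DANN(feats, ad_net)        # scalar BCE domain-confusion loss
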
/code/uda/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import math
6 | import torch.nn.functional as F
7 | import pdb
8 |
9 | def Entropy(input_):
10 | bs = input_.size(0)
11 | epsilon = 1e-5
12 | entropy = -input_ * torch.log(input_ + epsilon)
13 | entropy = torch.sum(entropy, dim=1)
14 | return entropy
15 |
16 | def grl_hook(coeff):
17 | def fun1(grad):
18 | return -coeff*grad.clone()
19 | return fun1
20 |
21 | def CDAN(input_list, ad_net, entropy=None, coeff=None, random_layer=None):
22 | softmax_output = input_list[1].detach()
23 | feature = input_list[0]
24 | if random_layer is None:
25 | op_out = torch.bmm(softmax_output.unsqueeze(2), feature.unsqueeze(1))
26 | ad_out = ad_net(op_out.view(-1, softmax_output.size(1) * feature.size(1)))
27 | else:
28 | random_out = random_layer.forward([feature, softmax_output])
29 | ad_out = ad_net(random_out.view(-1, random_out.size(1)))
30 | batch_size = softmax_output.size(0) // 2
31 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda()
32 | if entropy is not None:
33 | entropy.register_hook(grl_hook(coeff))
34 | entropy = 1.0+torch.exp(-entropy)
35 | source_mask = torch.ones_like(entropy)
36 | source_mask[feature.size(0)//2:] = 0
37 | source_weight = entropy*source_mask
38 | target_mask = torch.ones_like(entropy)
39 | target_mask[0:feature.size(0)//2] = 0
40 | target_weight = entropy*target_mask
41 | weight = source_weight / torch.sum(source_weight).detach().item() + \
42 | target_weight / torch.sum(target_weight).detach().item()
43 | return torch.sum(weight.view(-1, 1) * nn.BCELoss(reduction='none')(ad_out, dc_target)) / torch.sum(weight).detach().item()
44 | else:
45 | return nn.BCELoss()(ad_out, dc_target)
46 |
47 | def DANN(features, ad_net):
48 | ad_out = ad_net(features)
49 | batch_size = ad_out.size(0) // 2
50 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda()
51 | return nn.BCELoss()(ad_out, dc_target)
52 |
53 |
54 | class CrossEntropyLabelSmooth(nn.Module):
55 | """Cross entropy loss with label smoothing regularizer.
56 | Reference:
57 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
58 | Equation: y = (1 - epsilon) * y + epsilon / K.
59 | Args:
60 | num_classes (int): number of classes.
61 | epsilon (float): weight.
62 | """
63 |
64 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, reduction=True):
65 | super(CrossEntropyLabelSmooth, self).__init__()
66 | self.num_classes = num_classes
67 | self.epsilon = epsilon
68 | self.use_gpu = use_gpu
69 | self.reduction = reduction
70 | self.logsoftmax = nn.LogSoftmax(dim=1)
71 |
72 | def forward(self, inputs, targets):
73 | """
74 | Args:
75 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
76 |             targets: ground truth labels with shape (batch_size)
77 | """
78 | log_probs = self.logsoftmax(inputs)
79 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1)
80 | if self.use_gpu: targets = targets.cuda()
81 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
82 | loss = (- targets * log_probs).sum(dim=1)
83 | if self.reduction:
84 | return loss.mean()
85 | else:
86 | return loss
--------------------------------------------------------------------------------
/code/msda/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import math
6 | import torch.nn.functional as F
7 | import pdb
8 |
9 | def Entropy(input_):
10 | bs = input_.size(0)
11 | epsilon = 1e-5
12 | entropy = -input_ * torch.log(input_ + epsilon)
13 | entropy = torch.sum(entropy, dim=1)
14 | return entropy
15 |
16 | def grl_hook(coeff):
17 | def fun1(grad):
18 | return -coeff*grad.clone()
19 | return fun1
20 |
21 | def CDAN(input_list, ad_net, entropy=None, coeff=None, random_layer=None):
22 | softmax_output = input_list[1].detach()
23 | feature = input_list[0]
24 | if random_layer is None:
25 | op_out = torch.bmm(softmax_output.unsqueeze(2), feature.unsqueeze(1))
26 | ad_out = ad_net(op_out.view(-1, softmax_output.size(1) * feature.size(1)))
27 | else:
28 | random_out = random_layer.forward([feature, softmax_output])
29 | ad_out = ad_net(random_out.view(-1, random_out.size(1)))
30 | batch_size = softmax_output.size(0) // 2
31 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda()
32 | if entropy is not None:
33 | entropy.register_hook(grl_hook(coeff))
34 | entropy = 1.0+torch.exp(-entropy)
35 | source_mask = torch.ones_like(entropy)
36 | source_mask[feature.size(0)//2:] = 0
37 | source_weight = entropy*source_mask
38 | target_mask = torch.ones_like(entropy)
39 | target_mask[0:feature.size(0)//2] = 0
40 | target_weight = entropy*target_mask
41 | weight = source_weight / torch.sum(source_weight).detach().item() + \
42 | target_weight / torch.sum(target_weight).detach().item()
43 | return torch.sum(weight.view(-1, 1) * nn.BCELoss(reduction='none')(ad_out, dc_target)) / torch.sum(weight).detach().item()
44 | else:
45 | return nn.BCELoss()(ad_out, dc_target)
46 |
47 | def DANN(features, ad_net):
48 | ad_out = ad_net(features)
49 | batch_size = ad_out.size(0) // 2
50 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda()
51 | return nn.BCELoss()(ad_out, dc_target)
52 |
53 |
54 | class CrossEntropyLabelSmooth(nn.Module):
55 | """Cross entropy loss with label smoothing regularizer.
56 | Reference:
57 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
58 | Equation: y = (1 - epsilon) * y + epsilon / K.
59 | Args:
60 | num_classes (int): number of classes.
61 | epsilon (float): weight.
62 | """
63 |
64 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, reduction=True):
65 | super(CrossEntropyLabelSmooth, self).__init__()
66 | self.num_classes = num_classes
67 | self.epsilon = epsilon
68 | self.use_gpu = use_gpu
69 | self.reduction = reduction
70 | self.logsoftmax = nn.LogSoftmax(dim=1)
71 |
72 | def forward(self, inputs, targets):
73 | """
74 | Args:
75 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
76 |             targets: ground truth labels with shape (batch_size)
77 | """
78 | log_probs = self.logsoftmax(inputs)
79 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1)
80 | if self.use_gpu: targets = targets.cuda()
81 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
82 | loss = (- targets * log_probs).sum(dim=1)
83 | if self.reduction:
84 | return loss.mean()
85 | else:
86 | return loss
--------------------------------------------------------------------------------
/code/ssda/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import math
6 | import torch.nn.functional as F
7 | import pdb
8 |
9 | def Entropy(input_):
10 | bs = input_.size(0)
11 | epsilon = 1e-5
12 | entropy = -input_ * torch.log(input_ + epsilon)
13 | entropy = torch.sum(entropy, dim=1)
14 | return entropy
15 |
16 | def grl_hook(coeff):
17 | def fun1(grad):
18 | return -coeff*grad.clone()
19 | return fun1
20 |
21 | def CDAN(input_list, ad_net, entropy=None, coeff=None, random_layer=None):
22 | softmax_output = input_list[1].detach()
23 | feature = input_list[0]
24 | if random_layer is None:
25 | op_out = torch.bmm(softmax_output.unsqueeze(2), feature.unsqueeze(1))
26 | ad_out = ad_net(op_out.view(-1, softmax_output.size(1) * feature.size(1)))
27 | else:
28 | random_out = random_layer.forward([feature, softmax_output])
29 | ad_out = ad_net(random_out.view(-1, random_out.size(1)))
30 | batch_size = softmax_output.size(0) // 2
31 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda()
32 | if entropy is not None:
33 | entropy.register_hook(grl_hook(coeff))
34 | entropy = 1.0+torch.exp(-entropy)
35 | source_mask = torch.ones_like(entropy)
36 | source_mask[feature.size(0)//2:] = 0
37 | source_weight = entropy*source_mask
38 | target_mask = torch.ones_like(entropy)
39 | target_mask[0:feature.size(0)//2] = 0
40 | target_weight = entropy*target_mask
41 | weight = source_weight / torch.sum(source_weight).detach().item() + \
42 | target_weight / torch.sum(target_weight).detach().item()
43 | return torch.sum(weight.view(-1, 1) * nn.BCELoss(reduction='none')(ad_out, dc_target)) / torch.sum(weight).detach().item()
44 | else:
45 | return nn.BCELoss()(ad_out, dc_target)
46 |
47 | def DANN(features, ad_net):
48 | ad_out = ad_net(features)
49 | batch_size = ad_out.size(0) // 2
50 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda()
51 | return nn.BCELoss()(ad_out, dc_target)
52 |
53 |
54 | class CrossEntropyLabelSmooth(nn.Module):
55 | """Cross entropy loss with label smoothing regularizer.
56 | Reference:
57 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
58 | Equation: y = (1 - epsilon) * y + epsilon / K.
59 | Args:
60 | num_classes (int): number of classes.
61 | epsilon (float): weight.
62 | """
63 |
64 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, reduction=True):
65 | super(CrossEntropyLabelSmooth, self).__init__()
66 | self.num_classes = num_classes
67 | self.epsilon = epsilon
68 | self.use_gpu = use_gpu
69 | self.reduction = reduction
70 | self.logsoftmax = nn.LogSoftmax(dim=1)
71 |
72 | def forward(self, inputs, targets):
73 | """
74 | Args:
75 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
76 |             targets: ground truth labels with shape (batch_size)
77 | """
78 | log_probs = self.logsoftmax(inputs)
79 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1)
80 | if self.use_gpu: targets = targets.cuda()
81 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
82 | loss = (- targets * log_probs).sum(dim=1)
83 | if self.reduction:
84 | return loss.mean()
85 | else:
86 | return loss
--------------------------------------------------------------------------------
/code/msda/run_msda.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 |
4 | for ((s=0;s<=3;s++))
5 | do
6 | python image_source.py --gpu_id $1 --seed $2 --output "ckps/s"$2 --dset office-caltech --net resnet101 --max_epoch 100 --s $s
7 | python image_mixmatch.py --ps 0.0 --cls_par 0.0 --model source --gpu_id $1 --s $s --output_tar "ckps/s"$2 --output "ckps/mm"$2 --seed $2 --dset office-caltech --net resnet101 --max_epoch 100
8 |
9 | python image_target.py --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset office-caltech --net resnet101
10 | python image_mixmatch.py --ps 0.0 --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset office-caltech --net resnet101 --max_epoch 100
11 |
12 | python image_target.py --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset office-caltech --net resnet101
13 | python image_mixmatch.py --ps 0.0 --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset office-caltech --net resnet101 --max_epoch 100
14 |
15 | done
16 |
17 |
18 | for ((s=0;s<=3;s++))
19 | do
20 | python image_source.py --gpu_id $1 --seed $2 --output "ckps/s"$2 --dset office-home --net resnet50 --max_epoch 50 --s $s
21 | python image_mixmatch.py --ps 0.0 --cls_par 0.0 --model source --gpu_id $1 --s $s --output_tar "ckps/s"$2 --output "ckps/mm"$2 --seed $2 --dset office-home --net resnet50 --max_epoch 50
22 |
23 | python image_target.py --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset office-home --net resnet50
24 | python image_mixmatch.py --ps 0.0 --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset office-home --net resnet50 --max_epoch 50
25 |
26 | python image_target.py --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset office-home --net resnet50
27 | python image_mixmatch.py --ps 0.0 --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset office-home --net resnet50 --max_epoch 50
28 |
29 | done
30 |
31 |
32 | for ((s=0;s<=3;s++))
33 | do
34 | python image_source.py --gpu_id $1 --seed $2 --output "ckps/s"$2 --dset pacs --net resnet18 --max_epoch 100 --s $s
35 | python image_mixmatch.py --ps 0.0 --cls_par 0.0 --model source --gpu_id $1 --s $s --output_tar "ckps/s"$2 --output "ckps/mm"$2 --seed $2 --dset pacs --net resnet18 --max_epoch 100
36 |
37 | python image_target.py --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset pacs --net resnet18
38 | python image_mixmatch.py --ps 0.0 --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset pacs --net resnet18 --max_epoch 100
39 |
40 | python image_target.py --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset pacs --net resnet18
41 | python image_mixmatch.py --ps 0.0 --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset pacs --net resnet18 --max_epoch 100
42 |
43 | done
44 |
45 |
46 |
47 |
48 | for ((y=2019;y<=2021;y++))
49 | do
50 | for ((t=0;t<=3;t++))
51 | do
52 | python image_ms.py --output "san_ms/s"$y --output_src "ckps/s"$y --output_tar "ckps/t"$y --output_mm "ckps/mm"$y --dset office-caltech --t $t --cls_par 0.0 --ssl 0.0
53 | python image_ms.py --output "san_ms/s"$y --output_src "ckps/s"$y --output_tar "ckps/t"$y --output_mm "ckps/mm"$y --dset office-caltech --t $t --cls_par 0.3 --ssl 0.6
54 |
55 | python image_ms.py --output "san_ms/s"$y --output_src "ckps/s"$y --output_tar "ckps/t"$y --output_mm "ckps/mm"$y --dset pacs --net resnet18 --t $t --cls_par 0.0 --ssl 0.0
56 | python image_ms.py --output "san_ms/s"$y --output_src "ckps/s"$y --output_tar "ckps/t"$y --output_mm "ckps/mm"$y --dset pacs --net resnet18 --t $t --cls_par 0.3 --ssl 0.6
57 |
58 | python image_ms.py --output "san_ms/s"$y --output_src "ckps/s"$y --output_tar "ckps/t"$y --output_mm "ckps/mm"$y --dset office-home --net resnet50 --t $t --cls_par 0.0 --ssl 0.0
59 | python image_ms.py --output "san_ms/s"$y --output_src "ckps/s"$y --output_tar "ckps/t"$y --output_mm "ckps/mm"$y --dset office-home --net resnet50 --t $t --cls_par 0.3 --ssl 0.6
60 |
61 | done
62 | done
--------------------------------------------------------------------------------
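
run_msda.sh again takes $1 = GPU id and $2 = seed, and its final image_ms.py loop aggregates checkpoints from seeds 2019-2021, so a plausible (hypothetical) way to drive it is to run those three seeds in turn:

    for seed in 2019 2020 2021; do sh run_msda.sh 0 $seed; done
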
/code/data/ssda/office-home/labeled_target_images_Art_1.txt:
--------------------------------------------------------------------------------
1 | /Checkpoint/liangjian/da_dataset/office_home/Art/Alarm_Clock/00045.jpg 0
2 | /Checkpoint/liangjian/da_dataset/office_home/Art/Backpack/00019.jpg 1
3 | /Checkpoint/liangjian/da_dataset/office_home/Art/Batteries/00023.jpg 2
4 | /Checkpoint/liangjian/da_dataset/office_home/Art/Bed/00033.jpg 3
5 | /Checkpoint/liangjian/da_dataset/office_home/Art/Bike/00020.jpg 4
6 | /Checkpoint/liangjian/da_dataset/office_home/Art/Bottle/00044.jpg 5
7 | /Checkpoint/liangjian/da_dataset/office_home/Art/Bucket/00022.jpg 6
8 | /Checkpoint/liangjian/da_dataset/office_home/Art/Calculator/00009.jpg 7
9 | /Checkpoint/liangjian/da_dataset/office_home/Art/Calendar/00005.jpg 8
10 | /Checkpoint/liangjian/da_dataset/office_home/Art/Candles/00038.jpg 9
11 | /Checkpoint/liangjian/da_dataset/office_home/Art/Chair/00008.jpg 10
12 | /Checkpoint/liangjian/da_dataset/office_home/Art/Clipboards/00016.jpg 11
13 | /Checkpoint/liangjian/da_dataset/office_home/Art/Computer/00004.jpg 12
14 | /Checkpoint/liangjian/da_dataset/office_home/Art/Couch/00011.jpg 13
15 | /Checkpoint/liangjian/da_dataset/office_home/Art/Curtains/00012.jpg 14
16 | /Checkpoint/liangjian/da_dataset/office_home/Art/Desk_Lamp/00016.jpg 15
17 | /Checkpoint/liangjian/da_dataset/office_home/Art/Drill/00004.jpg 16
18 | /Checkpoint/liangjian/da_dataset/office_home/Art/Eraser/00016.jpg 17
19 | /Checkpoint/liangjian/da_dataset/office_home/Art/Exit_Sign/00016.jpg 18
20 | /Checkpoint/liangjian/da_dataset/office_home/Art/Fan/00024.jpg 19
21 | /Checkpoint/liangjian/da_dataset/office_home/Art/File_Cabinet/00002.jpg 20
22 | /Checkpoint/liangjian/da_dataset/office_home/Art/Flipflops/00018.jpg 21
23 | /Checkpoint/liangjian/da_dataset/office_home/Art/Flowers/00064.jpg 22
24 | /Checkpoint/liangjian/da_dataset/office_home/Art/Folder/00006.jpg 23
25 | /Checkpoint/liangjian/da_dataset/office_home/Art/Fork/00046.jpg 24
26 | /Checkpoint/liangjian/da_dataset/office_home/Art/Glasses/00015.jpg 25
27 | /Checkpoint/liangjian/da_dataset/office_home/Art/Hammer/00021.jpg 26
28 | /Checkpoint/liangjian/da_dataset/office_home/Art/Helmet/00074.jpg 27
29 | /Checkpoint/liangjian/da_dataset/office_home/Art/Kettle/00017.jpg 28
30 | /Checkpoint/liangjian/da_dataset/office_home/Art/Keyboard/00008.jpg 29
31 | /Checkpoint/liangjian/da_dataset/office_home/Art/Knives/00037.jpg 30
32 | /Checkpoint/liangjian/da_dataset/office_home/Art/Lamp_Shade/00029.jpg 31
33 | /Checkpoint/liangjian/da_dataset/office_home/Art/Laptop/00005.jpg 32
34 | /Checkpoint/liangjian/da_dataset/office_home/Art/Marker/00004.jpg 33
35 | /Checkpoint/liangjian/da_dataset/office_home/Art/Monitor/00017.jpg 34
36 | /Checkpoint/liangjian/da_dataset/office_home/Art/Mop/00004.jpg 35
37 | /Checkpoint/liangjian/da_dataset/office_home/Art/Mouse/00002.jpg 36
38 | /Checkpoint/liangjian/da_dataset/office_home/Art/Mug/00033.jpg 37
39 | /Checkpoint/liangjian/da_dataset/office_home/Art/Notebook/00016.jpg 38
40 | /Checkpoint/liangjian/da_dataset/office_home/Art/Oven/00016.jpg 39
41 | /Checkpoint/liangjian/da_dataset/office_home/Art/Pan/00018.jpg 40
42 | /Checkpoint/liangjian/da_dataset/office_home/Art/Paper_Clip/00016.jpg 41
43 | /Checkpoint/liangjian/da_dataset/office_home/Art/Pen/00012.jpg 42
44 | /Checkpoint/liangjian/da_dataset/office_home/Art/Pencil/00005.jpg 43
45 | /Checkpoint/liangjian/da_dataset/office_home/Art/Postit_Notes/00017.jpg 44
46 | /Checkpoint/liangjian/da_dataset/office_home/Art/Printer/00009.jpg 45
47 | /Checkpoint/liangjian/da_dataset/office_home/Art/Push_Pin/00024.jpg 46
48 | /Checkpoint/liangjian/da_dataset/office_home/Art/Radio/00036.jpg 47
49 | /Checkpoint/liangjian/da_dataset/office_home/Art/Refrigerator/00022.jpg 48
50 | /Checkpoint/liangjian/da_dataset/office_home/Art/Ruler/00010.jpg 49
51 | /Checkpoint/liangjian/da_dataset/office_home/Art/Scissors/00019.jpg 50
52 | /Checkpoint/liangjian/da_dataset/office_home/Art/Screwdriver/00007.jpg 51
53 | /Checkpoint/liangjian/da_dataset/office_home/Art/Shelf/00008.jpg 52
54 | /Checkpoint/liangjian/da_dataset/office_home/Art/Sink/00003.jpg 53
55 | /Checkpoint/liangjian/da_dataset/office_home/Art/Sneakers/00017.jpg 54
56 | /Checkpoint/liangjian/da_dataset/office_home/Art/Soda/00001.jpg 55
57 | /Checkpoint/liangjian/da_dataset/office_home/Art/Speaker/00005.jpg 56
58 | /Checkpoint/liangjian/da_dataset/office_home/Art/Spoon/00045.jpg 57
59 | /Checkpoint/liangjian/da_dataset/office_home/Art/Table/00005.jpg 58
60 | /Checkpoint/liangjian/da_dataset/office_home/Art/Telephone/00020.jpg 59
61 | /Checkpoint/liangjian/da_dataset/office_home/Art/ToothBrush/00026.jpg 60
62 | /Checkpoint/liangjian/da_dataset/office_home/Art/Toys/00017.jpg 61
63 | /Checkpoint/liangjian/da_dataset/office_home/Art/Trash_Can/00019.jpg 62
64 | /Checkpoint/liangjian/da_dataset/office_home/Art/TV/00010.jpg 63
65 | /Checkpoint/liangjian/da_dataset/office_home/Art/Webcam/00016.jpg 64
66 |
--------------------------------------------------------------------------------
/code/uda/run_uda.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # args: $1 = gpu_id, $2 = seed
4 |
5 | for ((s=0;s<=2;s++)); do
6 | python image_source.py --gpu_id $1 --seed $2 --output "ckps/s"$2 --dset office --max_epoch 100 --s $s
7 | python image_mixmatch.py --ps 0.0 --cls_par 0.0 --model source --gpu_id $1 --s $s --output_tar "ckps/s"$2 --output "ckps/mm"$2 --seed $2 --dset office --max_epoch 100
8 |
9 | python image_target.py --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset office
10 | python image_mixmatch.py --ps 0.0 --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset office --max_epoch 100
11 |
12 | python image_target.py --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset office
13 | python image_mixmatch.py --ps 0.0 --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset office --max_epoch 100
14 | done
15 |
16 |
17 | for ((s=0;s<=3;s++))
18 | do
19 | python image_source.py --gpu_id $1 --seed $2 --output "ckps/s"$2 --dset office-home --max_epoch 50 --s $s
20 | python image_mixmatch.py --ps 0.0 --cls_par 0.0 --model source --gpu_id $1 --s $s --output_tar "ckps/s"$2 --output "ckps/mm"$2 --seed $2 --dset office-home --max_epoch 50
21 |
22 | python image_target.py --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset office-home
23 | python image_mixmatch.py --ps 0.0 --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset office-home --max_epoch 50
24 |
25 | python image_target.py --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset office-home
26 | python image_mixmatch.py --ps 0.0 --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s $s --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset office-home --max_epoch 50
27 | done
28 |
29 | python image_source.py --gpu_id $1 --seed $2 --output "ckps/res101/s"$2 --dset VISDA-C --max_epoch 10 --s 0 --net resnet101
30 | python image_mixmatch.py --ps 0.0 --cls_par 0.0 --model source --gpu_id $1 --s 0 --t 1 --output_tar "ckps/res101/s"$2 --output "ckps/res101/mm"$2 --seed $2 --dset VISDA-C --max_epoch 10 --net resnet101
31 |
32 | python image_target.py --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s 0 --t 1 --output_src "ckps/res101/s"$2 --output "ckps/res101/t"$2 --seed $2 --dset VISDA-C --net resnet101
33 | python image_mixmatch.py --ps 0.0 --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s 0 --t 1 --output_tar "ckps/res101/t"$2 --output "ckps/res101/mm"$2 --seed $2 --dset VISDA-C --max_epoch 10 --net resnet101
34 |
35 | python image_target.py --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s 0 --t 1 --output_src "ckps/res101/s"$2 --output "ckps/res101/t"$2 --seed $2 --dset VISDA-C --net resnet101
36 | python image_mixmatch.py --ps 0.0 --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s 0 --t 1 --output_tar "ckps/res101/t"$2 --output "ckps/res101/mm"$2 --seed $2 --dset VISDA-C --max_epoch 10 --net resnet101
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 | python image_source.py --gpu_id $1 --seed $2 --output "ckps/s"$2 --dset VISDA-C --max_epoch 10 --s 0
45 | python image_mixmatch.py --ps 0.0 --cls_par 0.0 --model source --gpu_id $1 --s 0 --t 1 --output_tar "ckps/s"$2 --output "ckps/mm"$2 --seed $2 --dset VISDA-C --max_epoch 10
46 |
47 | python image_target.py --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s 0 --t 1 --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset VISDA-C
48 | python image_mixmatch.py --ps 0.0 --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s 0 --t 1 --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset VISDA-C --max_epoch 10
49 |
50 | python image_target.py --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s 0 --t 1 --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset VISDA-C
51 | python image_mixmatch.py --ps 0.0 --ssl 0.6 --cls_par 0.3 --gpu_id $1 --s 0 --t 1 --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset VISDA-C --max_epoch 10
52 |
53 | python image_target.py --gent '' --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s 0 --t 1 --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset VISDA-C
54 | python image_mixmatch.py --gent '' --ps 0.0 --ssl 0.0 --cls_par 0.0 --gpu_id $1 --s 0 --t 1 --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset VISDA-C --max_epoch 10
55 |
56 | python image_target.py --ssl 0.0 --cls_par 0.3 --gpu_id $1 --s 0 --t 1 --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset VISDA-C
57 | python image_mixmatch.py --ps 0.0 --ssl 0.0 --cls_par 0.3 --gpu_id $1 --s 0 --t 1 --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset VISDA-C --max_epoch 10
58 |
59 | python image_target.py --ssl 0.6 --cls_par 0.0 --gpu_id $1 --s 0 --t 1 --output_src "ckps/s"$2 --output "ckps/t"$2 --seed $2 --dset VISDA-C
60 | python image_mixmatch.py --ps 0.0 --ssl 0.6 --cls_par 0.0 --gpu_id $1 --s 0 --t 1 --output_tar "ckps/t"$2 --output "ckps/mm"$2 --seed $2 --dset VISDA-C --max_epoch 10
61 |
--------------------------------------------------------------------------------
/code/data/ssda/office-home/labeled_target_images_Clipart_1.txt:
--------------------------------------------------------------------------------
1 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Alarm_Clock/00032.jpg 0
2 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Backpack/00052.jpg 1
3 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Batteries/00043.jpg 2
4 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Bed/00073.jpg 3
5 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Bike/00023.jpg 4
6 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Bottle/00059.jpg 5
7 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Bucket/00006.jpg 6
8 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Calculator/00007.jpg 7
9 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Calendar/00009.jpg 8
10 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Candles/00004.jpg 9
11 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Chair/00060.jpg 10
12 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Clipboards/00033.jpg 11
13 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Computer/00036.jpg 12
14 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Couch/00063.jpg 13
15 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Curtains/00032.jpg 14
16 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Desk_Lamp/00004.jpg 15
17 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Drill/00035.jpg 16
18 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Eraser/00015.jpg 17
19 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Exit_Sign/00032.jpg 18
20 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Fan/00027.jpg 19
21 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/File_Cabinet/00030.jpg 20
22 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Flipflops/00012.jpg 21
23 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Flowers/00039.jpg 22
24 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Folder/00095.jpg 23
25 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Fork/00018.jpg 24
26 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Glasses/00014.jpg 25
27 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Hammer/00098.jpg 26
28 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Helmet/00051.jpg 27
29 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Kettle/00027.jpg 28
30 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Keyboard/00053.jpg 29
31 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Knives/00037.jpg 30
32 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Lamp_Shade/00031.jpg 31
33 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Laptop/00061.jpg 32
34 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Marker/00065.jpg 33
35 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Monitor/00089.jpg 34
36 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Mop/00032.jpg 35
37 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Mouse/00044.jpg 36
38 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Mug/00088.jpg 37
39 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Notebook/00065.jpg 38
40 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Oven/00008.jpg 39
41 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Pan/00035.jpg 40
42 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Paper_Clip/00015.jpg 41
43 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Pen/00004.jpg 42
44 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Pencil/00072.jpg 43
45 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Postit_Notes/00002.jpg 44
46 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Printer/00056.jpg 45
47 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Push_Pin/00025.jpg 46
48 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Radio/00030.jpg 47
49 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Refrigerator/00017.jpg 48
50 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Ruler/00050.jpg 49
51 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Scissors/00073.jpg 50
52 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Screwdriver/00062.jpg 51
53 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Shelf/00021.jpg 52
54 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Sink/00023.jpg 53
55 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Sneakers/00019.jpg 54
56 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Soda/00029.jpg 55
57 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Speaker/00037.jpg 56
58 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Spoon/00046.jpg 57
59 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Table/00041.jpg 58
60 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Telephone/00048.jpg 59
61 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/ToothBrush/00007.jpg 60
62 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Toys/00021.jpg 61
63 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Trash_Can/00003.jpg 62
64 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/TV/00012.jpg 63
65 | /Checkpoint/liangjian/da_dataset/office_home/Clipart/Webcam/00006.jpg 64
66 |
--------------------------------------------------------------------------------
/code/data/ssda/office-home/labeled_target_images_Product_1.txt:
--------------------------------------------------------------------------------
1 | /Checkpoint/liangjian/da_dataset/office_home/Product/Alarm_Clock/00003.jpg 0
2 | /Checkpoint/liangjian/da_dataset/office_home/Product/Backpack/00095.jpg 1
3 | /Checkpoint/liangjian/da_dataset/office_home/Product/Batteries/00050.jpg 2
4 | /Checkpoint/liangjian/da_dataset/office_home/Product/Bed/00008.jpg 3
5 | /Checkpoint/liangjian/da_dataset/office_home/Product/Bike/00004.jpg 4
6 | /Checkpoint/liangjian/da_dataset/office_home/Product/Bottle/00028.jpg 5
7 | /Checkpoint/liangjian/da_dataset/office_home/Product/Bucket/00019.jpg 6
8 | /Checkpoint/liangjian/da_dataset/office_home/Product/Calculator/00014.jpg 7
9 | /Checkpoint/liangjian/da_dataset/office_home/Product/Calendar/00067.jpg 8
10 | /Checkpoint/liangjian/da_dataset/office_home/Product/Candles/00052.jpg 9
11 | /Checkpoint/liangjian/da_dataset/office_home/Product/Chair/00038.jpg 10
12 | /Checkpoint/liangjian/da_dataset/office_home/Product/Clipboards/00049.jpg 11
13 | /Checkpoint/liangjian/da_dataset/office_home/Product/Computer/00081.jpg 12
14 | /Checkpoint/liangjian/da_dataset/office_home/Product/Couch/00039.jpg 13
15 | /Checkpoint/liangjian/da_dataset/office_home/Product/Curtains/00058.jpg 14
16 | /Checkpoint/liangjian/da_dataset/office_home/Product/Desk_Lamp/00036.jpg 15
17 | /Checkpoint/liangjian/da_dataset/office_home/Product/Drill/00063.jpg 16
18 | /Checkpoint/liangjian/da_dataset/office_home/Product/Eraser/00027.jpg 17
19 | /Checkpoint/liangjian/da_dataset/office_home/Product/Exit_Sign/00004.jpg 18
20 | /Checkpoint/liangjian/da_dataset/office_home/Product/Fan/00045.jpg 19
21 | /Checkpoint/liangjian/da_dataset/office_home/Product/File_Cabinet/00058.jpg 20
22 | /Checkpoint/liangjian/da_dataset/office_home/Product/Flipflops/00039.jpg 21
23 | /Checkpoint/liangjian/da_dataset/office_home/Product/Flowers/00079.jpg 22
24 | /Checkpoint/liangjian/da_dataset/office_home/Product/Folder/00034.jpg 23
25 | /Checkpoint/liangjian/da_dataset/office_home/Product/Fork/00022.jpg 24
26 | /Checkpoint/liangjian/da_dataset/office_home/Product/Glasses/00053.jpg 25
27 | /Checkpoint/liangjian/da_dataset/office_home/Product/Hammer/00005.jpg 26
28 | /Checkpoint/liangjian/da_dataset/office_home/Product/Helmet/00046.jpg 27
29 | /Checkpoint/liangjian/da_dataset/office_home/Product/Kettle/00062.jpg 28
30 | /Checkpoint/liangjian/da_dataset/office_home/Product/Keyboard/00033.jpg 29
31 | /Checkpoint/liangjian/da_dataset/office_home/Product/Knives/00039.jpg 30
32 | /Checkpoint/liangjian/da_dataset/office_home/Product/Lamp_Shade/00017.jpg 31
33 | /Checkpoint/liangjian/da_dataset/office_home/Product/Laptop/00048.jpg 32
34 | /Checkpoint/liangjian/da_dataset/office_home/Product/Marker/00003.jpg 33
35 | /Checkpoint/liangjian/da_dataset/office_home/Product/Monitor/00013.jpg 34
36 | /Checkpoint/liangjian/da_dataset/office_home/Product/Mop/00036.jpg 35
37 | /Checkpoint/liangjian/da_dataset/office_home/Product/Mouse/00023.jpg 36
38 | /Checkpoint/liangjian/da_dataset/office_home/Product/Mug/00039.jpg 37
39 | /Checkpoint/liangjian/da_dataset/office_home/Product/Notebook/00010.jpg 38
40 | /Checkpoint/liangjian/da_dataset/office_home/Product/Oven/00016.jpg 39
41 | /Checkpoint/liangjian/da_dataset/office_home/Product/Pan/00053.jpg 40
42 | /Checkpoint/liangjian/da_dataset/office_home/Product/Paper_Clip/00002.jpg 41
43 | /Checkpoint/liangjian/da_dataset/office_home/Product/Pen/00048.jpg 42
44 | /Checkpoint/liangjian/da_dataset/office_home/Product/Pencil/00002.jpg 43
45 | /Checkpoint/liangjian/da_dataset/office_home/Product/Postit_Notes/00001.jpg 44
46 | /Checkpoint/liangjian/da_dataset/office_home/Product/Printer/00018.jpg 45
47 | /Checkpoint/liangjian/da_dataset/office_home/Product/Push_Pin/00039.jpg 46
48 | /Checkpoint/liangjian/da_dataset/office_home/Product/Radio/00011.jpg 47
49 | /Checkpoint/liangjian/da_dataset/office_home/Product/Refrigerator/00024.jpg 48
50 | /Checkpoint/liangjian/da_dataset/office_home/Product/Ruler/00001.jpg 49
51 | /Checkpoint/liangjian/da_dataset/office_home/Product/Scissors/00072.jpg 50
52 | /Checkpoint/liangjian/da_dataset/office_home/Product/Screwdriver/00015.jpg 51
53 | /Checkpoint/liangjian/da_dataset/office_home/Product/Shelf/00037.jpg 52
54 | /Checkpoint/liangjian/da_dataset/office_home/Product/Sink/00024.jpg 53
55 | /Checkpoint/liangjian/da_dataset/office_home/Product/Sneakers/00061.jpg 54
56 | /Checkpoint/liangjian/da_dataset/office_home/Product/Soda/00041.jpg 55
57 | /Checkpoint/liangjian/da_dataset/office_home/Product/Speaker/00071.jpg 56
58 | /Checkpoint/liangjian/da_dataset/office_home/Product/Spoon/00038.jpg 57
59 | /Checkpoint/liangjian/da_dataset/office_home/Product/Table/00043.jpg 58
60 | /Checkpoint/liangjian/da_dataset/office_home/Product/Telephone/00011.jpg 59
61 | /Checkpoint/liangjian/da_dataset/office_home/Product/ToothBrush/00002.jpg 60
62 | /Checkpoint/liangjian/da_dataset/office_home/Product/Toys/00037.jpg 61
63 | /Checkpoint/liangjian/da_dataset/office_home/Product/Trash_Can/00090.jpg 62
64 | /Checkpoint/liangjian/da_dataset/office_home/Product/TV/00029.jpg 63
65 | /Checkpoint/liangjian/da_dataset/office_home/Product/Webcam/00064.jpg 64
66 |
--------------------------------------------------------------------------------
/code/data/ssda/office-home/labeled_target_images_Real_1.txt:
--------------------------------------------------------------------------------
1 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Alarm_Clock/00085.jpg 0
2 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Backpack/00062.jpg 1
3 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Batteries/00041.jpg 2
4 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Bed/00055.jpg 3
5 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Bike/00056.jpg 4
6 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Bottle/00078.jpg 5
7 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Bucket/00064.jpg 6
8 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Calculator/00044.jpg 7
9 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Calendar/00064.jpg 8
10 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Candles/00097.jpg 9
11 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Chair/00069.jpg 10
12 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Clipboards/00040.jpg 11
13 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Computer/00004.jpg 12
14 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Couch/00037.jpg 13
15 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Curtains/00001.jpg 14
16 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Desk_Lamp/00003.jpg 15
17 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Drill/00038.jpg 16
18 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Eraser/00040.jpg 17
19 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Exit_Sign/00014.jpg 18
20 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Fan/00050.jpg 19
21 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/File_Cabinet/00032.jpg 20
22 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Flipflops/00077.jpg 21
23 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Flowers/00028.jpg 22
24 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Folder/00020.jpg 23
25 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Fork/00011.jpg 24
26 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Glasses/00031.jpg 25
27 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Hammer/00022.jpg 26
28 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Helmet/00034.jpg 27
29 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Kettle/00055.jpg 28
30 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Keyboard/00064.jpg 29
31 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Knives/00010.jpg 30
32 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Lamp_Shade/00026.jpg 31
33 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Laptop/00051.jpg 32
34 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Marker/00009.jpg 33
35 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Monitor/00055.jpg 34
36 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Mop/00044.jpg 35
37 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Mouse/00028.jpg 36
38 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Mug/00042.jpg 37
39 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Notebook/00013.jpg 38
40 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Oven/00036.jpg 39
41 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Pan/00011.jpg 40
42 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Paper_Clip/00034.jpg 41
43 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Pen/00064.jpg 42
44 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Pencil/00046.jpg 43
45 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Postit_Notes/00064.jpg 44
46 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Printer/00027.jpg 45
47 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Push_Pin/00034.jpg 46
48 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Radio/00003.jpg 47
49 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Refrigerator/00039.jpg 48
50 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Ruler/00035.jpg 49
51 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Scissors/00036.jpg 50
52 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Screwdriver/00038.jpg 51
53 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Shelf/00018.jpg 52
54 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Sink/00063.jpg 53
55 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Sneakers/00072.jpg 54
56 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Soda/00052.jpg 55
57 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Speaker/00014.jpg 56
58 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Spoon/00019.jpg 57
59 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Table/00025.jpg 58
60 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Telephone/00041.jpg 59
61 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/ToothBrush/00051.jpg 60
62 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Toys/00002.jpg 61
63 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Trash_Can/00077.jpg 62
64 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/TV/00015.jpg 63
65 | /Checkpoint/liangjian/da_dataset/office_home/RealWorld/Webcam/00033.jpg 64
66 |
--------------------------------------------------------------------------------
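Each of the four lists above provides exactly one labeled target image per Office-Home class (labels 0-64), i.e. the 1-shot labeled split used in the SSDA experiments. A minimal sketch for sanity-checking such a list (the path in the usage comment is illustrative):

```python
# Hedged sanity check for a 1-shot SSDA label list: every class id in 0..64
# should appear exactly once, and every line should read '<image_path> <label>'.
from collections import Counter

def check_one_shot_list(txt_path, num_classes=65):
    with open(txt_path) as f:
        pairs = [line.split() for line in f if line.strip()]
    labels = Counter(int(label) for _, label in pairs)
    assert sorted(labels) == list(range(num_classes)), 'missing class ids'
    assert all(count == 1 for count in labels.values()), 'duplicated class ids'
    return len(pairs)

# Illustrative usage:
# check_one_shot_list('code/data/ssda/office-home/labeled_target_images_Art_1.txt')
```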
/code/msda/data_list.py:
--------------------------------------------------------------------------------
1 | #from __future__ import print_function, division
2 |
3 | import torch
4 | import numpy as np
5 | import random
6 | from PIL import Image
7 | from torch.utils.data import Dataset
8 | import os
9 | import os.path
10 |
11 | import cv2
12 | import torchvision
13 |
14 | def make_dataset(image_list, labels):
15 | if labels:
16 | len_ = len(image_list)
17 | images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)]
18 | else:
19 | if len(image_list[0].split()) > 2:
20 | images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list]
21 | else:
22 | images = [(val.split()[0], int(val.split()[1])) for val in image_list]
23 | return images
24 |
25 |
26 | def rgb_loader(path):
27 | with open(path, 'rb') as f:
28 | with Image.open(f) as img:
29 | return img.convert('RGB')
30 |
31 | def l_loader(path):
32 | with open(path, 'rb') as f:
33 | with Image.open(f) as img:
34 | return img.convert('L')
35 |
36 | class ImageList(Dataset):
37 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'):
38 | imgs = make_dataset(image_list, labels)
39 | if len(imgs) == 0:
40 | raise(RuntimeError("Found 0 images in the given image_list; "
41 | "each line should look like '<image_path> <label>'"))
42 |
43 | self.imgs = imgs
44 | self.transform = transform
45 | self.target_transform = target_transform
46 | if mode == 'RGB':
47 | self.loader = rgb_loader
48 | elif mode == 'L':
49 | self.loader = l_loader
50 |
51 | def __getitem__(self, index):
52 | path, target = self.imgs[index]
53 | img = self.loader(path)
54 | if self.transform is not None:
55 | img = self.transform(img)
56 | if self.target_transform is not None:
57 | target = self.target_transform(target)
58 |
59 | return img, target
60 |
61 | def __len__(self):
62 | return len(self.imgs)
63 |
64 |
65 | class ImageList_twice(Dataset):
66 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'):
67 | imgs = make_dataset(image_list, labels)
68 | if len(imgs) == 0:
69 | raise(RuntimeError("Found 0 images in the given image_list; "
70 | "each line should look like '<image_path> <label>'"))
71 |
72 | self.imgs = imgs
73 | self.transform = transform
74 | self.target_transform = target_transform
75 | if mode == 'RGB':
76 | self.loader = rgb_loader
77 | elif mode == 'L':
78 | self.loader = l_loader
79 |
80 | def __getitem__(self, index):
81 | path, target = self.imgs[index]
82 | img = self.loader(path)
83 | if self.target_transform is not None:
84 | target = self.target_transform(target)
85 | if self.transform is not None:
86 | if isinstance(self.transform, list):
87 | img = [t(img) for t in self.transform]
88 | else:
89 | img = self.transform(img)
90 |
91 | return img, target, index
92 |
93 | def __len__(self):
94 | return len(self.imgs)
95 |
96 | class ImageList_idx(Dataset):
97 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'):
98 | imgs = make_dataset(image_list, labels)
99 | if len(imgs) == 0:
100 | raise(RuntimeError("Found 0 images in the given image_list; "
101 | "each line should look like '<image_path> <label>'"))
102 |
103 | self.imgs = imgs
104 | self.transform = transform
105 | self.target_transform = target_transform
106 | if mode == 'RGB':
107 | self.loader = rgb_loader
108 | elif mode == 'L':
109 | self.loader = l_loader
110 |
111 | def __getitem__(self, index):
112 | path, target = self.imgs[index]
113 | img = self.loader(path)
114 | if self.transform is not None:
115 | img = self.transform(img)
116 | if self.target_transform is not None:
117 | target = self.target_transform(target)
118 |
119 | return img, target, index
120 |
121 | def __len__(self):
122 | return len(self.imgs)
123 |
124 |
125 | class ImageValueList(Dataset):
126 | def __init__(self, image_list, labels=None, transform=None, target_transform=None,
127 | loader=rgb_loader):
128 | imgs = make_dataset(image_list, labels)
129 | if len(imgs) == 0:
130 | raise(RuntimeError("Found 0 images in the given image_list; "
131 | "each line should look like '<image_path> <label>'"))
132 |
133 | self.imgs = imgs
134 | self.values = [1.0] * len(imgs)
135 | self.transform = transform
136 | self.target_transform = target_transform
137 | self.loader = loader
138 |
139 | def set_values(self, values):
140 | self.values = values
141 |
142 | def __getitem__(self, index):
143 | path, target = self.imgs[index]
144 | img = self.loader(path)
145 | if self.transform is not None:
146 | img = self.transform(img)
147 | if self.target_transform is not None:
148 | target = self.target_transform(target)
149 |
150 | return img, target
151 |
152 | def __len__(self):
153 | return len(self.imgs)
154 |
155 | class alexnetlist(Dataset):
156 |
157 | def __init__(self, list, training=True):
158 | self.images = []
159 | self.labels = []
160 | self.multi_scale = [256, 257]
161 | self.output_size = [227, 227]
162 | self.training = training
163 | self.mean_color=[104.006, 116.668, 122.678]
164 |
165 | list_file = open(list)
166 | lines = list_file.readlines()
167 | for line in lines:
168 | fields = line.split()
169 | self.images.append(fields[0])
170 | self.labels.append(int(fields[1]))
171 |
172 | def __len__(self):
173 | return len(self.images)
174 |
175 | def __getitem__(self, index):
176 | image_path = self.images[index]
177 | label = self.labels[index]
178 | img = cv2.imread(image_path)
179 | if img is None:
180 | print('Error: Image at {} not found.'.format(image_path))
181 |
182 | if self.training and np.random.random() < 0.5:
183 | img = cv2.flip(img, 1)
184 | new_size = np.random.randint(self.multi_scale[0], self.multi_scale[1], 1)[0]
185 |
186 | img = cv2.resize(img, (new_size, new_size))
187 | img = img.astype(np.float32)
188 |
189 | # cropping
190 | if self.training:
191 | diff = new_size - self.output_size[0]
192 | offset_x = np.random.randint(0, diff, 1)[0]
193 | offset_y = np.random.randint(0, diff, 1)[0]
194 | else:
195 | offset_x = img.shape[0]//2 - self.output_size[0] // 2
196 | offset_y = img.shape[1]//2 - self.output_size[1] // 2
197 |
198 | img = img[offset_x:(offset_x+self.output_size[0]),
199 | offset_y:(offset_y+self.output_size[1])]
200 |
201 | # subtract mean
202 | img -= np.array(self.mean_color)
203 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
204 |
205 | # ToTensor transform cv2 HWC->CHW, only byteTensor will be div by 255.
206 | tensor = torchvision.transforms.ToTensor()
207 | img = tensor(img)
208 | # img = np.transpose(img, (2, 0, 1))
209 |
210 | return img, label
--------------------------------------------------------------------------------
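The `ImageList` dataset defined above consumes a plain-text list whose lines read `<image_path> <label>`. A minimal usage sketch follows, with an illustrative list path and a standard torchvision preprocessing pipeline (the transforms used by the actual training scripts may differ):

```python
# Hedged usage sketch for data_list.ImageList; paths and transform values are illustrative.
import torch
from torchvision import transforms
from data_list import ImageList

train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],   # ImageNet statistics
                         std=[0.229, 0.224, 0.225]),
])

# Each line of the list file is '<image_path> <label>'.
with open('../data/ssda/office-home/labeled_target_images_Art_1.txt') as f:
    image_list = f.readlines()

dataset = ImageList(image_list, transform=train_transform)
loader = torch.utils.data.DataLoader(dataset, batch_size=36, shuffle=True, num_workers=4)

images, labels = next(iter(loader))
print(images.shape, labels.shape)  # e.g. torch.Size([36, 3, 224, 224]) torch.Size([36])
```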
/code/pda/data_list.py:
--------------------------------------------------------------------------------
1 | #from __future__ import print_function, division
2 |
3 | import torch
4 | import numpy as np
5 | import random
6 | from PIL import Image
7 | from torch.utils.data import Dataset
8 | import os
9 | import os.path
10 |
11 | import cv2
12 | import torchvision
13 |
14 | def make_dataset(image_list, labels):
15 | if labels:
16 | len_ = len(image_list)
17 | images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)]
18 | else:
19 | if len(image_list[0].split()) > 2:
20 | images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list]
21 | else:
22 | images = [(val.split()[0], int(val.split()[1])) for val in image_list]
23 | return images
24 |
25 |
26 | def rgb_loader(path):
27 | with open(path, 'rb') as f:
28 | with Image.open(f) as img:
29 | return img.convert('RGB')
30 |
31 | def l_loader(path):
32 | with open(path, 'rb') as f:
33 | with Image.open(f) as img:
34 | return img.convert('L')
35 |
36 | class ImageList(Dataset):
37 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'):
38 | imgs = make_dataset(image_list, labels)
39 | if len(imgs) == 0:
40 | raise(RuntimeError("Found 0 images in the given image_list; "
41 | "each line should look like '<image_path> <label>'"))
42 |
43 | self.imgs = imgs
44 | self.transform = transform
45 | self.target_transform = target_transform
46 | if mode == 'RGB':
47 | self.loader = rgb_loader
48 | elif mode == 'L':
49 | self.loader = l_loader
50 |
51 | def __getitem__(self, index):
52 | path, target = self.imgs[index]
53 | img = self.loader(path)
54 | if self.transform is not None:
55 | img = self.transform(img)
56 | if self.target_transform is not None:
57 | target = self.target_transform(target)
58 |
59 | return img, target
60 |
61 | def __len__(self):
62 | return len(self.imgs)
63 |
64 |
65 | class ImageList_twice(Dataset):
66 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'):
67 | imgs = make_dataset(image_list, labels)
68 | if len(imgs) == 0:
69 | raise(RuntimeError("Found 0 images in the given image_list; "
70 | "each line should look like '<image_path> <label>'"))
71 |
72 | self.imgs = imgs
73 | self.transform = transform
74 | self.target_transform = target_transform
75 | if mode == 'RGB':
76 | self.loader = rgb_loader
77 | elif mode == 'L':
78 | self.loader = l_loader
79 |
80 | def __getitem__(self, index):
81 | path, target = self.imgs[index]
82 | img = self.loader(path)
83 | if self.target_transform is not None:
84 | target = self.target_transform(target)
85 | if self.transform is not None:
86 | if isinstance(self.transform, list):
87 | img = [t(img) for t in self.transform]
88 | else:
89 | img = self.transform(img)
90 |
91 | return img, target, index
92 |
93 | def __len__(self):
94 | return len(self.imgs)
95 |
96 | class ImageList_idx(Dataset):
97 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'):
98 | imgs = make_dataset(image_list, labels)
99 | if len(imgs) == 0:
100 | raise(RuntimeError("Found 0 images in the given image_list; "
101 | "each line should look like '<image_path> <label>'"))
102 |
103 | self.imgs = imgs
104 | self.transform = transform
105 | self.target_transform = target_transform
106 | if mode == 'RGB':
107 | self.loader = rgb_loader
108 | elif mode == 'L':
109 | self.loader = l_loader
110 |
111 | def __getitem__(self, index):
112 | path, target = self.imgs[index]
113 | img = self.loader(path)
114 | if self.transform is not None:
115 | img = self.transform(img)
116 | if self.target_transform is not None:
117 | target = self.target_transform(target)
118 |
119 | return img, target, index
120 |
121 | def __len__(self):
122 | return len(self.imgs)
123 |
124 |
125 | class ImageValueList(Dataset):
126 | def __init__(self, image_list, labels=None, transform=None, target_transform=None,
127 | loader=rgb_loader):
128 | imgs = make_dataset(image_list, labels)
129 | if len(imgs) == 0:
130 | raise(RuntimeError("Found 0 images in the given image_list; "
131 | "each line should look like '<image_path> <label>'"))
132 |
133 | self.imgs = imgs
134 | self.values = [1.0] * len(imgs)
135 | self.transform = transform
136 | self.target_transform = target_transform
137 | self.loader = loader
138 |
139 | def set_values(self, values):
140 | self.values = values
141 |
142 | def __getitem__(self, index):
143 | path, target = self.imgs[index]
144 | img = self.loader(path)
145 | if self.transform is not None:
146 | img = self.transform(img)
147 | if self.target_transform is not None:
148 | target = self.target_transform(target)
149 |
150 | return img, target
151 |
152 | def __len__(self):
153 | return len(self.imgs)
154 |
155 | class alexnetlist(Dataset):
156 |
157 | def __init__(self, list, training=True):
158 | self.images = []
159 | self.labels = []
160 | self.multi_scale = [256, 257]
161 | self.output_size = [227, 227]
162 | self.training = training
163 | self.mean_color=[104.006, 116.668, 122.678]
164 |
165 | list_file = open(list)
166 | lines = list_file.readlines()
167 | for line in lines:
168 | fields = line.split()
169 | self.images.append(fields[0])
170 | self.labels.append(int(fields[1]))
171 |
172 | def __len__(self):
173 | return len(self.images)
174 |
175 | def __getitem__(self, index):
176 | image_path = self.images[index]
177 | label = self.labels[index]
178 | img = cv2.imread(image_path)
179 | if img is None:
180 | print('Error: Image at {} not found.'.format(image_path))
181 |
182 | if self.training and np.random.random() < 0.5:
183 | img = cv2.flip(img, 1)
184 | new_size = np.random.randint(self.multi_scale[0], self.multi_scale[1], 1)[0]
185 |
186 | img = cv2.resize(img, (new_size, new_size))
187 | img = img.astype(np.float32)
188 |
189 | # cropping
190 | if self.training:
191 | diff = new_size - self.output_size[0]
192 | offset_x = np.random.randint(0, diff, 1)[0]
193 | offset_y = np.random.randint(0, diff, 1)[0]
194 | else:
195 | offset_x = img.shape[0]//2 - self.output_size[0] // 2
196 | offset_y = img.shape[1]//2 - self.output_size[1] // 2
197 |
198 | img = img[offset_x:(offset_x+self.output_size[0]),
199 | offset_y:(offset_y+self.output_size[1])]
200 |
201 | # subtract mean
202 | img -= np.array(self.mean_color)
203 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
204 |
205 | # ToTensor transform cv2 HWC->CHW, only byteTensor will be div by 255.
206 | tensor = torchvision.transforms.ToTensor()
207 | img = tensor(img)
208 | # img = np.transpose(img, (2, 0, 1))
209 |
210 | return img, label
--------------------------------------------------------------------------------
/code/ssda/data_list.py:
--------------------------------------------------------------------------------
1 | #from __future__ import print_function, division
2 |
3 | import torch
4 | import numpy as np
5 | import random
6 | from PIL import Image
7 | from torch.utils.data import Dataset
8 | import os
9 | import os.path
10 |
11 | import cv2
12 | import torchvision
13 |
14 | def make_dataset(image_list, labels):
15 | if labels:
16 | len_ = len(image_list)
17 | images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)]
18 | else:
19 | if len(image_list[0].split()) > 2:
20 | images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list]
21 | else:
22 | images = [(val.split()[0], int(val.split()[1])) for val in image_list]
23 | return images
24 |
25 |
26 | def rgb_loader(path):
27 | with open(path, 'rb') as f:
28 | with Image.open(f) as img:
29 | return img.convert('RGB')
30 |
31 | def l_loader(path):
32 | with open(path, 'rb') as f:
33 | with Image.open(f) as img:
34 | return img.convert('L')
35 |
36 | class ImageList(Dataset):
37 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'):
38 | imgs = make_dataset(image_list, labels)
39 | if len(imgs) == 0:
40 | raise(RuntimeError("Found 0 images in the given image_list; "
41 | "each line should look like '<image_path> <label>'"))
42 |
43 | self.imgs = imgs
44 | self.transform = transform
45 | self.target_transform = target_transform
46 | if mode == 'RGB':
47 | self.loader = rgb_loader
48 | elif mode == 'L':
49 | self.loader = l_loader
50 |
51 | def __getitem__(self, index):
52 | path, target = self.imgs[index]
53 | img = self.loader(path)
54 | if self.transform is not None:
55 | img = self.transform(img)
56 | if self.target_transform is not None:
57 | target = self.target_transform(target)
58 |
59 | return img, target
60 |
61 | def __len__(self):
62 | return len(self.imgs)
63 |
64 |
65 | class ImageList_twice(Dataset):
66 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'):
67 | imgs = make_dataset(image_list, labels)
68 | if len(imgs) == 0:
69 | raise(RuntimeError("Found 0 images in the given image_list; "
70 | "each line should look like '<image_path> <label>'"))
71 |
72 | self.imgs = imgs
73 | self.transform = transform
74 | self.target_transform = target_transform
75 | if mode == 'RGB':
76 | self.loader = rgb_loader
77 | elif mode == 'L':
78 | self.loader = l_loader
79 |
80 | def __getitem__(self, index):
81 | path, target = self.imgs[index]
82 | img = self.loader(path)
83 | if self.target_transform is not None:
84 | target = self.target_transform(target)
85 | if self.transform is not None:
86 | if isinstance(self.transform, list):
87 | img = [t(img) for t in self.transform]
88 | else:
89 | img = self.transform(img)
90 |
91 | return img, target, index
92 |
93 | def __len__(self):
94 | return len(self.imgs)
95 |
96 | class ImageList_idx(Dataset):
97 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'):
98 | imgs = make_dataset(image_list, labels)
99 | if len(imgs) == 0:
100 | raise(RuntimeError("Found 0 images in the given image_list; "
101 | "each line should look like '<image_path> <label>'"))
102 |
103 | self.imgs = imgs
104 | self.transform = transform
105 | self.target_transform = target_transform
106 | if mode == 'RGB':
107 | self.loader = rgb_loader
108 | elif mode == 'L':
109 | self.loader = l_loader
110 |
111 | def __getitem__(self, index):
112 | path, target = self.imgs[index]
113 | img = self.loader(path)
114 | if self.transform is not None:
115 | img = self.transform(img)
116 | if self.target_transform is not None:
117 | target = self.target_transform(target)
118 |
119 | return img, target, index
120 |
121 | def __len__(self):
122 | return len(self.imgs)
123 |
124 |
125 | class ImageValueList(Dataset):
126 | def __init__(self, image_list, labels=None, transform=None, target_transform=None,
127 | loader=rgb_loader):
128 | imgs = make_dataset(image_list, labels)
129 | if len(imgs) == 0:
130 | raise(RuntimeError("Found 0 images in the given image_list; "
131 | "each line should look like '<image_path> <label>'"))
132 |
133 | self.imgs = imgs
134 | self.values = [1.0] * len(imgs)
135 | self.transform = transform
136 | self.target_transform = target_transform
137 | self.loader = loader
138 |
139 | def set_values(self, values):
140 | self.values = values
141 |
142 | def __getitem__(self, index):
143 | path, target = self.imgs[index]
144 | img = self.loader(path)
145 | if self.transform is not None:
146 | img = self.transform(img)
147 | if self.target_transform is not None:
148 | target = self.target_transform(target)
149 |
150 | return img, target
151 |
152 | def __len__(self):
153 | return len(self.imgs)
154 |
155 | class alexnetlist(Dataset):
156 |
157 | def __init__(self, list, training=True):
158 | self.images = []
159 | self.labels = []
160 | self.multi_scale = [256, 257]
161 | self.output_size = [227, 227]
162 | self.training = training
163 | self.mean_color=[104.006, 116.668, 122.678]
164 |
165 | list_file = open(list)
166 | lines = list_file.readlines()
167 | for line in lines:
168 | fields = line.split()
169 | self.images.append(fields[0])
170 | self.labels.append(int(fields[1]))
171 |
172 | def __len__(self):
173 | return len(self.images)
174 |
175 | def __getitem__(self, index):
176 | image_path = self.images[index]
177 | label = self.labels[index]
178 | img = cv2.imread(image_path)
179 | if img is None:
180 | print('Error: Image at {} not found.'.format(image_path))
181 |
182 | if self.training and np.random.random() < 0.5:
183 | img = cv2.flip(img, 1)
184 | new_size = np.random.randint(self.multi_scale[0], self.multi_scale[1], 1)[0]
185 |
186 | img = cv2.resize(img, (new_size, new_size))
187 | img = img.astype(np.float32)
188 |
189 | # cropping
190 | if self.training:
191 | diff = new_size - self.output_size[0]
192 | offset_x = np.random.randint(0, diff, 1)[0]
193 | offset_y = np.random.randint(0, diff, 1)[0]
194 | else:
195 | offset_x = img.shape[0]//2 - self.output_size[0] // 2
196 | offset_y = img.shape[1]//2 - self.output_size[1] // 2
197 |
198 | img = img[offset_x:(offset_x+self.output_size[0]),
199 | offset_y:(offset_y+self.output_size[1])]
200 |
201 | # subtract mean
202 | img -= np.array(self.mean_color)
203 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
204 |
205 | # ToTensor transform cv2 HWC->CHW, only byteTensor will be div by 255.
206 | tensor = torchvision.transforms.ToTensor()
207 | img = tensor(img)
208 | # img = np.transpose(img, (2, 0, 1))
209 |
210 | return img, label
--------------------------------------------------------------------------------
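`ImageList_twice` (defined in the same file) accepts a *list* of transforms and returns one view per transform plus the sample index, which is the form of input consumed by MixMatch-style training. A hedged usage sketch with two illustrative augmentation pipelines:

```python
# Hedged usage sketch for data_list.ImageList_twice; paths and transforms are illustrative.
import torch
from torchvision import transforms
from data_list import ImageList_twice

weak = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
strong = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomResizedCrop(224, scale=(0.5, 1.0)),
    transforms.ColorJitter(0.4, 0.4, 0.4),
    transforms.ToTensor(),
])

with open('../data/ssda/office-home/labeled_target_images_Art_1.txt') as f:
    image_list = f.readlines()

dataset = ImageList_twice(image_list, transform=[weak, strong])
loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)

views, labels, idx = next(iter(loader))
# 'views' holds one batched tensor per transform: [weak_batch, strong_batch]
print(views[0].shape, views[1].shape, labels.shape, idx.shape)
```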
/code/uda/data_list.py:
--------------------------------------------------------------------------------
1 | #from __future__ import print_function, division
2 |
3 | import torch
4 | import numpy as np
5 | import random
6 | from PIL import Image
7 | from torch.utils.data import Dataset
8 | import os
9 | import os.path
10 |
11 | import cv2
12 | import torchvision
13 |
14 | def make_dataset(image_list, labels):
15 | if labels:
16 | len_ = len(image_list)
17 | images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)]
18 | else:
19 | if len(image_list[0].split()) > 2:
20 | images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list]
21 | else:
22 | images = [(val.split()[0], int(val.split()[1])) for val in image_list]
23 | return images
24 |
25 |
26 | def rgb_loader(path):
27 | with open(path, 'rb') as f:
28 | with Image.open(f) as img:
29 | return img.convert('RGB')
30 |
31 | def l_loader(path):
32 | with open(path, 'rb') as f:
33 | with Image.open(f) as img:
34 | return img.convert('L')
35 |
36 | class ImageList(Dataset):
37 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'):
38 | imgs = make_dataset(image_list, labels)
39 | if len(imgs) == 0:
40 | raise(RuntimeError("Found 0 images in the given image_list; "
41 | "each line should look like '<image_path> <label>'"))
42 |
43 | self.imgs = imgs
44 | self.transform = transform
45 | self.target_transform = target_transform
46 | if mode == 'RGB':
47 | self.loader = rgb_loader
48 | elif mode == 'L':
49 | self.loader = l_loader
50 |
51 | def __getitem__(self, index):
52 | path, target = self.imgs[index]
53 | img = self.loader(path)
54 | if self.transform is not None:
55 | img = self.transform(img)
56 | if self.target_transform is not None:
57 | target = self.target_transform(target)
58 |
59 | return img, target
60 |
61 | def __len__(self):
62 | return len(self.imgs)
63 |
64 |
65 | class ImageList_twice(Dataset):
66 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'):
67 | imgs = make_dataset(image_list, labels)
68 | if len(imgs) == 0:
69 | raise(RuntimeError("Found 0 images in the given image_list; "
70 | "each line should look like '<image_path> <label>'"))
71 |
72 | self.imgs = imgs
73 | self.transform = transform
74 | self.target_transform = target_transform
75 | if mode == 'RGB':
76 | self.loader = rgb_loader
77 | elif mode == 'L':
78 | self.loader = l_loader
79 |
80 | def __getitem__(self, index):
81 | path, target = self.imgs[index]
82 | img = self.loader(path)
83 | if self.target_transform is not None:
84 | target = self.target_transform(target)
85 | if self.transform is not None:
86 | if isinstance(self.transform, list):
87 | img = [t(img) for t in self.transform]
88 | else:
89 | img = self.transform(img)
90 |
91 | return img, target, index
92 |
93 | def __len__(self):
94 | return len(self.imgs)
95 |
96 | class ImageList_idx(Dataset):
97 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'):
98 | imgs = make_dataset(image_list, labels)
99 | if len(imgs) == 0:
100 | raise(RuntimeError("Found 0 images in the given image_list; "
101 | "each line should look like '<image_path> <label>'"))
102 |
103 | self.imgs = imgs
104 | self.transform = transform
105 | self.target_transform = target_transform
106 | if mode == 'RGB':
107 | self.loader = rgb_loader
108 | elif mode == 'L':
109 | self.loader = l_loader
110 |
111 | def __getitem__(self, index):
112 | path, target = self.imgs[index]
113 | img = self.loader(path)
114 | if self.transform is not None:
115 | img = self.transform(img)
116 | if self.target_transform is not None:
117 | target = self.target_transform(target)
118 |
119 | return img, target, index
120 |
121 | def __len__(self):
122 | return len(self.imgs)
123 |
124 |
125 | class ImageValueList(Dataset):
126 | def __init__(self, image_list, labels=None, transform=None, target_transform=None,
127 | loader=rgb_loader):
128 | imgs = make_dataset(image_list, labels)
129 | if len(imgs) == 0:
130 | raise(RuntimeError("Found 0 images in the given image_list; "
131 | "each line should look like '<image_path> <label>'"))
132 |
133 | self.imgs = imgs
134 | self.values = [1.0] * len(imgs)
135 | self.transform = transform
136 | self.target_transform = target_transform
137 | self.loader = loader
138 |
139 | def set_values(self, values):
140 | self.values = values
141 |
142 | def __getitem__(self, index):
143 | path, target = self.imgs[index]
144 | img = self.loader(path)
145 | if self.transform is not None:
146 | img = self.transform(img)
147 | if self.target_transform is not None:
148 | target = self.target_transform(target)
149 |
150 | return img, target
151 |
152 | def __len__(self):
153 | return len(self.imgs)
154 |
155 | class alexnetlist(Dataset):
156 |
157 | def __init__(self, list, training=True):
158 | self.images = []
159 | self.labels = []
160 | self.multi_scale = [256, 257]
161 | self.output_size = [227, 227]
162 | self.training = training
163 | self.mean_color=[104.006, 116.668, 122.678]
164 |
165 | list_file = open(list)
166 | lines = list_file.readlines()
167 | for line in lines:
168 | fields = line.split()
169 | self.images.append(fields[0])
170 | self.labels.append(int(fields[1]))
171 |
172 | def __len__(self):
173 | return len(self.images)
174 |
175 | def __getitem__(self, index):
176 | image_path = self.images[index]
177 | label = self.labels[index]
178 | img = cv2.imread(image_path)
179 | if img is None:
180 | print('Error: Image at {} not found.'.format(image_path))
181 |
182 | if self.training and np.random.random() < 0.5:
183 | img = cv2.flip(img, 1)
184 | new_size = np.random.randint(self.multi_scale[0], self.multi_scale[1], 1)[0]
185 |
186 | img = cv2.resize(img, (new_size, new_size))
187 | img = img.astype(np.float32)
188 |
189 | # cropping
190 | if self.training:
191 | diff = new_size - self.output_size[0]
192 | offset_x = np.random.randint(0, diff, 1)[0]
193 | offset_y = np.random.randint(0, diff, 1)[0]
194 | else:
195 | offset_x = img.shape[0]//2 - self.output_size[0] // 2
196 | offset_y = img.shape[1]//2 - self.output_size[1] // 2
197 |
198 | img = img[offset_x:(offset_x+self.output_size[0]),
199 | offset_y:(offset_y+self.output_size[1])]
200 |
201 | # subtract mean
202 | img -= np.array(self.mean_color)
203 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
204 |
205 | # ToTensor transform cv2 HWC->CHW, only byteTensor will be div by 255.
206 | tensor = torchvision.transforms.ToTensor()
207 | img = tensor(img)
208 | # img = np.transpose(img, (2, 0, 1))
209 |
210 | return img, label
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Official implementation for **SHOT++**
2 |
3 | ## [**[TPAMI-2021] Source Data-absent Unsupervised Domain Adaptation through Hypothesis Transfer and Labeling Transfer**](https://ieeexplore.ieee.org/abstract/document/9512429/)
4 |
5 | ### Framework:
6 |
7 | 1. train on the source domain; (Section 3.1)
8 | 2. **hypothesis transfer with information maximization and *self-supervised learning*; (Section 3.2 & Section 3.3)**
9 | (note that SHOT here means results after step 2, which contains an additional rotation-driven self-supervised objective compared with the original SHOT in ICML 2020)
10 |
11 |
12 |
13 | 3. **labeling transfer with semi-supervised learning. (Section 3.4)**
14 | (note that SHOT++ adds an extra semi-supervised learning step via MixMatch; a minimal sketch of these objectives is given below)
15 |
16 |
17 |
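For concreteness, here is a minimal, hedged PyTorch sketch of the kind of objective optimized in step 2: information maximization (per-sample entropy minimization plus a diversity term on the mean prediction), an optional pseudo-label cross-entropy term weighted by `cls_par`, and the rotation-prediction self-supervised loss weighted by `ssl`. It is a simplification of what `image_target.py` and `rotation.py` implement, and all tensor names are illustrative.

```python
# Hedged sketch of the step-2 (hypothesis transfer) objective; not the exact code
# in image_target.py / rotation.py.
import torch
import torch.nn.functional as F

def information_maximization(logits, eps=1e-5):
    # per-sample entropy minimization + diversity term on the mean prediction
    probs = F.softmax(logits, dim=1)
    ent = -(probs * torch.log(probs + eps)).sum(dim=1).mean()
    mean_probs = probs.mean(dim=0)
    div = (mean_probs * torch.log(mean_probs + eps)).sum()  # = -H(mean prediction)
    return ent + div

def hypothesis_transfer_loss(logits, pseudo_labels, rot_logits, rot_labels,
                             cls_par=0.3, ssl=0.6):
    loss = information_maximization(logits)
    if cls_par > 0:
        # self-training on clustering-based pseudo labels
        loss = loss + cls_par * F.cross_entropy(logits, pseudo_labels)
    if ssl > 0:
        # 4-way rotation prediction (0 / 90 / 180 / 270 degrees)
        loss = loss + ssl * F.cross_entropy(rot_logits, rot_labels)
    return loss
```

Step 3 then ranks target samples by prediction confidence (see the `--ps` and `--choice ent` options), treats the confident split as labeled and the rest as unlabeled, and fine-tunes with MixMatch-style semi-supervised learning in `image_mixmatch.py`.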
18 | ### Prerequisites:
19 | - python == 3.6.8
20 | - pytorch == 1.1.0
21 | - torchvision == 0.3.0
22 | - numpy, scipy, sklearn, PIL, argparse, tqdm
23 |
24 | ### Dataset:
25 |
26 | - Please manually download the datasets [Office](https://drive.google.com/file/d/0B4IapRTv9pJ1WGZVd1VDMmhwdlE/view), [Office-Home](https://drive.google.com/file/d/0B81rNlvomiwed0V1YUxQdC1uOTg/view), [VisDA-C](https://github.com/VisionLearningGroup/taskcv-2017-public/tree/master/classification), [Office-Caltech](http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar) from the official websites, and modify the image paths in each '.txt' file under the folder './object/data/'. [**Instructions for generating such '.txt' files can be found at https://github.com/tim-learn/Generate_list**; a minimal sketch is also given after this list.]
27 |
28 | - Concerning the **Digits** datasets, the code automatically downloads the three digit datasets (i.e., MNIST, USPS, and SVHN) into './digit/data/'.
29 |
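As noted above, each '.txt' list simply pairs an image path with an integer class id. Below is a minimal sketch of generating such a list for one domain, assuming one subfolder per class; it is not the Generate_list tool itself, and the example paths are illustrative.

```python
# Hedged sketch: write a '<image_path> <label>' list for one domain whose
# images are stored as <domain_dir>/<class_name>/<image_file>.
import os

def write_image_list(domain_dir, out_txt):
    classes = sorted(d for d in os.listdir(domain_dir)
                     if os.path.isdir(os.path.join(domain_dir, d)))
    with open(out_txt, 'w') as f:
        for label, cls in enumerate(classes):
            cls_dir = os.path.join(domain_dir, cls)
            for name in sorted(os.listdir(cls_dir)):
                if name.lower().endswith(('.jpg', '.jpeg', '.png')):
                    f.write('{} {}\n'.format(os.path.join(cls_dir, name), label))

# Illustrative usage:
# write_image_list('/data/office_home/Art', 'office-home/Art.txt')
```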
30 | ### Training:
31 | 1. ##### Unsupervised Closed-set Domain Adaptation (UDA) on the Digits dataset
32 | - MNIST -> USPS (**m2u**)
33 | ```python
34 | cd digit/
35 | python uda_digit.py --gpu_id 0 --seed 2021 --dset m2u --output ckps_digits --cls_par 0.1 --ssl 0.2
36 | python digit_mixmatch.py --gpu_id 0 --seed 2021 --dset m2u --output ckps_mm --output_tar ckps_digits --cls_par 0.1 --ssl 0.2 --alpha 0.1
37 | ```
38 | 2. ##### Unsupervised Closed-set Domain Adaptation (UDA) on the Office/ Office-Home dataset
39 | - Train model on the source domain **A** (**s = 0**) [--max_epoch 50 for Office-Home]
40 | ```python
41 | cd uda/
42 | python image_source.py --gpu_id 0 --seed 2021 --trte val --da uda --output ckps/source/ --dset office --max_epoch 100 --s 0
43 | ```
44 | - Adaptation to other target domains **D and W** (hypothesis transfer)
45 | ```python
46 | python image_target.py --gpu_id 0 --seed 2021 --da uda --output ckps/target/ --dset office --s 0 --cls_par 0.3 --ssl 0.6
47 | ```
48 | - Adaptation to other target domains **D and W** (following labeling transfer) [--max_epoch 50 for Office-Home]
49 | ```python
50 | python image_mixmatch.py --gpu_id 0 --seed 2021 --da uda --dset office --max_epoch 100 --s 0 --output_tar ckps/target/ --output ckps/mixmatch/ --cls_par 0.3 --ssl 0.6 --choice ent --ps 0.0
51 | ```
52 |
53 | 3. ##### Unsupervised Closed-set Domain Adaptation (UDA) on the VISDA-C dataset
54 | - Train model on the Synthetic domain [--max_epoch 10 --lr 1e-3]
55 | ```python
56 | cd uda/
57 | python image_source.py --gpu_id 0 --seed 2021 --trte val --da uda --output ckps/source/ --dset VISDA-C --net resnet101 --lr 1e-3 --max_epoch 10 --s 0
58 | ```
59 | - Adaptation to the real domain (hypothesis transfer)
60 | ```python
61 | python image_target.py --gpu_id 0 --seed 2021 --da uda --output ckps/target/ --dset VISDA-C --s 0 --net resnet101 --cls_par 0.3 --ssl 0.6
62 | ```
63 | - Adaptation to the real domain (following labeling transfer)
64 | ```python
65 | python image_mixmatch.py --gpu_id 0 --seed 2021 --da uda --dset VISDA-C --max_epoch 10 --s 0 --output_tar ckps/target/ --output ckps/mixmatch/ --net resnet101 --cls_par 0.3 --ssl 0.6 --choice ent --ps 0.0
66 | ```
67 |
68 | 4. ##### Unsupervised Partial-set Domain Adaptation (PDA) on the Office-Home dataset
69 | - Train model on the source domain **A** (**s = 0**)
70 | ```python
71 | cd pda/
72 | python image_source.py --gpu_id 0 --seed 2021 --trte val --da pda --output ckps/source/ --dset office-home --max_epoch 50 --s 0
73 | ```
74 |
75 | - Adaptation to other target domains (hypothesis transfer)
76 | ```python
77 | python image_target.py --gpu_id 0 --seed 2021 --da pda --dset office-home --s 0 --output_src ckps/source/ --output ckps/target/ --cls_par 0.3 --ssl 0.6
78 | ```
79 |
80 | - Adaptation to the real domain (following labeling transfer)
81 | ```python
82 | python image_mixmatch.py --gpu_id 0 --seed 2021 --da pda --dset office-home --max_epoch 50 --s 0 --output_tar ckps/target/ --output ckps/mixmatch/ --cls_par 0.3 --ssl 0.6 --choice ent --ps 0.0
83 | ```
84 |
85 | 5. ##### Unsupervised Multi-source Domain Adaptation (MSDA) on the Office-Home dataset
86 | - Train model on the source domains **Ar** (**s = 0**), **Cl** (**s = 1**), **Pr** (**s = 2**), respectively
87 | ```python
88 | cd msda/
89 | python image_source.py --gpu_id 0 --seed 2021 --trte val --da uda --dset office-home --output ckps/source/ --net resnet50 --max_epoch 50 --s 0
90 | python image_source.py --gpu_id 0 --seed 2021 --trte val --da uda --dset office-home --output ckps/source/ --net resnet50 --max_epoch 50 --s 1
91 | python image_source.py --gpu_id 0 --seed 2021 --trte val --da uda --dset office-home --output ckps/source/ --net resnet50 --max_epoch 50 --s 2
92 | ```
93 |
94 | - Adaptation to the target domain (hypothesis transfer)
95 | ```python
96 | python image_target.py --gpu_id 0 --seed 2021 --cls_par 0.3 --ssl 0.6 --da uda --dset office-home --output_src ckps/source/ --output ckps/target/ --net resnet50 --s 0
97 | python image_target.py --gpu_id 0 --seed 2021 --cls_par 0.3 --ssl 0.6 --da uda --dset office-home --output_src ckps/source/ --output ckps/target/ --net resnet50 --s 1
98 | python image_target.py --gpu_id 0 --seed 2021 --cls_par 0.3 --ssl 0.6 --da uda --dset office-home --output_src ckps/source/ --output ckps/target/ --net resnet50 --s 2
99 | ```
100 |
101 | - Adaptation to the target domain (labeling transfer)
102 | ```python
103 | python image_mixmatch.py --gpu_id 0 --seed 2021 --da uda --dset office-home --max_epoch 50 --output_tar ckps/target/ --output ckps/mixmatch/ --cls_par 0.3 --ssl 0.6 --choice ent --ps 0.0 --net resnet50 --s 0
104 | python image_mixmatch.py --gpu_id 0 --seed 2021 --da uda --dset office-home --max_epoch 50 --output_tar ckps/target/ --output ckps/mixmatch/ --cls_par 0.3 --ssl 0.6 --choice ent --ps 0.0 --net resnet50 --s 1
105 | python image_mixmatch.py --gpu_id 0 --seed 2021 --da uda --dset office-home --max_epoch 50 --output_tar ckps/target/ --output ckps/mixmatch/ --cls_par 0.3 --ssl 0.6 --choice ent --ps 0.0 --net resnet50 --s 2
106 | ```
107 |
108 |   - Combine domain-specific scores together (a minimal sketch of this fusion is given below)
109 | ```python
110 | python image_ms.py --gpu_id 0 --seed 2021 --cls_par 0.3 --ssl 0.6 --da uda --dset office-home --output_src ckps/source/ --output ckps/target/ --output_mm ckps/mixmatch/ --net resnet50 --t 3
111 | ```
112 |
113 | 6. ##### Semi-supervised Domain Adaptation (SSDA) on the Office-Home dataset
114 | - Train model on the source domain **Ar** (**s = 0**)
115 | ```python
116 | cd ssda/
117 | python image_source.py --gpu_id 0 --seed 2021 --output ckps/source/ --dset office-home --max_epoch 50 --s 0
118 | ```
119 |
120 | - Adaptation to the target domain **Cl** (**t = 1**) [hypothesis transfer]
121 | ```python
122 | python image_target.py --gpu_id 0 --seed 2021 --cls_par 0.1 --ssl 0.2 --output_src ckps/source --output ckps/target --dset office-home --s 0 --t 1
123 | ```
124 |
125 | - Adaptation to the target domain **Cl** (**t = 1**) [labeling transfer]
126 | ```python
127 | python image_mixmatch.py --gpu_id 0 --seed 2021 --ps 0.0 --cls_par 0.1 --ssl 0.2 --output_tar ckps/target --output ckps/mixmatch --dset office-home --max_epoch 50 --s 0 --t 1
128 | ```
129 |
130 | **Please refer to *./xxda/run_xxda.sh* for all the settings for the different methods and scenarios.**
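
In the multi-source step above, *msda/image_ms.py* fuses the predictions of the per-source adapted models by summing their softmax scores on the shared target list and taking the arg-max of the sum. Below is a minimal sketch of that fusion; the helper name `fuse_scores` is ours for illustration and is not part of the released scripts.

```python
import torch

# Sketch of the score fusion performed in msda/image_ms.py: each adapted source model
# is evaluated on the same target list, the per-model softmax scores are accumulated,
# and the fused arg-max gives the final multi-source prediction.
def fuse_scores(score_list, labels):
    """score_list: list of (num_samples, num_classes) softmax tensors; labels: (num_samples,) tensor."""
    fused = torch.zeros_like(score_list[0])
    for score in score_list:
        fused += score                    # accumulate evidence from every source model
    _, predict = torch.max(fused, 1)      # fused prediction
    acc = (predict.float() == labels.float()).float().mean().item()
    return fused, acc
```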
131 |
132 | ### Citation
133 |
134 | If you find this code useful for your research, please cite our papers:
135 | ```
136 | @article{liang2021source,
137 | title={Source Data-absent Unsupervised Domain Adaptation through Hypothesis Transfer and Labeling Transfer},
138 | author={Liang, Jian and Hu, Dapeng and Wang, Yunbo and He, Ran and Feng, Jiashi},
139 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)},
140 | year={2021},
141 | note={In Press}
142 | }
143 |
144 | @inproceedings{liang2020we,
145 | title={Do We Really Need to Access the Source Data? Source Hypothesis Transfer for Unsupervised Domain Adaptation},
146 | author={Liang, Jian and Hu, Dapeng and Feng, Jiashi},
147 | booktitle={International Conference on Machine Learning (ICML)},
148 | pages={6028--6039},
149 | year={2020}
150 | }
151 | ```
152 |
153 | ### Contact
154 |
155 | - [liangjian92@gmail.com](mailto:liangjian92@gmail.com)
156 | - [dapeng.hu@u.nus.edu](mailto:dapeng.hu@u.nus.edu)
157 | - [elefjia@nus.edu.sg](mailto:elefjia@nus.edu.sg)
158 |
--------------------------------------------------------------------------------
/code/digit/data_load/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import hashlib
4 | import gzip
5 | import errno
6 | import tarfile
7 | import zipfile
8 |
9 | import torch
10 | from torch.utils.model_zoo import tqdm
11 |
12 |
13 | def gen_bar_updater():
14 | pbar = tqdm(total=None)
15 |
16 | def bar_update(count, block_size, total_size):
17 | if pbar.total is None and total_size:
18 | pbar.total = total_size
19 | progress_bytes = count * block_size
20 | pbar.update(progress_bytes - pbar.n)
21 |
22 | return bar_update
23 |
24 |
25 | def calculate_md5(fpath, chunk_size=1024 * 1024):
26 | md5 = hashlib.md5()
27 | with open(fpath, 'rb') as f:
28 | for chunk in iter(lambda: f.read(chunk_size), b''):
29 | md5.update(chunk)
30 | return md5.hexdigest()
31 |
32 |
33 | def check_md5(fpath, md5, **kwargs):
34 | return md5 == calculate_md5(fpath, **kwargs)
35 |
36 |
37 | def check_integrity(fpath, md5=None):
38 | if not os.path.isfile(fpath):
39 | return False
40 | if md5 is None:
41 | return True
42 | return check_md5(fpath, md5)
43 |
44 |
45 | def download_url(url, root, filename=None, md5=None):
46 | """Download a file from a url and place it in root.
47 |
48 | Args:
49 | url (str): URL to download file from
50 | root (str): Directory to place downloaded file in
51 | filename (str, optional): Name to save the file under. If None, use the basename of the URL
52 | md5 (str, optional): MD5 checksum of the download. If None, do not check
53 | """
54 | import urllib
55 |
56 | root = os.path.expanduser(root)
57 | if not filename:
58 | filename = os.path.basename(url)
59 | fpath = os.path.join(root, filename)
60 |
61 | os.makedirs(root, exist_ok=True)
62 |
63 | # check if file is already present locally
64 | if check_integrity(fpath, md5):
65 | print('Using downloaded and verified file: ' + fpath)
66 | else: # download the file
67 | try:
68 | print('Downloading ' + url + ' to ' + fpath)
69 | urllib.request.urlretrieve(
70 | url, fpath,
71 | reporthook=gen_bar_updater()
72 | )
73 | except (urllib.error.URLError, IOError) as e:
74 | if url[:5] == 'https':
75 | url = url.replace('https:', 'http:')
76 | print('Failed download. Trying https -> http instead.'
77 | ' Downloading ' + url + ' to ' + fpath)
78 | urllib.request.urlretrieve(
79 | url, fpath,
80 | reporthook=gen_bar_updater()
81 | )
82 | else:
83 | raise e
84 | # check integrity of downloaded file
85 | if not check_integrity(fpath, md5):
86 | raise RuntimeError("File not found or corrupted.")
87 |
88 |
89 | def list_dir(root, prefix=False):
90 | """List all directories at a given root
91 |
92 | Args:
93 | root (str): Path to directory whose folders need to be listed
94 | prefix (bool, optional): If true, prepends the path to each result, otherwise
95 | only returns the name of the directories found
96 | """
97 | root = os.path.expanduser(root)
98 | directories = list(
99 | filter(
100 | lambda p: os.path.isdir(os.path.join(root, p)),
101 | os.listdir(root)
102 | )
103 | )
104 |
105 | if prefix is True:
106 | directories = [os.path.join(root, d) for d in directories]
107 |
108 | return directories
109 |
110 |
111 | def list_files(root, suffix, prefix=False):
112 | """List all files ending with a suffix at a given root
113 |
114 | Args:
115 |         root (str): Path to directory whose files need to be listed
116 | suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
117 | It uses the Python "str.endswith" method and is passed directly
118 | prefix (bool, optional): If true, prepends the path to each result, otherwise
119 | only returns the name of the files found
120 | """
121 | root = os.path.expanduser(root)
122 | files = list(
123 | filter(
124 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),
125 | os.listdir(root)
126 | )
127 | )
128 |
129 | if prefix is True:
130 | files = [os.path.join(root, d) for d in files]
131 |
132 | return files
133 |
134 |
135 | def download_file_from_google_drive(file_id, root, filename=None, md5=None):
136 |     """Download a file from Google Drive and place it in root.
137 |
138 | Args:
139 | file_id (str): id of file to be downloaded
140 | root (str): Directory to place downloaded file in
141 | filename (str, optional): Name to save the file under. If None, use the id of the file.
142 | md5 (str, optional): MD5 checksum of the download. If None, do not check
143 | """
144 | # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
145 | import requests
146 | url = "https://docs.google.com/uc?export=download"
147 |
148 | root = os.path.expanduser(root)
149 | if not filename:
150 | filename = file_id
151 | fpath = os.path.join(root, filename)
152 |
153 | os.makedirs(root, exist_ok=True)
154 |
155 | if os.path.isfile(fpath) and check_integrity(fpath, md5):
156 | print('Using downloaded and verified file: ' + fpath)
157 | else:
158 | session = requests.Session()
159 |
160 | response = session.get(url, params={'id': file_id}, stream=True)
161 | token = _get_confirm_token(response)
162 |
163 | if token:
164 | params = {'id': file_id, 'confirm': token}
165 | response = session.get(url, params=params, stream=True)
166 |
167 | _save_response_content(response, fpath)
168 |
169 |
170 | def _get_confirm_token(response):
171 | for key, value in response.cookies.items():
172 | if key.startswith('download_warning'):
173 | return value
174 |
175 | return None
176 |
177 |
178 | def _save_response_content(response, destination, chunk_size=32768):
179 | with open(destination, "wb") as f:
180 | pbar = tqdm(total=None)
181 | progress = 0
182 | for chunk in response.iter_content(chunk_size):
183 | if chunk: # filter out keep-alive new chunks
184 | f.write(chunk)
185 | progress += len(chunk)
186 | pbar.update(progress - pbar.n)
187 | pbar.close()
188 |
189 |
190 | def _is_tarxz(filename):
191 | return filename.endswith(".tar.xz")
192 |
193 |
194 | def _is_tar(filename):
195 | return filename.endswith(".tar")
196 |
197 |
198 | def _is_targz(filename):
199 | return filename.endswith(".tar.gz")
200 |
201 |
202 | def _is_tgz(filename):
203 | return filename.endswith(".tgz")
204 |
205 |
206 | def _is_gzip(filename):
207 | return filename.endswith(".gz") and not filename.endswith(".tar.gz")
208 |
209 |
210 | def _is_zip(filename):
211 | return filename.endswith(".zip")
212 |
213 |
214 | def extract_archive(from_path, to_path=None, remove_finished=False):
215 | if to_path is None:
216 | to_path = os.path.dirname(from_path)
217 |
218 | if _is_tar(from_path):
219 | with tarfile.open(from_path, 'r') as tar:
220 | tar.extractall(path=to_path)
221 | elif _is_targz(from_path) or _is_tgz(from_path):
222 | with tarfile.open(from_path, 'r:gz') as tar:
223 | tar.extractall(path=to_path)
224 | elif _is_tarxz(from_path):
225 | with tarfile.open(from_path, 'r:xz') as tar:
226 | tar.extractall(path=to_path)
227 | elif _is_gzip(from_path):
228 | to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
229 | with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
230 | out_f.write(zip_f.read())
231 | elif _is_zip(from_path):
232 | with zipfile.ZipFile(from_path, 'r') as z:
233 | z.extractall(to_path)
234 | else:
235 | raise ValueError("Extraction of {} not supported".format(from_path))
236 |
237 | if remove_finished:
238 | os.remove(from_path)
239 |
240 |
241 | def download_and_extract_archive(url, download_root, extract_root=None, filename=None,
242 | md5=None, remove_finished=False):
243 | download_root = os.path.expanduser(download_root)
244 | if extract_root is None:
245 | extract_root = download_root
246 | if not filename:
247 | filename = os.path.basename(url)
248 |
249 | download_url(url, download_root, filename, md5)
250 |
251 | archive = os.path.join(download_root, filename)
252 | print("Extracting {} to {}".format(archive, extract_root))
253 | extract_archive(archive, extract_root, remove_finished)
254 |
255 |
256 | def iterable_to_str(iterable):
257 | return "'" + "', '".join([str(item) for item in iterable]) + "'"
258 |
259 |
260 | def verify_str_arg(value, arg=None, valid_values=None, custom_msg=None):
261 | if not isinstance(value, torch._six.string_classes):
262 | if arg is None:
263 | msg = "Expected type str, but got type {type}."
264 | else:
265 | msg = "Expected type str for argument {arg}, but got type {type}."
266 | msg = msg.format(type=type(value), arg=arg)
267 | raise ValueError(msg)
268 |
269 | if valid_values is None:
270 | return value
271 |
272 | if value not in valid_values:
273 | if custom_msg is not None:
274 | msg = custom_msg
275 | else:
276 | msg = ("Unknown value '{value}' for argument {arg}. "
277 | "Valid values are {{{valid_values}}}.")
278 | msg = msg.format(value=value, arg=arg,
279 | valid_values=iterable_to_str(valid_values))
280 | raise ValueError(msg)
281 |
282 | return value
--------------------------------------------------------------------------------
/code/digit/data_load/usps.py:
--------------------------------------------------------------------------------
1 | """Dataset setting and data loader for USPS.
2 | Modified from
3 | https://github.com/mingyuliutw/CoGAN/blob/master/cogan_pytorch/src/dataset_usps.py
4 | """
5 |
6 | import gzip
7 | import os
8 | import pickle
9 | import urllib
10 | from PIL import Image
11 |
12 | import numpy as np
13 | import torch
14 | import torch.utils.data as data
15 | from torch.utils.data.sampler import WeightedRandomSampler
16 | from torchvision import datasets, transforms
17 |
18 |
19 | class USPS(data.Dataset):
20 | """USPS Dataset.
21 | Args:
22 |         root (string): Root directory of dataset where the dataset file exists.
23 | train (bool, optional): If True, resample from dataset randomly.
24 | download (bool, optional): If true, downloads the dataset
25 | from the internet and puts it in root directory.
26 | If dataset is already downloaded, it is not downloaded again.
27 | transform (callable, optional): A function/transform that takes in
28 |             a PIL image and returns a transformed version.
29 |             E.g., ``transforms.RandomCrop``
30 | """
31 |
32 | url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl"
33 |
34 | def __init__(self, root, train=True, transform=None, download=False):
35 | """Init USPS dataset."""
36 | # init params
37 | self.root = os.path.expanduser(root)
38 | self.filename = "usps_28x28.pkl"
39 | self.train = train
40 |         # Num of Train = 7438, Num of Test = 1860
41 | self.transform = transform
42 | self.dataset_size = None
43 |
44 | # download dataset.
45 | if download:
46 | self.download()
47 | if not self._check_exists():
48 | raise RuntimeError("Dataset not found." +
49 | " You can use download=True to download it")
50 |
51 | self.train_data, self.train_labels = self.load_samples()
52 | if self.train:
53 | total_num_samples = self.train_labels.shape[0]
54 | indices = np.arange(total_num_samples)
55 | self.train_data = self.train_data[indices[0:self.dataset_size], ::]
56 | self.train_labels = self.train_labels[indices[0:self.dataset_size]]
57 | self.train_data *= 255.0
58 | self.train_data = np.squeeze(self.train_data).astype(np.uint8)
59 |
60 | def __getitem__(self, index):
61 | """Get images and target for data loader.
62 | Args:
63 | index (int): Index
64 | Returns:
65 | tuple: (image, target) where target is index of the target class.
66 | """
67 | img, label = self.train_data[index], self.train_labels[index]
68 | img = Image.fromarray(img, mode='L')
69 | img = img.copy()
70 | if self.transform is not None:
71 | img = self.transform(img)
72 | return img, label.astype("int64")
73 |
74 | def __len__(self):
75 | """Return size of dataset."""
76 | return len(self.train_data)
77 |
78 | def _check_exists(self):
79 |         """Check if dataset is downloaded and in the right place."""
80 | return os.path.exists(os.path.join(self.root, self.filename))
81 |
82 | def download(self):
83 | """Download dataset."""
84 | filename = os.path.join(self.root, self.filename)
85 | dirname = os.path.dirname(filename)
86 | if not os.path.isdir(dirname):
87 | os.makedirs(dirname)
88 | if os.path.isfile(filename):
89 | return
90 | print("Download %s to %s" % (self.url, os.path.abspath(filename)))
91 | urllib.request.urlretrieve(self.url, filename)
92 | print("[DONE]")
93 | return
94 |
95 | def load_samples(self):
96 | """Load sample images from dataset."""
97 | filename = os.path.join(self.root, self.filename)
98 | f = gzip.open(filename, "rb")
99 | data_set = pickle.load(f, encoding="bytes")
100 | f.close()
101 | if self.train:
102 | images = data_set[0][0]
103 | labels = data_set[0][1]
104 | self.dataset_size = labels.shape[0]
105 | else:
106 | images = data_set[1][0]
107 | labels = data_set[1][1]
108 | self.dataset_size = labels.shape[0]
109 | return images, labels
110 |
111 |
112 | class USPS_idx(data.Dataset):
113 | """USPS Dataset.
114 | Args:
115 |         root (string): Root directory of dataset where the dataset file exists.
116 | train (bool, optional): If True, resample from dataset randomly.
117 | download (bool, optional): If true, downloads the dataset
118 | from the internet and puts it in root directory.
119 | If dataset is already downloaded, it is not downloaded again.
120 | transform (callable, optional): A function/transform that takes in
121 |             a PIL image and returns a transformed version.
122 |             E.g., ``transforms.RandomCrop``
123 | """
124 |
125 | url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl"
126 |
127 | def __init__(self, root, train=True, transform=None, download=False):
128 | """Init USPS dataset."""
129 | # init params
130 | self.root = os.path.expanduser(root)
131 | self.filename = "usps_28x28.pkl"
132 | self.train = train
133 |         # Num of Train = 7438, Num of Test = 1860
134 | self.transform = transform
135 | self.dataset_size = None
136 |
137 | # download dataset.
138 | if download:
139 | self.download()
140 | if not self._check_exists():
141 | raise RuntimeError("Dataset not found." +
142 | " You can use download=True to download it")
143 |
144 | self.train_data, self.train_labels = self.load_samples()
145 | if self.train:
146 | total_num_samples = self.train_labels.shape[0]
147 | indices = np.arange(total_num_samples)
148 | self.train_data = self.train_data[indices[0:self.dataset_size], ::]
149 | self.train_labels = self.train_labels[indices[0:self.dataset_size]]
150 | self.train_data *= 255.0
151 | self.train_data = np.squeeze(self.train_data).astype(np.uint8)
152 |
153 | def __getitem__(self, index):
154 | """Get images and target for data loader.
155 | Args:
156 | index (int): Index
157 | Returns:
158 | tuple: (image, target) where target is index of the target class.
159 | """
160 | img, label = self.train_data[index], self.train_labels[index]
161 | img = Image.fromarray(img, mode='L')
162 | img = img.copy()
163 | if self.transform is not None:
164 | img = self.transform(img)
165 | return img, label.astype("int64"), index
166 |
167 | def __len__(self):
168 | """Return size of dataset."""
169 | return len(self.train_data)
170 |
171 | def _check_exists(self):
172 |         """Check if dataset is downloaded and in the right place."""
173 | return os.path.exists(os.path.join(self.root, self.filename))
174 |
175 | def download(self):
176 | """Download dataset."""
177 | filename = os.path.join(self.root, self.filename)
178 | dirname = os.path.dirname(filename)
179 | if not os.path.isdir(dirname):
180 | os.makedirs(dirname)
181 | if os.path.isfile(filename):
182 | return
183 | print("Download %s to %s" % (self.url, os.path.abspath(filename)))
184 | urllib.request.urlretrieve(self.url, filename)
185 | print("[DONE]")
186 | return
187 |
188 | def load_samples(self):
189 | """Load sample images from dataset."""
190 | filename = os.path.join(self.root, self.filename)
191 | f = gzip.open(filename, "rb")
192 | data_set = pickle.load(f, encoding="bytes")
193 | f.close()
194 | if self.train:
195 | images = data_set[0][0]
196 | labels = data_set[0][1]
197 | self.dataset_size = labels.shape[0]
198 | else:
199 | images = data_set[1][0]
200 | labels = data_set[1][1]
201 | self.dataset_size = labels.shape[0]
202 | return images, labels
203 |
204 |
205 | class USPS_twice(USPS):
206 | url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl"
207 |
208 | def __init__(self, root, train=True, transform=None, download=False):
209 | """Init USPS dataset."""
210 | # init params
211 | self.root = os.path.expanduser(root)
212 | self.filename = "usps_28x28.pkl"
213 | self.train = train
214 |         # Num of Train = 7438, Num of Test = 1860
215 | self.transform = transform
216 | self.dataset_size = None
217 |
218 | # download dataset.
219 | if download:
220 | self.download()
221 | if not self._check_exists():
222 | raise RuntimeError("Dataset not found." +
223 | " You can use download=True to download it")
224 |
225 | self.train_data, self.train_labels = self.load_samples()
226 | if self.train:
227 | total_num_samples = self.train_labels.shape[0]
228 | indices = np.arange(total_num_samples)
229 | self.train_data = self.train_data[indices[0:self.dataset_size], ::]
230 | self.train_labels = self.train_labels[indices[0:self.dataset_size]]
231 | self.train_data *= 255.0
232 | self.train_data = np.squeeze(self.train_data).astype(np.uint8)
233 |
234 | def __getitem__(self, index):
235 | """Get images and target for data loader.
236 | Args:
237 | index (int): Index
238 | Returns:
239 | tuple: (image, target) where target is index of the target class.
240 | """
241 | img, label = self.train_data[index], self.train_labels[index]
242 | img = Image.fromarray(img, mode='L')
243 | img = img.copy()
244 | if self.transform is not None:
245 | img = [self.transform(img), self.transform(img)]
246 | return img, label.astype("int64"), index
247 |
248 | def __len__(self):
249 | """Return size of dataset."""
250 | return len(self.train_data)
251 |
--------------------------------------------------------------------------------
/code/msda/image_ms.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 | import os.path as osp
4 | import torchvision
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torchvision import transforms
10 | import network, loss
11 | from torch.utils.data import DataLoader
12 | from data_list import ImageList, ImageList_idx
13 | import random, pdb, math, copy
14 | from tqdm import tqdm
15 | from loss import CrossEntropyLabelSmooth
16 | from scipy.spatial.distance import cdist
17 | from sklearn.metrics import confusion_matrix
18 |
19 | def image_train(resize_size=256, crop_size=224, alexnet=False):
20 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
21 | std=[0.229, 0.224, 0.225])
22 | return transforms.Compose([
23 | transforms.Resize((resize_size, resize_size)),
24 | transforms.RandomCrop(crop_size),
25 | transforms.RandomHorizontalFlip(),
26 | transforms.ToTensor(),
27 | normalize
28 | ])
29 |
30 | def image_test(resize_size=256, crop_size=224, alexnet=False):
31 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
32 | std=[0.229, 0.224, 0.225])
33 | return transforms.Compose([
34 | transforms.Resize((resize_size, resize_size)),
35 | transforms.CenterCrop(crop_size),
36 | transforms.ToTensor(),
37 | normalize
38 | ])
39 |
40 | def data_load(args):
41 | ## prepare data
42 | dsets = {}
43 | dset_loaders = {}
44 | train_bs = args.batch_size
45 | txt_test = open(args.test_dset_path).readlines()
46 |
47 | dsets["test"] = ImageList_idx(txt_test, transform=image_test())
48 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*3,
49 | shuffle=False, num_workers=args.worker, drop_last=False)
50 |
51 | return dset_loaders
52 |
53 | def cal_acc(loader, netF, netB, netC):
54 | start_test = True
55 | with torch.no_grad():
56 | iter_test = iter(loader)
57 | for i in range(len(loader)):
58 |             data = next(iter_test)
59 | inputs = data[0]
60 | labels = data[1]
61 | inputs = inputs.cuda()
62 | outputs = netC(netB(netF(inputs)))
63 | if start_test:
64 | all_output = outputs.float().cpu()
65 | all_label = labels.float()
66 | start_test = False
67 | else:
68 | all_output = torch.cat((all_output, outputs.float().cpu()), 0)
69 | all_label = torch.cat((all_label, labels.float()), 0)
70 | _, predict = torch.max(all_output, 1)
71 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
72 | mean_ent = torch.mean(loss.Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item()
73 |
74 | return accuracy, all_label, nn.Softmax(dim=1)(all_output)
75 |
76 | def print_args(args):
77 | s = "==========================================\n"
78 | for arg, content in args.__dict__.items():
79 | s += "{}:{}\n".format(arg, content)
80 | return s
81 |
82 | def test_target_srconly(args):
83 | dset_loaders = data_load(args)
84 | ## set base network
85 | if args.net[0:3] == 'res':
86 | netF = network.ResBase(res_name=args.net).cuda()
87 | elif args.net[0:3] == 'vgg':
88 | netF = network.VGGBase(vgg_name=args.net).cuda()
89 |
90 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
91 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
92 |
93 | args.modelpath = args.output_dir_src + '/source_F.pt'
94 | netF.load_state_dict(torch.load(args.modelpath))
95 | args.modelpath = args.output_dir_src + '/source_B.pt'
96 | netB.load_state_dict(torch.load(args.modelpath))
97 | args.modelpath = args.output_dir_src + '/source_C.pt'
98 | netC.load_state_dict(torch.load(args.modelpath))
99 | netF.eval()
100 | netB.eval()
101 | netC.eval()
102 |
103 | acc, y, py = cal_acc(dset_loaders['test'], netF, netB, netC)
104 | log_str = '\nTask: {}, Accuracy = {:.2f}%'.format(args.name, acc*100)
105 | args.out_file.write(log_str)
106 | args.out_file.flush()
107 | print(log_str)
108 |
109 | return y, py
110 |
111 | def test_target(args):
112 | dset_loaders = data_load(args)
113 | ## set base network
114 | if args.net[0:3] == 'res':
115 | netF = network.ResBase(res_name=args.net).cuda()
116 | elif args.net[0:3] == 'vgg':
117 | netF = network.VGGBase(vgg_name=args.net).cuda()
118 |
119 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
120 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
121 |
122 | args.modelpath = args.output_dir_ori + "/target_F_" + args.savename + ".pt"
123 | netF.load_state_dict(torch.load(args.modelpath))
124 | args.modelpath = args.output_dir_ori + "/target_B_" + args.savename + ".pt"
125 | netB.load_state_dict(torch.load(args.modelpath))
126 | args.modelpath = args.output_dir_ori + "/target_C_" + args.savename + ".pt"
127 | netC.load_state_dict(torch.load(args.modelpath))
128 | netF.eval()
129 | netB.eval()
130 | netC.eval()
131 |
132 | acc, y, py = cal_acc(dset_loaders['test'], netF, netB, netC)
133 | log_str = '\nTask: {}, Accuracy = {:.2f}%'.format(args.name, acc*100)
134 | args.out_file.write(log_str)
135 | args.out_file.flush()
136 | print(log_str)
137 |
138 | return y, py
139 |
140 | if __name__ == "__main__":
141 | parser = argparse.ArgumentParser(description='SHOT++')
142 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
143 | parser.add_argument('--s', type=int, default=0, help="source")
144 | parser.add_argument('--t', type=int, default=1, help="target")
145 |     parser.add_argument('--max_epoch', type=int, default=20, help="max epochs")
146 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size")
147 | parser.add_argument('--worker', type=int, default=4, help="number of workers")
148 | parser.add_argument('--dset', type=str, default='office-caltech', choices=['pacs', 'office-home', 'office-caltech'])
149 | parser.add_argument('--net', type=str, default='resnet101', help="alexnet, vgg16, resnet50")
150 | parser.add_argument('--seed', type=int, default=2020, help="random seed")
151 |
152 | parser.add_argument('--gent', type=bool, default=True)
153 | parser.add_argument('--ent', type=bool, default=True)
154 | parser.add_argument('--threshold', type=int, default=0)
155 | parser.add_argument('--cls_par', type=float, default=0.3)
156 |
157 | parser.add_argument('--bottleneck', type=int, default=256)
158 | parser.add_argument('--epsilon', type=float, default=1e-5)
159 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"])
160 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"])
161 | parser.add_argument('--output', type=str, default='san')
162 | parser.add_argument('--output_src', type=str, default='san')
163 | parser.add_argument('--output_tar', type=str, default='san')
164 | parser.add_argument('--output_mm', type=str, default='san')
165 | parser.add_argument('--da', type=str, default='uda', choices=['uda'])
166 | parser.add_argument('--ssl', type=float, default=0.0)
167 | parser.add_argument('--savename', type=str, default='san')
168 | args = parser.parse_args()
169 |
170 | if args.dset == 'office-home':
171 | names = ['Art', 'Clipart', 'Product', 'RealWorld']
172 | args.class_num = 65
173 | elif args.dset == 'office-caltech':
174 | names = ['amazon', 'caltech', 'dslr', 'webcam']
175 | args.class_num = 10
176 | elif args.dset == 'pacs':
177 | names = ['art_painting', 'cartoon', 'photo', 'sketch']
178 | args.class_num = 7
179 |
180 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
181 | SEED = args.seed
182 | torch.manual_seed(SEED)
183 | torch.cuda.manual_seed(SEED)
184 | np.random.seed(SEED)
185 | random.seed(SEED)
186 | # torch.backends.cudnn.deterministic = True
187 |
188 | args.output_dir = osp.join(args.output, args.da, args.dset, names[args.t][0].upper())
189 | if not osp.exists(args.output_dir):
190 | os.system('mkdir -p ' + args.output_dir)
191 | if not osp.exists(args.output_dir):
192 | os.mkdir(args.output_dir)
193 |
194 | args.savename = 'par_' + str(args.cls_par)
195 | if args.ssl > 0:
196 | args.savename += ('_ssl_' + str(args.ssl))
197 |
198 | args.out_file = open(osp.join(args.output_dir, 'log_' + args.savename + '.txt'), 'w')
199 | args.out_file.write(print_args(args)+'\n')
200 | args.out_file.flush()
201 |
202 | score_srconly = 0
203 | score_srconly_mm = 0
204 | score = 0
205 | score_mm = 0
206 |
207 | for i in range(len(names)):
208 | if i == args.t:
209 | continue
210 | args.s = i
211 |
212 | folder = '../data/'
213 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt'
214 | args.name = names[args.s][0].upper()+names[args.t][0].upper()
215 |
216 | if args.dset == 'pacs':
217 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_test_kfold.txt'
218 |
219 | args.output_dir_src = osp.join(args.output_src, args.da, args.dset, names[args.s][0].upper())
220 | label, output_srconly = test_target_srconly(args)
221 | score_srconly += output_srconly
222 |
223 | args.output_dir_ori = osp.join(args.output_tar, args.da, args.dset, names[args.s][0].upper()+names[args.t][0].upper())
224 | _, output = test_target(args)
225 | score += output
226 |
227 | args.output_dir = osp.join(args.output_mm, args.da, args.dset, names[args.s][0].upper()+names[args.t][0].upper())
228 | mm = np.load(args.output_dir + '/ps_0.0_' + args.savename + '.npz')
229 | score_mm += torch.from_numpy(mm['score'])
230 |
231 | args.output_dir = osp.join(args.output_mm, args.da, args.dset, names[args.s][0].upper()+names[args.t][0].upper())
232 | mm = np.load(args.output_dir + '/ps_0.0_srconly.npz')
233 | score_srconly_mm += torch.from_numpy(mm['score'])
234 |
235 | _, predict = torch.max(score_srconly, 1)
236 | acc = torch.sum(torch.squeeze(predict).float() == label).item() / float(label.size()[0])
237 | log_str = '\nTask: {}, Accuracy = {:.2f}%'.format(names[args.t][0].upper(), acc*100)
238 | args.out_file.write(log_str)
239 | args.out_file.flush()
240 | print(log_str)
241 |
242 | _, predict = torch.max(score_srconly_mm, 1)
243 | acc = torch.sum(torch.squeeze(predict).float() == label).item() / float(label.size()[0])
244 | log_str = '\nTask: {}, Accuracy = {:.2f}%'.format(names[args.t][0].upper(), acc*100)
245 | args.out_file.write(log_str)
246 | args.out_file.flush()
247 | print(log_str)
248 |
249 | _, predict = torch.max(score, 1)
250 | acc = torch.sum(torch.squeeze(predict).float() == label).item() / float(label.size()[0])
251 | log_str = '\nTask: {}, Accuracy = {:.2f}%'.format(names[args.t][0].upper(), acc*100)
252 | args.out_file.write(log_str)
253 | args.out_file.flush()
254 | print(log_str)
255 |
256 | _, predict = torch.max(score_mm, 1)
257 | acc = torch.sum(torch.squeeze(predict).float() == label).item() / float(label.size()[0])
258 | log_str = '\nTask: {}, Accuracy = {:.2f}%'.format(names[args.t][0].upper(), acc*100)
259 | args.out_file.write(log_str)
260 | args.out_file.flush()
261 | print(log_str)
262 |
263 | args.out_file.close()
--------------------------------------------------------------------------------
/code/ssda/image_source.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 | import os.path as osp
4 | import torchvision
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torchvision import transforms
10 | import network
11 | import loss
12 | from torch.utils.data import DataLoader
13 | from data_list import ImageList
14 | import random, pdb, math, copy
15 | from tqdm import tqdm
16 | from loss import CrossEntropyLabelSmooth
17 | from scipy.spatial.distance import cdist
18 | from sklearn.metrics import confusion_matrix
19 | from sklearn.cluster import KMeans
20 |
21 | def op_copy(optimizer):
22 | for param_group in optimizer.param_groups:
23 | param_group['lr0'] = param_group['lr']
24 | return optimizer
25 |
26 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
27 | decay = (1 + gamma * iter_num / max_iter) ** (-power)
28 | for param_group in optimizer.param_groups:
29 | param_group['lr'] = param_group['lr0'] * decay
30 | param_group['weight_decay'] = 1e-3
31 | param_group['momentum'] = 0.9
32 | param_group['nesterov'] = True
33 | return optimizer
34 |
35 | def image_train(resize_size=256, crop_size=224, alexnet=False):
36 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
37 | std=[0.229, 0.224, 0.225])
38 | return transforms.Compose([
39 | transforms.Resize((resize_size, resize_size)),
40 | transforms.RandomCrop(crop_size),
41 | transforms.RandomHorizontalFlip(),
42 | transforms.ToTensor(),
43 | normalize
44 | ])
45 |
46 | def image_test(resize_size=256, crop_size=224, alexnet=False):
47 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
48 | std=[0.229, 0.224, 0.225])
49 | return transforms.Compose([
50 | transforms.Resize((resize_size, resize_size)),
51 | transforms.CenterCrop(crop_size),
52 | transforms.ToTensor(),
53 | normalize
54 | ])
55 |
56 | def data_load(args):
57 | ## prepare data
58 | dsets = {}
59 | dset_loaders = {}
60 | train_bs = args.batch_size
61 | txt_src = open(args.s_dset_path).readlines()
62 | txt_test = open(args.test_dset_path).readlines()
63 |
64 | dsize = len(txt_src)
65 | tr_size = int(0.9*dsize)
66 | tr_txt, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size])
67 |
68 | dsets["source_tr"] = ImageList(tr_txt, transform=image_train())
69 | dset_loaders["source_tr"] = DataLoader(dsets["source_tr"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
70 | dsets["source_te"] = ImageList(te_txt, transform=image_test())
71 | dset_loaders["source_te"] = DataLoader(dsets["source_te"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False)
72 | dsets["test"] = ImageList(txt_test, transform=image_test())
73 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*2, shuffle=True, num_workers=args.worker, drop_last=False)
74 |
75 | return dset_loaders
76 |
77 | def cal_acc(loader, netF, netB, netC, flag=False):
78 | start_test = True
79 | with torch.no_grad():
80 | iter_test = iter(loader)
81 | for i in range(len(loader)):
82 |             data = next(iter_test)
83 | inputs = data[0]
84 | labels = data[1]
85 | inputs = inputs.cuda()
86 | outputs = netC(netB(netF(inputs)))
87 | if start_test:
88 | all_output = outputs.float().cpu()
89 | all_label = labels.float()
90 | start_test = False
91 | else:
92 | all_output = torch.cat((all_output, outputs.float().cpu()), 0)
93 | all_label = torch.cat((all_label, labels.float()), 0)
94 |
95 | all_output = nn.Softmax(dim=1)(all_output)
96 | _, predict = torch.max(all_output, 1)
97 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
98 | mean_ent = torch.mean(loss.Entropy(all_output)).cpu().data.item()
99 |
100 | return accuracy*100, mean_ent
101 |
102 | def train_source(args):
103 | dset_loaders = data_load(args)
104 | ## set base network
105 | if args.net[0:3] == 'res':
106 | netF = network.ResBase(res_name=args.net).cuda()
107 | elif args.net[0:3] == 'vgg':
108 | netF = network.VGGBase(vgg_name=args.net).cuda()
109 |
110 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
111 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
112 |
113 | param_group = []
114 | learning_rate = args.lr
115 | for k, v in netF.named_parameters():
116 | param_group += [{'params': v, 'lr': learning_rate*0.1}]
117 | for k, v in netB.named_parameters():
118 | param_group += [{'params': v, 'lr': learning_rate}]
119 | for k, v in netC.named_parameters():
120 | param_group += [{'params': v, 'lr': learning_rate}]
121 | optimizer = optim.SGD(param_group)
122 | optimizer = op_copy(optimizer)
123 |
124 | acc_init = 0
125 | max_iter = args.max_epoch * len(dset_loaders["source_tr"])
126 | interval_iter = max_iter // 10
127 | iter_num = 0
128 |
129 | netF.train()
130 | netB.train()
131 | netC.train()
132 |
133 | while iter_num < max_iter:
134 | try:
135 |             inputs_source, labels_source = next(iter_source)
136 |         except:  # first pass or exhausted loader: (re)build the source iterator
137 |             iter_source = iter(dset_loaders["source_tr"])
138 |             inputs_source, labels_source = next(iter_source)
139 |
140 | if inputs_source.size(0) == 1:
141 | continue
142 |
143 | iter_num += 1
144 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)
145 |
146 | inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
147 | outputs_source = netC(netB(netF(inputs_source)))
148 | classifier_loss = CrossEntropyLabelSmooth(num_classes=args.class_num, epsilon=args.smooth)(outputs_source, labels_source)
149 |
150 | optimizer.zero_grad()
151 | classifier_loss.backward()
152 | optimizer.step()
153 |
154 | if iter_num % interval_iter == 0 or iter_num == max_iter:
155 | netF.eval()
156 | netB.eval()
157 | netC.eval()
158 |
159 | acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC, False)
160 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name_src, iter_num, max_iter, acc_s_te)
161 |
162 | args.out_file.write(log_str + '\n')
163 | args.out_file.flush()
164 | print(log_str+'\n')
165 |
166 | if acc_s_te >= acc_init:
167 | acc_init = acc_s_te
168 | best_netF = netF.state_dict()
169 | best_netB = netB.state_dict()
170 | best_netC = netC.state_dict()
171 |
172 | netF.train()
173 | netB.train()
174 | netC.train()
175 |
176 | torch.save(best_netF, osp.join(args.output_dir_src, "source_F.pt"))
177 | torch.save(best_netB, osp.join(args.output_dir_src, "source_B.pt"))
178 | torch.save(best_netC, osp.join(args.output_dir_src, "source_C.pt"))
179 |
180 | return netF, netB, netC
181 |
182 | def test_target(args):
183 | dset_loaders = data_load(args)
184 | ## set base network
185 | if args.net[0:3] == 'res':
186 | netF = network.ResBase(res_name=args.net).cuda()
187 | elif args.net[0:3] == 'vgg':
188 | netF = network.VGGBase(vgg_name=args.net).cuda()
189 |
190 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
191 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
192 |
193 | args.modelpath = args.output_dir_src + '/source_F.pt'
194 | netF.load_state_dict(torch.load(args.modelpath))
195 | args.modelpath = args.output_dir_src + '/source_B.pt'
196 | netB.load_state_dict(torch.load(args.modelpath))
197 | args.modelpath = args.output_dir_src + '/source_C.pt'
198 | netC.load_state_dict(torch.load(args.modelpath))
199 | netF.eval()
200 | netB.eval()
201 | netC.eval()
202 |
203 | acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False)
204 | log_str = '\nTask: {}, Accuracy = {:.2f}%'.format(args.name, acc)
205 |
206 | args.out_file.write(log_str)
207 | args.out_file.flush()
208 | print(log_str)
209 |
210 | def print_args(args):
211 | s = "==========================================\n"
212 | for arg, content in args.__dict__.items():
213 | s += "{}:{}\n".format(arg, content)
214 | return s
215 |
216 | if __name__ == "__main__":
217 | parser = argparse.ArgumentParser(description='SHOT++')
218 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
219 | parser.add_argument('--s', type=int, default=0, help="source")
220 | parser.add_argument('--t', type=int, default=1, help="target")
221 |     parser.add_argument('--max_epoch', type=int, default=20, help="max epochs")
222 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size")
223 | parser.add_argument('--worker', type=int, default=4, help="number of workers")
224 | parser.add_argument('--dset', type=str, default='office-home', choices=['office-home'])
225 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
226 | parser.add_argument('--net', type=str, default='vgg16', help="resnet34, vgg16, resnet50, resnet101")
227 | parser.add_argument('--seed', type=int, default=2020, help="random seed")
228 | parser.add_argument('--bottleneck', type=int, default=256)
229 | parser.add_argument('--epsilon', type=float, default=1e-5)
230 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"])
231 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"])
232 | parser.add_argument('--smooth', type=float, default=0.1)
233 | parser.add_argument('--output', type=str, default='san')
234 | parser.add_argument('--da', type=str, default='ssda')
235 | parser.add_argument('--shot', type=int, default=1, choices=[1, 3])
236 |
237 | args = parser.parse_args()
238 |
239 | if args.dset == 'office-home':
240 | names = ['Art', 'Clipart', 'Product', 'Real']
241 | args.class_num = 65
242 |
243 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
244 | SEED = args.seed
245 | torch.manual_seed(SEED)
246 | torch.cuda.manual_seed(SEED)
247 | np.random.seed(SEED)
248 | random.seed(SEED)
249 | # torch.backends.cudnn.deterministic = True
250 |
251 | folder = '../data/ssda/'
252 | args.s_dset_path = folder + args.dset + '/labeled_source_images_' + names[args.s] + '.txt'
253 | args.test_dset_path = folder + args.dset + '/unlabeled_target_images_' + names[args.t] + '_' + str(args.shot) + '.txt'
254 |
255 | args.output_dir_src = osp.join(args.output, args.da, args.dset, names[args.s][0].upper())
256 | args.name_src = names[args.s][0].upper()
257 | if not osp.exists(args.output_dir_src):
258 | os.system('mkdir -p ' + args.output_dir_src)
259 | if not osp.exists(args.output_dir_src):
260 | os.mkdir(args.output_dir_src)
261 |
262 | args.out_file = open(osp.join(args.output_dir_src, 'log.txt'), 'w')
263 | args.out_file.write(print_args(args)+'\n')
264 | args.out_file.flush()
265 | train_source(args)
266 |
267 | args.out_file = open(osp.join(args.output_dir_src, 'log_test.txt'), 'w')
268 | for i in range(len(names)):
269 | if i == args.s:
270 | continue
271 | args.t = i
272 | args.name = names[args.s][0].upper() + names[args.t][0].upper()
273 |
274 | args.s_dset_path = folder + args.dset + '/labeled_source_images_' + names[args.s] + '.txt'
275 | args.test_dset_path = folder + args.dset + '/unlabeled_target_images_' + names[args.t] + '_' + str(args.shot) + '.txt'
276 | test_target(args)
277 | args.out_file.close()
--------------------------------------------------------------------------------
/code/data/pacs/photo_crossval_kfold.txt:
--------------------------------------------------------------------------------
1 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0001.jpg 1
2 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0002.jpg 1
3 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0003.jpg 1
4 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0004.jpg 1
5 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0005.jpg 1
6 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0006.jpg 1
7 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0007.jpg 1
8 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0009.jpg 1
9 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0010.jpg 1
10 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0011.jpg 1
11 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0012.jpg 1
12 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0013.jpg 1
13 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0014.jpg 1
14 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0015.jpg 1
15 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0016.jpg 1
16 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0017.jpg 1
17 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0018.jpg 1
18 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0020.jpg 1
19 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/dog/056_0021.jpg 1
20 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0001.jpg 2
21 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0002.jpg 2
22 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0003.jpg 2
23 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0004.jpg 2
24 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0005.jpg 2
25 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0006.jpg 2
26 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0007.jpg 2
27 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0008.jpg 2
28 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0009.jpg 2
29 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0010.jpg 2
30 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0011.jpg 2
31 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0012.jpg 2
32 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0013.jpg 2
33 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0014.jpg 2
34 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0015.jpg 2
35 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0016.jpg 2
36 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0017.jpg 2
37 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0018.jpg 2
38 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0019.jpg 2
39 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0020.jpg 2
40 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/elephant/064_0021.jpg 2
41 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0001.jpg 3
42 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0002.jpg 3
43 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0003.jpg 3
44 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0004.jpg 3
45 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0005.jpg 3
46 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0006.jpg 3
47 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0007.jpg 3
48 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0008.jpg 3
49 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0009.jpg 3
50 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0010.jpg 3
51 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0011.jpg 3
52 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0012.jpg 3
53 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0013.jpg 3
54 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0014.jpg 3
55 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0015.jpg 3
56 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0016.jpg 3
57 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0017.jpg 3
58 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0018.jpg 3
59 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/giraffe/084_0019.jpg 3
60 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0001.jpg 4
61 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0002.jpg 4
62 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0003.jpg 4
63 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0004.jpg 4
64 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0005.jpg 4
65 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0006.jpg 4
66 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0007.jpg 4
67 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0008.jpg 4
68 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0009.jpg 4
69 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0010.jpg 4
70 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0012.jpg 4
71 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0013.jpg 4
72 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0016.jpg 4
73 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0018.jpg 4
74 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0019.jpg 4
75 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0020.jpg 4
76 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0021.jpg 4
77 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0022.jpg 4
78 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/guitar/063_0023.jpg 4
79 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0002.jpg 5
80 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0003.jpg 5
81 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0007.jpg 5
82 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0008.jpg 5
83 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0009.jpg 5
84 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0010.jpg 5
85 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0012.jpg 5
86 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0013.jpg 5
87 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0022.jpg 5
88 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0025.jpg 5
89 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0028.jpg 5
90 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0029.jpg 5
91 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0030.jpg 5
92 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0033.jpg 5
93 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0037.jpg 5
94 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0038.jpg 5
95 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0041.jpg 5
96 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0042.jpg 5
97 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0047.jpg 5
98 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/horse/105_0048.jpg 5
99 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_010.jpg 6
100 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_011.jpg 6
101 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_012.jpg 6
102 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_013.jpg 6
103 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_014.jpg 6
104 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_015.jpg 6
105 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_016.jpg 6
106 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_017.jpg 6
107 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_018.jpg 6
108 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_021.jpg 6
109 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_019.jpg 6
110 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_022.jpg 6
111 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_020.jpg 6
112 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_023.jpg 6
113 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_024.jpg 6
114 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_026.jpg 6
115 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_025.jpg 6
116 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_027.jpg 6
117 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_028.jpg 6
118 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_029.jpg 6
119 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_031.jpg 6
120 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_239.jpg 6
121 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_240.jpg 6
122 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_241.jpg 6
123 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_242.jpg 6
124 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_248.jpg 6
125 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_246.jpg 6
126 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_247.jpg 6
127 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/house/pic_244.jpg 6
128 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0001.jpg 0
129 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0002.jpg 0
130 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0003.jpg 0
131 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0004.jpg 0
132 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0005.jpg 0
133 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0006.jpg 0
134 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0007.jpg 0
135 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0008.jpg 0
136 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0009.jpg 0
137 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0010.jpg 0
138 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0011.jpg 0
139 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0012.jpg 0
140 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0013.jpg 0
141 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0014.jpg 0
142 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0015.jpg 0
143 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0016.jpg 0
144 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0017.jpg 0
145 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0018.jpg 0
146 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0019.jpg 0
147 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0020.jpg 0
148 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0021.jpg 0
149 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0022.jpg 0
150 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0023.jpg 0
151 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0024.jpg 0
152 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0025.jpg 0
153 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0026.jpg 0
154 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0027.jpg 0
155 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0028.jpg 0
156 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0029.jpg 0
157 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0030.jpg 0
158 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0031.jpg 0
159 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0032.jpg 0
160 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0033.jpg 0
161 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0034.jpg 0
162 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0035.jpg 0
163 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0036.jpg 0
164 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0037.jpg 0
165 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0038.jpg 0
166 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0039.jpg 0
167 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0040.jpg 0
168 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0041.jpg 0
169 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0042.jpg 0
170 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0043.jpg 0
171 | /Checkpoint/liangjian/da_dataset/pacs/kfold/photo/person/253_0044.jpg 0
172 |
--------------------------------------------------------------------------------
/code/msda/image_source.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 | import os.path as osp
4 | import torchvision
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torchvision import transforms
10 | import network
11 | import loss
12 | from torch.utils.data import DataLoader
13 | from data_list import ImageList
14 | import random, pdb, math, copy
15 | from tqdm import tqdm
16 | from loss import CrossEntropyLabelSmooth
17 | from scipy.spatial.distance import cdist
18 | from sklearn.metrics import confusion_matrix
19 | from sklearn.cluster import KMeans
20 |
21 | def op_copy(optimizer):
22 | for param_group in optimizer.param_groups:
23 | param_group['lr0'] = param_group['lr']
24 | return optimizer
25 |
26 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
27 | decay = (1 + gamma * iter_num / max_iter) ** (-power)
28 | for param_group in optimizer.param_groups:
29 | param_group['lr'] = param_group['lr0'] * decay
30 | param_group['weight_decay'] = 1e-3
31 | param_group['momentum'] = 0.9
32 | param_group['nesterov'] = True
33 | return optimizer
34 |
35 | def image_train(resize_size=256, crop_size=224, alexnet=False):
36 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
37 | std=[0.229, 0.224, 0.225])
38 | return transforms.Compose([
39 | transforms.Resize((resize_size, resize_size)),
40 | transforms.RandomCrop(crop_size),
41 | transforms.RandomHorizontalFlip(),
42 | transforms.ToTensor(),
43 | normalize
44 | ])
45 |
46 | def image_test(resize_size=256, crop_size=224, alexnet=False):
47 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
48 | std=[0.229, 0.224, 0.225])
49 | return transforms.Compose([
50 | transforms.Resize((resize_size, resize_size)),
51 | transforms.CenterCrop(crop_size),
52 | transforms.ToTensor(),
53 | normalize
54 | ])
55 |
56 | def data_load(args):
57 | ## prepare data
58 | dsets = {}
59 | dset_loaders = {}
60 | train_bs = args.batch_size
61 | txt_src = open(args.s_dset_path).readlines()
62 | txt_test = open(args.test_dset_path).readlines()
63 |
64 | dsize = len(txt_src)
65 | tr_size = int(0.9*dsize)
66 | tr_txt, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size])
67 |
68 | dsets["source_tr"] = ImageList(tr_txt, transform=image_train())
69 | dset_loaders["source_tr"] = DataLoader(dsets["source_tr"], batch_size=train_bs, shuffle=True,
70 | num_workers=args.worker, drop_last=False)
71 | dsets["source_te"] = ImageList(te_txt, transform=image_test())
72 | dset_loaders["source_te"] = DataLoader(dsets["source_te"], batch_size=train_bs*3, shuffle=False,
73 | num_workers=args.worker, drop_last=False)
74 | dsets["test"] = ImageList(txt_test, transform=image_test())
75 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*3, shuffle=False,
76 | num_workers=args.worker, drop_last=False)
77 |
78 | return dset_loaders
79 |
80 | def cal_acc(loader, netF, netB, netC):
81 | start_test = True
82 | with torch.no_grad():
83 | iter_test = iter(loader)
84 | for i in range(len(loader)):
85 | data = next(iter_test)
86 | inputs = data[0]
87 | labels = data[1]
88 | inputs = inputs.cuda()
89 | outputs = netC(netB(netF(inputs)))
90 | if start_test:
91 | all_output = outputs.float().cpu()
92 | all_label = labels.float()
93 | start_test = False
94 | else:
95 | all_output = torch.cat((all_output, outputs.float().cpu()), 0)
96 | all_label = torch.cat((all_label, labels.float()), 0)
97 |
98 | all_output = nn.Softmax(dim=1)(all_output)
99 | _, predict = torch.max(all_output, 1)
100 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
101 | mean_ent = torch.mean(loss.Entropy(all_output)).cpu().data.item()
102 |
103 | return accuracy*100, mean_ent
104 |
105 | def train_source(args):
106 | dset_loaders = data_load(args)
107 | ## set base network
108 | if args.net[0:3] == 'res':
109 | netF = network.ResBase(res_name=args.net).cuda()
110 | elif args.net[0:3] == 'vgg':
111 | netF = network.VGGBase(vgg_name=args.net).cuda()
112 |
113 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
114 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
115 |
116 | param_group = []
117 | learning_rate = args.lr
118 | for k, v in netF.named_parameters():
119 | param_group += [{'params': v, 'lr': learning_rate*0.1}]
120 | for k, v in netB.named_parameters():
121 | param_group += [{'params': v, 'lr': learning_rate}]
122 | for k, v in netC.named_parameters():
123 | param_group += [{'params': v, 'lr': learning_rate}]
124 | optimizer = optim.SGD(param_group)
125 | optimizer = op_copy(optimizer)
126 |
127 | acc_init = 0
128 | max_iter = args.max_epoch * len(dset_loaders["source_tr"])
129 | interval_iter = max_iter // 10
130 | iter_num = 0
131 |
132 | netF.train()
133 | netB.train()
134 | netC.train()
135 |
136 | while iter_num < max_iter:
137 | try:
138 | inputs_source, labels_source = next(iter_source)
139 | except (NameError, StopIteration):  # first pass of the loop, or loader exhausted
140 | iter_source = iter(dset_loaders["source_tr"])
141 | inputs_source, labels_source = next(iter_source)
142 |
143 | if inputs_source.size(0) == 1:
144 | continue
145 |
146 | iter_num += 1
147 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)
148 |
149 | inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
150 | outputs_source = netC(netB(netF(inputs_source)))
151 | classifier_loss = CrossEntropyLabelSmooth(num_classes=args.class_num, epsilon=args.smooth)(outputs_source, labels_source)
152 |
153 | optimizer.zero_grad()
154 | classifier_loss.backward()
155 | optimizer.step()
156 |
157 | if iter_num % interval_iter == 0 or iter_num == max_iter:
158 | netF.eval()
159 | netB.eval()
160 | netC.eval()
161 | acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC)
162 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name_src, iter_num, max_iter, acc_s_te)
163 | args.out_file.write(log_str + '\n')
164 | args.out_file.flush()
165 | print(log_str+'\n')
166 |
167 | if acc_s_te >= acc_init:
168 | acc_init = acc_s_te
169 | best_netF = netF.state_dict()
170 | best_netB = netB.state_dict()
171 | best_netC = netC.state_dict()
172 |
173 | netF.train()
174 | netB.train()
175 | netC.train()
176 |
177 | torch.save(best_netF, osp.join(args.output_dir_src, "source_F.pt"))
178 | torch.save(best_netB, osp.join(args.output_dir_src, "source_B.pt"))
179 | torch.save(best_netC, osp.join(args.output_dir_src, "source_C.pt"))
180 |
181 | return netF, netB, netC
182 |
183 | def test_target(args):
184 | dset_loaders = data_load(args)
185 | ## set base network
186 | if args.net[0:3] == 'res':
187 | netF = network.ResBase(res_name=args.net).cuda()
188 | elif args.net[0:3] == 'vgg':
189 | netF = network.VGGBase(vgg_name=args.net).cuda()
190 |
191 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
192 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
193 |
194 | args.modelpath = args.output_dir_src + '/source_F.pt'
195 | netF.load_state_dict(torch.load(args.modelpath))
196 | args.modelpath = args.output_dir_src + '/source_B.pt'
197 | netB.load_state_dict(torch.load(args.modelpath))
198 | args.modelpath = args.output_dir_src + '/source_C.pt'
199 | netC.load_state_dict(torch.load(args.modelpath))
200 | netF.eval()
201 | netB.eval()
202 | netC.eval()
203 |
204 | acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC)
205 | log_str = '\nTask: {}, Accuracy = {:.2f}%'.format(args.name, acc)
206 |
207 | args.out_file.write(log_str)
208 | args.out_file.flush()
209 | print(log_str)
210 |
211 | def print_args(args):
212 | s = "==========================================\n"
213 | for arg, content in args.__dict__.items():
214 | s += "{}:{}\n".format(arg, content)
215 | return s
216 |
217 | if __name__ == "__main__":
218 | parser = argparse.ArgumentParser(description='SHOT')
219 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
220 | parser.add_argument('--s', type=int, default=0, help="source")
221 | parser.add_argument('--t', type=int, default=1, help="target")
222 | parser.add_argument('--max_epoch', type=int, default=20, help="maximum number of training epochs")
223 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size")
224 | parser.add_argument('--worker', type=int, default=4, help="number of workers")
225 | parser.add_argument('--dset', type=str, default='office-home', choices=['pacs', 'office-home', 'office-caltech'])
226 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
227 | parser.add_argument('--net', type=str, default='resnet50', help="vgg16, resnet18, resnet34, resnet50, resnet101")
228 | parser.add_argument('--seed', type=int, default=2020, help="random seed")
229 | parser.add_argument('--bottleneck', type=int, default=256)
230 | parser.add_argument('--epsilon', type=float, default=1e-5)
231 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"])
232 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"])
233 | parser.add_argument('--smooth', type=float, default=0.1)
234 | parser.add_argument('--output', type=str, default='san')
235 | parser.add_argument('--da', type=str, default='uda', choices=['uda', 'pda'])
236 | args = parser.parse_args()
237 |
238 | if args.dset == 'office-home':
239 | names = ['Art', 'Clipart', 'Product', 'RealWorld']
240 | args.class_num = 65
241 | elif args.dset == 'office-caltech':
242 | names = ['amazon', 'caltech', 'dslr', 'webcam']
243 | args.class_num = 10
244 | elif args.dset == 'pacs':
245 | names = ['art_painting', 'cartoon', 'photo', 'sketch']
246 | args.class_num = 7
247 |
248 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
249 | SEED = args.seed
250 | torch.manual_seed(SEED)
251 | torch.cuda.manual_seed(SEED)
252 | np.random.seed(SEED)
253 | random.seed(SEED)
254 | # torch.backends.cudnn.deterministic = True
255 |
256 | folder = '/Checkpoint/liangjian/tran/data/'
257 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt'
258 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt'
259 |
260 | if args.dset == 'pacs':
261 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_test_kfold.txt'
262 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_test_kfold.txt'
263 |
264 | args.output_dir_src = osp.join(args.output, args.da, args.dset, names[args.s][0].upper())
265 | args.name_src = names[args.s][0].upper()
266 | if not osp.exists(args.output_dir_src):
267 | os.system('mkdir -p ' + args.output_dir_src)
268 | if not osp.exists(args.output_dir_src):
269 | os.mkdir(args.output_dir_src)
270 |
271 | args.out_file = open(osp.join(args.output_dir_src, 'log.txt'), 'w')
272 | args.out_file.write(print_args(args)+'\n')
273 | args.out_file.flush()
274 | train_source(args)
275 |
276 | args.out_file = open(osp.join(args.output_dir_src, 'log_test.txt'), 'w')
277 | for i in range(len(names)):
278 | if i == args.s:
279 | continue
280 | args.t = i
281 | args.name = names[args.s][0].upper() + names[args.t][0].upper()
282 |
283 | folder = '../data/'
284 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt'
285 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt'
286 |
287 | if args.dset == 'pacs':
288 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_test_kfold.txt'
289 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_test_kfold.txt'
290 |
291 | test_target(args)
292 |
293 | args.out_file.close()
--------------------------------------------------------------------------------
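Note: the source-training script above pairs a backbone `netF` (trained at 0.1x the base learning rate) with `netB`/`netC` at the full rate, and `lr_scheduler()` decays every parameter group as `lr0 * (1 + 10 * iter_num / max_iter) ** (-0.75)`. A standalone sketch of that multiplier (illustration only, not part of the repository) to show how quickly the rate falls:

```python
# Minimal sketch of the power-decay schedule used by lr_scheduler() above:
# lr(i) = lr0 * (1 + gamma * i / max_iter) ** (-power), with gamma=10, power=0.75.
def decay_factor(iter_num, max_iter, gamma=10, power=0.75):
    return (1 + gamma * iter_num / max_iter) ** (-power)

if __name__ == "__main__":
    max_iter = 1000
    for frac in (0.0, 0.25, 0.5, 1.0):
        i = int(frac * max_iter)
        # at the end of training the multiplier is (1 + 10) ** -0.75, roughly 0.166
        print(f"progress {frac:.2f}: lr multiplier = {decay_factor(i, max_iter):.3f}")
```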
/code/digit/data_load/svhn.py:
--------------------------------------------------------------------------------
1 | from .vision import VisionDataset
2 | from PIL import Image
3 | import os
4 | import os.path
5 | import numpy as np
6 | from .utils import download_url, check_integrity, verify_str_arg
7 |
8 |
9 | class SVHN(VisionDataset):
10 | """`SVHN `_ Dataset.
11 | Note: The SVHN dataset assigns the label `10` to the digit `0`. However, in this Dataset,
12 | we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which
13 | expect the class labels to be in the range `[0, C-1]`
14 |
15 | .. warning::
16 |
17 | This class needs `scipy <https://docs.scipy.org/doc/>`_ to load data from `.mat` format.
18 |
19 | Args:
20 | root (string): Root directory of dataset where directory
21 | ``SVHN`` exists.
22 | split (string): One of {'train', 'test', 'extra'}.
23 | Accordingly dataset is selected. 'extra' is Extra training set.
24 | transform (callable, optional): A function/transform that takes in a PIL image
25 | and returns a transformed version. E.g., ``transforms.RandomCrop``
26 | target_transform (callable, optional): A function/transform that takes in the
27 | target and transforms it.
28 | download (bool, optional): If true, downloads the dataset from the internet and
29 | puts it in root directory. If dataset is already downloaded, it is not
30 | downloaded again.
31 |
32 | """
33 |
34 | split_list = {
35 | 'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
36 | "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"],
37 | 'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
38 | "test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"],
39 | 'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
40 | "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]}
41 |
42 | def __init__(self, root, split='train', transform=None, target_transform=None,
43 | download=False):
44 | super(SVHN, self).__init__(root, transform=transform,
45 | target_transform=target_transform)
46 | self.split = verify_str_arg(split, "split", tuple(self.split_list.keys()))
47 | self.url = self.split_list[split][0]
48 | self.filename = self.split_list[split][1]
49 | self.file_md5 = self.split_list[split][2]
50 |
51 | if download:
52 | self.download()
53 |
54 | if not self._check_integrity():
55 | raise RuntimeError('Dataset not found or corrupted.' +
56 | ' You can use download=True to download it')
57 |
58 | # import here rather than at top of file because this is
59 | # an optional dependency for torchvision
60 | import scipy.io as sio
61 |
62 | # reading(loading) mat file as array
63 | loaded_mat = sio.loadmat(os.path.join(self.root, self.filename))
64 |
65 | self.data = loaded_mat['X']
66 | # loading from the .mat file gives an np array of type np.uint8
67 | # converting to np.int64, so that we have a LongTensor after
68 | # the conversion from the numpy array
69 | # the squeeze is needed to obtain a 1D tensor
70 | self.labels = loaded_mat['y'].astype(np.int64).squeeze()
71 |
72 | # the svhn dataset assigns the class label "10" to the digit 0
73 | # this makes it inconsistent with several loss functions
74 | # which expect the class labels to be in the range [0, C-1]
75 | np.place(self.labels, self.labels == 10, 0)
76 | self.data = np.transpose(self.data, (3, 2, 0, 1))
77 |
78 | def __getitem__(self, index):
79 | """
80 | Args:
81 | index (int): Index
82 |
83 | Returns:
84 | tuple: (image, target) where target is index of the target class.
85 | """
86 | img, target = self.data[index], int(self.labels[index])
87 |
88 | # doing this so that it is consistent with all other datasets
89 | # to return a PIL Image
90 | img = Image.fromarray(np.transpose(img, (1, 2, 0)))
91 |
92 | if self.transform is not None:
93 | img = self.transform(img)
94 |
95 | if self.target_transform is not None:
96 | target = self.target_transform(target)
97 |
98 | return img, target
99 |
100 | def __len__(self):
101 | return len(self.data)
102 |
103 | def _check_integrity(self):
104 | root = self.root
105 | md5 = self.split_list[self.split][2]
106 | fpath = os.path.join(root, self.filename)
107 | return check_integrity(fpath, md5)
108 |
109 | def download(self):
110 | md5 = self.split_list[self.split][2]
111 | download_url(self.url, self.root, self.filename, md5)
112 |
113 | def extra_repr(self):
114 | return "Split: {split}".format(**self.__dict__)
115 |
116 | class SVHN_idx(VisionDataset):
117 | """`SVHN `_ Dataset.
118 | Note: The SVHN dataset assigns the label `10` to the digit `0`. However, in this Dataset,
119 | we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which
120 | expect the class labels to be in the range `[0, C-1]`
121 |
122 | .. warning::
123 |
124 | This class needs `scipy <https://docs.scipy.org/doc/>`_ to load data from `.mat` format.
125 |
126 | Args:
127 | root (string): Root directory of dataset where directory
128 | ``SVHN`` exists.
129 | split (string): One of {'train', 'test', 'extra'}.
130 | Accordingly dataset is selected. 'extra' is Extra training set.
131 | transform (callable, optional): A function/transform that takes in a PIL image
132 | and returns a transformed version. E.g., ``transforms.RandomCrop``
133 | target_transform (callable, optional): A function/transform that takes in the
134 | target and transforms it.
135 | download (bool, optional): If true, downloads the dataset from the internet and
136 | puts it in root directory. If dataset is already downloaded, it is not
137 | downloaded again.
138 |
139 | """
140 |
141 | split_list = {
142 | 'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
143 | "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"],
144 | 'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
145 | "test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"],
146 | 'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
147 | "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]}
148 |
149 | def __init__(self, root, split='train', transform=None, target_transform=None,
150 | download=False):
151 | super(SVHN_idx, self).__init__(root, transform=transform,
152 | target_transform=target_transform)
153 | self.split = verify_str_arg(split, "split", tuple(self.split_list.keys()))
154 | self.url = self.split_list[split][0]
155 | self.filename = self.split_list[split][1]
156 | self.file_md5 = self.split_list[split][2]
157 |
158 | if download:
159 | self.download()
160 |
161 | if not self._check_integrity():
162 | raise RuntimeError('Dataset not found or corrupted.' +
163 | ' You can use download=True to download it')
164 |
165 | # import here rather than at top of file because this is
166 | # an optional dependency for torchvision
167 | import scipy.io as sio
168 |
169 | # reading(loading) mat file as array
170 | loaded_mat = sio.loadmat(os.path.join(self.root, self.filename))
171 |
172 | self.data = loaded_mat['X']
173 | # loading from the .mat file gives an np array of type np.uint8
174 | # converting to np.int64, so that we have a LongTensor after
175 | # the conversion from the numpy array
176 | # the squeeze is needed to obtain a 1D tensor
177 | self.labels = loaded_mat['y'].astype(np.int64).squeeze()
178 |
179 | # the svhn dataset assigns the class label "10" to the digit 0
180 | # this makes it inconsistent with several loss functions
181 | # which expect the class labels to be in the range [0, C-1]
182 | np.place(self.labels, self.labels == 10, 0)
183 | self.data = np.transpose(self.data, (3, 2, 0, 1))
184 |
185 | def __getitem__(self, index):
186 | """
187 | Args:
188 | index (int): Index
189 |
190 | Returns:
191 | tuple: (image, target, index) where target is the class index and index is the sample index.
192 | """
193 | img, target = self.data[index], int(self.labels[index])
194 |
195 | # doing this so that it is consistent with all other datasets
196 | # to return a PIL Image
197 | img = Image.fromarray(np.transpose(img, (1, 2, 0)))
198 |
199 | if self.transform is not None:
200 | img = self.transform(img)
201 |
202 | if self.target_transform is not None:
203 | target = self.target_transform(target)
204 |
205 | return img, target, index
206 |
207 | def __len__(self):
208 | return len(self.data)
209 |
210 | def _check_integrity(self):
211 | root = self.root
212 | md5 = self.split_list[self.split][2]
213 | fpath = os.path.join(root, self.filename)
214 | return check_integrity(fpath, md5)
215 |
216 | def download(self):
217 | md5 = self.split_list[self.split][2]
218 | download_url(self.url, self.root, self.filename, md5)
219 |
220 | def extra_repr(self):
221 | return "Split: {split}".format(**self.__dict__)
222 |
223 |
224 | class SVHN_twice(SVHN):
225 | split_list = {
226 | 'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
227 | "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"],
228 | 'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
229 | "test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"],
230 | 'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
231 | "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]}
232 |
233 | def __init__(self, root, split='train', transform=None, target_transform=None,
234 | download=False):
235 | super(SVHN_twice, self).__init__(root, transform=transform,
236 | target_transform=target_transform)
237 | self.split = verify_str_arg(split, "split", tuple(self.split_list.keys()))
238 | self.url = self.split_list[split][0]
239 | self.filename = self.split_list[split][1]
240 | self.file_md5 = self.split_list[split][2]
241 |
242 | if download:
243 | self.download()
244 |
245 | if not self._check_integrity():
246 | raise RuntimeError('Dataset not found or corrupted.' +
247 | ' You can use download=True to download it')
248 |
249 | # import here rather than at top of file because this is
250 | # an optional dependency for torchvision
251 | import scipy.io as sio
252 |
253 | # reading(loading) mat file as array
254 | loaded_mat = sio.loadmat(os.path.join(self.root, self.filename))
255 |
256 | self.data = loaded_mat['X']
257 | # loading from the .mat file gives an np array of type np.uint8
258 | # converting to np.int64, so that we have a LongTensor after
259 | # the conversion from the numpy array
260 | # the squeeze is needed to obtain a 1D tensor
261 | self.labels = loaded_mat['y'].astype(np.int64).squeeze()
262 |
263 | # the svhn dataset assigns the class label "10" to the digit 0
264 | # this makes it inconsistent with several loss functions
265 | # which expect the class labels to be in the range [0, C-1]
266 | np.place(self.labels, self.labels == 10, 0)
267 | self.data = np.transpose(self.data, (3, 2, 0, 1))
268 |
269 | def __getitem__(self, index):
270 | """
271 | Args:
272 | index (int): Index
273 |
274 | Returns:
275 | tuple: ([image, image], target, index) where the two images are independent augmentations of the same sample.
276 | """
277 | img, target = self.data[index], int(self.labels[index])
278 |
279 | # doing this so that it is consistent with all other datasets
280 | # to return a PIL Image
281 | img = Image.fromarray(np.transpose(img, (1, 2, 0)))
282 |
283 | if self.transform is not None:
284 | img = [self.transform(img), self.transform(img)]
285 |
286 | if self.target_transform is not None:
287 | target = self.target_transform(target)
288 |
289 | return img, target, index
290 |
291 | def __len__(self):
292 | return len(self.data)
--------------------------------------------------------------------------------
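Compared with `SVHN`, `SVHN_idx` additionally returns the sample index from `__getitem__`, and `SVHN_twice` returns two independently augmented views of the same image plus the index. A hedged usage sketch follows; the import path and data directory are assumptions for illustration, and `download=True` fetches the official `.mat` files from the URLs listed in `split_list`:

```python
# Hedged usage sketch (assumes this file is importable as data_load.svhn and
# that ./data/svhn is a writable directory for the downloaded .mat files).
from torchvision import transforms
from data_load.svhn import SVHN_idx, SVHN_twice

aug = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.ToTensor()])

ds_idx = SVHN_idx('./data/svhn', split='train', transform=aug, download=True)
img, target, index = ds_idx[0]                   # index identifies the sample across epochs

ds_twice = SVHN_twice('./data/svhn', split='train', transform=aug, download=True)
(view_1, view_2), target, index = ds_twice[0]    # two independent augmentations of one image
```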
/code/data/office-caltech/dslr_list.txt:
--------------------------------------------------------------------------------
1 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/back_pack/frame_0001.jpg 0
2 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/back_pack/frame_0002.jpg 0
3 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/back_pack/frame_0003.jpg 0
4 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/back_pack/frame_0004.jpg 0
5 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/back_pack/frame_0005.jpg 0
6 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/back_pack/frame_0006.jpg 0
7 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/back_pack/frame_0007.jpg 0
8 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/back_pack/frame_0008.jpg 0
9 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/back_pack/frame_0009.jpg 0
10 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/back_pack/frame_0010.jpg 0
11 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/back_pack/frame_0011.jpg 0
12 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/back_pack/frame_0012.jpg 0
13 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0001.jpg 1
14 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0002.jpg 1
15 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0003.jpg 1
16 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0004.jpg 1
17 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0005.jpg 1
18 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0006.jpg 1
19 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0007.jpg 1
20 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0008.jpg 1
21 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0009.jpg 1
22 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0010.jpg 1
23 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0011.jpg 1
24 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0012.jpg 1
25 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0013.jpg 1
26 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0014.jpg 1
27 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0015.jpg 1
28 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0016.jpg 1
29 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0017.jpg 1
30 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0018.jpg 1
31 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0019.jpg 1
32 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0020.jpg 1
33 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/bike/frame_0021.jpg 1
34 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/calculator/frame_0001.jpg 2
35 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/calculator/frame_0002.jpg 2
36 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/calculator/frame_0003.jpg 2
37 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/calculator/frame_0004.jpg 2
38 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/calculator/frame_0005.jpg 2
39 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/calculator/frame_0006.jpg 2
40 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/calculator/frame_0007.jpg 2
41 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/calculator/frame_0008.jpg 2
42 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/calculator/frame_0009.jpg 2
43 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/calculator/frame_0010.jpg 2
44 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/calculator/frame_0011.jpg 2
45 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/calculator/frame_0012.jpg 2
46 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/headphones/frame_0001.jpg 3
47 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/headphones/frame_0002.jpg 3
48 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/headphones/frame_0003.jpg 3
49 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/headphones/frame_0004.jpg 3
50 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/headphones/frame_0005.jpg 3
51 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/headphones/frame_0006.jpg 3
52 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/headphones/frame_0007.jpg 3
53 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/headphones/frame_0008.jpg 3
54 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/headphones/frame_0009.jpg 3
55 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/headphones/frame_0010.jpg 3
56 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/headphones/frame_0011.jpg 3
57 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/headphones/frame_0012.jpg 3
58 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/headphones/frame_0013.jpg 3
59 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/keyboard/frame_0001.jpg 4
60 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/keyboard/frame_0002.jpg 4
61 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/keyboard/frame_0003.jpg 4
62 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/keyboard/frame_0004.jpg 4
63 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/keyboard/frame_0005.jpg 4
64 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/keyboard/frame_0006.jpg 4
65 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/keyboard/frame_0007.jpg 4
66 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/keyboard/frame_0008.jpg 4
67 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/keyboard/frame_0009.jpg 4
68 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/keyboard/frame_0010.jpg 4
69 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0001.jpg 5
70 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0002.jpg 5
71 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0003.jpg 5
72 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0004.jpg 5
73 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0005.jpg 5
74 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0006.jpg 5
75 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0007.jpg 5
76 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0008.jpg 5
77 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0009.jpg 5
78 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0010.jpg 5
79 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0011.jpg 5
80 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0012.jpg 5
81 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0013.jpg 5
82 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0014.jpg 5
83 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0015.jpg 5
84 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0016.jpg 5
85 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0017.jpg 5
86 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0018.jpg 5
87 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0019.jpg 5
88 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0020.jpg 5
89 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0021.jpg 5
90 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0022.jpg 5
91 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0023.jpg 5
92 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/laptop_computer/frame_0024.jpg 5
93 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0001.jpg 6
94 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0002.jpg 6
95 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0003.jpg 6
96 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0004.jpg 6
97 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0005.jpg 6
98 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0006.jpg 6
99 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0007.jpg 6
100 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0008.jpg 6
101 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0009.jpg 6
102 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0010.jpg 6
103 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0011.jpg 6
104 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0012.jpg 6
105 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0013.jpg 6
106 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0014.jpg 6
107 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0015.jpg 6
108 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0016.jpg 6
109 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0017.jpg 6
110 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0018.jpg 6
111 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0019.jpg 6
112 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0020.jpg 6
113 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0021.jpg 6
114 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/monitor/frame_0022.jpg 6
115 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mouse/frame_0001.jpg 7
116 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mouse/frame_0002.jpg 7
117 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mouse/frame_0003.jpg 7
118 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mouse/frame_0004.jpg 7
119 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mouse/frame_0005.jpg 7
120 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mouse/frame_0006.jpg 7
121 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mouse/frame_0007.jpg 7
122 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mouse/frame_0008.jpg 7
123 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mouse/frame_0009.jpg 7
124 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mouse/frame_0010.jpg 7
125 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mouse/frame_0011.jpg 7
126 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mouse/frame_0012.jpg 7
127 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mug/frame_0001.jpg 8
128 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mug/frame_0002.jpg 8
129 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mug/frame_0003.jpg 8
130 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mug/frame_0004.jpg 8
131 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mug/frame_0005.jpg 8
132 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mug/frame_0006.jpg 8
133 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mug/frame_0007.jpg 8
134 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/mug/frame_0008.jpg 8
135 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0001.jpg 9
136 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0002.jpg 9
137 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0003.jpg 9
138 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0004.jpg 9
139 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0005.jpg 9
140 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0006.jpg 9
141 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0007.jpg 9
142 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0008.jpg 9
143 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0009.jpg 9
144 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0010.jpg 9
145 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0011.jpg 9
146 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0012.jpg 9
147 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0013.jpg 9
148 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0014.jpg 9
149 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0015.jpg 9
150 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0016.jpg 9
151 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0017.jpg 9
152 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0018.jpg 9
153 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0019.jpg 9
154 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0020.jpg 9
155 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0021.jpg 9
156 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0022.jpg 9
157 | /Checkpoint/liangjian/da_dataset/office31/dslr/images/projector/frame_0023.jpg 9
158 |
--------------------------------------------------------------------------------
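The list files under `code/data/` (such as `dslr_list.txt` above) are plain text with one `absolute-image-path class-index` pair per line and a trailing blank line; they are consumed by `ImageList` in `data_list.py`. A minimal illustrative parser of that format, not the repository's own loader:

```python
# Hedged sketch: parse a "path label" list such as dslr_list.txt above into
# (path, class_index) pairs. The real loader is ImageList in data_list.py.
def read_image_list(list_path):
    samples = []
    with open(list_path) as f:
        for line in f:
            line = line.strip()
            if not line:                          # skip the trailing blank line
                continue
            path, label = line.rsplit(' ', 1)     # label is the last whitespace-separated field
            samples.append((path, int(label)))
    return samples

# samples = read_image_list('code/data/office-caltech/dslr_list.txt')
# -> [('/Checkpoint/liangjian/da_dataset/office31/dslr/images/back_pack/frame_0001.jpg', 0), ...]
```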
/code/uda/image_source.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 | import os.path as osp
4 | import torchvision
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torchvision import transforms
10 | import network
11 | import loss
12 | from torch.utils.data import DataLoader
13 | from data_list import ImageList
14 | import random, pdb, math, copy
15 | from tqdm import tqdm
16 | from loss import CrossEntropyLabelSmooth
17 | from scipy.spatial.distance import cdist
18 | from sklearn.metrics import confusion_matrix
19 | from sklearn.cluster import KMeans
20 |
21 | def op_copy(optimizer):
22 | for param_group in optimizer.param_groups:
23 | param_group['lr0'] = param_group['lr']
24 | return optimizer
25 |
26 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
27 | decay = (1 + gamma * iter_num / max_iter) ** (-power)
28 | for param_group in optimizer.param_groups:
29 | param_group['lr'] = param_group['lr0'] * decay
30 | param_group['weight_decay'] = 1e-3
31 | param_group['momentum'] = 0.9
32 | param_group['nesterov'] = True
33 | return optimizer
34 |
35 | def image_train(resize_size=256, crop_size=224, alexnet=False):
36 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
37 | std=[0.229, 0.224, 0.225])
38 | return transforms.Compose([
39 | transforms.Resize((resize_size, resize_size)),
40 | transforms.RandomCrop(crop_size),
41 | transforms.RandomHorizontalFlip(),
42 | transforms.ToTensor(),
43 | normalize
44 | ])
45 |
46 | def image_test(resize_size=256, crop_size=224, alexnet=False):
47 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
48 | std=[0.229, 0.224, 0.225])
49 | return transforms.Compose([
50 | transforms.Resize((resize_size, resize_size)),
51 | transforms.CenterCrop(crop_size),
52 | transforms.ToTensor(),
53 | normalize
54 | ])
55 |
56 | def data_load(args):
57 | ## prepare data
58 | dsets = {}
59 | dset_loaders = {}
60 | train_bs = args.batch_size
61 | txt_src = open(args.s_dset_path).readlines()
62 | txt_test = open(args.test_dset_path).readlines()
63 |
64 | dsize = len(txt_src)
65 | tr_size = int(0.9*dsize)
66 | # print(dsize, tr_size, dsize - tr_size)
67 | tr_txt, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size])
68 |
69 | dsets["source_tr"] = ImageList(tr_txt, transform=image_train())
70 | dset_loaders["source_tr"] = DataLoader(dsets["source_tr"], batch_size=train_bs, shuffle=True,
71 | num_workers=args.worker, drop_last=False)
72 | dsets["source_te"] = ImageList(te_txt, transform=image_test())
73 | dset_loaders["source_te"] = DataLoader(dsets["source_te"], batch_size=train_bs*3, shuffle=False,
74 | num_workers=args.worker, drop_last=False)
75 | dsets["test"] = ImageList(txt_test, transform=image_test())
76 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*3, shuffle=False,
77 | num_workers=args.worker, drop_last=False)
78 |
79 | return dset_loaders
80 |
81 | def cal_acc(loader, netF, netB, netC, flag=False):
82 | start_test = True
83 | with torch.no_grad():
84 | iter_test = iter(loader)
85 | for i in range(len(loader)):
86 | data = next(iter_test)
87 | inputs = data[0]
88 | labels = data[1]
89 | inputs = inputs.cuda()
90 | outputs = netC(netB(netF(inputs)))
91 | if start_test:
92 | all_output = outputs.float().cpu()
93 | all_label = labels.float()
94 | start_test = False
95 | else:
96 | all_output = torch.cat((all_output, outputs.float().cpu()), 0)
97 | all_label = torch.cat((all_label, labels.float()), 0)
98 |
99 | all_output = nn.Softmax(dim=1)(all_output)
100 | _, predict = torch.max(all_output, 1)
101 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
102 | mean_ent = torch.mean(loss.Entropy(all_output)).cpu().data.item()
103 |
104 | if flag:
105 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
106 | matrix = matrix[np.unique(all_label).astype(int),:]
107 | acc = matrix.diagonal()/matrix.sum(axis=1) * 100
108 | aacc = acc.mean()
109 | aa = [str(np.round(i, 2)) for i in acc]
110 | acc = ' '.join(aa)
111 | return aacc, acc
112 | else:
113 | return accuracy*100, mean_ent
114 |
115 | def train_source(args):
116 | dset_loaders = data_load(args)
117 | ## set base network
118 | if args.net[0:3] == 'res':
119 | netF = network.ResBase(res_name=args.net).cuda()
120 | elif args.net[0:3] == 'vgg':
121 | netF = network.VGGBase(vgg_name=args.net).cuda()
122 |
123 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
124 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
125 |
126 | param_group = []
127 | learning_rate = args.lr
128 | for k, v in netF.named_parameters():
129 | param_group += [{'params': v, 'lr': learning_rate*0.1}]
130 | for k, v in netB.named_parameters():
131 | param_group += [{'params': v, 'lr': learning_rate}]
132 | for k, v in netC.named_parameters():
133 | param_group += [{'params': v, 'lr': learning_rate}]
134 | optimizer = optim.SGD(param_group)
135 | optimizer = op_copy(optimizer)
136 |
137 | acc_init = 0
138 | max_iter = args.max_epoch * len(dset_loaders["source_tr"])
139 | interval_iter = max_iter // 10
140 | iter_num = 0
141 |
142 | netF.train()
143 | netB.train()
144 | netC.train()
145 |
146 | while iter_num < max_iter:
147 | try:
148 | inputs_source, labels_source = next(iter_source)
149 | except (NameError, StopIteration):  # first pass of the loop, or loader exhausted
150 | iter_source = iter(dset_loaders["source_tr"])
151 | inputs_source, labels_source = next(iter_source)
152 |
153 | if inputs_source.size(0) == 1:
154 | continue
155 |
156 | iter_num += 1
157 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)
158 |
159 | inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
160 | outputs_source = netC(netB(netF(inputs_source)))
161 | classifier_loss = CrossEntropyLabelSmooth(num_classes=args.class_num, epsilon=args.smooth)(outputs_source, labels_source)
162 |
163 | optimizer.zero_grad()
164 | classifier_loss.backward()
165 | optimizer.step()
166 |
167 | if iter_num % interval_iter == 0 or iter_num == max_iter:
168 | netF.eval()
169 | netB.eval()
170 | netC.eval()
171 | if args.dset=='VISDA-C':
172 | acc_s_te, acc_list = cal_acc(dset_loaders['source_te'], netF, netB, netC, True)
173 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name_src, iter_num, max_iter, acc_s_te) + '\n' + acc_list
174 | else:
175 | acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC, False)
176 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name_src, iter_num, max_iter, acc_s_te)
177 | args.out_file.write(log_str + '\n')
178 | args.out_file.flush()
179 | print(log_str+'\n')
180 |
181 | if acc_s_te >= acc_init:
182 | acc_init = acc_s_te
183 | best_netF = netF.state_dict()
184 | best_netB = netB.state_dict()
185 | best_netC = netC.state_dict()
186 |
187 | netF.train()
188 | netB.train()
189 | netC.train()
190 |
191 | torch.save(best_netF, osp.join(args.output_dir_src, "source_F.pt"))
192 | torch.save(best_netB, osp.join(args.output_dir_src, "source_B.pt"))
193 | torch.save(best_netC, osp.join(args.output_dir_src, "source_C.pt"))
194 |
195 | return netF, netB, netC
196 |
197 | def test_target(args):
198 | dset_loaders = data_load(args)
199 | ## set base network
200 | if args.net[0:3] == 'res':
201 | netF = network.ResBase(res_name=args.net).cuda()
202 | elif args.net[0:3] == 'vgg':
203 | netF = network.VGGBase(vgg_name=args.net).cuda()
204 |
205 | netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
206 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
207 |
208 | args.modelpath = args.output_dir_src + '/source_F.pt'
209 | netF.load_state_dict(torch.load(args.modelpath))
210 | args.modelpath = args.output_dir_src + '/source_B.pt'
211 | netB.load_state_dict(torch.load(args.modelpath))
212 | args.modelpath = args.output_dir_src + '/source_C.pt'
213 | netC.load_state_dict(torch.load(args.modelpath))
214 | netF.eval()
215 | netB.eval()
216 | netC.eval()
217 |
218 | if args.dset=='VISDA-C':
219 | acc, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC, True)
220 | log_str = '\nTask: {}, Accuracy = {:.2f}%'.format(args.name, acc) + '\n' + acc_list
221 | else:
222 | acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False)
223 | log_str = '\nTask: {}, Accuracy = {:.2f}%'.format(args.name, acc)
224 |
225 | args.out_file.write(log_str)
226 | args.out_file.flush()
227 | print(log_str)
228 |
229 | def print_args(args):
230 | s = "==========================================\n"
231 | for arg, content in args.__dict__.items():
232 | s += "{}:{}\n".format(arg, content)
233 | return s
234 |
235 | if __name__ == "__main__":
236 | parser = argparse.ArgumentParser(description='SHOT++')
237 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
238 | parser.add_argument('--s', type=int, default=0, help="source")
239 | parser.add_argument('--t', type=int, default=1, help="target")
240 | parser.add_argument('--max_epoch', type=int, default=20, help="maximum number of training epochs")
241 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size")
242 | parser.add_argument('--worker', type=int, default=4, help="number of workers")
243 | parser.add_argument('--dset', type=str, default='office-home', choices=['VISDA-C', 'office', 'office-home'])
244 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
245 | parser.add_argument('--net', type=str, default='resnet50', help="vgg16, resnet18, resnet34, resnet50, resnet101")
246 | parser.add_argument('--seed', type=int, default=2020, help="random seed")
247 | parser.add_argument('--bottleneck', type=int, default=256)
248 | parser.add_argument('--epsilon', type=float, default=1e-5)
249 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"])
250 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"])
251 | parser.add_argument('--smooth', type=float, default=0.1)
252 | parser.add_argument('--output', type=str, default='san')
253 | parser.add_argument('--da', type=str, default='uda', choices=['uda'])
254 | args = parser.parse_args()
255 |
256 | if args.dset == 'office-home':
257 | names = ['Art', 'Clipart', 'Product', 'RealWorld']
258 | args.class_num = 65
259 | elif args.dset == 'office':
260 | names = ['amazon', 'dslr', 'webcam']
261 | args.class_num = 31
262 | elif args.dset == 'VISDA-C':
263 | names = ['train', 'validation']
264 | args.class_num = 12
265 | args.lr = 1e-3
266 |
267 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
268 | SEED = args.seed
269 | torch.manual_seed(SEED)
270 | torch.cuda.manual_seed(SEED)
271 | np.random.seed(SEED)
272 | random.seed(SEED)
273 | # torch.backends.cudnn.deterministic = True
274 |
275 | folder = '../data/'
276 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt'
277 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt'
278 |
279 | args.output_dir_src = osp.join(args.output, args.da, args.dset, names[args.s][0].upper())
280 | args.name_src = names[args.s][0].upper()
281 | if not osp.exists(args.output_dir_src):
282 | os.system('mkdir -p ' + args.output_dir_src)
283 | if not osp.exists(args.output_dir_src):
284 | os.mkdir(args.output_dir_src)
285 |
286 | args.out_file = open(osp.join(args.output_dir_src, 'log.txt'), 'w')
287 | args.out_file.write(print_args(args)+'\n')
288 | args.out_file.flush()
289 | train_source(args)
290 |
291 | args.out_file = open(osp.join(args.output_dir_src, 'log_test.txt'), 'w')
292 | for i in range(len(names)):
293 | if i == args.s:
294 | continue
295 | args.t = i
296 | args.name = names[args.s][0].upper() + names[args.t][0].upper()
297 |
298 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt'
299 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt'
300 |
301 | test_target(args)
302 |
303 | args.out_file.close()
--------------------------------------------------------------------------------
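In the uda variant above, `cal_acc(..., flag=True)` is used for VISDA-C: it builds a confusion matrix and logs the mean of per-class accuracies rather than overall accuracy. A tiny numeric sketch of that branch with toy labels (not repository data):

```python
# Sketch of the flag=True branch of cal_acc() above: per-class accuracy (recall)
# from a confusion matrix, then the class-averaged accuracy reported for VISDA-C.
import numpy as np
from sklearn.metrics import confusion_matrix

labels  = np.array([0, 0, 1, 1, 2, 2])               # toy ground truth
preds   = np.array([0, 1, 1, 1, 2, 0])               # toy predictions
matrix  = confusion_matrix(labels, preds)
matrix  = matrix[np.unique(labels).astype(int), :]   # keep rows for classes that appear
per_cls = matrix.diagonal() / matrix.sum(axis=1) * 100
print(per_cls)          # [ 50. 100.  50.]
print(per_cls.mean())   # 66.67 -> the "Accuracy" logged for VISDA-C
```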