├── .gitignore
├── .idea
├── SymNets.iml
├── misc.xml
├── modules.xml
├── vcs.xml
└── workspace.xml
├── LICENSE
├── README.md
└── symnets
├── data
├── __init__.py
├── folder_new.py
└── prepare_data.py
├── main.py
├── models
├── DomainClassifierSource.py
├── DomainClassifierTarget.py
├── EntropyMinimizationPrinciple.py
├── __init__.py
└── resnet.py
├── opts.py
├── run.sh
└── trainer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/.idea/SymNets.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 | 1558150423042
73 |
74 |
75 | 1558150423042
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Yabin Zhang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SymNets
2 | Official PyTorch implementation of ["Domain-Symmetric Networks for Adversarial Domain Adaptation (CVPR 2019)"](http://openaccess.thecvf.com/content_CVPR_2019/papers/Zhang_Domain-Symmetric_Networks_for_Adversarial_Domain_Adaptation_CVPR_2019_paper.pdf).
3 |
4 | # News!
5 | An extension of this work was recently accepted by **TPAMI 2020**. It includes:
6 | * A new theoretical framework closely supports/motivates a series of algorithms, including SymNets and MCD.
7 | * An algorithm improvement and a unified framework for adversarial UDA.
8 | * Excellent results of SymNets on tasks of partial and open set UDA.
9 |
10 | TPAMI paper: ["Unsupervised Multi-Class Domain Adaptation: Theory, Algorithms, and Practice"](https://arxiv.org/pdf/2002.08681.pdf), and
11 | [code](https://github.com/YBZh/MultiClassDA).
12 | ### Prerequisites
13 | Linux
14 |
15 | NVIDIA GPU + CUDA (optionally cuDNN) and a corresponding PyTorch build (version 0.5.0)
16 |
17 | Python 3.6
18 |
19 | ### Training and Evaluation
20 | Please refer to 'run.sh'
21 |
22 | ## Citation
23 |
24 | @inproceedings{zhang2019domain,
25 | title={Domain-symmetric networks for adversarial domain adaptation},
26 | author={Zhang, Yabin and Tang, Hui and Jia, Kui and Tan, Mingkui},
27 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
28 | pages={5031--5040},
29 | year={2019}
30 | }
31 | @article{zhang2020unsupervised,
32 | title={Unsupervised multi-class domain adaptation: Theory, algorithms, and practice},
33 | author={Zhang, Yabin and Deng, Bin and Tang, Hui and Zhang, Lei and Jia, Kui},
34 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
35 | year={2020},
36 | publisher={IEEE}
37 | }
38 |
39 | ## Contact
40 | If you have any problems with our code, feel free to contact
41 | - zhang.yabin@mail.scut.edu.cn
42 |
43 | or describe your problem in Issues.
44 |
--------------------------------------------------------------------------------
/symnets/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/SymNets/45e023762c80cfe9a9b625e2e01c9e989150a4f8/symnets/data/__init__.py
--------------------------------------------------------------------------------
/symnets/data/folder_new.py:
--------------------------------------------------------------------------------
1 | ### Modify the ImageFolder function to get the image path in the data loader
2 | import torch.utils.data as data
3 |
4 | from PIL import Image
5 | import os
6 | import os.path
7 | from PIL import ImageFile
8 |
# Let PIL load truncated image files instead of failing on them; the
# original author needed this for some problematic images in the datasets.
ImageFile.LOAD_TRUNCATED_IMAGES = True
10 |
# Import-time notice that this customized loader module is the one in use.
print('the data loader file has been modified')
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']


def is_image_file(filename):
    """Tell whether ``filename`` ends with a supported image extension.

    The comparison is case-insensitive and looks only at the end of the name.

    Args:
        filename (string): path or bare name of a file.
    Returns:
        bool: True if the lower-cased name ends with an entry of
        ``IMG_EXTENSIONS``.
    """
    return filename.lower().endswith(tuple(IMG_EXTENSIONS))
24 |
25 |
def find_classes(dir):
    """List the class sub-directories of ``dir`` and number them.

    Args:
        dir (string): dataset root; each immediate sub-directory is a class.
    Returns:
        tuple: ``(classes, class_to_idx)`` where ``classes`` is the sorted list
        of sub-directory names and ``class_to_idx`` maps each name to its
        position in that list.
    """
    classes = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry))
    )
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx
31 |
32 |
def make_dataset(dir, class_to_idx):
    """Collect ``(image_path, class_index)`` pairs from a class-per-folder tree.

    Every immediate sub-directory ``dir/<class>`` is walked recursively, and
    each file accepted by ``is_image_file`` is labelled
    ``class_to_idx[<class>]``. Directories and file names are visited in
    sorted order. Plain files directly under ``dir`` are ignored.

    Args:
        dir (string): dataset root (``~`` is expanded).
        class_to_idx (dict): class-directory name -> integer label.
    Returns:
        list: ``(path, class_index)`` tuples.
    """
    root_dir = os.path.expanduser(dir)
    samples = []
    for class_name in sorted(os.listdir(root_dir)):
        class_dir = os.path.join(root_dir, class_name)
        if not os.path.isdir(class_dir):
            continue
        for walk_root, _subdirs, file_names in sorted(os.walk(class_dir)):
            # The label is looked up per accepted image, so a class folder
            # without images never needs an entry in class_to_idx.
            samples.extend(
                (os.path.join(walk_root, name), class_to_idx[class_name])
                for name in sorted(file_names)
                if is_image_file(name)
            )
    return samples
49 |
50 |
def pil_loader(path):
    """Load the image at ``path`` with PIL and return it converted to RGB."""
    # Opening the file ourselves guarantees the handle is closed, which avoids
    # a ResourceWarning (https://github.com/python-pillow/Pillow/issues/835).
    # convert() runs inside the with-block so the pixels are read while open.
    with open(path, 'rb') as handle:
        rgb_image = Image.open(handle).convert('RGB')
    return rgb_image
56 |
57 |
def accimage_loader(path):
    """Load ``path`` with accimage, falling back to PIL on an IOError."""
    import accimage
    try:
        image = accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        image = pil_loader(path)
    return image
65 |
66 |
def default_loader(path):
    """Load ``path`` with the loader matching torchvision's image backend.

    Uses ``accimage_loader`` when the backend is ``'accimage'`` and
    ``pil_loader`` otherwise.
    """
    from torchvision import get_image_backend
    load = accimage_loader if get_image_backend() == 'accimage' else pil_loader
    return load(path)
73 |
74 |
class ImageFolder_new(data.Dataset):
    """Class-per-folder image dataset whose samples also carry the file path.

    Images are expected to be arranged like::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png
        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Args:
        root (string): Root directory path.
        transform (callable, optional): applied to each loaded image,
            e.g. ``transforms.RandomCrop``.
        target_transform (callable, optional): applied to each class index.
        loader (callable, optional): loads an image given its path.

    Attributes:
        classes (list): Sorted class (sub-directory) names.
        class_to_idx (dict): Maps class name -> class index.
        imgs (list): ``(image path, class_index)`` tuples.

    Raises:
        RuntimeError: if no supported image file is found under ``root``.
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader):
        classes, class_to_idx = find_classes(root)
        samples = make_dataset(root, class_to_idx)
        if not samples:
            message = ("Found 0 images in subfolders of: " + root
                       + "\nSupported image extensions are: "
                       + ",".join(IMG_EXTENSIONS))
            raise RuntimeError(message)

        self.root = root
        self.imgs = samples
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(image, target, path)`` for sample ``index``.

        ``image`` and ``target`` have the optional transforms applied; ``path``
        is the file the image was loaded from.
        """
        path, target = self.imgs[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target, path

    def __len__(self):
        return len(self.imgs)

    def __repr__(self):
        def indented(label, obj):
            # Align continuation lines of multi-line reprs under the label.
            return label + repr(obj).replace('\n', '\n' + ' ' * len(label))

        lines = [
            'Dataset ' + type(self).__name__,
            ' Number of datapoints: {}'.format(len(self)),
            ' Root Location: {}'.format(self.root),
            indented(' Transforms (if any): ', self.transform),
            indented(' Target Transforms (if any): ', self.target_transform),
        ]
        return '\n'.join(lines)
141 |
--------------------------------------------------------------------------------
/symnets/data/prepare_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.transforms as transforms
4 | import torchvision.datasets as datasets
5 | from data.folder_new import ImageFolder_new
6 |
7 |
# Standard ImageNet channel statistics used to normalize every split.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]


def _build_transform(train):
    """Return the preprocessing pipeline for training (``train=True``) or evaluation.

    Both pipelines resize to 256, crop to 224, convert to a tensor and normalize.
    Training crops at a random position and randomly flips horizontally;
    evaluation uses a deterministic center crop.
    """
    if train:
        crop = [transforms.RandomCrop(224), transforms.RandomHorizontalFlip()]
    else:
        crop = [transforms.CenterCrop(224)]
    return transforms.Compose(
        [transforms.Resize(256)] + crop + [
            transforms.ToTensor(),
            transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
        ]
    )


def generate_dataloader(args):
    """Build the four data loaders used for training and evaluation.

    Reads from ``args``:
        ``data_path_source``/``src``: source train folder;
        ``data_path_source_t``/``src_t``: target train folder;
        ``data_path_target``/``tar``: target validation folder;
        ``batch_size_s``, ``batch_size_t``, ``workers``.

    Returns:
        tuple: ``(source_train_loader, source_val_loader,
        target_train_loader, target_val_loader)``. The train loaders shuffle,
        augment and drop the last incomplete batch. The val loaders use
        ``ImageFolder_new`` (samples include the image path), keep a fixed
        order and apply no augmentation.

    Raises:
        ValueError: if any of the three input folders does not exist.
    """
    traindir_source = os.path.join(args.data_path_source, args.src)
    traindir_target = os.path.join(args.data_path_source_t, args.src_t)
    valdir = os.path.join(args.data_path_target, args.tar)
    # Validate every input folder up front with the same kind of error.
    # Previously only the source folder was checked, and a missing target
    # folder surfaced later as a less specific error from the dataset code.
    for folder, message in ((traindir_source, 'Null path of source train data!!!'),
                            (traindir_target, 'Null path of target train data!!!'),
                            (valdir, 'Null path of target val data!!!')):
        if not os.path.isdir(folder):
            raise ValueError(message)

    source_train_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(traindir_source, _build_transform(train=True)),
        batch_size=args.batch_size_s, shuffle=True,
        drop_last=True, num_workers=args.workers, pin_memory=True, sampler=None
    )
    # Source images again, un-augmented and in a fixed order, for evaluation.
    source_val_loader = torch.utils.data.DataLoader(
        ImageFolder_new(traindir_source, _build_transform(train=False)),
        batch_size=args.batch_size_s, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=None
    )
    target_train_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(traindir_target, _build_transform(train=True)),
        batch_size=args.batch_size_t, shuffle=True,
        drop_last=True, num_workers=args.workers, pin_memory=True, sampler=None
    )
    target_val_loader = torch.utils.data.DataLoader(
        ImageFolder_new(valdir, _build_transform(train=False)),
        batch_size=args.batch_size_t, shuffle=False,
        num_workers=args.workers, pin_memory=True
    )
    return source_train_loader, source_val_loader, target_train_loader, target_val_loader
73 |
74 |
--------------------------------------------------------------------------------
/symnets/main.py:
--------------------------------------------------------------------------------
1 | ##############################################################################
2 | # The simplified official code for the CVPR19 paper: Domain-Symmetric Networks for Adversarial Domain Adaptation
3 | ##############################################################################
4 | import json
5 | import os
6 | import shutil
7 | import time
8 |
9 | import torch.backends.cudnn as cudnn
10 | import torch.nn as nn
11 | import torch.optim
12 | from data.prepare_data import generate_dataloader # Prepare the data and dataloader
13 | from models.resnet import resnet # The model construction
14 | from opts import opts # The options for the project
15 | from trainer import train # For the training process
16 | from trainer import validate # For the validate (test) process
17 | from trainer import adjust_learning_rate
18 | from models.DomainClassifierTarget import DClassifierForTarget
19 | from models.DomainClassifierSource import DClassifierForSource
20 | from models.EntropyMinimizationPrinciple import EMLossForTarget
21 | import ipdb
22 |
best_prec1 = 0


def _append_log(args, text):
    """Append ``text`` to ``<args.log>/log.txt``; the file is always closed."""
    with open(os.path.join(args.log, 'log.txt'), 'a') as log:
        log.write(text)


def main():
    """Train a SymNets model, or only evaluate it when ``args.test_only`` is set.

    Builds the network, the three SymNets criteria plus cross-entropy, and an
    SGD optimizer. Optionally resumes from ``args.resume``, then builds the
    data loaders. It then alternates training iterations (``train``) with
    periodic validation; whenever validation runs, a checkpoint is saved, and
    the best model (by prec@1) is kept.

    Raises:
        ValueError: for an unsupported ``args.arch`` or a missing resume file.
    """
    global args, best_prec1
    current_epoch = 0
    # Which loader's length defines an epoch; switched to 'target' below when
    # the target loader has more batches than the source loader.
    epoch_count_dataset = 'source'
    args = opts()
    if args.arch == 'alexnet':
        raise ValueError('the request arch is not prepared', args.arch)
    elif args.arch.find('resnet') != -1:
        model = resnet(args)
    else:
        raise ValueError('Unavailable model architecture!!!')
    # define-multi GPU
    model = torch.nn.DataParallel(model).cuda()
    print(model)
    criterion_classifier_target = DClassifierForTarget(nClass=args.num_classes).cuda()
    criterion_classifier_source = DClassifierForSource(nClass=args.num_classes).cuda()
    criterion_em_target = EMLossForTarget(nClass=args.num_classes).cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    # To apply different learning rates to different layers. Only ResNet
    # architectures reach this point (every other arch raised above). The
    # 'name' tags are presumably read by adjust_learning_rate; confirm in
    # trainer.py.
    optimizer = torch.optim.SGD([
        {'params': model.module.conv1.parameters(), 'name': 'pre-trained'},
        {'params': model.module.bn1.parameters(), 'name': 'pre-trained'},
        {'params': model.module.layer1.parameters(), 'name': 'pre-trained'},
        {'params': model.module.layer2.parameters(), 'name': 'pre-trained'},
        {'params': model.module.layer3.parameters(), 'name': 'pre-trained'},
        {'params': model.module.layer4.parameters(), 'name': 'pre-trained'},
        {'params': model.module.fc.parameters(), 'name': 'new-added'}
    ],
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
        nesterov=True)

    # optionally resume from a checkpoint
    if args.resume:
        if not os.path.isfile(args.resume):
            raise ValueError('The file to be resumed from is not exited', args.resume)
        print("==> loading checkpoints '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        # Continue the epoch counter from the checkpoint. Previously only
        # args.start_epoch was set and current_epoch stayed 0. A resumed run
        # then restarted the LR schedule (adjust_learning_rate receives
        # current_epoch) and trained for a full args.epochs again.
        current_epoch = args.start_epoch
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("==> loaded checkpoint '{}'(epoch {})"
              .format(args.resume, checkpoint['epoch']))
    os.makedirs(args.log, exist_ok=True)
    _append_log(args, json.dumps(dict(args._get_kwargs())) + '\n')

    cudnn.benchmark = True
    # process the data and prepare the dataloaders.
    source_train_loader, source_val_loader, target_train_loader, val_loader = generate_dataloader(args)
    # test only
    if args.test_only:
        validate(val_loader, model, criterion, -1, args)
        return
    # start time
    _append_log(args, '\n-------------------------------------------\n'
                + time.asctime(time.localtime(time.time()))
                + '\n-------------------------------------------')
    source_train_loader_batch = enumerate(source_train_loader)
    target_train_loader_batch = enumerate(target_train_loader)
    batch_number_s = len(source_train_loader)
    batch_number_t = len(target_train_loader)
    if batch_number_s < batch_number_t:
        epoch_count_dataset = 'target'
    while current_epoch < args.epochs:
        # train for one iteration
        adjust_learning_rate(optimizer, current_epoch, args)
        (source_train_loader_batch, target_train_loader_batch,
         current_epoch, new_epoch_flag) = train(
            source_train_loader, source_train_loader_batch,
            target_train_loader, target_train_loader_batch, model,
            criterion_classifier_source, criterion_classifier_target,
            criterion_em_target, criterion, optimizer, current_epoch,
            epoch_count_dataset, args)
        # evaluate on the val data
        if new_epoch_flag:
            if (current_epoch + 1) % args.test_freq == 0 or current_epoch == 0:
                prec1 = validate(val_loader, model, criterion, current_epoch, args)
                # record the best prec1 and save checkpoint
                is_best = prec1 > best_prec1
                best_prec1 = max(prec1, best_prec1)
                if is_best:
                    _append_log(args, ' Best acc: %3f' % (best_prec1))
                save_checkpoint({
                    'epoch': current_epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args)

    # end time
    _append_log(args, '\n-------------------------------------------\n'
                + time.asctime(time.localtime(time.time()))
                + '\n-------------------------------------------\n')
145 |
146 |
def save_checkpoint(state, is_best, args):
    """Write ``state`` to ``<args.log>/checkpoint.pth.tar`` with ``torch.save``.

    When ``is_best`` is true, the freshly written file is also copied to
    ``<args.log>/model_best.pth.tar``.
    """
    latest_path = os.path.join(args.log, 'checkpoint.pth.tar')
    torch.save(state, latest_path)
    if not is_best:
        return
    shutil.copyfile(latest_path, os.path.join(args.log, 'model_best.pth.tar'))


if __name__ == '__main__':
    main()
157 |
158 |
159 |
160 |
161 |
162 |
--------------------------------------------------------------------------------
/symnets/models/DomainClassifierSource.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 |
6 |
def _assert_no_grad(variable):
    """Assert that ``variable`` does not require gradients (its only call is commented out)."""
    assert not variable.requires_grad, \
        "nn criterions don't compute the gradient w.r.t. targets - please " \
        "mark these variables as volatile or not requiring gradients"


class _Loss(nn.Module):
    """Minimal loss base class that stores the ``size_average`` flag."""

    def __init__(self, size_average=True):
        super(_Loss, self).__init__()
        self.size_average = size_average


class _WeightedLoss(_Loss):
    """Loss base class that also registers an optional ``weight`` buffer."""

    def __init__(self, weight=None, size_average=True):
        super(_WeightedLoss, self).__init__(size_average)
        self.register_buffer('weight', weight)


class DClassifierForSource(_WeightedLoss):
    """Domain-classification loss toward the source half of the outputs.

    The input's softmax is taken over all columns. The loss is the negative
    mean log of the probability mass in the first ``nClass`` columns. The
    remaining columns are presumably the target classifier's outputs; confirm
    with the model head.

    Args:
        nClass (int): number of leading columns that form the source half.
        weight, size_average: stored by the base classes, unused here.
        ignore_index, reduce: accepted for signature compatibility, unused.
    """

    def __init__(self, weight=None, size_average=True, ignore_index=-100, reduce=True, nClass=10):
        super(DClassifierForSource, self).__init__(weight, size_average)
        self.nClass = nClass

    def forward(self, input):
        prob = F.softmax(input, dim=1)
        source_mass = prob[:, :self.nClass].sum(1)
        # Guard against log(0): samples whose source mass is exactly zero get
        # a 1e-6 offset, and all others get +0, which leaves them unchanged.
        # The mask is built on input's own device/dtype, so this works on CPU
        # (previously it was hard-coded to .cuda()) and needs no host sync.
        eps = (source_mass == 0).to(source_mass.dtype) * 1e-6
        return -((source_mass + eps).log().mean())
--------------------------------------------------------------------------------
/symnets/models/DomainClassifierTarget.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 |
6 |
def _assert_no_grad(variable):
    """Assert that ``variable`` does not require gradients (its only call is commented out)."""
    assert not variable.requires_grad, \
        "nn criterions don't compute the gradient w.r.t. targets - please " \
        "mark these variables as volatile or not requiring gradients"


class _Loss(nn.Module):
    """Minimal loss base class that stores the ``size_average`` flag."""

    def __init__(self, size_average=True):
        super(_Loss, self).__init__()
        self.size_average = size_average


class _WeightedLoss(_Loss):
    """Loss base class that also registers an optional ``weight`` buffer."""

    def __init__(self, weight=None, size_average=True):
        super(_WeightedLoss, self).__init__(size_average)
        self.register_buffer('weight', weight)


class DClassifierForTarget(_WeightedLoss):
    """Domain-classification loss toward the target half of the outputs.

    The input's softmax is taken over all columns. The loss is the negative
    mean log of the probability mass in columns ``nClass`` onward. The first
    ``nClass`` columns are presumably the source classifier's outputs; confirm
    with the model head.

    Args:
        nClass (int): index of the first column of the target half.
        weight, size_average: stored by the base classes, unused here.
        ignore_index, reduce: accepted for signature compatibility, unused.
    """

    def __init__(self, weight=None, size_average=True, ignore_index=-100, reduce=True, nClass = 10):
        super(DClassifierForTarget, self).__init__(weight, size_average)
        self.nClass = nClass

    def forward(self, input):
        prob = F.softmax(input, dim=1)
        target_mass = prob[:, self.nClass:].sum(1)
        # Guard against log(0): samples whose target mass is exactly zero get
        # a 1e-6 offset, and all others get +0, which leaves them unchanged.
        # The mask is built on input's own device/dtype, so this works on CPU
        # (previously it was hard-coded to .cuda()) and needs no host sync.
        eps = (target_mass == 0).to(target_mass.dtype) * 1e-6
        return -((target_mass + eps).log().mean())
--------------------------------------------------------------------------------
/symnets/models/EntropyMinimizationPrinciple.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 |
6 |
def _assert_no_grad(variable):
    """Assert that ``variable`` does not require gradients (unused helper)."""
    assert not variable.requires_grad, \
        "nn criterions don't compute the gradient w.r.t. targets - please " \
        "mark these variables as volatile or not requiring gradients"


class _Loss(nn.Module):
    """Minimal loss base class that stores the ``size_average`` flag."""

    def __init__(self, size_average=True):
        super(_Loss, self).__init__()
        self.size_average = size_average


class _WeightedLoss(_Loss):
    """Loss base class that also registers an optional ``weight`` buffer."""

    def __init__(self, weight=None, size_average=True):
        super(_WeightedLoss, self).__init__(size_average)
        self.register_buffer('weight', weight)


class EMLossForTarget(_WeightedLoss):
    """Entropy-minimization loss over class probabilities summed across halves.

    The softmax is taken over all columns. The first ``nClass`` columns and
    the next ``nClass`` columns are added elementwise, which requires at least
    ``2 * nClass`` columns. The loss is the mean entropy
    ``-sum_k p_k * log(p_k)`` of that summed distribution.

    Args:
        nClass (int): width of each half.
        weight, size_average: stored by the base classes, unused here.
        ignore_index, reduce: accepted for signature compatibility, unused.
    """

    def __init__(self, weight=None, size_average=True, ignore_index=-100, reduce=True, nClass = 10):
        super(EMLossForTarget, self).__init__(weight, size_average)
        self.nClass = nClass

    def forward(self, input):
        prob = F.softmax(input, dim=1)
        prob_sum = prob[:, self.nClass:] + prob[:, :self.nClass]
        # Entries that are exactly zero get a 1e-6 offset inside the log, so
        # 0 * log(0) becomes 0 instead of NaN; all others get +0 (unchanged).
        # The mask is built on input's own device/dtype, so this works on CPU
        # (previously it was hard-coded to .cuda()) and needs no host sync.
        eps = (prob_sum == 0).to(prob_sum.dtype) * 1e-6
        return -(prob_sum + eps).log().mul(prob_sum).sum(1).mean()
--------------------------------------------------------------------------------
/symnets/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/SymNets/45e023762c80cfe9a9b625e2e01c9e989150a4f8/symnets/models/__init__.py
--------------------------------------------------------------------------------
/symnets/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch.utils.model_zoo as model_zoo
4 | import torch
5 | import ipdb
6 |
7 |
__all__ = [
    'ResNet',
    'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
]

# torchvision's ImageNet-pretrained ResNet checkpoints, keyed by architecture.
_MODEL_URL_ROOT = 'https://download.pytorch.org/models/'
model_urls = {
    arch: _MODEL_URL_ROOT + weights_file
    for arch, weights_file in (
        ('resnet18', 'resnet18-5c106cde.pth'),
        ('resnet34', 'resnet34-333f7ec4.pth'),
        ('resnet50', 'resnet50-19c8e357.pth'),
        ('resnet101', 'resnet101-5d3b4d8f.pth'),
        ('resnet152', 'resnet152-b121ed2d.pth'),
    )
}
19 |
20 |
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    Padding 1 keeps the spatial size unchanged when ``stride`` is 1.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
25 |
26 |
class BasicBlock(nn.Module):
    """Residual block of two 3x3 conv + BatchNorm layers.

    The shortcut is the input itself, or ``downsample(x)`` when a projection
    module is given (needed when ``stride`` or the channel count changes).
    ``expansion = 1``: the block outputs ``planes`` channels.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Bias-free 3x3 convs with padding 1; only the first one may stride.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        if self.downsample is not None:
            shortcut = self.downsample(x)
        y += shortcut
        return self.relu(y)
57 |
58 |
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 reduce, 3x3 (strided), 1x1 expand.

    Each conv is followed by BatchNorm. The output has ``planes * 4``
    channels (``expansion = 4``). The shortcut is the input itself, or
    ``downsample(x)`` when a projection module is given.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        widened = planes * 4
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        y = x
        # conv -> BN -> ReLU for the first two stages.
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            y = self.relu(norm(conv(y)))
        y = self.bn3(self.conv3(y))
        shortcut = x if self.downsample is None else self.downsample(x)
        y += shortcut
        return self.relu(y)
96 |
97 |
class ResNet(nn.Module):
    """ResNet backbone followed by global average pooling and a linear head.

    Args:
        block: residual block class (e.g. ``BasicBlock`` or ``Bottleneck``).
            It must expose an ``expansion`` attribute and accept
            ``(inplanes, planes, stride, downsample)``.
        layers (sequence of 4 ints): number of blocks in ``layer1``..``layer4``.
        num_classes (int): output size of ``fc``.
    """

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Global average pooling to 1x1. The total stride is 32, so a 224x224
        # input gives a 7x7 map here; this is equivalent to the former fixed
        # nn.AvgPool2d(7), which failed for other input resolutions.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Convs: normal init with std sqrt(2 / (k_h * k_w * out_channels)).
        # BatchNorm starts as the identity affine map (weight 1, bias 0).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks and advance ``self.inplanes``.

        Only the first block may stride. It gets a 1x1 conv + BN projection
        shortcut when the stride or the channel count changes.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
156 |
157 |
def resnet18(args, **kwargs):
    """Construct a ResNet-18 whose head has ``2 * args.num_classes`` outputs.

    The first ``num_classes`` output columns form the source classifier and
    the remaining ones the target classifier. This is the layout ``resnet50``
    builds and that the trainer slices with ``[:, :num_classes]`` /
    ``[:, num_classes:]``.

    Args:
        args: parsed options. Reads ``args.pretrained`` (load ImageNet weights
            before replacing the head) and ``args.num_classes``.
        **kwargs: forwarded to ``ResNet``.
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if args.pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    # Replace the constructor's head with the source+target classifier head.
    model.fc = nn.Linear(model.fc.in_features, args.num_classes * 2)
    model.fc.weight.data.normal_(0.0, 0.02)
    # NOTE(review): normal_(0) draws the bias from N(0, 1) (std defaults to 1),
    # much wider than the weight init above; confirm this is intended.
    model.fc.bias.data.normal_(0)
    return model
172 |
173 |
def resnet34(args, **kwargs):
    """Construct a ResNet-34 whose head has ``2 * args.num_classes`` outputs.

    The first ``num_classes`` output columns form the source classifier and
    the remaining ones the target classifier. This matches ``resnet50`` and
    the slicing done by the trainer.

    Args:
        args: parsed options. Reads ``args.pretrained`` (load ImageNet weights
            before replacing the head) and ``args.num_classes``.
        **kwargs: forwarded to ``ResNet``.
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if args.pretrained:
        print('Load ImageNet pre-trained resnet model')
        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
    # Replace the constructor's head with the source+target classifier head
    # (default nn.Linear initialisation).
    model.fc = nn.Linear(model.fc.in_features, args.num_classes * 2)
    return model
188 |
189 |
def resnet50(args, **kwargs):
    """Construct the ResNet-50 used by SymNets, with a ``2 * args.num_classes`` head.

    Output columns ``[:num_classes]`` are the source classifier and
    ``[num_classes:]`` the target classifier. Initialisation depends on args:

    * ``pretrained`` and ``pretrained_checkpoint``: load a self-pretrained
      checkpoint (a dict with a ``'state_dict'`` entry whose keys may carry a
      ``'module.'`` prefix). Its ``fc`` rows are copied into the first
      ``num_classes`` rows of the new head; the remaining rows keep their
      fresh initialisation.
    * ``pretrained`` only: load ImageNet weights, then attach a fresh head.
    * neither: random initialisation with a fresh head.

    Args:
        args: parsed options. Reads ``pretrained``, ``pretrained_checkpoint``
            and ``num_classes``.
        **kwargs: forwarded to ``ResNet``.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    num_classes = args.num_classes
    use_checkpoint = args.pretrained and args.pretrained_checkpoint
    if args.pretrained and not args.pretrained_checkpoint:
        # ImageNet weights are loaded before the head is replaced, so the
        # checkpoint's fc shape must match the constructor's fc.
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    # Always install the source+target head. Previously the non-pretrained
    # path kept the constructor's head, whose size did not match the
    # 2 * num_classes layout the trainer slices.
    model.fc = nn.Linear(model.fc.in_features, num_classes * 2)
    if use_checkpoint:
        init_dict = model.state_dict()
        # map_location='cpu': the freshly built model is on the CPU, and this
        # lets checkpoints saved on a GPU load on any machine.
        checkpoint = torch.load(args.pretrained_checkpoint, map_location='cpu')['state_dict']
        # Drop the DataParallel 'module.' prefix from checkpoint keys.
        pretrained_dict = {k.replace('module.', ''): v for k, v in checkpoint.items()}
        for key in ('fc.weight', 'fc.bias'):
            merged = init_dict[key].clone()
            merged[:num_classes] = pretrained_dict[key]
            pretrained_dict[key] = merged
        model.load_state_dict(pretrained_dict)
    return model
217 |
218 |
def resnet101(args, **kwargs):
    """Construct a ResNet-101 whose head has ``2 * args.num_classes`` outputs.

    The first ``num_classes`` output columns form the source classifier and
    the remaining ones the target classifier. This matches ``resnet50`` and
    the slicing done by the trainer.

    Args:
        args: parsed options. Reads ``args.pretrained`` (load ImageNet weights
            before replacing the head) and ``args.num_classes``.
        **kwargs: forwarded to ``ResNet``.
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if args.pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    # Replace the constructor's head with the source+target classifier head.
    model.fc = nn.Linear(model.fc.in_features, args.num_classes * 2)
    model.fc.weight.data.normal_(0.0, 0.02)
    # NOTE(review): normal_(0) draws the bias from N(0, 1) (std defaults to 1),
    # much wider than the weight init above; confirm this is intended.
    model.fc.bias.data.normal_(0)
    return model
233 |
234 |
def resnet152(args, **kwargs):
    """Construct a ResNet-152 whose head has ``2 * args.num_classes`` outputs.

    The first ``num_classes`` output columns form the source classifier and
    the remaining ones the target classifier. This matches ``resnet50`` and
    the slicing done by the trainer.

    Args:
        args: parsed options. Reads ``args.pretrained`` (load ImageNet weights
            before replacing the head) and ``args.num_classes``.
        **kwargs: forwarded to ``ResNet``.
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if args.pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    # Replace the constructor's head with the source+target classifier head
    # (default nn.Linear initialisation).
    model.fc = nn.Linear(model.fc.in_features, args.num_classes * 2)
    return model
248 |
249 |
def resnet(args, **kwargs):
    """Build the network selected by ``args.arch``.

    Only ResNet-50 is fully supported by this simple code (it is the only
    variant with a self-pretrained-checkpoint path).

    Args:
        args: parsed options; ``args.arch`` selects the constructor and the
            whole namespace is passed on to it.
        **kwargs: forwarded to the selected constructor (and from there to
            ``ResNet``). Previously these were accepted but silently dropped.

    Raises:
        ValueError: if ``args.arch`` is not one of resnet18/34/50/101/152.
    """
    print("==> creating model '{}' ".format(args.arch))
    if args.arch == 'resnet18':
        return resnet18(args, **kwargs)
    elif args.arch == 'resnet34':
        return resnet34(args, **kwargs)
    elif args.arch == 'resnet50':
        return resnet50(args, **kwargs)
    elif args.arch == 'resnet101':
        return resnet101(args, **kwargs)
    elif args.arch == 'resnet152':
        return resnet152(args, **kwargs)
    else:
        raise ValueError('Unrecognized model architecture', args.arch)
--------------------------------------------------------------------------------
/symnets/opts.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
def opts(argv=None):
    """Parse the SymNets command-line options.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like the original zero-argument call.

    Returns:
        argparse.Namespace with all options. ``log`` is rewritten to
        ``<log>_<src>2<tar>_<arch>_<flag>`` so each setting gets its own folder.
    """
    # ArgumentDefaultsHelpFormatter appends "(default: ...)" to every help
    # string automatically, so help texts below do not repeat the defaults.
    parser = argparse.ArgumentParser(description='Train SymNets for domain adaptation',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--data_path_source', type=str, default='',
                        help='Root of train data set of the source domain')
    parser.add_argument('--data_path_source_t', type=str, default='',
                        help='Root of train data set of the target domain')
    parser.add_argument('--data_path_target', type=str, default='',
                        help='Root of the test data set of the target domain')
    parser.add_argument('--src', type=str, default='amazon',
                        help='choose between amazon | dslr | webcam')
    parser.add_argument('--src_t', type=str, default='webcam',
                        help='choose between amazon | dslr | webcam')
    parser.add_argument('--tar', type=str, default='webcam',
                        help='choose between amazon | dslr | webcam')
    parser.add_argument('--num_classes', type=int, default=31,
                        help='number of classes of data used to fine-tune the pre-trained model')
    # Optimization options
    parser.add_argument('--epochs', '-e', type=int, default=200, help='Number of epochs to train')
    parser.add_argument('--batch_size_s', '-b-s', type=int, default=128, help='Batch size of the source data.')
    parser.add_argument('--batch_size_t', '-b-t', type=int, default=128, help='Batch size of the target data.')
    parser.add_argument('--lr', '--learning_rate', type=float, default=0.01, help='The Learning Rate.')
    parser.add_argument('--momentum', '-m', type=float, default=0.9, help='Momentum.')
    parser.add_argument('--weight_decay', '-wd', type=float, default=0.0001, help='Weight decay (L2 penalty).')
    parser.add_argument('--schedule', type=int, nargs='+', default=[79, 119],
                        help='Decrease learning rate at these epochs[used in step decay].')
    parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.')
    # checkpoints
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--resume', type=str, default='', help='Checkpoints path to resume(default none)')
    parser.add_argument('--pretrained_checkpoint', type=str, default='', help='Self-Pretrained checkpoint to resume (default none)')
    parser.add_argument('--test_only', '-t', action='store_true', help='Test only flag')
    # Architecture
    parser.add_argument('--arch', type=str, default='resnet50', help='Model name')
    # NOTE(review): trainer.train only accepts 'symnet' or 'no_em' and raises
    # ValueError otherwise; the 'original' default is kept for compatibility
    # with existing log-folder names, so pass --flag explicitly when training.
    parser.add_argument('--flag', type=str, default='original',
                        help='training variant used by the trainer: symnet | no_em')
    parser.add_argument('--pretrained', action='store_true', help='whether using pretrained model')
    # i/o
    parser.add_argument('--log', type=str, default='./checkpoints', help='Log folder')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers')
    parser.add_argument('--test_freq', default=10, type=int,
                        help='test frequency')
    parser.add_argument('--print_freq', '-p', default=10, type=int,
                        metavar='N', help='print frequency')
    parser.add_argument('--score_frep', default=300, type=int,
                        metavar='N', help='print frequency (default: 300, not download score)')
    args = parser.parse_args(argv)

    args.log = args.log + '_' + args.src + '2' + args.tar + '_' + args.arch + '_' + args.flag
    return args
55 |
--------------------------------------------------------------------------------
/symnets/run.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Train SymNets on Office-31, amazon (source) -> webcam (target), starting
# from a pretrained backbone (--pretrained) with the 'symnet' loss variant.
DATA_ROOT=/data/domain_adaptation/Office31/

python main.py \
    --data_path_source "$DATA_ROOT" --src amazon \
    --data_path_source_t "$DATA_ROOT" --src_t webcam \
    --data_path_target "$DATA_ROOT" --tar webcam \
    --num_classes 31 --epochs 200 \
    --lr 0.01 --gamma 0.1 --weight_decay 1e-4 \
    --workers 4 --print_freq 1 --test_freq 1 \
    --pretrained --flag symnet --log office31
5 |
6 |
7 |
--------------------------------------------------------------------------------
/symnets/trainer.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import os
4 | import math
5 | import ipdb
6 | import torch.nn.functional as F
7 |
def _next_batch(loader_iter, loader):
    """Fetch the next batch from ``loader_iter``, restarting it when exhausted.

    ``loader_iter`` is an ``enumerate`` over ``loader``; the index is dropped.

    Returns:
        ``(batch, loader_iter, restarted)``. ``restarted`` is True when the
        iterator ran out and was replaced by a fresh ``enumerate(loader)``.
    """
    try:
        return next(loader_iter)[1], loader_iter, False
    except StopIteration:
        loader_iter = enumerate(loader)
        return next(loader_iter)[1], loader_iter, True


def _is_classifier_param(name):
    """Return True for parameters of the final ``fc`` layer.

    Also matches DataParallel-wrapped names such as ``module.fc.weight``.
    """
    return name.split('.')[-2] == 'fc'


def train(source_train_loader, source_train_loader_batch, target_train_loader, target_train_loader_batch, model, criterion_classifier_source, criterion_classifier_target, criterion_em_target, criterion, optimizer, epoch, epoch_count_dataset, args):
    """Run ONE SymNets training iteration on one source and one target batch.

    The final ``fc`` layer (the classifiers) is updated with the gradient of
    ``loss_classifier``. Every other parameter (the feature extractor) is
    updated with the gradient of ``loss_G``.

    Args:
        source_train_loader, target_train_loader: data loaders, used to
            restart iteration when the matching ``*_batch`` iterator runs out.
        source_train_loader_batch, target_train_loader_batch: ``enumerate``
            iterators over those loaders.
        model: network whose output concatenates the source classifier
            (first ``args.num_classes`` columns) and the target classifier.
        criterion_classifier_source, criterion_classifier_target,
        criterion_em_target: loss modules applied to the full model output.
        criterion: classification loss taking integer labels.
        optimizer: optimizer over the model parameters.
        epoch: current epoch counter. It is incremented when the loader named
            by ``epoch_count_dataset`` ('source' or 'target') is exhausted.
        args: reads ``num_classes``, ``epochs``, ``flag``, ``print_freq``, ``log``.

    Returns:
        ``(source_train_loader_batch, target_train_loader_batch, epoch,
        new_epoch_flag)``, to be passed back in on the next call.

    Raises:
        ValueError: if ``args.flag`` is neither 'no_em' nor 'symnet'.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses_classifier = AverageMeter()
    losses_G = AverageMeter()
    top1_source = AverageMeter()
    top1_target = AverageMeter()
    model.train()
    end = time.time()

    (input_source, target_source), source_train_loader_batch, source_restarted = \
        _next_batch(source_train_loader_batch, source_train_loader)
    (input_target, _), target_train_loader_batch, target_restarted = \
        _next_batch(target_train_loader_batch, target_train_loader)
    new_epoch_flag = ((source_restarted and epoch_count_dataset == 'source') or
                      (target_restarted and epoch_count_dataset == 'target'))
    if new_epoch_flag:
        epoch = epoch + 1
    data_time.update(time.time() - end)

    # Labels for the target-classifier half of the output: class c -> c + num_classes.
    # non_blocking replaces the pre-0.4 `async` keyword, which is a reserved
    # word (SyntaxError) since Python 3.7.
    target_source_shifted = (target_source + args.num_classes).cuda(non_blocking=True)
    target_source = target_source.cuda(non_blocking=True)

    # ---- source samples
    output_source = model(input_source)
    loss_task_s_Cs = criterion(output_source[:, :args.num_classes], target_source)
    loss_task_s_Ct = criterion(output_source[:, args.num_classes:], target_source)
    loss_domain_st_Cst_part1 = criterion_classifier_source(output_source)
    loss_category_st_G = 0.5 * criterion(output_source, target_source) + 0.5 * criterion(output_source, target_source_shifted)

    # ---- target samples (their labels are not used)
    output_target = model(input_target)
    loss_domain_st_Cst_part2 = criterion_classifier_target(output_target)
    loss_domain_st_G = 0.5 * criterion_classifier_target(output_target) + 0.5 * criterion_classifier_source(output_target)
    loss_target_em = criterion_em_target(output_target)

    # Weight ramping from 0 (epoch 0) towards 1 as epoch / args.epochs grows.
    lam = 2 / (1 + math.exp(-1 * 10 * epoch / args.epochs)) - 1
    if args.flag == 'no_em':
        loss_classifier = loss_task_s_Cs + loss_task_s_Ct + loss_domain_st_Cst_part1 + loss_domain_st_Cst_part2
        loss_G = loss_category_st_G + lam * loss_domain_st_G
    elif args.flag == 'symnet':
        loss_classifier = loss_task_s_Cs + loss_task_s_Ct + loss_domain_st_Cst_part1 + loss_domain_st_Cst_part2
        loss_G = loss_category_st_G + lam * (loss_domain_st_G + loss_target_em)
    else:
        raise ValueError('unrecognized flag:', args.flag)

    # Measure accuracy and record loss. .item() yields Python floats
    # (indexing a 0-dim loss with [0] fails on PyTorch >= 0.5).
    batch_size = input_source.size(0)
    prec1_source = accuracy(output_source.detach()[:, :args.num_classes], target_source, topk=(1, 5))[0]
    prec1_target = accuracy(output_source.detach()[:, args.num_classes:], target_source, topk=(1, 5))[0]
    losses_classifier.update(loss_classifier.item(), batch_size)
    losses_G.update(loss_G.item(), batch_size)
    top1_source.update(prec1_source.item(), batch_size)
    top1_target.update(prec1_target.item(), batch_size)

    # Compute both gradients, then keep the loss_classifier gradient only for
    # the fc (classifier) parameters. Selecting by name replaces a hard-coded
    # index (159 = ResNet-50 parameter tensors before fc) that was wrong for
    # every other architecture.
    classifier_params = [p for name, p in model.named_parameters() if _is_classifier_param(name)]
    optimizer.zero_grad()
    loss_classifier.backward(retain_graph=True)
    classifier_grads = [p.grad.detach().clone() for p in classifier_params]
    optimizer.zero_grad()
    loss_G.backward()
    for param, grad in zip(classifier_params, classifier_grads):
        param.grad.data.copy_(grad)
    optimizer.step()

    batch_time.update(time.time() - end)
    if (epoch + 1) % args.print_freq == 0 or epoch == 0:
        print('Train: [{0}/{1}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss@C {loss_c.val:.4f} ({loss_c.avg:.4f})\t'
              'Loss@G {loss_g.val:.4f} ({loss_g.avg:.4f})\t'
              'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
              'top1T {top1T.val:.3f} ({top1T.avg:.3f})'.format(
               epoch, args.epochs, batch_time=batch_time,
               data_time=data_time, loss_c=losses_classifier, loss_g=losses_G, top1S=top1_source, top1T=top1_target))
    if new_epoch_flag:
        with open(os.path.join(args.log, 'log.txt'), 'a') as log:
            log.write("\n")
            log.write("Train:epoch: %d, loss@min: %4f, loss@max: %4f, Top1S acc: %3f, Top1T acc: %3f" % (epoch, losses_classifier.avg, losses_G.avg, top1_source.avg, top1_target.avg))

    return source_train_loader_batch, target_train_loader_batch, epoch, new_epoch_flag
125 |
126 |
def validate(val_loader, model, criterion, epoch, args):
    """Evaluate both classifier halves of ``model`` on ``val_loader``.

    Output columns ``[:num_classes]`` are scored as the source classifier (S)
    and ``[num_classes:]`` as the target classifier (T), both against the
    same ground-truth labels. Progress is printed every ``args.print_freq``
    batches and a summary line is appended to ``<args.log>/log.txt``.

    Args:
        val_loader: yields ``(input, target, _)`` triples; the third item is ignored.
        model: network with ``2 * args.num_classes`` output columns (assumed
            from the slicing below).
        criterion: loss applied to each half with integer labels.
        epoch: epoch number, used only for printing and logging.
        args: reads ``num_classes``, ``print_freq`` and ``log``.

    Returns:
        The larger of the two average top-1 accuracies (percent).
    """
    batch_time = AverageMeter()
    losses_source = AverageMeter()
    losses_target = AverageMeter()
    top1_source = AverageMeter()
    top1_target = AverageMeter()
    # switch to evaluate mode
    model.eval()

    end = time.time()
    for i, (images, target, _) in enumerate(val_loader):
        # non_blocking replaces the pre-0.4 `async` keyword, which is a
        # reserved word (SyntaxError) since Python 3.7.
        target = target.cuda(non_blocking=True)
        with torch.no_grad():
            output = model(images)
            output_source = output[:, :args.num_classes]
            output_target = output[:, args.num_classes:]
            loss_source = criterion(output_source, target)
            loss_target = criterion(output_target, target)
            prec1_source = accuracy(output_source, target, topk=(1, 5))[0]
            prec1_target = accuracy(output_target, target, topk=(1, 5))[0]

        # .item() yields Python floats (indexing a 0-dim loss with [0] fails
        # on PyTorch >= 0.5).
        batch_size = images.size(0)
        losses_source.update(loss_source.item(), batch_size)
        losses_target.update(loss_target.item(), batch_size)
        top1_source.update(prec1_source.item(), batch_size)
        top1_target.update(prec1_target.item(), batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print('Test: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'LS {lossS.val:.4f} ({lossS.avg:.4f})\t'
                  'LT {lossT.val:.4f} ({lossT.avg:.4f})\t'
                  'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
                  'top1T {top1T.val:.3f} ({top1T.avg:.3f})'.format(
                   epoch, i, len(val_loader), batch_time=batch_time, lossS=losses_source, lossT=losses_target,
                   top1S=top1_source, top1T=top1_target))

    print(' * Top1@S {top1S.avg:.3f} Top1@T {top1T.avg:.3f}'
          .format(top1S=top1_source, top1T=top1_target))
    with open(os.path.join(args.log, 'log.txt'), 'a') as log:
        log.write("\n")
        log.write(" Test:epoch: %d, LS: %4f, LT: %4f, Top1S: %3f, Top1T: %3f" %
                  (epoch, losses_source.avg, losses_target.avg, top1_source.avg, top1_target.avg))
    return max(top1_source.avg, top1_target.avg)
176 |
177 |
class AverageMeter(object):
    """Keep the latest value of a metric together with its weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, the running sum/mean and the sample count."""
        self.val = self.avg = self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted as ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
194 |
195 |
def adjust_learning_rate(optimizer, epoch, args):
    """Set every param group's learning rate from an inverse-decay schedule.

    ``lr = args.lr / (1 + 10 * epoch / args.epochs) ** 0.75``. Param groups
    whose ``'name'`` is ``'pre-trained'`` get 0.1x that value. All other
    groups, including groups without a ``'name'`` key (which previously
    raised KeyError), get the full value.

    An alternative step-decay schedule (``args.lr * args.gamma ** k`` after
    each epoch listed in ``args.schedule``) was previously left here as
    commented-out code. The ``--schedule``/``--gamma`` options exist only for
    that variant.
    """
    decay = pow((1 + 10 * epoch / args.epochs), 0.75)
    lr = args.lr / decay
    lr_pretrain = args.lr * 0.1 / decay
    for param_group in optimizer.param_groups:
        if param_group.get('name') == 'pre-trained':
            param_group['lr'] = lr_pretrain
        else:
            param_group['lr'] = lr
212 |
213 |
214 |
215 |
def accuracy(output, target, topk=(1,)):
    """Compute precision@k (in percent) for each k in ``topk``.

    Args:
        output: scores of shape ``(batch, num_classes)``; classes are ranked
            along dim 1.
        target: ground-truth class indices of shape ``(batch,)``.
        topk: values of k to evaluate. A k larger than ``num_classes`` counts
            every class and therefore yields 100%.

    Returns:
        list with one 1-element tensor per k, in the order of ``topk``.
    """
    # Clamp so topk() does not fail when there are fewer classes than max(topk).
    maxk = min(max(topk), output.size(1))
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: `correct` comes from a transposed (non-contiguous)
        # tensor, and view(-1) raises for k > 1 on recent PyTorch versions.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
229 |
--------------------------------------------------------------------------------