├── .gitignore
├── README.md
├── data
│   └── pacs.py
├── dataset
│   └── pacs
├── figs
│   └── sagnet.png
├── modules
│   ├── loss.py
│   ├── randomizations.py
│   ├── sag_resnet.py
│   └── utils.py
└── train.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
script/
checkpoint/
__pycache__/
*.pyc
.DS_Store

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
![sagnet](figs/sagnet.png)

# Style-Agnostic Networks (SagNets)
By Hyeonseob Nam, HyunJae Lee, Jongchan Park, Wonjun Yoon, and Donggeun Yoo.

Lunit, Inc.

### Introduction
This repository contains a PyTorch implementation of Style-Agnostic Networks (SagNets) for domain generalization.
It also extends our method, which won first place in the Semi-Supervised Domain Adaptation track of the [Visual Domain Adaptation (VisDA)-2019 Challenge](https://ai.bu.edu/visda-2019/).
Details are described in [Reducing Domain Gap by Reducing Style Bias](https://openaccess.thecvf.com/content/CVPR2021/papers/Nam_Reducing_Domain_Gap_by_Reducing_Style_Bias_CVPR_2021_paper.pdf), **CVPR 2021 (Oral)**.

### Citation
If you use this code in your research, please cite:

```
@inproceedings{nam2021reducing,
  title={Reducing Domain Gap by Reducing Style Bias},
  author={Nam, Hyeonseob and Lee, HyunJae and Park, Jongchan and Yoon, Wonjun and Yoo, Donggeun},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2021}
}
```

### Prerequisites
- [PyTorch 1.0.0+](https://pytorch.org/)
- Python 3.6+
- CUDA 8.0+

### Setup
Download the [PACS](http://www.eecs.qmul.ac.uk/~dl307/project_iccv2017) dataset into ```./dataset/pacs```:
```
images -> ./dataset/pacs/images/kfold/art_painting/dog/pic_001.jpg, ...
splits -> ./dataset/pacs/splits/art_painting_train_kfold.txt, ...
```

### Usage
#### Multi-Source Domain Generalization
```
python train.py --sources Rest --targets [domain] --method sagnet --sagnet --batch-size 32 -g [gpus]
```
#### Single-Source Domain Generalization
```
python train.py --sources [domain] --targets Rest --method sagnet --sagnet --batch-size 96 -g [gpus]
```
Results are saved into ```./checkpoint```.
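For example, to train on the three non-sketch domains and evaluate on sketch (multi-source), or to train on photo alone and evaluate on the remaining domains (single-source), both on GPU 0 (the domains and GPU id here are arbitrary choices):
```
python train.py --sources Rest --targets sketch --method sagnet --sagnet --batch-size 32 -g 0
python train.py --sources photo --targets Rest --method sagnet --sagnet --batch-size 96 -g 0
```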
--------------------------------------------------------------------------------
/data/pacs.py:
--------------------------------------------------------------------------------
import os
import numpy as np

from torchvision.datasets.folder import default_loader
from torch.utils.data import Dataset


class PACS(Dataset):

    def __init__(self, root, split_dir, domain, split, transform=None):
        self.root = os.path.expanduser(root)
        self.split_dir = os.path.expanduser(split_dir)
        self.domain = domain
        self.split = split
        self.transform = transform
        self.loader = default_loader

        self.preprocess()

    def __getitem__(self, index):
        image_path = os.path.join(self.root, self.samples[index][0])
        label = self.samples[index][1]

        image = self.loader(image_path)
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        return len(self.samples)

    def preprocess(self):
        split_path = os.path.join(self.split_dir, '{}_{}_kfold.txt'.format(self.domain, self.split))
        self.samples = np.genfromtxt(split_path, dtype=str).tolist()
        # labels in the PACS split files are 1-indexed; convert to 0-indexed
        self.samples = [(img, int(lbl) - 1) for img, lbl in self.samples]

        print('domain: {:14s} split: {:10s} n_images: {:<6d}'
              .format(self.domain, self.split, len(self.samples)))
--------------------------------------------------------------------------------
/dataset/pacs:
--------------------------------------------------------------------------------
/data/dataset/PACS
--------------------------------------------------------------------------------
/figs/sagnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hyseob/SagNet/f87161c26afc3b20b42ff28a9306e927239abc15/figs/sagnet.png
--------------------------------------------------------------------------------
/modules/loss.py:
--------------------------------------------------------------------------------
import torch


class AdvLoss(torch.nn.Module):
    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, inputs):
        inputs = inputs.softmax(dim=1)
        loss = - torch.log(inputs + self.eps).mean(dim=1)
        return loss.mean()
--------------------------------------------------------------------------------
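A note on `AdvLoss` above: up to the stabilizing `eps`, `-log(softmax(x)).mean(dim=1)` is the cross-entropy between the prediction and the uniform distribution over classes, so the adversarial step uses it to push the style-biased head toward uninformative, style-agnostic predictions. A minimal sketch checking this equivalence (the tensor sizes are illustrative):
```
import torch
from modules.loss import AdvLoss

torch.manual_seed(0)
logits = torch.randn(4, 7)  # 4 samples, 7 PACS classes

loss = AdvLoss(eps=0)(logits)  # eps=0 to make the comparison exact

# cross-entropy against a uniform target (probability 1/K for each of K classes)
uniform_ce = -torch.log_softmax(logits, dim=1).mean(dim=1).mean()

assert torch.allclose(loss, uniform_ce)
```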
/modules/randomizations.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class StyleRandomization(nn.Module):
    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        N, C, H, W = x.size()

        if self.training:
            x = x.view(N, C, -1)
            mean = x.mean(-1, keepdim=True)
            var = x.var(-1, keepdim=True)

            x = (x - mean) / (var + self.eps).sqrt()

            # interpolate the style statistics with those of a random partner
            idx_swap = torch.randperm(N)
            alpha = torch.rand(N, 1, 1)
            if x.is_cuda:
                alpha = alpha.cuda()
            mean = alpha * mean + (1 - alpha) * mean[idx_swap]
            var = alpha * var + (1 - alpha) * var[idx_swap]

            x = x * (var + self.eps).sqrt() + mean
            x = x.view(N, C, H, W)

        return x


class ContentRandomization(nn.Module):
    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        N, C, H, W = x.size()

        if self.training:
            x = x.view(N, C, -1)
            mean = x.mean(-1, keepdim=True)
            var = x.var(-1, keepdim=True)

            x = (x - mean) / (var + self.eps).sqrt()

            # swap the normalized content across the batch; detach blocks its gradient
            idx_swap = torch.randperm(N)
            x = x[idx_swap].detach()

            x = x * (var + self.eps).sqrt() + mean
            x = x.view(N, C, H, W)

        return x
--------------------------------------------------------------------------------
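Both modules above treat the per-sample, per-channel mean and variance as "style", as in AdaIN: `StyleRandomization` interpolates these statistics with those of a random partner in the batch while keeping the normalized feature map (the "content"), and `ContentRandomization` does the opposite, swapping the normalized content across the batch while restoring each sample's own statistics. A quick behavioral sketch (sizes are illustrative):
```
import torch
from modules.randomizations import StyleRandomization, ContentRandomization

x = torch.randn(8, 64, 28, 28)  # an (N, C, H, W) feature map

style_rand = StyleRandomization().train()
content_rand = ContentRandomization().train()

# both are shape-preserving, and both are identity mappings at test time
assert style_rand(x).shape == content_rand(x).shape == x.shape
assert torch.equal(style_rand.eval()(x), x)
```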
/modules/sag_resnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

from .randomizations import StyleRandomization, ContentRandomization

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, padding=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, drop=0, sagnet=True, style_stage=3):
        super().__init__()

        self.drop = drop
        self.sagnet = sagnet
        self.style_stage = style_stage

        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(self.drop)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        if self.sagnet:
            # randomizations
            self.style_randomization = StyleRandomization()
            self.content_randomization = ContentRandomization()

            # style-biased network
            style_layers = []
            if style_stage == 1:
                self.inplanes = 64
                style_layers += [self._make_layer(block, 64, layers[0])]
            if style_stage <= 2:
                self.inplanes = 64 * block.expansion
                style_layers += [self._make_layer(block, 128, layers[1], stride=2)]
            if style_stage <= 3:
                self.inplanes = 128 * block.expansion
                style_layers += [self._make_layer(block, 256, layers[2], stride=2)]
            if style_stage <= 4:
                self.inplanes = 256 * block.expansion
                style_layers += [self._make_layer(block, 512, layers[3], stride=2)]
            self.style_net = nn.Sequential(*style_layers)

            self.style_avgpool = nn.AdaptiveAvgPool2d(1)
            self.style_dropout = nn.Dropout(self.drop)
            self.style_fc = nn.Linear(512 * block.expansion, num_classes)

        # init weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def adv_params(self):
        params = []
        layers = [self.bn1, self.layer1, self.layer2, self.layer3, self.layer4]
        for layer in layers[:self.style_stage]:
            for m in layer.modules():
                if isinstance(m, nn.BatchNorm2d):
                    params += [p for p in m.parameters()]
        return params

    def style_params(self):
        params = []
        for m in [self.style_net, self.style_fc]:
            params += [p for p in m.parameters()]
        return params

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        for i, layer in enumerate([self.layer1, self.layer2, self.layer3, self.layer4]):
            if self.sagnet and i + 1 == self.style_stage:
                # randomization
                x_style = self.content_randomization(x)
                x = self.style_randomization(x)
            x = layer(x)

        # content output
        feat = self.avgpool(x)
        feat = feat.view(x.size(0), -1)
        feat = self.dropout(feat)
        y = self.fc(feat)

        if self.sagnet:
            # style output
            x_style = self.style_net(x_style)
            feat = self.style_avgpool(x_style)
            feat = feat.view(feat.size(0), -1)
            feat = self.style_dropout(feat)
            y_style = self.style_fc(feat)
        else:
            y_style = None

        return y, y_style

def sag_resnet(depth, pretrained=False, **kwargs):
    if depth == 18:
        model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    elif depth == 34:
        model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    elif depth == 50:
        model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    elif depth == 101:
        model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    elif depth == 152:
        model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    else:
        raise ValueError('unsupported depth: {}'.format(depth))

    if pretrained:
        model_url = model_urls['resnet' + str(depth)]
        print('load a pretrained model from {}'.format(model_url))

        states = model_zoo.load_url(model_url)
        states.pop('fc.weight')
        states.pop('fc.bias')
        model.load_state_dict(states, strict=False)

        if model.sagnet:
            # copy the pretrained weights of stages style_stage..4 into the style net
            states_style = {}
            for i in range(model.style_stage, 5):
                for k, v in states.items():
                    if k.startswith('layer' + str(i)):
                        states_style[str(i - model.style_stage) + k[6:]] = v
            model.style_net.load_state_dict(states_style)

    return model
--------------------------------------------------------------------------------
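`sag_resnet` builds a two-headed network: `y` comes from the content-biased path and `y_style` from the style-biased path that branches off at `style_stage`. A minimal forward-pass sketch (the arguments mirror the defaults in `train.py`; `pretrained=False` avoids the download):
```
import torch
from modules.sag_resnet import sag_resnet

model = sag_resnet(depth=18, pretrained=False, num_classes=7,
                   drop=0.5, sagnet=True, style_stage=3)
model.train()  # the randomizations are active only in train mode

x = torch.randn(2, 3, 224, 224)
y, y_style = model(x)
assert y.shape == y_style.shape == (2, 7)
```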
/modules/utils.py:
--------------------------------------------------------------------------------
import os
import sys
import json
import numpy as np

import torch


class Logger(object):
    def __init__(self, log_path):
        self.terminal = sys.stdout
        self.log_path = log_path
        f = open(self.log_path, "w")
        f.close()

    def write(self, message):
        self.terminal.write(message)
        with open(self.log_path, "a") as f:
            f.write(message)

    def flush(self):
        pass


class AverageMeter(object):
    def __init__(self, limit=100):
        self.items = []
        self.limit = limit
        self.avg = 0

    def __repr__(self):
        return '{:.4f}'.format(self.avg)

    def toJSON(self):
        return json.dumps(self.avg)

    def update(self, val):
        self.items.append(val)
        if len(self.items) > self.limit:
            self.items = self.items[1:]
        self.avg = sum(self.items) / len(self.items)


def compute_accuracy(pred, label):
    if not isinstance(pred, np.ndarray):
        pred = pred.data.cpu().numpy()
    if not isinstance(label, np.ndarray):
        label = label.data.cpu().numpy()
    pred = pred.argmax(axis=1)
    correct = (pred == label).sum()
    acc = correct / len(label)
    return acc


def save_result(result, save_dir):
    def _dumper(obj):
        try:
            return obj.toJSON()
        except:
            return obj.__dict__
    with open(os.path.join(save_dir, 'result.json'), 'w') as fp:
        json.dump(result, fp, default=_dumper, indent=4)


def save_model(model, save_dir, postfix):
    model_path = os.path.join(save_dir, 'checkpoint_{}.pth'.format(postfix))
    model.cpu()
    print('save model to {}'.format(model_path))
    torch.save(model.state_dict(), model_path)
    model.cuda()
--------------------------------------------------------------------------------
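Note that `AverageMeter` reports a sliding-window mean over the last `limit` updates rather than a running mean over the whole history:
```
from modules.utils import AverageMeter

m = AverageMeter(limit=3)
for v in [1.0, 2.0, 3.0, 4.0]:
    m.update(v)
print(m)  # 3.0000 -- the mean of the last three values (2, 3, 4)
```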
/train.py:
--------------------------------------------------------------------------------
import argparse
import sys
import os
import time
import copy
import numpy as np
from collections import OrderedDict

import torch
import torch.optim as optim
from torchvision import transforms

from modules.sag_resnet import sag_resnet
from modules.loss import *
from modules.utils import *


# Training settings
parser = argparse.ArgumentParser(description='PyTorch SagNet')

# dataset
parser.add_argument('--dataset-dir', type=str, default='dataset',
                    help='home directory to dataset')
parser.add_argument('--dataset', type=str, default='pacs',
                    help='dataset name')
parser.add_argument('--sources', type=str, nargs='*',
                    help='domains for train')
parser.add_argument('--targets', type=str, nargs='*',
                    help='domains for test')

# save dir
parser.add_argument('--save-dir', type=str, default='checkpoint',
                    help='home directory to save model')
parser.add_argument('--method', type=str, default='sagnet',
                    help='method name')

# data loader
parser.add_argument('--workers', type=int, default=4,
                    help='number of workers')
parser.add_argument('--batch-size', type=int, default=32,
                    help='batch size for each source domain')
parser.add_argument('--input-size', type=int, default=256,
                    help='input image size')
parser.add_argument('--crop-size', type=int, default=224,
                    help='crop image size')
parser.add_argument('--colorjitter', type=float, default=0.4,
                    help='color jittering')

# model
parser.add_argument('--arch', type=str, default='sag_resnet',
                    help='network architecture')
parser.add_argument('--depth', type=str, default='18',
                    help='depth of network')
parser.add_argument('--drop', type=float, default=0.5,
                    help='dropout ratio')

# sagnet
parser.add_argument('--sagnet', action='store_true', default=False,
                    help='use sagnet')
parser.add_argument('--style-stage', type=int, default=3,
                    help='stage to extract style features {1, 2, 3, 4}')
parser.add_argument('--w-adv', type=float, default=0.1,
                    help='weight for adversarial loss')

# training policy
parser.add_argument('--from-sketch', action='store_true', default=False,
                    help='training from scratch')
parser.add_argument('--lr', type=float, default=0.004,
                    help='initial learning rate')
parser.add_argument('--weight-decay', type=float, default=1e-4,
                    help='weight decay')
parser.add_argument('--iterations', type=int, default=2000,
                    help='number of training iterations')
parser.add_argument('--scheduler', type=str, default='cosine',
                    help='learning rate scheduler {step, cosine}')
parser.add_argument('--milestones', type=int, nargs='+', default=[1000, 1500],
                    help='milestones to decay learning rate (for step scheduler)')
parser.add_argument('--gamma', type=float, default=0.1,
                    help='gamma to decay learning rate')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='SGD momentum')
parser.add_argument('--clip-adv', type=float, default=0.1,
                    help='grad clipping for adversarial loss')

# etc
parser.add_argument('--seed', type=int, default=-1,
                    help='random seed')
parser.add_argument('--log-interval', type=int, default=10,
                    help='iterations for logging training status')
parser.add_argument('--log-test-interval', type=int, default=10,
                    help='iterations for logging test status')
parser.add_argument('--test-interval', type=int, default=100,
                    help='iterations for test')
parser.add_argument('-g', '--gpu-id', type=str, default='0',
                    help='gpu id')


def main(args):
    global status

    # Set gpus
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    # Set domains
    if args.dataset == 'pacs':
        all_domains = ['art_painting', 'cartoon', 'sketch', 'photo']

    if args.sources[0] == 'Rest':
        args.sources = [d for d in all_domains if d not in args.targets]
    if args.targets[0] == 'Rest':
        args.targets = [d for d in all_domains if d not in args.sources]

    # Set save dir
    save_dir = os.path.join(args.save_dir, args.dataset, args.method, ','.join(args.sources))
    print('Save directory: {}'.format(save_dir))
    os.makedirs(save_dir, exist_ok=True)

    # Set Logger
    log_path = os.path.join(save_dir, 'log.txt')
    sys.stdout = Logger(log_path)

    # Print arguments
    print('\nArguments')
    for arg in vars(args):
        print(' - {}: {}'.format(arg, getattr(args, arg)))

    # Init seed
    if args.seed >= 0:
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

    # Initialize loader
    print('\nInitialize loaders...')
    init_loader()

    # Initialize model
    print('\nInitialize model...')
    init_model()

    # Initialize optimizer
    print('\nInitialize optimizers...')
    init_optimizer()

    # Initialize status
    src_keys = ['t_data', 't_net', 'l_c', 'l_s', 'l_adv', 'acc']
    status = OrderedDict([
        ('iteration', 0),
        ('lr', 0),
        ('src', OrderedDict([(k, AverageMeter()) for k in src_keys])),
        ('val_acc', OrderedDict([(domain, 0) for domain in args.sources])),
        ('mean_val_acc', 0),
        ('test_acc', OrderedDict([(domain, 0) for domain in args.targets])),
        ('mean_test_acc', 0),
    ])

    # Main loop
    print('\nStart training...')
    results = []
    for step in range(args.iterations):
        train(step)

        if (step + 1) % args.test_interval == 0:
            save_model(model, save_dir, 'latest')

            for i, domain in enumerate(args.sources):
                print('Validation: {}'.format(domain))
                status['val_acc'][domain] = test(loader_vals[i])
            for i, domain in enumerate(args.targets):
                print('Test: {}'.format(domain))
                status['test_acc'][domain] = test(loader_tgts[i])

            status['mean_val_acc'] = sum(status['val_acc'].values()) / len(status['val_acc'])
            status['mean_test_acc'] = sum(status['test_acc'].values()) / len(status['test_acc'])

            print('Val accuracy: {:.5f} ({})'.format(status['mean_val_acc'],
                ', '.join(['{}: {:.5f}'.format(k, v) for k, v in status['val_acc'].items()])))
            print('Test accuracy: {:.5f} ({})'.format(status['mean_test_acc'],
                ', '.join(['{}: {:.5f}'.format(k, v) for k, v in status['test_acc'].items()])))

            results.append(copy.deepcopy(status))
            save_result(results, save_dir)


def init_loader():
    global loader_srcs, loader_vals, loader_tgts
    global num_classes

    # Set transforms
    stats = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

    trans_list = []
    trans_list.append(transforms.RandomResizedCrop(args.crop_size, scale=(0.5, 1)))
    if args.colorjitter:
        trans_list.append(transforms.ColorJitter(*[args.colorjitter] * 4))
    trans_list.append(transforms.RandomHorizontalFlip())
    trans_list.append(transforms.ToTensor())
    trans_list.append(transforms.Normalize(*stats))

    train_transform = transforms.Compose(trans_list)
    test_transform = transforms.Compose([
        transforms.Resize(args.input_size),
        transforms.CenterCrop(args.crop_size),
        transforms.ToTensor(),
        transforms.Normalize(*stats)])

    # Set datasets
    if args.dataset == 'pacs':
        from data.pacs import PACS
        image_dir = os.path.join(args.dataset_dir, args.dataset, 'images', 'kfold')
        split_dir = os.path.join(args.dataset_dir, args.dataset, 'splits')

        print('--- Training ---')
        dataset_srcs = [PACS(image_dir,
                             split_dir,
                             domain=domain,
                             split='train',
                             transform=train_transform)
                        for domain in args.sources]
        print('--- Validation ---')
        dataset_vals = [PACS(image_dir,
                             split_dir,
                             domain=domain,
                             split='crossval',
                             transform=test_transform)
                        for domain in args.sources]
        print('--- Test ---')
        dataset_tgts = [PACS(image_dir,
                             split_dir,
                             domain=domain,
                             split='test',
                             transform=test_transform)
                        for domain in args.targets]
        num_classes = 7
    else:
        raise NotImplementedError('Unknown dataset: {}'.format(args.dataset))

    # Set loaders
    kwargs = {'num_workers': args.workers, 'pin_memory': True}
    loader_srcs = [torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        **kwargs)
        for dataset in dataset_srcs]
    loader_vals = [torch.utils.data.DataLoader(
        dataset,
        batch_size=int(args.batch_size * 4),
        shuffle=False,
        drop_last=False,
        **kwargs)
        for dataset in dataset_vals]
    loader_tgts = [torch.utils.data.DataLoader(
        dataset_tgt,
        batch_size=int(args.batch_size * 4),
        shuffle=False,
        drop_last=False,
        **kwargs)
        for dataset_tgt in dataset_tgts]


def init_model():
    global model
    model = sag_resnet(depth=int(args.depth),
                       pretrained=not args.from_sketch,
                       num_classes=num_classes,
                       drop=args.drop,
                       sagnet=args.sagnet,
                       style_stage=args.style_stage)

    print(model)
    model = torch.nn.DataParallel(model).cuda()


def init_optimizer():
    global optimizer, optimizer_style, optimizer_adv
    global scheduler, scheduler_style, scheduler_adv
    global criterion, criterion_style, criterion_adv

    # Set hyperparams
    optim_hyperparams = {'lr': args.lr,
                         'weight_decay': args.weight_decay,
                         'momentum': args.momentum}
    if args.scheduler == 'step':
        Scheduler = optim.lr_scheduler.MultiStepLR
        sch_hyperparams = {'milestones': args.milestones,
                           'gamma': args.gamma}
    elif args.scheduler == 'cosine':
        Scheduler = optim.lr_scheduler.CosineAnnealingLR
        sch_hyperparams = {'T_max': args.iterations}

    # Main learning
    params = model.module.parameters()
    optimizer = optim.SGD(params, **optim_hyperparams)
    scheduler = Scheduler(optimizer, **sch_hyperparams)
    criterion = torch.nn.CrossEntropyLoss()

    if args.sagnet:
        # Style learning
        params_style = model.module.style_params()
        optimizer_style = optim.SGD(params_style, **optim_hyperparams)
        scheduler_style = Scheduler(optimizer_style, **sch_hyperparams)
        criterion_style = torch.nn.CrossEntropyLoss()

        # Adversarial learning
        params_adv = model.module.adv_params()
        optimizer_adv = optim.SGD(params_adv, **optim_hyperparams)
        scheduler_adv = Scheduler(optimizer_adv, **sch_hyperparams)
        criterion_adv = AdvLoss()

def train(step):
    global dataiter_srcs

    ## Initialize iteration
    model.train()

    scheduler.step()
    if args.sagnet:
        scheduler_style.step()
        scheduler_adv.step()

    ## Load data
    tic = time.time()

    n_srcs = len(args.sources)
    if step == 0:
        dataiter_srcs = [None] * n_srcs
    data = [None] * n_srcs
    label = [None] * n_srcs
    for i in range(n_srcs):
        if step % len(loader_srcs[i]) == 0:
            dataiter_srcs[i] = iter(loader_srcs[i])
        data[i], label[i] = next(dataiter_srcs[i])

    data = torch.cat(data)
    label = torch.cat(label)
    rand_idx = torch.randperm(len(data))
    data = data[rand_idx]
    label = label[rand_idx].cuda()

    time_data = time.time() - tic

    ## Process batch
    tic = time.time()

    # forward
    y, y_style = model(data)

    if args.sagnet:
        # learn style
        loss_style = criterion(y_style, label)
        optimizer_style.zero_grad()
        loss_style.backward(retain_graph=True)
        optimizer_style.step()

        # learn style_adv
        loss_adv = args.w_adv * criterion_adv(y_style)
        optimizer_adv.zero_grad()
        loss_adv.backward(retain_graph=True)
        if args.clip_adv is not None:
            torch.nn.utils.clip_grad_norm_(model.module.adv_params(), args.clip_adv)
        optimizer_adv.step()

    # learn content
    loss = criterion(y, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    time_net = time.time() - tic

    ## Update status
    status['iteration'] = step + 1
    status['lr'] = optimizer.param_groups[0]['lr']
    status['src']['t_data'].update(time_data)
    status['src']['t_net'].update(time_net)
    status['src']['l_c'].update(loss.item())
    if args.sagnet:
        status['src']['l_s'].update(loss_style.item())
        status['src']['l_adv'].update(loss_adv.item())
    status['src']['acc'].update(compute_accuracy(y, label))

    ## Log result
    if step % args.log_interval == 0:
        print('[{}/{} ({:.0f}%)] lr {:.5f}, {}'.format(
            step, args.iterations, 100. * step / args.iterations, status['lr'],
            ', '.join(['{} {}'.format(k, v) for k, v in status['src'].items()])))

def test(loader_tgt):
    model.eval()
    preds, labels = [], []
    for batch_idx, (data, label) in enumerate(loader_tgt):
        # forward
        with torch.no_grad():
            y, _ = model(data)

        # result
        preds += [y.data.cpu().numpy()]
        labels += [label.data.cpu().numpy()]

        # log
        if args.log_test_interval != -1 and batch_idx % args.log_test_interval == 0:
            print('[{}/{} ({:.0f}%)]'.format(
                batch_idx, len(loader_tgt), 100. * batch_idx / len(loader_tgt)))

    # Aggregate result
    preds = np.concatenate(preds, axis=0)
    labels = np.concatenate(labels, axis=0)
    acc = compute_accuracy(preds, labels)
    return acc


if __name__ == '__main__':
    args = parser.parse_args()
    main(args)
--------------------------------------------------------------------------------
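To reuse a trained model outside `train.py`, note that `save_model` saves the state dict of the `DataParallel` wrapper, so the keys carry a `module.` prefix. A loading sketch (the checkpoint path assumes a single-source run with `--sources photo` and the default `--save-dir`; adjust it to your own run):
```
import torch
from modules.sag_resnet import sag_resnet

model = sag_resnet(depth=18, pretrained=False, num_classes=7,
                   drop=0.5, sagnet=True, style_stage=3)

state = torch.load('checkpoint/pacs/sagnet/photo/checkpoint_latest.pth',
                   map_location='cpu')
state = {k.replace('module.', '', 1): v for k, v in state.items()}  # strip DataParallel prefix
model.load_state_dict(state)
model.eval()
```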