├── requirements.txt
├── loss
│   ├── EntropyLoss.py
│   └── MaximumSquareLoss.py
├── util
│   ├── scheduler.py
│   └── util.py
├── script
│   ├── deepall.sh
│   ├── original.sh
│   └── general.sh
├── train
│   ├── eval.py
│   ├── deepall.py
│   └── general.py
├── model
│   ├── Discriminator.py
│   ├── alexnet.py
│   ├── resnet.py
│   └── caffenet.py
├── dataloader
│   ├── dataloader.py
│   └── Dataset.py
├── README.md
├── clustering
│   ├── clustering.py
│   └── domain_split.py
└── main
    └── main.py
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
scikit-learn
matplotlib
--------------------------------------------------------------------------------
/loss/EntropyLoss.py:
--------------------------------------------------------------------------------
from torch import nn
from torch.nn import functional as F

class HLoss(nn.Module):
    """Mean Shannon entropy of the softmax distribution over a batch of logits."""
    def __init__(self):
        super(HLoss, self).__init__()

    def forward(self, x):
        b = F.softmax(x, dim=1) * F.log_softmax(x, dim=1)
        b = -1.0 * b.sum(dim=1).mean()
        return b
--------------------------------------------------------------------------------
/loss/MaximumSquareLoss.py:
--------------------------------------------------------------------------------
from torch import nn
from torch.nn import functional as F
import torch

class MaximumSquareLoss(nn.Module):
    """Maximum squares loss: negative mean of the summed squared softmax probabilities, halved."""
    def __init__(self):
        super(MaximumSquareLoss, self).__init__()

    def forward(self, x):
        p = F.softmax(x, dim=1)
        b = torch.mul(p, p)
        b = -1.0 * b.sum(dim=1).mean() / 2
        return b
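Editor's note: the two modules above are the unsupervised terms that `train/general.py` adds to the classification and domain losses (one of the two is selected by the `--entropy` flag). A minimal sanity check, not part of the repository, assuming the repository root is on `PYTHONPATH` as `main/main.py` arranges with `sys.path.append`; the tensor shapes are illustrative only:
```
import torch
from loss.EntropyLoss import HLoss
from loss.MaximumSquareLoss import MaximumSquareLoss

logits = torch.randn(8, 7)                  # e.g. a batch of 8 samples over 7 classes
entropy = HLoss()(logits)                   # mean Shannon entropy of softmax(logits), always >= 0
max_square = MaximumSquareLoss()(logits)    # -0.5 * batch mean of sum_c p_c^2, always <= 0

# Both terms shrink as the predictions become more confident, which is why
# general.py minimizes them together with the class loss.
print(entropy.item(), max_square.item(), HLoss()(logits * 10).item())
```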
dir_name="PACS/default/${domain[j]}${i}" 12 | echo $dir_name 13 | python ../main/main.py \ 14 | --data-root='/data/unagi0/matsuura/PACS/raw_images/kfold/' \ 15 | --save-root='/data/unagi0/matsuura/result/dg_mmld/' \ 16 | --result-dir=$dir_name \ 17 | --train='general' \ 18 | --data='PACS' \ 19 | --model='caffenet' \ 20 | --entropy='default' \ 21 | --exp-num=$j \ 22 | --gpu=0 \ 23 | --num-epoch=30 \ 24 | --scheduler='step' \ 25 | --lr=1e-3 \ 26 | --lr-step=24 \ 27 | --lr-decay-gamma=0.1 \ 28 | --nesterov \ 29 | --fc-weight=10.0 \ 30 | --disc-weight=10.0 \ 31 | --entropy-weight=1.0 \ 32 | --grl-weight=1.0 \ 33 | --loss-disc-weight \ 34 | --color-jitter \ 35 | --min-scale=0.8 36 | done 37 | done -------------------------------------------------------------------------------- /script/general.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | domain=("photo" "art" "cartoon" "sketch") 4 | 5 | times=5 6 | for i in `seq 1 $times` 7 | do 8 | max=$((${#domain[@]}-1)) 9 | for j in `seq 0 $max` 10 | do 11 | dir_name="PACS/default/${domain[j]}${i}" 12 | echo $dir_name 13 | python ../main/main.py \ 14 | --data-root='/data/unagi0/matsuura/PACS/raw_images/kfold/' \ 15 | --save-root='/data/unagi0/matsuura/result/dg_mmld/' \ 16 | --result-dir=$dir_name \ 17 | --train='general' \ 18 | --data='PACS' \ 19 | --model='caffenet' \ 20 | --clustering \ 21 | --clustering-method='Kmeans' \ 22 | --num-clustering=3 \ 23 | --clustering-step=1 \ 24 | --entropy='default' \ 25 | --exp-num=$j \ 26 | --gpu=0 \ 27 | --num-epoch=30 \ 28 | --scheduler='step' \ 29 | --lr=1e-3 \ 30 | --lr-step=24 \ 31 | --lr-decay-gamma=0.1 \ 32 | --nesterov \ 33 | --fc-weight=10.0 \ 34 | --disc-weight=10.0 \ 35 | --entropy-weight=1.0 \ 36 | --grl-weight=1.0 \ 37 | --loss-disc-weight \ 38 | --color-jitter \ 39 | --min-scale=0.8 \ 40 | --instance-stat 41 | done 42 | done -------------------------------------------------------------------------------- /train/eval.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | 4 | def eval_model(model, eval_data, device, epoch, filename): 5 | criterion = nn.CrossEntropyLoss() 6 | model.eval() # Set model to training mode 7 | running_loss = 0.0 8 | running_corrects = 0 9 | # Iterate over data. 
    data_num = 0
    for inputs, labels in eval_data:
        with torch.no_grad():
            inputs = inputs.to(device)
            labels = labels.to(device)
            # forward
            outputs = model(inputs)
            if isinstance(outputs, tuple):
                outputs = outputs[0]
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)
            # statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data).item()
            data_num += inputs.size(0)
    epoch_loss = running_loss / len(eval_data.dataset)
    epoch_acc = running_corrects / len(eval_data.dataset)
    log = 'Eval: Epoch: {} Loss: {:.4f} Acc: {:.4f}'.format(epoch, epoch_loss, epoch_acc)
    print(log)
    with open(filename, 'a') as f:
        f.write(log + '\n')
    return epoch_acc
--------------------------------------------------------------------------------
/train/deepall.py:
--------------------------------------------------------------------------------
from torch import nn
from util.util import split_domain
import torch
from numpy.random import *
import numpy as np

def train(model, train_data, optimizers, device, epoch, num_epoch, filename, entropy=None, disc_weight=None, entropy_weight=None, grl_weight=None):
    criterion = nn.CrossEntropyLoss()

    model.train()  # Set model to training mode
    running_loss = 0.0
    running_corrects = 0
    # Iterate over data.
    for inputs, labels in train_data:
        inputs = inputs.to(device)
        labels = labels.to(device)
        # zero the parameter gradients
        for optimizer in optimizers:
            optimizer.zero_grad()
        # forward
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        _, preds = torch.max(outputs, 1)
        loss.backward()
        for optimizer in optimizers:
            optimizer.step()
        # statistics
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)
    epoch_loss = running_loss / len(train_data.dataset)
    epoch_acc = running_corrects.double() / len(train_data.dataset)

    log = 'Train: Epoch: {} Loss: {:.4f} Acc: {:.4f}'.format(epoch, epoch_loss, epoch_acc)
    print(log)
    with open(filename, 'a') as f:
        f.write(log + '\n')
    return model, optimizers
--------------------------------------------------------------------------------
/model/Discriminator.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
import torch

class GradReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, lambd, reverse=True):
        ctx.lambd = lambd
        ctx.reverse = reverse
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.reverse:
            return (grad_output * -ctx.lambd), None, None
        else:
            return (grad_output * ctx.lambd), None, None

def grad_reverse(x, lambd=1.0, reverse=True):
    return GradReverse.apply(x, lambd, reverse)

class Discriminator(nn.Module):
    def __init__(self, dims, grl=True, reverse=True):
        if len(dims) != 4:
            raise ValueError("Discriminator dims must contain four values: input, two hidden and output!")
        super(Discriminator, self).__init__()
        self.grl = grl
        self.reverse = reverse
        self.model = nn.Sequential(
            nn.Linear(dims[0], dims[1]),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(dims[1], dims[2]),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(dims[2], dims[3]),
        )
        self.lambd = 0.0

    def set_lambd(self, lambd):
        self.lambd = lambd

    def forward(self, x):
        if self.grl:
            x = grad_reverse(x, self.lambd, self.reverse)
        x = self.model(x)
        return x
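Editor's note: a minimal check, not part of the repository, of the gradient reversal layer defined above. The forward pass is the identity, while the backward pass multiplies the incoming gradient by `-lambd` (when `reverse=True`), which is what lets the discriminator loss push the shared features toward domain confusion:
```
import torch
from model.Discriminator import grad_reverse

x = torch.ones(3, requires_grad=True)
y = grad_reverse(x, lambd=2.0)   # identity in the forward pass
y.sum().backward()
print(x.grad)                    # tensor([-2., -2., -2.]): gradients are scaled by -lambd
```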
--------------------------------------------------------------------------------
/dataloader/dataloader.py:
--------------------------------------------------------------------------------
from torch.utils.data import DataLoader, random_split
import numpy as np
from copy import deepcopy
from dataloader.Dataset import DG_Dataset

def random_split_dataloader(data, data_root, source_domain, target_domain, batch_size,
                            get_domain_label=False, get_cluster=False, num_workers=4, color_jitter=True, min_scale=0.8):
    if data == 'VLCS':
        split_rate = 0.7
    else:
        split_rate = 0.9
    source = DG_Dataset(root_dir=data_root, domain=source_domain, split='val',
                        get_domain_label=False, get_cluster=False, color_jitter=color_jitter, min_scale=min_scale)
    source_train, source_val = random_split(source, [int(len(source)*split_rate), len(source)-int(len(source)*split_rate)])
    source_train = deepcopy(source_train)
    source_train.dataset.split = 'train'
    source_train.dataset.set_transform('train')
    source_train.dataset.get_domain_label = get_domain_label
    source_train.dataset.get_cluster = get_cluster

    target_test = DG_Dataset(root_dir=data_root, domain=target_domain, split='test',
                             get_domain_label=False, get_cluster=False)

    print('Train: {}, Val: {}, Test: {}'.format(len(source_train), len(source_val), len(target_test)))

    source_train = DataLoader(source_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    source_val = DataLoader(source_val, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    target_test = DataLoader(target_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return source_train, source_val, target_test
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Domain Generalization Using a Mixture of Multiple Latent Domains
![model](https://user-images.githubusercontent.com/22876486/68654944-64933100-0572-11ea-8cd0-2ff148ca1843.png)
This is the PyTorch implementation of the AAAI 2020 poster paper "Domain Generalization Using a Mixture of Multiple Latent Domains".

## Requirements
- Python 3.6
- PyTorch 0.4.1 and torchvision 0.2.1 ([pytorch.org](https://pytorch.org/))
- The Caffe model we used for [AlexNet](https://drive.google.com/file/d/1wUJTH1Joq2KAgrUDeKJghP1Wf7Q9w4z-/view?usp=sharing)
- PACS dataset ([website](https://dali-dl.github.io/project_iccv2017.html), [dataset](https://drive.google.com/drive/folders/0B6x7gtvErXgfUU1WcGY5SzdwZVk?resourcekey=0-2fvpQY_QSyJf2uIECzqPuQ))
- Install the Python requirements:
```
pip install -r requirements.txt
```

## Training and Testing
You can train the model with the following command.
```
cd script
bash general.sh
```
If you want to train the model without domain generalization (Deep All), use the following command instead.
```
cd script
bash deepall.sh
```

Set the following parameters to match your environment:
- `--data-root`: the dataset folder path
- `--save-root`: the folder path for saving the results
- `--gpu`: the gpu id used to run experiments
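Based on how `dataloader/Dataset.py` resolves paths (`root_dir + domain + '/'`, with one sub-folder per class), `--data-root` is expected to point to a directory laid out as below. The domain folder names come from `util/util.py`; the class folder and file names shown are only examples from PACS:
```
<data-root>/
├── photo/
│   ├── dog/
│   │   ├── xxx.jpg
│   │   └── ...
│   └── ...
├── art_painting/
├── cartoon/
└── sketch/
```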

## Citation
If you use this code, please cite the following paper:

Toshihiko Matsuura and Tatsuya Harada. Domain Generalization Using a Mixture of Multiple Latent Domains. In AAAI, 2020.
```
@InProceedings{dg_mmld,
  title={Domain Generalization Using a Mixture of Multiple Latent Domains},
  author={Toshihiko Matsuura and Tatsuya Harada},
  booktitle={AAAI},
  year={2020},
}
```
--------------------------------------------------------------------------------
/model/alexnet.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from model.Discriminator import Discriminator
from torchvision.models import AlexNet

__all__ = ['AlexNet', 'alexnet']

model_urls = {
    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
}

def alexnet(num_classes, num_domains=None, pretrained=True):
    """AlexNet model architecture from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = AlexNet()
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['alexnet']))
        print('Loaded the pre-trained model')
    num_ftrs = model.classifier[-1].in_features
    model.classifier[-1] = nn.Linear(num_ftrs, num_classes)
    nn.init.xavier_uniform_(model.classifier[-1].weight, .1)
    nn.init.constant_(model.classifier[-1].bias, 0.)
    return model

class DGalexnet(nn.Module):
    def __init__(self, num_classes, num_domains, pretrained=True, grl=True):
        super(DGalexnet, self).__init__()
        self.num_domains = num_domains
        self.base_model = alexnet(num_classes, pretrained=pretrained)
        self.discriminator = Discriminator([4096, 1024, 1024, num_domains], grl=grl, reverse=True)
        self.feature_layers = nn.Sequential(*list(self.base_model.classifier.children())[:-1])
        self.fc = list(self.base_model.classifier.children())[-1]

    def forward(self, x):
        x = self.base_model.features(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.feature_layers(x)
        output_class = self.fc(x)
        output_domain = self.discriminator(x)
        return output_class, output_domain

    def features(self, x):
        x = self.base_model.features(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.feature_layers(x)
        return x

    def conv_features(self, x):
        results = []
        for i, model in enumerate(self.base_model.features):
            x = model(x)
            if i in {4, 7}:
                results.append(x)
        return results

    def domain_features(self, x):
        for i, model in enumerate(self.base_model.features):
            x = model(x)
            if i == 7:
                break
        return x.view(x.size(0), -1)
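Editor's note: a quick shape check, not part of the repository, for the discriminator-augmented wrapper defined above. `pretrained=False` avoids downloading the ImageNet weights, the class/domain counts are illustrative, and the repository root is assumed to be on `PYTHONPATH`:
```
import torch
from model.alexnet import DGalexnet

model = DGalexnet(num_classes=7, num_domains=3, pretrained=False)
model.discriminator.set_lambd(0.5)              # GRL coefficient, ramped up in train/general.py
x = torch.randn(2, 3, 224, 224)                 # two 224x224 RGB images
output_class, output_domain = model(x)
print(output_class.shape, output_domain.shape)  # torch.Size([2, 7]) torch.Size([2, 3])
```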
--------------------------------------------------------------------------------
/model/resnet.py:
--------------------------------------------------------------------------------
from torchvision.models import resnet18
from model.Discriminator import Discriminator
import torch.nn as nn
import torch.nn.init as init

def resnet(num_classes, num_domains=None, pretrained=True):
    model = resnet18(pretrained=pretrained)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, num_classes)
    nn.init.xavier_uniform_(model.fc.weight, .1)
    nn.init.constant_(model.fc.bias, 0.)
    return model

class DGresnet(nn.Module):
    def __init__(self, num_classes, num_domains, pretrained=True, grl=True):
        super(DGresnet, self).__init__()
        self.num_domains = num_domains
        self.base_model = resnet(num_classes=num_classes, pretrained=pretrained)
        self.discriminator = Discriminator([512, 1024, 1024, num_domains], grl=grl, reverse=True)

    def forward(self, x):
        x = self.base_model.conv1(x)
        x = self.base_model.bn1(x)
        x = self.base_model.relu(x)
        x = self.base_model.maxpool(x)

        x = self.base_model.layer1(x)
        x = self.base_model.layer2(x)
        x = self.base_model.layer3(x)
        x = self.base_model.layer4(x)

        x = self.base_model.avgpool(x)
        x = x.view(x.size(0), -1)
        output_class = self.base_model.fc(x)
        output_domain = self.discriminator(x)
        return output_class, output_domain

    def features(self, x):
        x = self.base_model.conv1(x)
        x = self.base_model.bn1(x)
        x = self.base_model.relu(x)
        x = self.base_model.maxpool(x)

        x = self.base_model.layer1(x)
        x = self.base_model.layer2(x)
        x = self.base_model.layer3(x)
        x = self.base_model.layer4(x)

        x = self.base_model.avgpool(x)
        x = x.view(x.size(0), -1)
        return x

    def conv_features(self, x):
        results = []
        x = self.base_model.conv1(x)
        x = self.base_model.bn1(x)
        x = self.base_model.relu(x)
        # results.append(x)
        x = self.base_model.maxpool(x)
        x = self.base_model.layer1(x)
        results.append(x)
        x = self.base_model.layer2(x)
        results.append(x)
        x = self.base_model.layer3(x)
        x = self.base_model.layer4(x)
        # results.append(x)
        return results

    def domain_features(self, x):
        x = self.base_model.conv1(x)
        x = self.base_model.bn1(x)
        x = self.base_model.relu(x)
        x = self.base_model.maxpool(x)
        x = self.base_model.layer1(x)
        return x.view(x.size(0), -1)
--------------------------------------------------------------------------------
/train/general.py:
--------------------------------------------------------------------------------
from torch import nn
from util.util import split_domain
import torch
from numpy.random import *
import numpy as np
from loss.EntropyLoss import HLoss
from loss.MaximumSquareLoss import MaximumSquareLoss

def train(model, train_data, optimizers, device, epoch, num_epoch, filename, entropy, disc_weight=None, entropy_weight=1.0, grl_weight=1.0):
    class_criterion = nn.CrossEntropyLoss()
    print(disc_weight)
    domain_criterion = nn.CrossEntropyLoss(weight=disc_weight)
    if entropy == 'default':
        entropy_criterion = HLoss()
    else:
        entropy_criterion = MaximumSquareLoss()
    # Ramp the GRL coefficient (alpha) and the entropy weight (beta) from 0 towards their
    # maximum values over training, following the DANN schedule 2 / (1 + exp(-10p)) - 1.
    p = epoch / num_epoch
    alpha = (2. / (1. + np.exp(-10 * p)) - 1) * grl_weight
    beta = (2. / (1. + np.exp(-10 * p)) - 1) * entropy_weight
    model.discriminator.set_lambd(alpha)
    model.train()  # Set model to training mode
    running_loss_class = 0.0
    running_correct_class = 0
    running_loss_domain = 0.0
    running_correct_domain = 0
    running_loss_entropy = 0
    # Iterate over data.
    for inputs, labels, domains in train_data:
        inputs = inputs.to(device)
        labels = labels.to(device)
        domains = domains.to(device)
        # zero the parameter gradients
        for optimizer in optimizers:
            optimizer.zero_grad()
        # forward
        output_class, output_domain = model(inputs)

        loss_class = class_criterion(output_class, labels)
        loss_domain = domain_criterion(output_domain, domains)
        loss_entropy = entropy_criterion(output_class)
        _, pred_class = torch.max(output_class, 1)
        _, pred_domain = torch.max(output_domain, 1)

        total_loss = loss_class + loss_domain + loss_entropy * beta
        total_loss.backward()
        for optimizer in optimizers:
            optimizer.step()

        running_loss_class += loss_class.item() * inputs.size(0)
        running_correct_class += torch.sum(pred_class == labels.data)
        running_loss_domain += loss_domain.item() * inputs.size(0)
        running_correct_domain += torch.sum(pred_domain == domains.data)
        running_loss_entropy += loss_entropy.item() * inputs.size(0)

    epoch_loss_class = running_loss_class / len(train_data.dataset)
    epoch_acc_class = running_correct_class.double() / len(train_data.dataset)
    epoch_loss_domain = running_loss_domain / len(train_data.dataset)
    epoch_acc_domain = running_correct_domain.double() / len(train_data.dataset)
    epoch_loss_entropy = running_loss_entropy / len(train_data.dataset)

    log = 'Train: Epoch: {} Alpha: {:.4f} Loss Class: {:.4f} Acc Class: {:.4f}, Loss Domain: {:.4f} Acc Domain: {:.4f} Loss Entropy: {:.4f}'.format(epoch, alpha, epoch_loss_class, epoch_acc_class, epoch_loss_domain, epoch_acc_domain, epoch_loss_entropy)
    print(log)
    with open(filename, 'a') as f:
        f.write(log + '\n')
    return model, optimizers
--------------------------------------------------------------------------------
/clustering/clustering.py:
--------------------------------------------------------------------------------
import time

import numpy as np
from PIL import Image
from PIL import ImageFile
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from sklearn import preprocessing
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans, SpectralClustering, AgglomerativeClustering
from sklearn.mixture import GaussianMixture

ImageFile.LOAD_TRUNCATED_IMAGES = True

__all__ = ['Kmeans', 'GMM', 'Spectral', 'Agglomerative']


def preprocess_features(npdata, pca_dim=256, whitening=False, L2norm=False):
    """Preprocess an array of features.
    Args:
        npdata (np.array N * ndim): features to preprocess
        pca_dim (int): dim of output
    Returns:
        np.array of dim N * pca_dim: PCA-reduced and, optionally, whitened and L2-normalized data
    """
    _, ndim = npdata.shape
    npdata = npdata.astype('float32')
    pca = PCA(pca_dim, whiten=whitening)
    npdata = pca.fit_transform(npdata)
    # L2 normalization
    if L2norm:
        row_sums = np.linalg.norm(npdata, axis=1)
        npdata = npdata / row_sums[:, np.newaxis]
    return npdata

class Clustering:
    def __init__(self, k, pca_dim=256, whitening=False, L2norm=False):
        self.k = k
        self.pca_dim = pca_dim
        self.whitening = whitening
        self.L2norm = L2norm

    def cluster(self, data, verbose=False):
        """Performs clustering with the method defined by the subclass.
        Args:
            data (np.array N * dim): data to cluster
        """
        # PCA-reduction, whitening and L2-normalization
        xb = preprocess_features(data, self.pca_dim, self.whitening, self.L2norm)
        # cluster the data
        I = self.run_method(xb, self.k)
        self.images_lists = [[] for i in range(self.k)]
        for i in range(len(data)):
            self.images_lists[I[i]].append(i)
        return None

    def run_method(self, x, n_clusters):
        raise NotImplementedError('Define run_method in each subclass')

class Kmeans(Clustering):
    def __init__(self, k, pca_dim=256, whitening=False, L2norm=False):
        super().__init__(k, pca_dim, whitening, L2norm)

    def run_method(self, x, n_clusters):
        kmeans = KMeans(n_clusters=n_clusters)
        I = kmeans.fit_predict(x)
        return I

class GMM(Clustering):
    def __init__(self, k, pca_dim=256, whitening=False, L2norm=False):
        super().__init__(k, pca_dim, whitening, L2norm)

    def run_method(self, x, n_clusters):
        gmm = GaussianMixture(n_components=n_clusters)
        I = gmm.fit_predict(x)
        return I

class Spectral(Clustering):
    def __init__(self, k, pca_dim=256, whitening=False, L2norm=False):
        super().__init__(k, pca_dim, whitening, L2norm)

    def run_method(self, x, n_clusters):
        spectral = SpectralClustering(n_clusters=n_clusters)
        I = spectral.fit_predict(x)
        return I

class Agglomerative(Clustering):
    def __init__(self, k, pca_dim=256, whitening=False, L2norm=False):
        super().__init__(k, pca_dim, whitening, L2norm)

    def run_method(self, x, n_clusters):
        agg = AgglomerativeClustering(n_clusters=n_clusters)
        I = agg.fit_predict(x)
        return I
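Editor's note: a small, self-contained run, not part of the repository, of the `Kmeans` wrapper above. The random features stand in for the style statistics produced in `clustering/domain_split.py`, `k=3` mirrors `--num-clustering=3`, and the repository root is assumed to be on `PYTHONPATH`:
```
import numpy as np
from clustering.clustering import Kmeans

features = np.random.randn(300, 512).astype('float32')  # placeholder for conv-feature statistics

kmeans = Kmeans(k=3, pca_dim=256)
kmeans.cluster(features)

# images_lists[k] holds the sample indices assigned to pseudo domain k.
print([len(lst) for lst in kmeans.images_lists])         # three clusters covering all 300 samples
```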
--------------------------------------------------------------------------------
/model/caffenet.py:
--------------------------------------------------------------------------------
from collections import OrderedDict
from model.Discriminator import *
import torch
from torch import nn as nn


class Id(nn.Module):
    def __init__(self):
        super(Id, self).__init__()

    def forward(self, x):
        return x

class AlexNetCaffe(nn.Module):
    def __init__(self, num_classes=100, domains=3, dropout=True):
        super(AlexNetCaffe, self).__init__()
        print("Using Caffe AlexNet")
        self.features = nn.Sequential(OrderedDict([
            ("conv1", nn.Conv2d(3, 96, kernel_size=11, stride=4)),
            ("relu1", nn.ReLU(inplace=True)),
            ("pool1", nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)),
            ("norm1", nn.LocalResponseNorm(5, 1.e-4, 0.75)),
            ("conv2", nn.Conv2d(96, 256, kernel_size=5, padding=2, groups=2)),
            ("relu2", nn.ReLU(inplace=True)),
            ("pool2", nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)),
            ("norm2", nn.LocalResponseNorm(5, 1.e-4, 0.75)),
            ("conv3", nn.Conv2d(256, 384, kernel_size=3, padding=1)),
            ("relu3", nn.ReLU(inplace=True)),
            ("conv4", nn.Conv2d(384, 384, kernel_size=3, padding=1, groups=2)),
            ("relu4", nn.ReLU(inplace=True)),
            ("conv5", nn.Conv2d(384, 256, kernel_size=3, padding=1, groups=2)),
            ("relu5", nn.ReLU(inplace=True)),
            ("pool5", nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)),
        ]))
        self.classifier = nn.Sequential(OrderedDict([
            ("fc6", nn.Linear(256 * 6 * 6, 4096)),
            ("relu6", nn.ReLU(inplace=True)),
            ("drop6", nn.Dropout() if dropout else Id()),
            ("fc7", nn.Linear(4096, 4096)),
            ("relu7", nn.ReLU(inplace=True)),
            ("drop7", nn.Dropout() if dropout else Id())]))

        self.class_classifier = nn.Linear(4096, num_classes)

    def forward(self, x, lambda_val=0):
        x = self.features(x*57.6)
        # 57.6 is the magic number needed to bring torch data back to the range of caffe data, based on the std used
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return self.class_classifier(x)

def caffenet(num_classes, num_domains=None, pretrained=True):
    model = AlexNetCaffe(num_classes)
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight, .1)
            nn.init.constant_(m.bias, 0.)

    if pretrained:
        state_dict = torch.load("/data/unagi0/matsuura/model/alexnet_caffe.pth.tar")
        del state_dict["classifier.fc8.weight"]
        del state_dict["classifier.fc8.bias"]
        model.load_state_dict(state_dict, strict=False)
    return model

class DGcaffenet(nn.Module):
    def __init__(self, num_classes, num_domains, pretrained=True, grl=True):
        super(DGcaffenet, self).__init__()
        self.num_domains = num_domains
        self.base_model = caffenet(num_classes, pretrained=pretrained)
        self.discriminator = Discriminator([4096, 1024, 1024, num_domains], grl=grl, reverse=True)

    def forward(self, x):
        x = self.base_model.features(x*57.6)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.base_model.classifier(x)
        output_class = self.base_model.class_classifier(x)
        output_domain = self.discriminator(x)
        return output_class, output_domain

    def features(self, x):
        x = self.base_model.features(x*57.6)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.base_model.classifier(x)
        return x

    def conv_features(self, x):
        results = []
        for i, model in enumerate(self.base_model.features):
            if i == 0:
                x = model(x*57.6)
            else:
                x = model(x)
            if i in {5, 9}:
                results.append(x)
        return results

    def domain_features(self, x):
        for i, model in enumerate(self.base_model.features):
            if i == 0:
                x = model(x*57.6)
            else:
                x = model(x)
            if i == 5:
                break
        return x.view(x.size(0), -1)
--------------------------------------------------------------------------------
/dataloader/Dataset.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset
import sys
import os
from torchvision import transforms
from torchvision.datasets.folder import make_dataset, default_loader
import numpy as np

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')

class DG_Dataset(Dataset):
    def __init__(self, root_dir, domain, split, get_domain_label=False, get_cluster=False, color_jitter=True, min_scale=0.8):
        self.root_dir = root_dir
        self.domain = domain
        self.split = split
        self.get_domain_label = get_domain_label
        self.get_cluster = get_cluster
        self.color_jitter = color_jitter
        self.min_scale = min_scale
        self.set_transform(self.split)
        self.loader = default_loader

        self.load_dataset()

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        path, target = self.images[index], self.labels[index]
        image = self.loader(path)
        image = self.transform(image)
        output = [image, target]

        if self.get_domain_label:
            domain = np.copy(self.domains[index])
            domain = np.int64(domain)
            output.append(domain)
        if self.get_cluster:
            cluster = np.copy(self.clusters[index])
            cluster = np.int64(cluster)
            output.append(cluster)

        return tuple(output)

    def find_classes(self, dir_name):
        if sys.version_info >= (3, 5):
            # Faster and available in Python 3.5 and above
            classes = [d.name for d in os.scandir(dir_name) if d.is_dir()]
        else:
            classes = [d for d in os.listdir(dir_name) if os.path.isdir(os.path.join(dir_name, d))]
        classes.sort()
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        return classes, class_to_idx

    def load_dataset(self):
        total_samples = []
        self.domains = np.zeros(0)

        classes, class_to_idx = self.find_classes(self.root_dir + self.domain[0] + '/')
        self.num_class = len(classes)
        for i, item in enumerate(self.domain):
            path = self.root_dir + item + '/'
            samples = make_dataset(path, class_to_idx, IMG_EXTENSIONS)
            total_samples.extend(samples)
            self.domains = np.append(self.domains, np.ones(len(samples)) * i)

        self.clusters = np.zeros(len(self.domains), dtype=np.int64)
        self.images = [s[0] for s in total_samples]
        self.labels = [s[1] for s in total_samples]

    def set_cluster(self, cluster_list):
        if len(cluster_list) != len(self.images):
            raise ValueError("The length of cluster_list must be the same as that of self.images")
        else:
            self.clusters = cluster_list

    def set_domain(self, domain_list):
        if len(domain_list) != len(self.images):
            raise ValueError("The length of domain_list must be the same as that of self.images")
        else:
            self.domains = domain_list

    def set_transform(self, split):
        if split == 'train':
            if self.color_jitter:
                self.transform = transforms.Compose([
                    transforms.RandomResizedCrop(224, scale=(self.min_scale, 1.0)),
                    transforms.RandomHorizontalFlip(),
                    transforms.ColorJitter(.4, .4, .4, .4),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                ])
            else:
                self.transform = transforms.Compose([
                    transforms.RandomResizedCrop(224, scale=(self.min_scale, 1.0)),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                ])
        elif split == 'val' or split == 'test':
            self.transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
        else:
            raise Exception('Split must be train, val or test!')
--------------------------------------------------------------------------------
/clustering/domain_split.py:
--------------------------------------------------------------------------------
import numpy as np
from sklearn.metrics.cluster import normalized_mutual_info_score
import torch
from torch.utils.data import DataLoader
from clustering import clustering
from scipy.optimize import linear_sum_assignment

def calc_mean_std(feat, eps=1e-5):
    # eps is a small value added to the variance to avoid divide-by-zero.
    size = feat.size()
    assert (len(size) == 4)
    N, C = size[:2]
    feat_var = feat.view(N, C, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C)
    feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C)
    return feat_mean, feat_std

def reassign(y_before, y_pred):
    assert y_before.size == y_pred.size
    D = max(y_before.max(), y_pred.max()) + 1
    w = np.zeros((D, D), dtype=np.int64)
    for i in range(y_before.size):
        w[y_before[i], y_pred[i]] += 1
    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    return col_ind

def compute_features(dataloader, model, N, device):
    model.eval()
    # discard the label information in the dataloader
    for i, (input_tensor, _, _) in enumerate(dataloader):
        with torch.no_grad():
            input_var = input_tensor.to(device)
            aux = model.domain_features(input_var).data.cpu().numpy()
            if i == 0:
                features = np.zeros((N, aux.shape[1])).astype('float32')

            if i < len(dataloader) - 1:
                features[i * dataloader.batch_size: (i + 1) * dataloader.batch_size] = aux.astype('float32')
            else:
                # special treatment for final batch
                features[i * dataloader.batch_size:] = aux.astype('float32')
    return features

def compute_instance_stat(dataloader, model, N, device):
    model.eval()
    for i, (input_tensor, _, _) in enumerate(dataloader):
        with torch.no_grad():
            input_var = input_tensor.to(device)
            conv_feats = model.conv_features(input_var)
            for j, feats in enumerate(conv_feats):
                feat_mean, feat_std = calc_mean_std(feats)
                if j == 0:
                    aux = torch.cat((feat_mean, feat_std), 1).data.cpu().numpy()
                else:
                    aux = np.concatenate((aux, torch.cat((feat_mean, feat_std), 1).data.cpu().numpy()), axis=1)
            if i == 0:
                features = np.zeros((N, aux.shape[1])).astype('float32')
            if i < len(dataloader) - 1:
                features[i * dataloader.batch_size: (i + 1) * dataloader.batch_size] = aux.astype('float32')
            else:
                # special treatment for final batch
                features[i * dataloader.batch_size:] = aux.astype('float32')
    print(features.shape)
    return features

def arrange_clustering(images_lists):
    pseudolabels = []
    image_indexes = []
    for cluster, images in enumerate(images_lists):
        image_indexes.extend(images)
        pseudolabels.extend([cluster] * len(images))
    indexes = np.argsort(image_indexes)
    return np.asarray(pseudolabels)[indexes]


def domain_split(dataset, model, device, cluster_before, filename, epoch, nmb_cluster=3, method='Kmeans', pca_dim=256, batchsize=128, num_workers=4, whitening=False, L2norm=False, instance_stat=True):
    cluster_method = clustering.__dict__[method](nmb_cluster, pca_dim, whitening, L2norm)

    dataset.set_transform('val')
    dataloader = DataLoader(dataset, batch_size=batchsize, shuffle=False, num_workers=num_workers)

    if instance_stat:
        features = compute_instance_stat(dataloader, model, len(dataset), device)
    else:
        features = compute_features(dataloader, model, len(dataset), device)

    clustering_loss = cluster_method.cluster(features, verbose=False)
    cluster_list = arrange_clustering(cluster_method.images_lists)

    class_nmi = normalized_mutual_info_score(
        cluster_list, dataloader.dataset.labels, average_method='geometric')
    domain_nmi = normalized_mutual_info_score(
        cluster_list, dataloader.dataset.domains, average_method='geometric')
    before_nmi = normalized_mutual_info_score(
        cluster_list, cluster_before, average_method='arithmetic')

    log = 'Epoch: {}, NMI against class labels: {:.3f}, domain labels: {:.3f}, previous assignment: {:.3f}'.format(epoch, class_nmi, domain_nmi, before_nmi)
    print(log)
    if filename:
        with open(filename, 'a') as f:
            f.write(log + '\n')

    mapping = reassign(cluster_before, cluster_list)
    cluster_reassign = [cluster_method.images_lists[mapp] for mapp in mapping]
    dataset.set_transform(dataset.split)
    return arrange_clustering(cluster_reassign)
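Editor's note: a toy illustration, not part of the repository, of `reassign` above. When a new clustering equals the old one up to a renaming of the cluster ids, the Hungarian matching recovers that renaming, so pseudo-domain labels stay stable between epochs:
```
import numpy as np
from clustering.domain_split import reassign

y_before = np.array([0, 0, 1, 1, 2, 2])   # assignment from the previous epoch
y_pred   = np.array([1, 1, 2, 2, 0, 0])   # new assignment: same clusters, ids permuted

mapping = reassign(y_before, y_pred)
print(mapping)   # [1 2 0]: new cluster 1 matches old cluster 0, new 2 matches old 1, new 0 matches old 2
```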
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
from model import alexnet, caffenet, resnet
from torch import nn, optim
from torch.optim.lr_scheduler import StepLR, ExponentialLR
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np
import torch
from copy import deepcopy
from torch.nn import init
from dataloader.dataloader import *
from clustering.domain_split import calc_mean_std
from sklearn.decomposition import PCA
from util.scheduler import inv_lr_scheduler


def show_images(images, cols=1, titles=None):
    assert (titles is None) or (len(images) == len(titles))
    n_images = len(images)
    if titles is None: titles = ['Image (%d)' % i for i in range(1, n_images + 1)]
    fig = plt.figure()
    for n, (image, title) in enumerate(zip(images, titles)):
        a = fig.add_subplot(cols, int(np.ceil(n_images / float(cols))), n + 1)
        if image.ndim == 2:
            plt.gray()
        plt.imshow(image)
        a.set_title(title)
    fig.set_size_inches(np.array(fig.get_size_inches()) * n_images)
    plt.show()

def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        print('model.features parameters are fixed')
        for param in model.parameters():
            param.requires_grad = False

def split_domain(domains, split_idx, print_domain=True):
    source_domain = deepcopy(domains)
    target_domain = [source_domain.pop(split_idx)]
    if print_domain:
        print('Source domain: ', end='')
        for domain in source_domain:
            print(domain, end=', ')
        print('Target domain: ', end='')
        for domain in target_domain:
            print(domain)
    return source_domain, target_domain

domain_map = {
    'PACS': ['photo', 'art_painting', 'cartoon', 'sketch'],
    'PACS_random_split': ['photo', 'art_painting', 'cartoon', 'sketch'],
    'OfficeHome': ['Art', 'Clipart', 'Product', 'RealWorld'],
    'VLCS': ['Caltech', 'Labelme', 'Pascal', 'Sun']
}

def get_domain(name):
    if name not in domain_map:
        raise ValueError('Name of dataset unknown %s' % name)
    return domain_map[name]

nets_map = {
    'caffenet': {'deepall': caffenet.caffenet, 'general': caffenet.DGcaffenet},
    'alexnet': {'deepall': alexnet.alexnet, 'general': alexnet.DGalexnet},
    'resnet': {'deepall': resnet.resnet, 'general': resnet.DGresnet}
}

def get_model(name, train):
    if name not in nets_map:
        raise ValueError('Name of network unknown %s' % name)

    def get_network_fn(**kwargs):
        return nets_map[name][train](**kwargs)

    return get_network_fn
def get_model_lr(name, train, model, fc_weight=1.0, disc_weight=1.0):
    if name not in nets_map:
        raise ValueError('Name of network unknown %s' % name)
    if train not in train_map:
        raise ValueError('Name of train unknown %s' % train)

    if name == 'alexnet' and train == 'deepall':
        return [(model.features, 1.0), (model.classifier[:-1], 1.0), (model.classifier[-1], 1.0 * fc_weight)]
    elif name == 'caffenet' and train == 'deepall':
        return [(model.features, 1.0), (model.classifier, 1.0), (model.class_classifier, 1.0 * fc_weight)]
    elif name == 'resnet' and train == 'deepall':
        return [(model.conv1, 1.0), (model.bn1, 1.0), (model.layer1, 1.0), (model.layer2, 1.0), (model.layer3, 1.0),
                (model.layer4, 1.0), (model.fc, 1.0 * fc_weight)]

    elif name == 'alexnet' and train == 'general':
        return [(model.base_model.features, 1.0), (model.feature_layers, 1.0),
                (model.fc, 1.0 * fc_weight), (model.discriminator, 1.0 * disc_weight)]
    elif name == 'caffenet' and train == 'general':
        return [(model.base_model.features, 1.0), (model.base_model.classifier, 1.0),
                (model.base_model.class_classifier, 1.0 * fc_weight), (model.discriminator, 1.0 * disc_weight)]
    elif name == 'resnet' and train == 'general':
        return [(model.base_model.conv1, 1.0), (model.base_model.bn1, 1.0), (model.base_model.layer1, 1.0),
                (model.base_model.layer2, 1.0), (model.base_model.layer3, 1.0), (model.base_model.layer4, 1.0),
                (model.base_model.fc, 1.0 * fc_weight), (model.discriminator, 1.0 * disc_weight)]

def get_optimizer(model, init_lr, momentum, weight_decay, feature_fixed=False, nesterov=False, per_layer=False):
    if feature_fixed:
        # Only optimize the parameters that still require gradients.
        params_to_update = []
        for name, param in model.named_parameters():
            if param.requires_grad == True:
                params_to_update.append(param)
        optimizer = optim.SGD(
            params_to_update, lr=init_lr, momentum=momentum,
            weight_decay=weight_decay, nesterov=nesterov)
    else:
        if per_layer:
            if not isinstance(model, list):
                raise ValueError('Model must be a list type.')
            optimizer = optim.SGD(
                [{'params': model_.parameters(), 'lr': init_lr * alpha} for model_, alpha in model],
                lr=init_lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)

        else:
            params_to_update = model.parameters()
            optimizer = optim.SGD(
                params_to_update, lr=init_lr, momentum=momentum,
                weight_decay=weight_decay, nesterov=nesterov)

    return optimizer

schedulers_map = {
    'step': StepLR,
    'exponential': ExponentialLR,
    'inv': inv_lr_scheduler
}

def get_scheduler(name):
    if name not in schedulers_map:
        raise ValueError('Name of scheduler unknown %s' % name)

    def get_scheduler_fn(**kwargs):
        return schedulers_map[name](**kwargs)

    return get_scheduler_fn


def train_to_get_label(train, clustering):
    if train == 'deepall':
        return [False, False]
    elif train == 'general' and clustering == True:
        return [False, True]
    elif train == 'general' and clustering != True:
        return [True, False]
    else:
        raise ValueError('Name of train unknown %s' % train)


# Imported here, after the helpers above are defined, to avoid a circular import with train/*.py.
from train import deepall, general
train_map = {
    'deepall': deepall.train,
    'general': general.train
}

def get_train(name):
    if name not in train_map:
        raise ValueError('Name of train unknown %s' % name)
    def get_train_fn(**kwargs):
        return train_map[name](**kwargs)
    return get_train_fn

def get_disc_dim(name, clustering, domain_num, clustering_num):
    if name == 'deepall':
        return None
    elif name == 'general' and clustering == True:
        return clustering_num
    elif name == 'general' and clustering != True:
        return domain_num
    else:
        raise ValueError('Name of train unknown %s' % name)

def copy_weights(net_from, net_to):
    for m_from, m_to in zip(net_from.modules(), net_to.modules()):
        if isinstance(m_to, nn.Linear) or isinstance(m_to, nn.Conv2d) or isinstance(m_to, nn.BatchNorm2d):
            m_to.weight.data = m_from.weight.data.clone()
            if m_to.bias is not None:
                m_to.bias.data = m_from.bias.data.clone()
    return net_from, net_to
--------------------------------------------------------------------------------
/main/main.py:
--------------------------------------------------------------------------------
import sys
sys.path.append('../')

from torch.utils.data import DataLoader
import torch
import argparse
import os
from util.util import *
from train.eval import *
from clustering.domain_split import domain_split
from dataloader.dataloader import random_split_dataloader

if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    parser.add_argument('--data-root', default='/data/unagi0/matsuura/PACS/spilit/')
    parser.add_argument('--save-root', default='/data/unagi0/matsuura/result/dg_mmld')
    parser.add_argument('--result-dir', default='default')
    parser.add_argument('--train', default='deepall')
    parser.add_argument('--data', default='PACS')
    parser.add_argument('--model', default='caffenet')
    parser.add_argument('--clustering', action='store_true')
    parser.add_argument('--clustering-method', default='Kmeans')
    parser.add_argument('--num-clustering', type=int, default=3)
    parser.add_argument('--clustering-step', type=int, default=1)
    parser.add_argument('--entropy', choices=['default', 'maximum_square'])

    parser.add_argument('--exp-num', type=int, default=0)
    parser.add_argument('--gpu', type=int, default=0)

    parser.add_argument('--num-epoch', type=int, default=30)
    parser.add_argument('--eval-step', type=int, default=1)
    parser.add_argument('--save-step', type=int, default=100)

    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--scheduler', default='step')
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--lr-step', type=int, default=24)
    parser.add_argument('--lr-decay-gamma', type=float, default=0.1)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight-decay', type=float, default=5e-4)
    parser.add_argument('--nesterov', action='store_true')

    parser.add_argument('--fc-weight', type=float, default=1.0)
    parser.add_argument('--disc-weight', type=float, default=1.0)
    parser.add_argument('--entropy-weight', type=float, default=1.0)
    parser.add_argument('--grl-weight', type=float, default=1.0)
    parser.add_argument('--loss-disc-weight', action='store_true')

    parser.add_argument('--color-jitter', action='store_true')
    parser.add_argument('--min-scale', type=float, default=0.8)

    parser.add_argument('--instance-stat', action='store_true')
    parser.add_argument('--feature-fixed', action='store_true')
    args = parser.parse_args()

    path = os.path.join(args.save_root, args.result_dir)
    if not os.path.isdir(path):
        os.makedirs(path)
        os.makedirs(path + '/models')
    with open(path + '/args.txt', 'w') as f:
        f.write(str(args))

    domain = get_domain(args.data)
    source_domain, target_domain = split_domain(domain, args.exp_num)

    device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
    get_domain_label, get_cluster = train_to_get_label(args.train, args.clustering)

    source_train, source_val, target_test = random_split_dataloader(
        data=args.data, data_root=args.data_root, source_domain=source_domain, target_domain=target_domain,
        batch_size=args.batch_size, get_domain_label=get_domain_label, get_cluster=get_cluster, num_workers=4,
        color_jitter=args.color_jitter, min_scale=args.min_scale)

    # num_epoch = int(args.num_iteration / len(source_train))
    # lr_step = int(args.lr_step / min([len(domain) for domain in source_train]))
    # print(num_epoch)

    num_epoch = args.num_epoch
    lr_step = args.lr_step

    disc_dim = get_disc_dim(args.train, args.clustering, len(source_domain), args.num_clustering)

    model = get_model(args.model, args.train)(
        num_classes=source_train.dataset.dataset.num_class, num_domains=disc_dim, pretrained=True)

    model = model.to(device)
    model_lr = get_model_lr(args.model, args.train, model, fc_weight=args.fc_weight, disc_weight=args.disc_weight)
    optimizers = [get_optimizer(model_part, args.lr * alpha, args.momentum, args.weight_decay,
                                args.feature_fixed, args.nesterov, per_layer=False) for model_part, alpha in model_lr]

    if args.scheduler == 'inv':
        schedulers = [get_scheduler(args.scheduler)(optimizer=opt, alpha=10, beta=0.75, total_epoch=num_epoch)
                      for opt in optimizers]
    elif args.scheduler == 'step':
        schedulers = [get_scheduler(args.scheduler)(optimizer=opt, step_size=lr_step, gamma=args.lr_decay_gamma)
                      for opt in optimizers]
    else:
        raise ValueError('Name of scheduler unknown %s' % args.scheduler)

    best_acc = 0.0
    test_acc = 0.0
    best_epoch = 0

    for epoch in range(num_epoch):

        print('Epoch: {}/{}, Lr: {:.6f}'.format(epoch, num_epoch - 1, optimizers[0].param_groups[0]['lr']))
        print('Temporary Best Accuracy is {:.4f} ({:.4f} at Epoch {})'.format(test_acc, best_acc, best_epoch))

        dataset = source_train.dataset.dataset

        if args.clustering:
            if epoch % args.clustering_step == 0:
                pseudo_domain_label = domain_split(dataset, model, device=device,
                                                   cluster_before=dataset.clusters,
                                                   filename=path + '/nmi.txt', epoch=epoch,
                                                   nmb_cluster=args.num_clustering, method=args.clustering_method,
                                                   pca_dim=256, whitening=False, L2norm=False, instance_stat=args.instance_stat)
                dataset.set_cluster(np.array(pseudo_domain_label))

        if args.loss_disc_weight:
            if args.clustering:
                hist = dataset.clusters
            else:
                hist = dataset.domains

            weight = 1. / np.histogram(hist, bins=model.num_domains)[0]
            weight = weight / weight.sum() * model.num_domains
            weight = torch.from_numpy(weight).float().to(device)

        else:
            weight = None

        model, optimizers = get_train(args.train)(
            model=model, train_data=source_train, optimizers=optimizers, device=device,
            epoch=epoch, num_epoch=num_epoch, filename=path + '/source_train.txt', entropy=args.entropy,
            disc_weight=weight, entropy_weight=args.entropy_weight, grl_weight=args.grl_weight)

        if epoch % args.eval_step == 0:
            acc = eval_model(model, source_val, device, epoch, path + '/source_eval.txt')
            acc_ = eval_model(model, target_test, device, epoch, path + '/target_test.txt')

        if epoch % args.save_step == 0:
            torch.save(model.state_dict(), os.path.join(
                path, 'models',
                "model_{}.pt".format(epoch)))

        if acc >= best_acc:
            best_acc = acc
            test_acc = acc_
            best_epoch = epoch
            torch.save(model.state_dict(), os.path.join(
                path, 'models',
                "model_best.pt"))

        for scheduler in schedulers:
            scheduler.step()

    best_model = get_model(args.model, args.train)(num_classes=source_train.dataset.dataset.num_class, num_domains=disc_dim, pretrained=False)
    best_model.load_state_dict(torch.load(os.path.join(
        path, 'models',
        "model_best.pt"), map_location=device))
    best_model = best_model.to(device)
    test_acc = eval_model(best_model, target_test, device, best_epoch, path + '/target_best.txt')
    print('Test accuracy on the target domain of the model that was best on the source validation set: {} (at Epoch {})'.format(test_acc, best_epoch))
--------------------------------------------------------------------------------
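Editor's note: a short check, not part of the repository, of the `--loss-disc-weight` computation handled just above in `main/main.py`. Pseudo domains that receive fewer samples get a proportionally larger weight in the domain-discriminator loss, and the weights are normalized to sum to the number of domains; the cluster sizes below are made up:
```
import numpy as np

hist = np.array([0] * 50 + [1] * 30 + [2] * 20)   # an imbalanced pseudo-domain assignment
weight = 1. / np.histogram(hist, bins=3)[0]       # inverse cluster sizes
weight = weight / weight.sum() * 3
print(weight)   # approx. [0.58 0.97 1.45]; sums to 3
```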