├── .gitignore
├── README.md
├── SFIT
│   ├── datasets
│   │   ├── __init__.py
│   │   └── usps_shot.py
│   ├── loss
│   │   ├── __init__.py
│   │   ├── entropy.py
│   │   ├── js_div.py
│   │   ├── knowledge_distillation.py
│   │   ├── label_smooth.py
│   │   ├── mmd.py
│   │   ├── similarity_preserving.py
│   │   ├── style_loss.py
│   │   └── total_variation.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── classifier_shot.py
│   │   └── cyclegan.py
│   ├── trainers
│   │   ├── __init__.py
│   │   ├── da_trainer.py
│   │   └── sfit_trainer.py
│   └── utils
│       ├── logger.py
│       ├── meters.py
│       └── str2bool.py
├── gpu012.sh
├── train_DA.py
└── train_SFIT.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | *.pyc
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 | 
7 | # C extensions
8 | *.so
9 | 
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | pip-wheel-metadata/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 | 
31 | # PyInstaller
32 | #  Usually these files are written by a python script from a template
33 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 | 
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 | 
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | 
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 | 
67 | # Scrapy stuff:
68 | .scrapy
69 | 
70 | # Sphinx documentation
71 | docs/_build/
72 | 
73 | # PyBuilder
74 | target/
75 | 
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 | 
79 | # IPython
80 | profile_default/
81 | ipython_config.py
82 | 
83 | # pyenv
84 | .python-version
85 | 
86 | # pipenv
87 | #   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
88 | #   However, in case of collaboration, if having platform-specific dependencies or dependencies
89 | #   having no cross-platform support, pipenv may install dependencies that don't work, or not
90 | #   install all needed dependencies.
91 | #Pipfile.lock
92 | 
93 | # celery beat schedule file
94 | celerybeat-schedule
95 | 
96 | # SageMath parsed files
97 | *.sage.py
98 | 
99 | # Environments
100 | .env
101 | .venv
102 | env/
103 | venv/
104 | ENV/
105 | env.bak/
106 | venv.bak/
107 | 
108 | # Spyder project settings
109 | .spyderproject
110 | .spyproject
111 | 
112 | # Rope project settings
113 | .ropeproject
114 | 
115 | # mkdocs documentation
116 | /site
117 | 
118 | # mypy
119 | .mypy_cache/
120 | .dmypy.json
121 | dmypy.json
122 | 
123 | # Pyre type checker
124 | .pyre/
125 | 
126 | 
127 | 
128 | .idea/
129 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Visualizing Adapted Knowledge in Domain Transfer
2 | 
3 | ```
4 | @inproceedings{hou2021visualizing,
5 |   title={Visualizing Adapted Knowledge in Domain Transfer},
6 |   author={Hou, Yunzhong and Zheng, Liang},
7 |   booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
8 |   year={2021}
9 | }
10 | ```
11 | 
12 | ## Under construction
13 | 
14 | 
15 | ## Overview
16 | 
17 | This repo is dedicated to visualizing the learned knowledge in domain adaptation.
18 | To understand the adaptation process, we portray the knowledge difference between the source and target models with image translation, using the source-free image translation (SFIT) method proposed in our [CVPR2021](http://cvpr2021.thecvf.com/) paper *[Visualizing Adapted Knowledge in Domain Transfer](https://arxiv.org/abs/2104.10602)*.
19 | 
20 | Specifically, we feed the generated source-style image to the source model and the original target image to the target model, forming two parallel branches.
21 | By updating the generated image, we force the two branches to produce similar outputs. Once this requirement is met, the image difference compensates for, and can thus represent, the knowledge difference between the two models (see the sketch below).
22 | 
23 | ## Content
24 | - [Dependencies](#dependencies)
25 | - [Data Preparation](#data-preparation)
26 | - [Run the Code](#run-the-code)
27 |   * [Train source and target models](#train-source-and-target-models)
28 |   * [Visualization](#visualization)
29 | 
30 | 
31 | ## Dependencies
32 | This code uses the following libraries:
33 | - python 3.7+
34 | - pytorch 1.6+ & torchvision
35 | - numpy
36 | - matplotlib
37 | - pillow
38 | - scikit-learn
39 | 
40 | ## Data Preparation
41 | By default, all datasets are in `~/Data/`. We use the digits (automatically downloaded), [Office-31](https://people.eecs.berkeley.edu/~jhoffman/domainadapt/), and [VisDA](http://ai.bu.edu/visda-2017/) datasets.
42 | 
43 | Your `~/Data/` folder should look like this:
44 | ```
45 | Data
46 | ├── digits/
47 | │   └── ...
48 | ├── office31/
49 | │   └── ...
50 | └── visda/
51 |     └── ...
52 | ```
53 | 
54 | ## Run the Code
55 | 
56 | ### Train source and target models
57 | Once the data preparation is finished, you can train source and target models using unsupervised domain adaptation (UDA) methods:
58 | ```shell script
59 | python train_DA.py -d digits --source svhn --target mnist
60 | ```
61 | Currently, we support MMD (```--da_setting mmd```), ADDA (```--da_setting adda```), and SHOT (```--da_setting shot```), e.g., `python train_DA.py -d digits --source svhn --target mnist --da_setting shot`.
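
At its core, the visualization step in the next section updates a generated image so that the source model's output on it matches the target model's output on the original target image (the two branches from the Overview). Below is a minimal sketch of one generator update with illustrative names only, not the repo's exact API; the actual loop in `SFIT/trainers/sfit_trainer.py` combines several additional loss terms (e.g., style, similarity-preserving, and BN-statistics losses):

```python
import torch
import torch.nn.functional as F

def sfit_generator_step(net_G, net_S, net_T, tgt_img, optimizer_G):
    """One illustrative update of the generator G (hypothetical helper, simplified)."""
    gen_src_img = net_G(tgt_img)      # translate the target image into source style
    out_S = net_S(gen_src_img)        # source branch: source model on generated image
    with torch.no_grad():
        out_T = net_T(tgt_img)        # target branch: target model on original image
    # force similar outputs between the two branches (knowledge-distillation-style KL)
    loss = F.kl_div(F.log_softmax(out_S, dim=1), F.softmax(out_T, dim=1),
                    reduction='batchmean')
    optimizer_G.zero_grad()
    loss.backward()
    optimizer_G.step()
    return loss.item()
```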
62 | 
63 | ### Visualization
64 | Based on the trained source and target models, we visualize their knowledge difference via SFIT:
65 | ```shell script
66 | python train_SFIT.py -d digits --source svhn --target mnist
67 | ```
--------------------------------------------------------------------------------
/SFIT/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets import SVHN, CIFAR10, ImageFolder, MNIST
2 | from .usps_shot import USPS
3 | 
--------------------------------------------------------------------------------
/SFIT/datasets/usps_shot.py:
--------------------------------------------------------------------------------
1 | """Dataset setting and data loader for USPS.
2 | Modified from
3 | https://github.com/mingyuliutw/CoGAN/blob/master/cogan_pytorch/src/dataset_usps.py
4 | """
5 | 
6 | import gzip
7 | import os
8 | import pickle
9 | import urllib.request
10 | from PIL import Image
11 | 
12 | import numpy as np
13 | import torch
14 | import torch.utils.data as data
15 | from torch.utils.data.sampler import WeightedRandomSampler
16 | from torchvision import datasets, transforms
17 | 
18 | 
19 | class USPS(data.Dataset):
20 |     """USPS Dataset.
21 |     Args:
22 |         root (string): Root directory of dataset where dataset file exists.
23 |         train (bool, optional): If True, use the training split; otherwise the test split.
24 |         download (bool, optional): If true, downloads the dataset
25 |             from the internet and puts it in root directory.
26 |             If dataset is already downloaded, it is not downloaded again.
27 |         transform (callable, optional): A function/transform that takes in
28 |             a PIL image and returns a transformed version.
29 |             E.g., ``transforms.RandomCrop``
30 |     """
31 | 
32 |     url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl"
33 | 
34 |     def __init__(self, root, train=True, transform=None, download=False):
35 |         """Init USPS dataset."""
36 |         # init params
37 |         self.root = os.path.expanduser(root)
38 |         self.filename = "usps_28x28.pkl"
39 |         self.train = train
40 |         # Num of Train = 7438, Num of Test = 1860
41 |         self.transform = transform
42 |         self.dataset_size = None
43 | 
44 |         # download dataset.
45 |         if download:
46 |             self.download()
47 |         if not self._check_exists():
48 |             raise RuntimeError("Dataset not found." +
49 |                                " You can use download=True to download it")
50 | 
51 |         self.train_data, self.train_labels = self.load_samples()
52 |         if self.train:
53 |             total_num_samples = self.train_labels.shape[0]
54 |             indices = np.arange(total_num_samples)
55 |             self.train_data = self.train_data[indices[0:self.dataset_size], ::]
56 |             self.train_labels = self.train_labels[indices[0:self.dataset_size]]
57 |         self.train_data *= 255.0
58 |         self.train_data = np.squeeze(self.train_data).astype(np.uint8)
59 | 
60 |     def __getitem__(self, index):
61 |         """Get images and target for data loader.
62 |         Args:
63 |             index (int): Index
64 |         Returns:
65 |             tuple: (image, target) where target is index of the target class.
66 | """ 67 | img, label = self.train_data[index], self.train_labels[index] 68 | img = Image.fromarray(img, mode='L') 69 | img = img.copy() 70 | if self.transform is not None: 71 | img = self.transform(img) 72 | return img, label.astype("int64") 73 | 74 | def __len__(self): 75 | """Return size of dataset.""" 76 | return len(self.train_data) 77 | 78 | def _check_exists(self): 79 | """Check if dataset is download and in right place.""" 80 | return os.path.exists(os.path.join(self.root, self.filename)) 81 | 82 | def download(self): 83 | """Download dataset.""" 84 | filename = os.path.join(self.root, self.filename) 85 | dirname = os.path.dirname(filename) 86 | if not os.path.isdir(dirname): 87 | os.makedirs(dirname) 88 | if os.path.isfile(filename): 89 | return 90 | print("Download %s to %s" % (self.url, os.path.abspath(filename))) 91 | urllib.request.urlretrieve(self.url, filename) 92 | print("[DONE]") 93 | return 94 | 95 | def load_samples(self): 96 | """Load sample images from dataset.""" 97 | filename = os.path.join(self.root, self.filename) 98 | f = gzip.open(filename, "rb") 99 | data_set = pickle.load(f, encoding="bytes") 100 | f.close() 101 | if self.train: 102 | images = data_set[0][0] 103 | labels = data_set[0][1] 104 | self.dataset_size = labels.shape[0] 105 | else: 106 | images = data_set[1][0] 107 | labels = data_set[1][1] 108 | self.dataset_size = labels.shape[0] 109 | return images, labels 110 | 111 | 112 | class USPS_idx(data.Dataset): 113 | """USPS Dataset. 114 | Args: 115 | root (string): Root directory of dataset where dataset file exist. 116 | train (bool, optional): If True, resample from dataset randomly. 117 | download (bool, optional): If true, downloads the dataset 118 | from the internet and puts it in root directory. 119 | If dataset is already downloaded, it is not downloaded again. 120 | transform (callable, optional): A function/transform that takes in 121 | an PIL image and returns a transformed version. 122 | E.g, ``transforms.RandomCrop`` 123 | """ 124 | 125 | url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl" 126 | 127 | def __init__(self, root, train=True, transform=None, download=False): 128 | """Init USPS dataset.""" 129 | # init params 130 | self.root = os.path.expanduser(root) 131 | self.filename = "usps_28x28.pkl" 132 | self.train = train 133 | # Num of Train = 7438, Num ot Test 1860 134 | self.transform = transform 135 | self.dataset_size = None 136 | 137 | # download dataset. 138 | if download: 139 | self.download() 140 | if not self._check_exists(): 141 | raise RuntimeError("Dataset not found." + 142 | " You can use download=True to download it") 143 | 144 | self.train_data, self.train_labels = self.load_samples() 145 | if self.train: 146 | total_num_samples = self.train_labels.shape[0] 147 | indices = np.arange(total_num_samples) 148 | self.train_data = self.train_data[indices[0:self.dataset_size], ::] 149 | self.train_labels = self.train_labels[indices[0:self.dataset_size]] 150 | self.train_data *= 255.0 151 | self.train_data = np.squeeze(self.train_data).astype(np.uint8) 152 | 153 | def __getitem__(self, index): 154 | """Get images and target for data loader. 155 | Args: 156 | index (int): Index 157 | Returns: 158 | tuple: (image, target) where target is index of the target class. 
159 | """ 160 | img, label = self.train_data[index], self.train_labels[index] 161 | img = Image.fromarray(img, mode='L') 162 | img = img.copy() 163 | if self.transform is not None: 164 | img = self.transform(img) 165 | return img, label.astype("int64"), index 166 | 167 | def __len__(self): 168 | """Return size of dataset.""" 169 | return len(self.train_data) 170 | 171 | def _check_exists(self): 172 | """Check if dataset is download and in right place.""" 173 | return os.path.exists(os.path.join(self.root, self.filename)) 174 | 175 | def download(self): 176 | """Download dataset.""" 177 | filename = os.path.join(self.root, self.filename) 178 | dirname = os.path.dirname(filename) 179 | if not os.path.isdir(dirname): 180 | os.makedirs(dirname) 181 | if os.path.isfile(filename): 182 | return 183 | print("Download %s to %s" % (self.url, os.path.abspath(filename))) 184 | urllib.request.urlretrieve(self.url, filename) 185 | print("[DONE]") 186 | return 187 | 188 | def load_samples(self): 189 | """Load sample images from dataset.""" 190 | filename = os.path.join(self.root, self.filename) 191 | f = gzip.open(filename, "rb") 192 | data_set = pickle.load(f, encoding="bytes") 193 | f.close() 194 | if self.train: 195 | images = data_set[0][0] 196 | labels = data_set[0][1] 197 | self.dataset_size = labels.shape[0] 198 | else: 199 | images = data_set[1][0] 200 | labels = data_set[1][1] 201 | self.dataset_size = labels.shape[0] 202 | return images, labels -------------------------------------------------------------------------------- /SFIT/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .knowledge_distillation import KDLoss 2 | from .label_smooth import LabelSmoothLoss 3 | from .mmd import MMDLoss 4 | from .js_div import JSDivLoss 5 | from .total_variation import TotalVariationLoss 6 | from .entropy import HLoss 7 | from .similarity_preserving import BatchSimLoss, PixelSimLoss, ChannelSimLoss, ChannelSimLoss1D 8 | from .style_loss import StyleLoss 9 | -------------------------------------------------------------------------------- /SFIT/loss/entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class HLoss(nn.Module): 7 | def __init__(self): 8 | super(HLoss, self).__init__() 9 | 10 | def forward(self, x): 11 | b = -F.softmax(x, dim=1) * F.log_softmax(x, dim=1) 12 | b = b.sum(dim=1).mean() 13 | return b 14 | -------------------------------------------------------------------------------- /SFIT/loss/js_div.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class JSDivLoss(nn.Module): 6 | def __init__(self): 7 | super(JSDivLoss, self).__init__() 8 | 9 | def forward(self, p_output, q_output): 10 | # 1/2*KL(p,m) + 1/2*KL(q,m) 11 | p = F.softmax(p_output, dim=1) 12 | q = F.softmax(q_output, dim=1) 13 | log_m = (0.5 * (p + q)).log() 14 | # F.kl_div(x, y) -> F.kl_div(log_q, p) 15 | l_js = 0.5 * (F.kl_div(log_m, p) + F.kl_div(log_m, q)) 16 | return l_js 17 | -------------------------------------------------------------------------------- /SFIT/loss/knowledge_distillation.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class KDLoss(nn.Module): 6 | def __init__(self, temperature=1): 7 | super(KDLoss, 
self).__init__() 8 | self.temperature = temperature 9 | 10 | def forward(self, student_output, teacher_output): 11 | """ 12 | NOTE: the KL Divergence for PyTorch comparing the prob of teacher and log prob of student, 13 | mimicking the prob of ground truth (one-hot) and log prob of network in CE loss 14 | """ 15 | # x -> input -> log(q) 16 | log_q = F.log_softmax(student_output / self.temperature, dim=1) 17 | # y -> target -> p 18 | p = F.softmax(teacher_output / self.temperature, dim=1) 19 | # F.kl_div(x, y) -> F.kl_div(log_q, p) 20 | # l_n = y_n \cdot \left( \log y_n - x_n \right) = p * log(p/q) 21 | l_kl = F.kl_div(log_q, p) # forward KL 22 | return l_kl 23 | -------------------------------------------------------------------------------- /SFIT/loss/label_smooth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class LabelSmoothLoss(nn.Module): 7 | def __init__(self, e=0.1): 8 | super(LabelSmoothLoss, self).__init__() 9 | self.e = e 10 | 11 | def forward(self, x, target): 12 | target = torch.zeros_like(x).scatter_(1, target.unsqueeze(1), 1) 13 | smoothed_target = (1 - self.e) * target + self.e / x.size(1) 14 | loss = (- F.log_softmax(x, dim=1) * smoothed_target).sum(dim=1) 15 | return loss.mean() 16 | 17 | 18 | if __name__ == '__main__': 19 | loss = LabelSmoothLoss() 20 | output = torch.randn(64, 10) 21 | label = torch.randint(0, 10, [64]) 22 | loss(output, label) 23 | -------------------------------------------------------------------------------- /SFIT/loss/mmd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class MMDLoss(nn.Module): 7 | def __init__(self, 8 | base=1.0, 9 | sigma_list=[1, 2, 10]): 10 | super(MMDLoss, self).__init__() 11 | 12 | # sigma for MMD 13 | self.base = base 14 | self.sigma_list = sigma_list 15 | self.sigma_list = [sigma / self.base for sigma in self.sigma_list] 16 | 17 | def forward(self, Target, Source): 18 | mmd2_D = mix_rbf_mmd2(Target, Source, self.sigma_list) 19 | mmd2_D = F.relu(mmd2_D) 20 | mmd2_D = torch.sqrt(mmd2_D) 21 | 22 | return mmd2_D 23 | 24 | 25 | min_var_est = 1e-8 26 | 27 | 28 | # Consider linear time MMD with a linear kernel: 29 | # K(f(x), f(y)) = f(x)^Tf(y) 30 | # h(z_i, z_j) = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(x_j, y_i) 31 | # = [f(x_i) - f(y_i)]^T[f(x_j) - f(y_j)] 32 | # 33 | # f_of_X: batch_size * k 34 | # f_of_Y: batch_size * k 35 | def linear_mmd2(f_of_X, f_of_Y): 36 | loss = 0.0 37 | delta = f_of_X - f_of_Y 38 | loss = torch.mean((delta[:-1] * delta[1:]).sum(1)) 39 | return loss 40 | 41 | 42 | # Consider linear time MMD with a polynomial kernel: 43 | # K(f(x), f(y)) = (alpha*f(x)^Tf(y) + c)^d 44 | # f_of_X: batch_size * k 45 | # f_of_Y: batch_size * k 46 | def poly_mmd2(f_of_X, f_of_Y, d=2, alpha=1.0, c=2.0): 47 | K_XX = (alpha * (f_of_X[:-1] * f_of_X[1:]).sum(1) + c) 48 | K_XX_mean = torch.mean(K_XX.pow(d)) 49 | 50 | K_YY = (alpha * (f_of_Y[:-1] * f_of_Y[1:]).sum(1) + c) 51 | K_YY_mean = torch.mean(K_YY.pow(d)) 52 | 53 | K_XY = (alpha * (f_of_X[:-1] * f_of_Y[1:]).sum(1) + c) 54 | K_XY_mean = torch.mean(K_XY.pow(d)) 55 | 56 | K_YX = (alpha * (f_of_Y[:-1] * f_of_X[1:]).sum(1) + c) 57 | K_YX_mean = torch.mean(K_YX.pow(d)) 58 | 59 | return K_XX_mean + K_YY_mean - K_XY_mean - K_YX_mean 60 | 61 | 62 | def _mix_rbf_kernel(X, Y, sigma_list): 63 | assert (X.size(0) == Y.size(0)) 
64 | m = X.size(0) 65 | 66 | Z = torch.cat((X, Y), 0) 67 | ZZT = torch.mm(Z, Z.t()) 68 | diag_ZZT = torch.diag(ZZT).unsqueeze(1) 69 | Z_norm_sqr = diag_ZZT.expand_as(ZZT) 70 | exponent = Z_norm_sqr - 2 * ZZT + Z_norm_sqr.t() 71 | 72 | K = 0.0 73 | for sigma in sigma_list: 74 | gamma = 1.0 / (2 * sigma ** 2) 75 | K += torch.exp(-gamma * exponent) 76 | 77 | return K[:m, :m], K[:m, m:], K[m:, m:], len(sigma_list) 78 | 79 | 80 | def mix_rbf_mmd2(X, Y, sigma_list, biased=True): 81 | K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigma_list) 82 | # return _mmd2(K_XX, K_XY, K_YY, const_diagonal=d, biased=biased) 83 | return _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased) 84 | 85 | 86 | def mix_rbf_mmd2_and_ratio(X, Y, sigma_list, biased=True): 87 | K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigma_list) 88 | # return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=d, biased=biased) 89 | return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased) 90 | 91 | 92 | ################################################################################ 93 | # Helper functions to compute variances based on kernel matrices 94 | ################################################################################ 95 | 96 | 97 | def _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=False): 98 | m = K_XX.size(0) # assume X, Y are same shape 99 | 100 | # Get the various sums of kernels that we'll use 101 | # Kts drop the diagonal, but we don't need to compute them explicitly 102 | if const_diagonal is not False: 103 | diag_X = diag_Y = const_diagonal 104 | sum_diag_X = sum_diag_Y = m * const_diagonal 105 | else: 106 | diag_X = torch.diag(K_XX) # (m,) 107 | diag_Y = torch.diag(K_YY) # (m,) 108 | sum_diag_X = torch.sum(diag_X) 109 | sum_diag_Y = torch.sum(diag_Y) 110 | 111 | Kt_XX_sums = K_XX.sum(dim=1) - diag_X # \tilde{K}_XX * e = K_XX * e - diag_X 112 | Kt_YY_sums = K_YY.sum(dim=1) - diag_Y # \tilde{K}_YY * e = K_YY * e - diag_Y 113 | K_XY_sums_0 = K_XY.sum(dim=0) # K_{XY}^T * e 114 | 115 | Kt_XX_sum = Kt_XX_sums.sum() # e^T * \tilde{K}_XX * e 116 | Kt_YY_sum = Kt_YY_sums.sum() # e^T * \tilde{K}_YY * e 117 | K_XY_sum = K_XY_sums_0.sum() # e^T * K_{XY} * e 118 | 119 | if biased: 120 | mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m) 121 | + (Kt_YY_sum + sum_diag_Y) / (m * m) 122 | - 2.0 * K_XY_sum / (m * m)) 123 | else: 124 | mmd2 = (Kt_XX_sum / (m * (m - 1)) 125 | + Kt_YY_sum / (m * (m - 1)) 126 | - 2.0 * K_XY_sum / (m * m)) 127 | 128 | return mmd2 129 | 130 | 131 | def _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=False): 132 | mmd2, var_est = _mmd2_and_variance(K_XX, K_XY, K_YY, const_diagonal=const_diagonal, biased=biased) 133 | loss = mmd2 / torch.sqrt(torch.clamp(var_est, min=min_var_est)) 134 | return loss, mmd2, var_est 135 | 136 | 137 | def _mmd2_and_variance(K_XX, K_XY, K_YY, const_diagonal=False, biased=False): 138 | m = K_XX.size(0) # assume X, Y are same shape 139 | 140 | # Get the various sums of kernels that we'll use 141 | # Kts drop the diagonal, but we don't need to compute them explicitly 142 | if const_diagonal is not False: 143 | diag_X = diag_Y = const_diagonal 144 | sum_diag_X = sum_diag_Y = m * const_diagonal 145 | sum_diag2_X = sum_diag2_Y = m * const_diagonal ** 2 146 | else: 147 | diag_X = torch.diag(K_XX) # (m,) 148 | diag_Y = torch.diag(K_YY) # (m,) 149 | sum_diag_X = torch.sum(diag_X) 150 | sum_diag_Y = torch.sum(diag_Y) 151 | sum_diag2_X = diag_X.dot(diag_X) 152 | sum_diag2_Y = diag_Y.dot(diag_Y) 153 | 154 | Kt_XX_sums = K_XX.sum(dim=1) - diag_X # 
\tilde{K}_XX * e = K_XX * e - diag_X 155 | Kt_YY_sums = K_YY.sum(dim=1) - diag_Y # \tilde{K}_YY * e = K_YY * e - diag_Y 156 | K_XY_sums_0 = K_XY.sum(dim=0) # K_{XY}^T * e 157 | K_XY_sums_1 = K_XY.sum(dim=1) # K_{XY} * e 158 | 159 | Kt_XX_sum = Kt_XX_sums.sum() # e^T * \tilde{K}_XX * e 160 | Kt_YY_sum = Kt_YY_sums.sum() # e^T * \tilde{K}_YY * e 161 | K_XY_sum = K_XY_sums_0.sum() # e^T * K_{XY} * e 162 | 163 | Kt_XX_2_sum = (K_XX ** 2).sum() - sum_diag2_X # \| \tilde{K}_XX \|_F^2 164 | Kt_YY_2_sum = (K_YY ** 2).sum() - sum_diag2_Y # \| \tilde{K}_YY \|_F^2 165 | K_XY_2_sum = (K_XY ** 2).sum() # \| K_{XY} \|_F^2 166 | 167 | if biased: 168 | mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m) 169 | + (Kt_YY_sum + sum_diag_Y) / (m * m) 170 | - 2.0 * K_XY_sum / (m * m)) 171 | else: 172 | mmd2 = (Kt_XX_sum / (m * (m - 1)) 173 | + Kt_YY_sum / (m * (m - 1)) 174 | - 2.0 * K_XY_sum / (m * m)) 175 | 176 | var_est = ( 177 | 2.0 / (m ** 2 * (m - 1.0) ** 2) * ( 178 | 2 * Kt_XX_sums.dot(Kt_XX_sums) - Kt_XX_2_sum + 2 * Kt_YY_sums.dot(Kt_YY_sums) - Kt_YY_2_sum) 179 | - (4.0 * m - 6.0) / (m ** 3 * (m - 1.0) ** 3) * (Kt_XX_sum ** 2 + Kt_YY_sum ** 2) 180 | + 4.0 * (m - 2.0) / (m ** 3 * (m - 1.0) ** 2) * ( 181 | K_XY_sums_1.dot(K_XY_sums_1) + K_XY_sums_0.dot(K_XY_sums_0)) 182 | - 4.0 * (m - 3.0) / (m ** 3 * (m - 1.0) ** 2) * (K_XY_2_sum) - (8 * m - 12) / ( 183 | m ** 5 * (m - 1)) * K_XY_sum ** 2 184 | + 8.0 / (m ** 3 * (m - 1.0)) * ( 185 | 1.0 / m * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum 186 | - Kt_XX_sums.dot(K_XY_sums_1) 187 | - Kt_YY_sums.dot(K_XY_sums_0)) 188 | ) 189 | return mmd2, var_est 190 | -------------------------------------------------------------------------------- /SFIT/loss/similarity_preserving.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class BatchSimLoss(nn.Module): 7 | def __init__(self): 8 | super(BatchSimLoss, self).__init__() 9 | 10 | def forward(self, featmap_src_T, featmap_tgt_S): 11 | B, C, H, W = featmap_src_T.shape 12 | f_src, f_tgt = featmap_src_T.view([B, C * H * W]), featmap_tgt_S.view([B, C * H * W]) 13 | A_src, A_tgt = f_src @ f_src.T, f_tgt @ f_tgt.T 14 | A_src, A_tgt = F.normalize(A_src, p=2, dim=1), F.normalize(A_tgt, p=2, dim=1) 15 | loss_batch = torch.norm(A_src - A_tgt) ** 2 / B 16 | return loss_batch 17 | 18 | 19 | class PixelSimLoss(nn.Module): 20 | def __init__(self): 21 | super(PixelSimLoss, self).__init__() 22 | 23 | def forward(self, featmap_src_T, featmap_tgt_S): 24 | B, C, H, W = featmap_src_T.shape 25 | loss_pixel = 0 26 | for b in range(B): 27 | f_src, f_tgt = featmap_src_T[b].view([C, H * W]), featmap_tgt_S[b].view([C, H * W]) 28 | A_src, A_tgt = f_src.T @ f_src, f_tgt.T @ f_tgt 29 | A_src, A_tgt = F.normalize(A_src, p=2, dim=1), F.normalize(A_tgt, p=2, dim=1) 30 | loss_pixel += torch.norm(A_src - A_tgt) ** 2 / (H * W) 31 | loss_pixel /= B 32 | return loss_pixel 33 | 34 | 35 | class ChannelSimLoss(nn.Module): 36 | def __init__(self): 37 | super(ChannelSimLoss, self).__init__() 38 | 39 | def forward(self, featmap_src_T, featmap_tgt_S): 40 | B, C, H, W = featmap_src_T.shape 41 | loss = 0 42 | for b in range(B): 43 | f_src, f_tgt = featmap_src_T[b].view([C, H * W]), featmap_tgt_S[b].view([C, H * W]) 44 | A_src, A_tgt = f_src @ f_src.T, f_tgt @ f_tgt.T 45 | A_src, A_tgt = F.normalize(A_src, p=2, dim=1), F.normalize(A_tgt, p=2, dim=1) 46 | loss += torch.norm(A_src - A_tgt) ** 2 / C 47 | # loss += torch.norm(A_src - A_tgt, p=1) 48 | loss /= B 49 | 
return loss 50 | 51 | 52 | class ChannelSimLoss1D(nn.Module): 53 | def __init__(self): 54 | super(ChannelSimLoss1D, self).__init__() 55 | 56 | def forward(self, feat_src_T, feat_tgt_S): 57 | B, C = feat_src_T.shape 58 | loss = torch.zeros([]).cuda() 59 | for b in range(B): 60 | f_src, f_tgt = feat_src_T[b].view([C, 1]), feat_tgt_S[b].view([C, 1]) 61 | A_src, A_tgt = f_src @ f_src.T, f_tgt @ f_tgt.T 62 | A_src, A_tgt = F.normalize(A_src, p=2, dim=1), F.normalize(A_tgt, p=2, dim=1) 63 | loss += torch.norm(A_src - A_tgt) ** 2 / C 64 | # loss += torch.norm(A_src - A_tgt, p=1) 65 | loss /= B 66 | return loss 67 | 68 | 69 | if __name__ == '__main__': 70 | from SFIT.loss.style_loss import StyleLoss 71 | 72 | feat1, feat2 = torch.ones([16, 2048, 7, 7]), torch.zeros([16, 2048, 7, 7]) 73 | batch_loss = BatchSimLoss() 74 | l1 = batch_loss(feat1, feat2) 75 | pixel_loss = PixelSimLoss() 76 | l2 = pixel_loss(feat1, feat2) 77 | channel_loss = ChannelSimLoss() 78 | l3 = channel_loss(feat1, feat2) 79 | style_loss = StyleLoss() 80 | l4 = style_loss(feat1, feat2) 81 | 82 | feat1, feat2 = torch.ones([16, 2048]), torch.zeros([16, 2048]) 83 | channel_loss = ChannelSimLoss1D() 84 | l1 = channel_loss(feat1, feat2) 85 | pass 86 | -------------------------------------------------------------------------------- /SFIT/loss/style_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class StyleLoss(nn.Module): 7 | def __init__(self): 8 | super(StyleLoss, self).__init__() 9 | 10 | def forward(self, featmap_src_T, featmap_tgt_S): 11 | B, C, H, W = featmap_src_T.shape 12 | f_src, f_tgt = featmap_src_T.view([B, C, H * W]), featmap_tgt_S.view([B, C, H * W]) 13 | # calculate Gram matrices 14 | A_src, A_tgt = torch.bmm(f_src, f_src.transpose(1, 2)), torch.bmm(f_tgt, f_tgt.transpose(1, 2)) 15 | A_src, A_tgt = A_src / (H * W), A_tgt / (H * W) 16 | loss = F.mse_loss(A_src, A_tgt) 17 | return loss 18 | 19 | 20 | if __name__ == '__main__': 21 | feat1, feat2 = torch.ones([16, 2048, 7, 7]), torch.zeros([16, 2048, 7, 7]) 22 | style_loss = StyleLoss() 23 | l1 = style_loss(feat1, feat2) 24 | pass 25 | -------------------------------------------------------------------------------- /SFIT/loss/total_variation.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class TotalVariationLoss(nn.Module): 5 | def __init__(self): 6 | super(TotalVariationLoss, self).__init__() 7 | 8 | def forward(self, image): 9 | # COMPUTE total variation regularization loss 10 | loss_var_l2 = ((image[:, :, :, 1:] - image[:, :, :, :-1]) ** 2).mean() + \ 11 | ((image[:, :, 1:, :] - image[:, :, :-1, :]) ** 2).mean() 12 | 13 | loss_var_l1 = ((image[:, :, :, 1:] - image[:, :, :, :-1]).abs()).mean() + \ 14 | ((image[:, :, 1:, :] - image[:, :, :-1, :]).abs()).mean() 15 | return loss_var_l1, loss_var_l2 16 | -------------------------------------------------------------------------------- /SFIT/models/__init__.py: -------------------------------------------------------------------------------- 1 | # from .alexnet import AlexNet 2 | # from .vgg import VGG16 3 | # from .resnet import ResNet18, ResNet50, ResNet152 4 | # 5 | # custom_factory = { 6 | # 'alexnet': AlexNet, 7 | # 'vgg16': VGG16, 8 | # 'resnet18': ResNet18, 9 | # 'resnet50': ResNet50, 10 | # 'resnet152': ResNet152, 11 | # } 12 | 13 | from torchvision.models.alexnet import AlexNet 14 | from torchvision.models.vgg import 
vgg16_bn 15 | from torchvision.models.resnet import resnet18, resnet50, resnet152 16 | 17 | torchvision_factory = { 18 | 'alexnet': AlexNet, 19 | 'vgg16': vgg16_bn, 20 | 'resnet18': resnet18, 21 | 'resnet50': resnet50, 22 | 'resnet152': resnet152, 23 | } 24 | 25 | 26 | def names(): 27 | return sorted(torchvision_factory.keys()) 28 | 29 | 30 | def create(name, num_classes, pretrained=False): 31 | """ 32 | Create a model instance. 33 | """ 34 | if name not in torchvision_factory: 35 | raise KeyError("Unknown model:", name) 36 | return torchvision_factory[name](num_classes=num_classes, pretrained=pretrained) 37 | -------------------------------------------------------------------------------- /SFIT/models/classifier_shot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.utils.weight_norm as weight_norm 4 | from torchvision.models.resnet import resnet50, resnet101 5 | 6 | 7 | def init_weights(m): 8 | classname = m.__class__.__name__ 9 | if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1: 10 | nn.init.kaiming_uniform_(m.weight) 11 | nn.init.zeros_(m.bias) 12 | elif classname.find('BatchNorm') != -1: 13 | nn.init.normal_(m.weight, 1.0, 0.02) 14 | nn.init.zeros_(m.bias) 15 | elif classname.find('Linear') != -1: 16 | nn.init.xavier_normal_(m.weight) 17 | if m.bias is not None: 18 | nn.init.zeros_(m.bias) 19 | 20 | 21 | class ClassifierShot(nn.Module): 22 | def __init__(self, num_classes, arch='lenet', bottleneck_dim=256, use_shot=True): 23 | super(ClassifierShot, self).__init__() 24 | 25 | self.arch = arch 26 | dropout = 0.0 27 | if arch == 'lenet': 28 | self.base = LeNetBase() 29 | layer_ids = [2, 6] 30 | dropout = 0.5 31 | elif arch == 'dtn': 32 | self.base = DTNBase() 33 | layer_ids = [3, 7, 11] 34 | dropout = 0.5 35 | elif arch == 'resnet50': 36 | self.base = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2]) 37 | self.base.out_features = 2048 38 | layer_ids = [4, 5, 6, 7] 39 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 40 | elif arch == 'resnet101': 41 | self.base = nn.Sequential(*list(resnet101(pretrained=True).children())[:-2]) 42 | self.base.out_features = 2048 43 | layer_ids = [4, 5, 6, 7] 44 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 45 | else: 46 | raise Exception 47 | self.bottleneck = feat_bootleneck(self.base.out_features, bottleneck_dim, dropout, use_shot) 48 | self.classifier = feat_classifier(num_classes, bottleneck_dim, use_shot) 49 | 50 | # hook for record featmaps 51 | def store_featmap(module, inputs, output): 52 | # input is a tuple of packed inputs 53 | # output is a Tensor 54 | self.featmaps.append(output) 55 | 56 | self.featmaps = [] 57 | for layer_id in layer_ids: 58 | if 'resnet' in arch: 59 | self.base[layer_id].register_forward_hook(store_featmap) 60 | else: 61 | self.base.conv_params[layer_id].register_forward_hook(store_featmap) 62 | if use_shot: 63 | self.bottleneck.bn.register_forward_hook(store_featmap) 64 | else: 65 | self.bottleneck.bottleneck.register_forward_hook(store_featmap) 66 | 67 | def forward(self, x, out_featmaps=False): 68 | self.featmaps = [] 69 | x = self.base(x) 70 | if 'resnet' in self.arch: 71 | x = self.avgpool(x) 72 | x = x.view(x.size(0), -1) 73 | x = self.bottleneck(x) 74 | label = self.classifier(x) 75 | 76 | if out_featmaps: 77 | return (label, self.featmaps) 78 | else: 79 | return label 80 | 81 | 82 | class feat_bootleneck(nn.Module): 83 | def __init__(self, feature_dim, bottleneck_dim=256, 
dropout=0.5, use_shot=True): 84 | super(feat_bootleneck, self).__init__() 85 | self.bottleneck = nn.Linear(feature_dim, bottleneck_dim) 86 | self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True) 87 | self.dropout = nn.Dropout(p=dropout) 88 | self.bottleneck.apply(init_weights) 89 | self.use_shot = use_shot 90 | 91 | def forward(self, x): 92 | x = self.bottleneck(x) 93 | if self.use_shot: 94 | x = self.bn(x) 95 | x = self.dropout(x) 96 | return x 97 | 98 | 99 | class feat_classifier(nn.Module): 100 | def __init__(self, class_num, bottleneck_dim=256, use_shot=True): 101 | super(feat_classifier, self).__init__() 102 | 103 | self.fc = nn.Linear(bottleneck_dim, class_num, bias=False) 104 | if use_shot: 105 | self.fc = weight_norm(self.fc, name="weight") 106 | self.fc.apply(init_weights) 107 | 108 | def forward(self, x): 109 | x = self.fc(x) 110 | return x 111 | 112 | 113 | class DTNBase(nn.Module): 114 | def __init__(self): 115 | super(DTNBase, self).__init__() 116 | self.conv_params = nn.Sequential(nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2), 117 | nn.BatchNorm2d(64), nn.Dropout2d(0.1), nn.ReLU(), 118 | nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2), 119 | nn.BatchNorm2d(128), nn.Dropout2d(0.3), nn.ReLU(), 120 | nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2), 121 | nn.BatchNorm2d(256), nn.Dropout2d(0.5), nn.ReLU()) 122 | self.out_features = 256 * 4 * 4 123 | 124 | def forward(self, x): 125 | x = self.conv_params(x) 126 | return x 127 | 128 | 129 | class LeNetBase(nn.Module): 130 | def __init__(self): 131 | super(LeNetBase, self).__init__() 132 | self.conv_params = nn.Sequential(nn.Conv2d(1, 20, kernel_size=5), 133 | nn.MaxPool2d(2), nn.ReLU(), 134 | nn.Conv2d(20, 50, kernel_size=5), 135 | nn.Dropout2d(p=0.5), nn.MaxPool2d(2), nn.ReLU(), ) 136 | self.out_features = 50 * 4 * 4 137 | 138 | def forward(self, x): 139 | x = self.conv_params(x) 140 | return x 141 | 142 | 143 | class Discriminator(nn.Module): 144 | def __init__(self, in_channel): 145 | super(Discriminator, self).__init__() 146 | self.model = nn.Sequential(nn.Linear(in_channel, 256), nn.ReLU(), 147 | nn.Linear(256, 256), nn.ReLU(), 148 | nn.Linear(256, 1)) 149 | 150 | def forward(self, x): 151 | x = self.model(x) 152 | return x 153 | 154 | 155 | if __name__ == '__main__': 156 | net = ClassifierShot(10) 157 | img = torch.zeros([64, 1, 28, 28]) 158 | label = net(img) 159 | net = ClassifierShot(10, arch='resnet50') 160 | img = torch.zeros([64, 3, 224, 224]) 161 | label, featmaps = net(img, out_featmaps=True) 162 | pass 163 | -------------------------------------------------------------------------------- /SFIT/models/cyclegan.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | 6 | def weights_init_normal(m): 7 | classname = m.__class__.__name__ 8 | if classname.find("Conv") != -1: 9 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02) 10 | if hasattr(m, "bias") and m.bias is not None: 11 | torch.nn.init.constant_(m.bias.data, 0.0) 12 | elif classname.find("BatchNorm2d") != -1: 13 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | torch.nn.init.constant_(m.bias.data, 0.0) 15 | 16 | 17 | class ResidualBlock(nn.Module): 18 | def __init__(self, in_features): 19 | super(ResidualBlock, self).__init__() 20 | 21 | self.block = nn.Sequential( 22 | nn.ReflectionPad2d(1), 23 | nn.Conv2d(in_features, in_features, 3), 24 | nn.InstanceNorm2d(in_features), 25 | nn.ReLU(inplace=True), 26 | 
nn.ReflectionPad2d(1), 27 | nn.Conv2d(in_features, in_features, 3), 28 | nn.InstanceNorm2d(in_features), 29 | ) 30 | 31 | def forward(self, x): 32 | return x + self.block(x) 33 | 34 | 35 | class GeneratorResNet(nn.Module): 36 | def __init__(self, num_colors=3, num_residual_blocks=3): 37 | super(GeneratorResNet, self).__init__() 38 | 39 | # Initial convolution block 40 | out_features = 32 41 | model = [ 42 | nn.ReflectionPad2d(3), 43 | nn.Conv2d(num_colors, out_features, 7), 44 | nn.InstanceNorm2d(out_features), 45 | nn.ReLU(inplace=True), 46 | ] 47 | in_features = out_features 48 | 49 | # Downsampling 50 | for _ in range(2): 51 | out_features *= 2 52 | model += [ 53 | nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), 54 | nn.InstanceNorm2d(out_features), 55 | nn.ReLU(inplace=True), 56 | ] 57 | in_features = out_features 58 | 59 | # Residual blocks 60 | for _ in range(num_residual_blocks): 61 | model += [ResidualBlock(out_features)] 62 | 63 | # Upsampling 64 | for _ in range(2): 65 | out_features //= 2 66 | model += [ 67 | nn.Upsample(scale_factor=2), 68 | nn.Conv2d(in_features, out_features, 3, stride=1, padding=1), 69 | nn.InstanceNorm2d(out_features), 70 | nn.ReLU(inplace=True), 71 | ] 72 | in_features = out_features 73 | 74 | # Output layer 75 | model += [nn.ReflectionPad2d(3), nn.Conv2d(out_features, num_colors, 7), nn.Tanh()] 76 | 77 | self.model = nn.Sequential(*model) 78 | self.model.apply(weights_init_normal) 79 | 80 | def forward(self, x): 81 | return self.model(x) 82 | 83 | 84 | def test(): 85 | net = GeneratorResNet(num_colors=3) 86 | real_img = torch.zeros([10, 3, 32, 32]) 87 | gen_img = net(real_img) 88 | pass 89 | 90 | 91 | if __name__ == '__main__': 92 | test() 93 | -------------------------------------------------------------------------------- /SFIT/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .da_trainer import DATrainer 2 | from .sfit_trainer import SFITTrainer 3 | -------------------------------------------------------------------------------- /SFIT/trainers/da_trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from sklearn.metrics import confusion_matrix 7 | from SFIT.loss import * 8 | from SFIT.utils.meters import AverageMeter 9 | 10 | 11 | class DATrainer(object): 12 | def __init__(self, net_D, net_S, net_T, logdir, da_setting, source_LSR=True, test_visda=False): 13 | super(DATrainer, self).__init__() 14 | self.net_D = net_D 15 | self.net_S = net_S 16 | self.net_T = net_T 17 | self.CE_loss = nn.CrossEntropyLoss() 18 | self.H_loss = HLoss() 19 | self.LSR_loss = LabelSmoothLoss() 20 | self.D_loss = nn.BCEWithLogitsLoss() 21 | self.MMD_loss = MMDLoss() 22 | self.da_setting = da_setting 23 | self.source_LSR = source_LSR 24 | self.logdir = logdir 25 | self.test_visda = test_visda 26 | 27 | def train_net_S(self, epoch, dataloader, optimizer, scheduler=None, log_interval=1000): 28 | self.net_S.train() 29 | losses, correct, miss = 0, 0, 0 30 | t0 = time.time() 31 | for batch_idx, (data, target) in enumerate(dataloader): 32 | data, target = data.cuda(), target.cuda() 33 | if data.size(0) == 1: 34 | continue 35 | output = self.net_S(data) 36 | pred = torch.argmax(output, 1) 37 | correct += pred.eq(target).sum().item() 38 | miss += target.shape[0] - pred.eq(target).sum().item() 39 | if self.source_LSR and 'shot' in self.da_setting: 40 | 
loss = self.LSR_loss(output, target) 41 | else: 42 | loss = self.CE_loss(output, target) 43 | 44 | optimizer.zero_grad() 45 | loss.backward() 46 | optimizer.step() 47 | 48 | losses += loss.item() 49 | if scheduler is not None: 50 | if isinstance(scheduler, torch.optim.lr_scheduler.CosineAnnealingWarmRestarts): 51 | scheduler.step(epoch - 1 + batch_idx / len(dataloader)) 52 | elif isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR): 53 | scheduler.step() 54 | if (batch_idx + 1) % log_interval == 0: 55 | # print(cyclic_scheduler.last_epoch, optimizer.param_groups[0]['lr']) 56 | t1 = time.time() 57 | t_epoch = t1 - t0 58 | print('Train Epoch: {}, Batch:{}, \tLoss: {:.6f}, Prec: {:.1f}%, Time: {:.3f}'.format( 59 | epoch, (batch_idx + 1), losses / (batch_idx + 1), 100. * correct / (correct + miss), t_epoch)) 60 | 61 | t1 = time.time() 62 | t_epoch = t1 - t0 63 | print('Train Epoch: {}, Batch:{}, \tLoss: {:.6f}, Prec: {:.1f}%, Time: {:.3f}'.format( 64 | epoch, len(dataloader), losses / len(dataloader), 100. * correct / (correct + miss), t_epoch)) 65 | 66 | return losses / len(dataloader), correct / (correct + miss) 67 | 68 | def train_net_T(self, epoch, source_loader, target_loader, optimizer_S, optimizer_D, schedulers=None, 69 | log_interval=1000): 70 | # ----------------- 71 | # Train target model 72 | # ----------------- 73 | self.net_T.train() 74 | # self.net_T.bottleneck.eval() 75 | self.net_T.classifier.eval() 76 | 77 | t0 = time.time() 78 | 79 | loss_c, loss_d = torch.zeros([]).cuda(), torch.zeros([]).cuda() 80 | 81 | len_loaders = min(len(source_loader), len(target_loader)) 82 | zip_loaders = zip(source_loader, target_loader) 83 | for batch_idx, ((src_img, src_label), (tgt_img, _)) in enumerate(zip_loaders): 84 | src_img, src_label = src_img.cuda(), src_label.cuda() 85 | tgt_img = tgt_img.cuda() 86 | 87 | # SHOT loss 88 | if 'shot' in self.da_setting: 89 | output_tgt = self.net_T(tgt_img) 90 | # higher conf -> reduce entropy of each image decision 91 | loss_c = self.H_loss(output_tgt) 92 | # more even distribution among classes -> increase entropy of overall class prob 93 | avg_cls = F.softmax(output_tgt, dim=1).mean(dim=0) 94 | loss_d = (-avg_cls * torch.log(avg_cls)).sum() 95 | loss = loss_c - loss_d 96 | else: 97 | # source domain 98 | output_tgt, tgt_feat = self.net_T(tgt_img, True) 99 | if 'adda' in self.da_setting: 100 | # update D 101 | output_src, src_feat = self.net_S(src_img, True) 102 | src_gt_validity = torch.ones([src_img.shape[0], 1], requires_grad=False).cuda() 103 | tgt_gt_validity = torch.zeros([tgt_img.shape[0], 1], requires_grad=False).cuda() 104 | src_validity = self.net_D(src_feat[-1]) 105 | tgt_validity = self.net_D(tgt_feat[-1]) 106 | loss_d = self.D_loss(src_validity, src_gt_validity) + self.D_loss(tgt_validity, tgt_gt_validity) 107 | optimizer_D.zero_grad() 108 | loss_d.backward() 109 | optimizer_D.step() 110 | # update S 111 | output_tgt, tgt_feat = self.net_T(tgt_img, True) 112 | tgt_validity = self.net_D(tgt_feat[-1]) 113 | loss = self.D_loss(tgt_validity, src_gt_validity) 114 | elif 'mmd' in self.da_setting: 115 | output_src, src_feat = self.net_T(src_img, True) 116 | loss_c = self.CE_loss(output_src, src_label) 117 | loss_d = self.MMD_loss(src_feat[-1], tgt_feat[-1]) 118 | # back-prop 119 | loss = loss_c + loss_d 120 | else: 121 | raise Exception 122 | 123 | optimizer_S.zero_grad() 124 | loss.backward() 125 | optimizer_S.step() 126 | 127 | def adjust(scheduler): 128 | if scheduler is not None: 129 | if isinstance(scheduler, 
torch.optim.lr_scheduler.CosineAnnealingWarmRestarts): 130 | scheduler.step(epoch - 1 + batch_idx / len_loaders) 131 | elif isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR): 132 | scheduler.step() 133 | 134 | if isinstance(schedulers, list): 135 | for one_scheduler in schedulers: 136 | adjust(one_scheduler) 137 | else: 138 | adjust(schedulers) 139 | 140 | if (batch_idx + 1) % log_interval == 0: 141 | t1 = time.time() 142 | t_epoch = t1 - t0 143 | print('Train Epoch: {}, Batch:{}, S: [c: {:.3f}, d: {:.3f}], Time: {:.3f}'. 144 | format(epoch, (batch_idx + 1), loss_c.item(), loss_d.item(), t_epoch)) 145 | 146 | t1 = time.time() 147 | t_epoch = t1 - t0 148 | print('Train Epoch: {}, Batch:{}, ' 149 | 'S: [c: {:.3f}, d: {:.3f}], Time: {:.3f}'. 150 | format(epoch, len_loaders, loss_c.item(), loss_d.item(), t_epoch)) 151 | 152 | return loss_c.item(), loss_d.item() 153 | 154 | def test_net_S(self, test_loader): 155 | self.net_S.eval() 156 | losses, correct, miss = 0, 0, 0 157 | t0 = time.time() 158 | all_preds, all_labels = [], [] 159 | for batch_idx, (data, target) in enumerate(test_loader): 160 | data, target = data.cuda(), target.cuda() 161 | with torch.no_grad(): 162 | output = self.net_S(data) 163 | pred = torch.argmax(output, 1) 164 | correct += pred.eq(target).sum().item() 165 | miss += target.shape[0] - pred.eq(target).sum().item() 166 | loss = self.CE_loss(output, target) 167 | losses += loss.item() 168 | all_preds.append(pred.cpu()) 169 | all_labels.append(target.cpu()) 170 | 171 | t1 = time.time() 172 | t_epoch = t1 - t0 173 | print('Test, Loss: {:.6f}, Prec: {:.1f}%, Time: {:.3f}'. 174 | format(losses / (len(test_loader) + 1), 100. * correct / (correct + miss), t_epoch)) 175 | 176 | if self.test_visda: 177 | all_preds = torch.cat(all_preds, dim=0) 178 | all_labels = torch.cat(all_labels, dim=0) 179 | matrix = confusion_matrix(all_labels, all_preds) 180 | acc = matrix.diagonal() / matrix.sum(axis=1) * 100 181 | acc_str = ' '.join([str(np.round(i, 2)) for i in acc]) 182 | print(f'visda per class accuracy\n{acc_str}') 183 | print('visda class-averaged accuracy: {:.1f}%'.format(acc.mean())) 184 | return losses / len(test_loader), acc.mean() 185 | 186 | return losses / len(test_loader), correct / (correct + miss) 187 | 188 | def test_net_T(self, test_loader): 189 | self.net_T.eval() 190 | tgt_C_loss = AverageMeter() 191 | correct = 0 192 | t0 = time.time() 193 | all_preds, all_labels = [], [] 194 | for batch_idx, (img, label) in enumerate(test_loader): 195 | img, label = img.cuda(), label.cuda() 196 | with torch.no_grad(): 197 | output = self.net_T(img) 198 | pred_label = torch.argmax(output, 1) 199 | target_C_loss = self.CE_loss(output, label) 200 | tgt_C_loss.update(target_C_loss.item()) 201 | correct += (pred_label == label).sum().item() 202 | all_preds.append(pred_label.cpu()) 203 | all_labels.append(label.cpu()) 204 | 205 | t1 = time.time() 206 | t_epoch = t1 - t0 207 | print('Test, loss: {:.3f}, prec: {:.1f}%, Time: {:.3f}'. 208 | format(tgt_C_loss.avg, 100. 
* correct / len(test_loader.dataset), t_epoch)) 209 | 210 | if self.test_visda: 211 | all_preds = torch.cat(all_preds, dim=0) 212 | all_labels = torch.cat(all_labels, dim=0) 213 | matrix = confusion_matrix(all_labels, all_preds) 214 | acc = matrix.diagonal() / matrix.sum(axis=1) * 100 215 | acc_str = ' '.join([str(np.round(i, 2)) for i in acc]) 216 | print(f'visda per class accuracy\n{acc_str}') 217 | print('visda class-averaged accuracy: {:.1f}%'.format(acc.mean())) 218 | return tgt_C_loss.avg, acc.mean() 219 | return tgt_C_loss.avg, 100. * correct / len(test_loader.dataset) 220 | -------------------------------------------------------------------------------- /SFIT/trainers/sfit_trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | from sklearn.cluster import KMeans 3 | from mpl_toolkits.axes_grid1 import make_axes_locatable 4 | import matplotlib.pyplot as plt 5 | import torchvision.transforms as T 6 | import numpy as np 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | from torchvision.utils import save_image, make_grid 11 | from sklearn.metrics import confusion_matrix 12 | from SFIT.loss import * 13 | from SFIT.utils.meters import AverageMeter 14 | 15 | 16 | class SFITTrainer(object): 17 | def __init__(self, net_G, net_S, net_T, logdir, opts, test_visda=False): 18 | super(SFITTrainer, self).__init__() 19 | self.net_G = net_G 20 | self.net_S = net_S 21 | self.net_T = net_T 22 | self.denorm = lambda x: (x + 1) / 2 23 | self.CE_loss = nn.CrossEntropyLoss() 24 | self.H_loss = HLoss() 25 | self.LSR_loss = LabelSmoothLoss() 26 | self.logdir = logdir 27 | self.test_visda = test_visda 28 | self.D_loss = nn.MSELoss() 29 | self.MMD_loss = MMDLoss() 30 | self.KD_loss = KDLoss(opts.KD_T) 31 | self.JS_loss = JSDivLoss() 32 | self.bn_loss = nn.MSELoss() 33 | self.cyc_loss = nn.L1Loss() 34 | self.id_loss = nn.L1Loss() 35 | self.content_loss = nn.L1Loss() 36 | self.tv_loss = TotalVariationLoss() 37 | self.batch_loss = BatchSimLoss() 38 | self.pixel_loss = PixelSimLoss() 39 | self.style_loss = StyleLoss() 40 | self.channel_loss = ChannelSimLoss() if opts.use_channel else StyleLoss() 41 | self.channel_loss_1d = ChannelSimLoss1D() 42 | self.opts = opts 43 | 44 | def train_net_G(self, epoch, target_loader, optimizer_G, pretrain=False, scheduler=None, log_interval=1000): 45 | # ----------------- 46 | # Train Generator 47 | # ----------------- 48 | self.net_G.train() 49 | self.net_S.eval() 50 | self.net_T.eval() 51 | 52 | def store_mean_var(module, inputs, output): 53 | # input is a tuple of packed inputs 54 | # output is a Tensor 55 | cur_means.append(inputs[0].mean(dim=[0, 2, 3])) 56 | cur_vars.append(inputs[0].var(dim=[0, 2, 3])) 57 | 58 | stat_means, stat_vars = [], [] 59 | cur_means, cur_vars = [], [] 60 | running_means, running_vars = [], [] 61 | handles = [] 62 | use_BN_loss = False 63 | for layer in self.net_S.modules(): 64 | if isinstance(layer, nn.BatchNorm2d): 65 | stat_means.append(layer.running_mean.clone()) 66 | stat_vars.append(layer.running_var.clone()) 67 | handles.append(layer.register_forward_hook(store_mean_var)) 68 | use_BN_loss = True 69 | 70 | t0 = time.time() 71 | 72 | loss_conf, loss_G_BN, loss_cyc, loss_batch, loss_pixel, loss_content, = \ 73 | torch.zeros([]).cuda(), torch.zeros([]).cuda(), torch.zeros([]).cuda(), \ 74 | torch.zeros([]).cuda(), torch.zeros([]).cuda(), torch.zeros([]).cuda() 75 | 76 | loss_avg_kd, loss_avg_bn, loss_avg_channel, loss_avg_conf = \ 77 | AverageMeter(), 
AverageMeter(), AverageMeter(), AverageMeter() 78 | 79 | for batch_idx, (real_tgt_img, _) in enumerate(target_loader): 80 | real_tgt_img = real_tgt_img.cuda() 81 | 82 | # real target 83 | with torch.no_grad(): 84 | if self.opts.content_ratio or pretrain: 85 | output_tgt_S, featmaps_tgt_S = self.net_S(real_tgt_img, out_featmaps=True) 86 | output_tgt_T, featmaps_tgt_T = self.net_T(real_tgt_img, out_featmaps=True) 87 | cur_means, cur_vars = [], [] 88 | # fake source 89 | gen_src_img = self.net_G(real_tgt_img) 90 | output_src_S, featmaps_src_S = self.net_S(gen_src_img, out_featmaps=True) 91 | # loss 92 | # style (BN) 93 | if use_BN_loss: 94 | loss_G_BN = 0 95 | for layer_id in range(len(stat_means)): 96 | if layer_id >= len(running_means): 97 | running_means.append(cur_means[layer_id]) 98 | running_vars.append(cur_vars[layer_id]) 99 | else: 100 | running_means[layer_id] = running_means[layer_id].detach() * (1 - self.opts.mAvrgAlpha) + \ 101 | cur_means[layer_id] * self.opts.mAvrgAlpha 102 | running_vars[layer_id] = running_vars[layer_id].detach() * (1 - self.opts.mAvrgAlpha) + \ 103 | cur_vars[layer_id] * self.opts.mAvrgAlpha 104 | loss_G_BN += self.bn_loss(stat_means[layer_id], running_means[layer_id]) + \ 105 | self.bn_loss(torch.sqrt(stat_vars[layer_id]), torch.sqrt(running_vars[layer_id])) 106 | loss_G_BN = loss_G_BN / len(stat_means) 107 | # style 108 | loss_style = torch.zeros([]).cuda() 109 | for layer_id in range(len(featmaps_tgt_T) - 1): 110 | loss_style += self.style_loss(featmaps_src_S[layer_id], featmaps_tgt_T[layer_id]) 111 | # similarity preserving 112 | loss_channel = self.channel_loss(featmaps_src_S[-2], featmaps_tgt_T[-2]) 113 | # loss_channel = self.channel_loss_1d(featmaps_src_S[-1], featmaps_tgt_T[-1]) 114 | loss_batch = self.batch_loss(featmaps_src_S[-2], featmaps_tgt_T[-2]) 115 | loss_pixel = self.pixel_loss(featmaps_src_S[-2], featmaps_tgt_T[-2]) 116 | # kd 117 | loss_kd = self.KD_loss(output_src_S, output_tgt_T) 118 | # content 119 | if self.opts.content_ratio or pretrain: 120 | loss_content = self.content_loss(featmaps_src_S[-2], featmaps_tgt_S[-2]) 121 | else: 122 | loss_content = torch.zeros([]).cuda() 123 | # others 124 | loss_id = self.id_loss(gen_src_img, real_tgt_img) 125 | loss_tv = self.tv_loss(gen_src_img)[1] 126 | loss_activation = -featmaps_src_S[-2].abs().mean() 127 | # SHOT loss 128 | loss_conf = self.H_loss(output_src_S) 129 | avg_cls = torch.nn.functional.softmax(output_src_S, dim=1).mean(dim=0) 130 | loss_div = -(avg_cls * torch.log(avg_cls)).sum() 131 | # co-training 132 | loss_js = self.JS_loss(output_tgt_T, output_src_S) 133 | loss_G = loss_conf * self.opts.conf_ratio - loss_div * self.opts.div_ratio + loss_js * self.opts.js_ratio + \ 134 | loss_content * self.opts.content_ratio + loss_style * self.opts.style_ratio + loss_channel * self.opts.channel_ratio + \ 135 | loss_id * self.opts.id_ratio + loss_kd * self.opts.kd_ratio + loss_tv * self.opts.tv_ratio + \ 136 | loss_batch * self.opts.batch_ratio + loss_pixel * self.opts.pixel_ratio 137 | if use_BN_loss: 138 | loss_G += loss_G_BN * self.opts.bn_ratio 139 | # first train the G as a transparent filter 140 | if pretrain: 141 | loss_G = loss_id + loss_content + self.KD_loss(output_src_S, output_tgt_S) 142 | 143 | optimizer_G.zero_grad() 144 | loss_G.backward() 145 | optimizer_G.step() 146 | 147 | if scheduler is not None: 148 | if isinstance(scheduler, torch.optim.lr_scheduler.CosineAnnealingWarmRestarts): 149 | scheduler.step(epoch - 1 + batch_idx / len(target_loader)) 150 | elif isinstance(scheduler, 
torch.optim.lr_scheduler.OneCycleLR): 151 | scheduler.step() 152 | 153 | loss_avg_kd.update(loss_kd.item()) 154 | loss_avg_bn.update(loss_G_BN.item()) 155 | loss_avg_channel.update(loss_channel.item()) 156 | loss_avg_conf.update(loss_conf.item()) 157 | 158 | if (batch_idx + 1) % log_interval == 0 or (batch_idx + 1) == len(target_loader): 159 | # print(alpha) 160 | t1 = time.time() 161 | t_epoch = t1 - t0 162 | print(f'Train Epoch: {epoch}, Batch:{batch_idx + 1}, G: [c: {loss_avg_conf.avg:.3f}, ' 163 | f'bn: {loss_avg_bn.avg:.5f}, channel: {loss_avg_channel.avg:.5f}, kd: {loss_avg_kd.avg:.3f}], ' 164 | f'Time: {t_epoch:.3f}') 165 | self.sample_image(real_tgt_img, fname=self.logdir + f'/imgs/{epoch}.png') 166 | # print(f'pixel sim: {loss_pixel.item()}, batch sim: {loss_batch.item()}') 167 | 168 | # remove forward hooks registered in this epoch 169 | for handle in handles: 170 | handle.remove() 171 | 172 | return loss_avg_kd.avg, loss_avg_channel.avg, loss_avg_conf.avg 173 | 174 | def train_net_T(self, epoch, target_loader, optimizer_T, scheduler=None, use_generator=False, log_interval=1000): 175 | # ----------------- 176 | # Train Target Network 177 | # ----------------- 178 | self.net_G.eval() 179 | self.net_S.eval() 180 | self.net_T.train() 181 | self.net_T.classifier.eval() 182 | 183 | t0 = time.time() 184 | 185 | loss_conf, loss_div = torch.zeros([]).cuda(), torch.zeros([]).cuda() 186 | 187 | for batch_idx, (real_tgt_img, _) in enumerate(target_loader): 188 | real_tgt_img = real_tgt_img.cuda() 189 | 190 | # SHOT loss 191 | output_tgt_T, featmaps_tgt_T = self.net_T(real_tgt_img, out_featmaps=True) 192 | # higher conf -> reduce entropy of each image decision 193 | loss_conf = self.H_loss(output_tgt_T) 194 | # more even distribution among classes -> increase entropy of overall class prob 195 | avg_cls = F.softmax(output_tgt_T, dim=1).mean(dim=0) 196 | loss_div = (-avg_cls * torch.log(avg_cls)).sum() 197 | loss_T = loss_conf - loss_div 198 | 199 | # generator 200 | if use_generator: # and (self.opts.T_batch_ratio or self.opts.T_pixel_ratio or self.opts.js_ratio): 201 | with torch.no_grad(): 202 | gen_src_img = self.net_G(real_tgt_img) 203 | output_src_S, featmaps_src_S = self.net_S(gen_src_img, out_featmaps=True) 204 | # output_src_T, featmaps_src_T = self.net_T(gen_src_img, out_featmaps=True) 205 | # reset SHOT loss 206 | pred_label_T = torch.argmax(output_tgt_T, 1) 207 | pred_label_S = torch.argmax(output_src_S, 1) 208 | valid_idx = pred_label_S == pred_label_T 209 | loss_conf = self.H_loss(output_tgt_T[valid_idx]) 210 | # conf = F.softmax(output_src_S, dim=1).max(dim=1)[0] 211 | # valid_idx = conf > self.opts.confidence_thres 212 | # loss_ce = self.CE_loss(output_tgt_T[valid_idx], pred_label_S[valid_idx]) 213 | loss_T = loss_conf - loss_div # + loss_ce 214 | 215 | loss_batch = self.batch_loss(featmaps_src_S[-2], featmaps_tgt_T[-2]) 216 | loss_pixel = self.pixel_loss(featmaps_src_S[-2], featmaps_tgt_T[-2]) 217 | loss_T += loss_batch * self.opts.T_batch_ratio + loss_pixel * self.opts.T_pixel_ratio 218 | # co-training 219 | loss_js = self.JS_loss(output_tgt_T, output_src_S) 220 | loss_T += loss_js * self.opts.js_ratio 221 | 222 | optimizer_T.zero_grad() 223 | loss_T.backward() 224 | optimizer_T.step() 225 | 226 | if scheduler is not None: 227 | if isinstance(scheduler, torch.optim.lr_scheduler.CosineAnnealingWarmRestarts): 228 | scheduler.step(epoch - 1 + batch_idx / len(target_loader)) 229 | elif isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR): 230 | scheduler.step() 231 | 232 | 
if (batch_idx + 1) % log_interval == 0 or (batch_idx + 1) == len(target_loader): 233 | # print(alpha) 234 | t1 = time.time() 235 | t_epoch = t1 - t0 236 | print(f'Train Epoch: {epoch}, Batch: {batch_idx + 1}, S: [c: {loss_conf.item():.3f}, ' 237 | f'd: {loss_div.item():.3f}], Time: {t_epoch:.3f}') 238 | 239 | return loss_conf.item(), loss_div.item() 240 | 241 | def test_net_S(self, test_loader, imgs_type=None): 242 | self.net_S.eval() 243 | losses, correct, miss = 0, 0, 0 244 | t0 = time.time() 245 | all_preds, all_labels = [], [] 246 | for batch_idx, (data, target) in enumerate(test_loader): 247 | data, target = data.cuda(), target.cuda() 248 | with torch.no_grad(): 249 | output = self.net_S(data) 250 | pred = torch.argmax(output, 1) 251 | correct += pred.eq(target).sum().item() 252 | miss += target.shape[0] - pred.eq(target).sum().item() 253 | loss = self.CE_loss(output, target) 254 | losses += loss.item() 255 | all_preds.append(pred.cpu()) 256 | all_labels.append(target.cpu()) 257 | 258 | t1 = time.time() 259 | t_epoch = t1 - t0 260 | print('Test, Loss: {:.6f}, Prec: {:.1f}%, Time: {:.3f}'. 261 | format(losses / len(test_loader), 100. * correct / (correct + miss), t_epoch)) 262 | if imgs_type is not None: 263 | data, target = next(iter(test_loader)) 264 | self.sample_image(data.cuda(), self.logdir + f'/imgs/{imgs_type}.png', False) 265 | 266 | if self.test_visda: 267 | all_preds = torch.cat(all_preds, dim=0) 268 | all_labels = torch.cat(all_labels, dim=0) 269 | matrix = confusion_matrix(all_labels, all_preds) 270 | acc = matrix.diagonal() / matrix.sum(axis=1) * 100 271 | acc_str = ' '.join([str(np.round(i, 2)) for i in acc]) 272 | print(f'visda per class accuracy\n{acc_str}') 273 | print('visda class-averaged accuracy: {:.1f}%'.format(acc.mean())) 274 | return losses / len(test_loader), acc.mean() 275 | 276 | return losses / len(test_loader), correct / (correct + miss) 277 | 278 | def test(self, test_loader, epoch=None, use_generator=False, visualize=False): 279 | self.net_G.eval() 280 | self.net_S.eval() 281 | self.net_T.eval() 282 | tgt_C_loss = AverageMeter() 283 | correct = 0 284 | correct_valid, valid = 0, 0 285 | t0 = time.time() 286 | all_preds, all_labels = [], [] 287 | for batch_idx, (img, label) in enumerate(test_loader): 288 | img, label = img.cuda(), label.cuda() 289 | with torch.no_grad(): 290 | if use_generator: 291 | output_T, featmaps_tgt_T = self.net_T(img, out_featmaps=True) 292 | pred_label_T = torch.argmax(output_T, 1) 293 | # translate the target image into a source-style (fake source) image 294 | img = self.net_G(img) 295 | output, _ = output_S, featmaps_src_S = self.net_S(img, out_featmaps=True) 296 | # channel (pixelsim preserving) 297 | # conf = F.softmax(output, dim=1).max(dim=1)[0] 298 | # valid_idx = conf > self.opts.confidence_thres 299 | pred_label = pred_label_S = torch.argmax(output_S, 1) 300 | valid_idx = pred_label_S == pred_label_T 301 | else: 302 | output = self.net_T(img) 303 | # conf = F.softmax(output, dim=1).max(dim=1)[0] 304 | # valid_idx = conf > self.opts.confidence_thres 305 | pred_label = torch.argmax(output, 1) 306 | valid_idx = torch.ones_like(pred_label).bool() 307 | 308 | target_C_loss = self.CE_loss(output, label) 309 | 310 | tgt_C_loss.update(target_C_loss.item()) 311 | correct += (pred_label == label).sum().item() 312 | correct_valid += (pred_label[valid_idx] == label[valid_idx]).sum().item() 313 | valid += valid_idx.sum().item() 314 | 315 | all_preds.append(pred_label.cpu()) 316 | all_labels.append(label.cpu()) 317 | 318 | t1 = time.time() 319 | t_epoch = t1 - t0 320 | print(f'Test, 
loss: {tgt_C_loss.avg:.3f}, prec: {100. * correct / len(test_loader.dataset):.1f}%, ' 321 | f'Time: {t_epoch:.3f}') 322 | print(f'valid ratio: {100. * valid / len(test_loader.dataset):.1f}%, ' 323 | f'precision in valid ones: {100. * correct_valid / (valid + 1e-10):.1f}%') 324 | 325 | if epoch is not None and use_generator: 326 | img, _ = next(iter(test_loader)) 327 | self.sample_image(img.cuda(), fname=self.logdir + f'/imgs/{epoch}.png') 328 | 329 | if visualize: 330 | # 2d 331 | # f_src, f_tgt = featmaps_src_S[-2], featmaps_tgt_T[-2] 332 | # B, C, H, W = f_src.shape 333 | # f_src, f_tgt = f_src[0].view(C, -1), f_tgt[0].view(C, -1) 334 | # 1d 335 | f_src, f_tgt = featmaps_src_S[-1], featmaps_tgt_T[-1] 336 | B, C = f_src.shape 337 | f_src, f_tgt = f_src[0].view(C, 1), f_tgt[0].view(C, 1) 338 | indices = np.argsort(KMeans(4).fit_predict(f_src.cpu().detach())) 339 | f_src, f_tgt = f_src[indices], f_tgt[indices] 340 | A_src, A_tgt = f_src @ f_src.T, f_tgt @ f_tgt.T 341 | loss1 = (F.normalize(A_src, p=2, dim=1) - F.normalize(A_tgt, p=2, dim=1)).cpu().detach() 342 | loss2 = (A_src - A_tgt).cpu().detach() 343 | loss1_max, loss2_max = loss1.abs().max(), loss2.abs().max() 344 | 345 | fig, ax = plt.subplots(figsize=(4.5, 5)) 346 | im = ax.imshow(loss1, vmin=-loss1_max, vmax=loss1_max, cmap='seismic') 347 | divider = make_axes_locatable(ax) 348 | cax = divider.new_vertical(size="5%", pad=0.4, pack_start=True) 349 | fig.add_axes(cax) 350 | fig.colorbar(im, cax=cax, orientation="horizontal") 351 | plt.show() 352 | 353 | fig, ax = plt.subplots(figsize=(4.5, 5)) 354 | im = ax.imshow(loss2, vmin=-loss2_max, vmax=loss2_max, cmap='seismic') 355 | divider = make_axes_locatable(ax) 356 | cax = divider.new_vertical(size="5%", pad=0.4, pack_start=True) 357 | fig.add_axes(cax) 358 | fig.colorbar(im, cax=cax, orientation="horizontal") 359 | plt.show() 360 | 361 | if use_generator: 362 | data = [] 363 | # transform = T.Compose([T.Normalize((-1, -1, -1), (2, 2, 2)), T.ToPILImage(), ]) 364 | if self.test_visda: 365 | # indices = [47521, 27974, 32185, 11317] 366 | indices = [43359, 39475, 28118] 367 | else: 368 | indices = [371, 325, 55] 369 | for idx in indices: 370 | img = test_loader.dataset[idx][0] 371 | data.append(img) 372 | data = torch.stack(data, dim=0).cuda() 373 | self.sample_image(data, self.logdir + f'/visualize.png', True, 1) 374 | 375 | if self.test_visda: 376 | all_preds = torch.cat(all_preds, dim=0) 377 | all_labels = torch.cat(all_labels, dim=0) 378 | matrix = confusion_matrix(all_labels, all_preds) 379 | acc = matrix.diagonal() / matrix.sum(axis=1) * 100 380 | acc_str = ' '.join([str(np.round(i, 2)) for i in acc]) 381 | print(f'visda per class accuracy\n{acc_str}') 382 | print('visda class-averaged accuracy: {:.1f}%'.format(acc.mean())) 383 | return tgt_C_loss.avg, acc.mean() 384 | return tgt_C_loss.avg, 100. 
* correct / len(test_loader.dataset) 385 | 386 | def sample_image(self, real_tgt, fname, use_generator=True, nrow=8): 387 | """Saves a generated sample from the test set""" 388 | if use_generator: 389 | self.net_G.eval() 390 | with torch.no_grad(): 391 | gen_src = self.net_G(real_tgt) 392 | # Arrange images along x-axis 393 | real_T = make_grid(self.denorm(real_tgt), nrow=nrow) 394 | gen_S = make_grid(self.denorm(gen_src), nrow=nrow) 395 | # Arrange images along y-axis 396 | image_grid = torch.cat((real_T, gen_S), 2) 397 | else: 398 | image_grid = make_grid(self.denorm(real_tgt), nrow=nrow) 399 | 400 | save_image(image_grid, fname, normalize=False) 401 | 402 | 403 | def imshow(tensor): 404 | image = tensor.cpu().clone() # clone the tensor so the original is not modified 405 | unloader = T.Compose([T.Normalize((-1, -1, -1), (2, 2, 2)), T.ToPILImage(), ]) 406 | image = unloader(image) 407 | plt.imshow(image) 408 | plt.show() 409 | image.save('tmp.png') 410 | -------------------------------------------------------------------------------- /SFIT/utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | 5 | class Logger(object): 6 | def __init__(self, fpath=None): 7 | self.console = sys.stdout 8 | self.file = None 9 | if fpath is not None: 10 | os.makedirs(os.path.dirname(fpath), exist_ok=True) 11 | self.file = open(fpath, 'w') 12 | 13 | def __del__(self): 14 | self.close() 15 | 16 | def __enter__(self): 17 | return self 18 | 19 | def __exit__(self, *args): 20 | self.close() 21 | 22 | def write(self, msg): 23 | self.console.write(msg) 24 | if self.file is not None: 25 | self.file.write(msg) 26 | 27 | def flush(self): 28 | self.console.flush() 29 | if self.file is not None: 30 | self.file.flush() 31 | os.fsync(self.file.fileno()) 32 | 33 | def close(self): 34 | # do not close the console (sys.stdout); closing it would break later prints 35 | if self.file is not None: 36 | self.file.close() 37 | -------------------------------------------------------------------------------- /SFIT/utils/meters.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """Computes and stores the average and current value""" 3 | 4 | def __init__(self): 5 | self.val = 0 6 | self.avg = 0 7 | self.sum = 0 8 | self.count = 0 9 | 10 | def reset(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | 16 | def update(self, val, n=1): 17 | self.val = val 18 | self.sum += val * n 19 | self.count += n 20 | self.avg = self.sum / self.count 21 | -------------------------------------------------------------------------------- /SFIT/utils/str2bool.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def str2bool(v): 5 | if isinstance(v, bool): 6 | return v 7 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 8 | return True 9 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 10 | return False 11 | else: 12 | raise argparse.ArgumentTypeError('Boolean value expected.') 13 | -------------------------------------------------------------------------------- /gpu012.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=0 python train_DA.py -d digits --source svhn --target mnist & 3 | CUDA_VISIBLE_DEVICES=1 python train_DA.py -d digits --source usps --target mnist & 4 | CUDA_VISIBLE_DEVICES=2 python train_DA.py -d digits --source mnist --target usps & 5 | wait 6 | #CUDA_VISIBLE_DEVICES=0 python 
train_DA.py -d office31 --source amazon --target webcam & 7 | #CUDA_VISIBLE_DEVICES=1 python train_DA.py -d office31 --source dslr --target webcam & 8 | #CUDA_VISIBLE_DEVICES=2 python train_DA.py -d office31 --source webcam --target dslr & 9 | #wait 10 | #CUDA_VISIBLE_DEVICES=0 python train_DA.py -d office31 --source amazon --target dslr & 11 | #CUDA_VISIBLE_DEVICES=1 python train_DA.py -d office31 --source dslr --target amazon & 12 | #CUDA_VISIBLE_DEVICES=2 python train_DA.py -d office31 --source webcam --target amazon & 13 | #wait 14 | #CUDA_VISIBLE_DEVICES=0 python train_DA.py -d visda & 15 | #CUDA_VISIBLE_DEVICES=1 python train_DA.py -d visda --da_setting mmd & 16 | #CUDA_VISIBLE_DEVICES=2 python train_DA.py -d visda --da_setting adda & 17 | #wait 18 | 19 | 20 | CUDA_VISIBLE_DEVICES=0 python train_SFIT.py -d digits --source svhn --target mnist & 21 | CUDA_VISIBLE_DEVICES=1 python train_SFIT.py -d digits --source usps --target mnist & 22 | CUDA_VISIBLE_DEVICES=2 python train_SFIT.py -d digits --source mnist --target usps & 23 | wait 24 | #CUDA_VISIBLE_DEVICES=0 python train_SFIT.py -d office31 --source amazon --target webcam & 25 | #CUDA_VISIBLE_DEVICES=1 python train_SFIT.py -d office31 --source dslr --target webcam & 26 | #CUDA_VISIBLE_DEVICES=2 python train_SFIT.py -d office31 --source webcam --target dslr & 27 | #wait 28 | #CUDA_VISIBLE_DEVICES=0 python train_SFIT.py -d office31 --source amazon --target dslr & 29 | #CUDA_VISIBLE_DEVICES=1 python train_SFIT.py -d office31 --source dslr --target amazon & 30 | #CUDA_VISIBLE_DEVICES=2 python train_SFIT.py -d office31 --source webcam --target amazon & 31 | #wait 32 | #CUDA_VISIBLE_DEVICES=0 python train_SFIT.py -d visda & 33 | #CUDA_VISIBLE_DEVICES=1 python train_SFIT.py -d visda --da_setting mmd & 34 | #CUDA_VISIBLE_DEVICES=2 python train_SFIT.py -d visda --da_setting adda & 35 | #wait -------------------------------------------------------------------------------- /train_DA.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ['OMP_NUM_THREADS'] = '1' 4 | import sys 5 | import shutil 6 | from distutils.dir_util import copy_tree 7 | import datetime 8 | from tqdm import tqdm 9 | import argparse 10 | import numpy as np 11 | import random 12 | import torch 13 | import torch.optim as optim 14 | import torchvision.transforms as T 15 | from torch.utils.data import DataLoader 16 | from SFIT import datasets 17 | from SFIT.models.classifier_shot import ClassifierShot, Discriminator 18 | from SFIT.trainers import DATrainer 19 | from SFIT.utils.str2bool import str2bool 20 | from SFIT.utils.logger import Logger 21 | 22 | 23 | def main(args): 24 | # check if in debug mode 25 | gettrace = getattr(sys, 'gettrace', None) 26 | if gettrace(): 27 | print('Hmm, Big Debugger is watching me') 28 | is_debug = True 29 | else: 30 | print('No sys.gettrace') 31 | is_debug = False 32 | 33 | # seed 34 | if args.seed is not None: 35 | np.random.seed(args.seed) 36 | torch.manual_seed(args.seed) 37 | torch.cuda.manual_seed(args.seed) 38 | random.seed(args.seed) 39 | torch.backends.cudnn.deterministic = True 40 | torch.backends.cudnn.benchmark = False 41 | # torch.backends.cudnn.benchmark = True 42 | else: 43 | torch.backends.cudnn.benchmark = True 44 | 45 | # dataset 46 | data_path = os.path.expanduser(f'~/Data/{args.dataset}') 47 | if args.dataset == 'digits': 48 | n_classes = 10 49 | use_src_test = True 50 | args.batch_size = 64 51 | 52 | if args.source == 'svhn' and args.target == 'mnist': 53 | 
source_trans = T.Compose([T.Resize(32), T.ToTensor(), T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 54 | target_trans = T.Compose([T.Resize(32), T.Lambda(lambda x: x.convert("RGB")), 55 | T.ToTensor(), T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 56 | source_train_dataset = datasets.SVHN(f'{data_path}/svhn', split='train', download=True, 57 | transform=source_trans) 58 | source_test_dataset = datasets.SVHN(f'{data_path}/svhn', split='test', download=True, 59 | transform=source_trans) 60 | target_train_dataset = datasets.MNIST(f'{data_path}/mnist', train=True, download=True, 61 | transform=target_trans) 62 | target_test_dataset = datasets.MNIST(f'{data_path}/mnist', train=False, download=True, 63 | transform=target_trans) 64 | args.arch = 'dtn' 65 | elif args.source == 'usps' and args.target == 'mnist': 66 | source_trans = T.Compose([T.RandomCrop(28, padding=4), T.RandomRotation(10), 67 | T.ToTensor(), T.Normalize([0.5, ], [0.5, ])]) 68 | target_trans = T.Compose([T.ToTensor(), T.Normalize([0.5, ], [0.5, ])]) 69 | source_train_dataset = datasets.USPS(f'{data_path}/usps', train=True, download=True, transform=source_trans) 70 | source_test_dataset = datasets.USPS(f'{data_path}/usps', train=False, download=True, transform=source_trans) 71 | target_train_dataset = datasets.MNIST(f'{data_path}/mnist', train=True, download=True, 72 | transform=target_trans) 73 | target_test_dataset = datasets.MNIST(f'{data_path}/mnist', train=False, download=True, 74 | transform=target_trans) 75 | args.arch = 'lenet' 76 | elif args.source == 'mnist' and args.target == 'usps': 77 | source_trans = T.Compose([T.ToTensor(), T.Normalize([0.5, ], [0.5, ])]) 78 | target_trans = T.Compose([T.ToTensor(), T.Normalize([0.5, ], [0.5, ])]) 79 | source_train_dataset = datasets.MNIST(f'{data_path}/mnist', train=True, download=True, 80 | transform=source_trans) 81 | source_test_dataset = datasets.MNIST(f'{data_path}/mnist', train=False, download=True, 82 | transform=source_trans) 83 | target_train_dataset = datasets.USPS(f'{data_path}/usps', train=True, download=True, transform=target_trans) 84 | target_test_dataset = datasets.USPS(f'{data_path}/usps', train=False, download=True, transform=target_trans) 85 | args.arch = 'lenet' 86 | else: 87 | raise Exception('digits supports the pairs svhn->mnist, usps->mnist, and mnist->usps') 88 | elif args.dataset == 'office31': 89 | n_classes = 31 90 | use_src_test = False 91 | args.epochs_S = 100 92 | args.epochs_T = 15 93 | if args.arch is None: args.arch = 'resnet50' 94 | train_trans = T.Compose([T.Resize([256, 256]), T.RandomCrop(224), T.RandomHorizontalFlip(), T.ToTensor(), 95 | T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), ]) 96 | test_trans = T.Compose([T.Resize([256, 256]), T.CenterCrop(224), T.ToTensor(), 97 | T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), ]) 98 | 99 | source_train_dataset = datasets.ImageFolder(f'{data_path}/{args.source}/images', transform=train_trans) 100 | source_test_dataset = datasets.ImageFolder(f'{data_path}/{args.source}/images', transform=train_trans) 101 | target_train_dataset = datasets.ImageFolder(f'{data_path}/{args.target}/images', transform=train_trans) 102 | target_test_dataset = datasets.ImageFolder(f'{data_path}/{args.target}/images', transform=test_trans) 103 | elif args.dataset == 'visda': 104 | n_classes = 12 105 | use_src_test = False 106 | args.lr_D *= 0.1 107 | args.lr_S *= 0.1 108 | args.lr_T *= 0.1 109 | args.epochs_S = 10 110 | args.epochs_T = 5 111 | if args.arch is None: args.arch = 'resnet101' 112 | args.source, args.target = 'syn', 'real' 113 | 
train_trans = T.Compose([T.Resize([256, 256]), T.RandomCrop(224), T.RandomHorizontalFlip(), T.ToTensor(), 114 | T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), ]) 115 | test_trans = T.Compose([T.Resize([256, 256]), T.CenterCrop(224), T.ToTensor(), 116 | T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), ]) 117 | 118 | source_train_dataset = datasets.ImageFolder(f'{data_path}/train', transform=train_trans) 119 | source_test_dataset = datasets.ImageFolder(f'{data_path}/train', transform=train_trans) 120 | target_train_dataset = datasets.ImageFolder(f'{data_path}/validation', transform=train_trans) 121 | target_test_dataset = datasets.ImageFolder(f'{data_path}/validation', transform=test_trans) 122 | else: 123 | raise Exception('please choose dataset from [digits, office31, visda]') 124 | 125 | if 'shot' in args.da_setting: 126 | args.batch_size = 64 127 | 128 | source_train_loader = DataLoader(source_train_dataset, batch_size=args.batch_size, shuffle=True, 129 | num_workers=args.num_workers, drop_last=True) 130 | source_train_loader_64 = DataLoader(source_train_dataset, batch_size=64, shuffle=True, 131 | num_workers=args.num_workers, drop_last=True) 132 | source_test_loader = DataLoader(source_test_dataset, batch_size=args.batch_size, shuffle=False, 133 | num_workers=args.num_workers) 134 | target_train_loader = DataLoader(target_train_dataset, batch_size=args.batch_size, shuffle=True, 135 | num_workers=args.num_workers, drop_last=True) 136 | target_test_loader = DataLoader(target_test_dataset, batch_size=args.batch_size, shuffle=False, 137 | num_workers=args.num_workers) 138 | 139 | logdir = f'logs/{args.da_setting}/{args.dataset}/s_{args.source}/t_{args.target}/' \ 140 | f'{"debug_" if is_debug else ""}{datetime.datetime.today():%Y-%m-%d_%H-%M-%S}/' 141 | print(logdir) 142 | 143 | # logging 144 | if True: 145 | os.makedirs(logdir + 'imgs', exist_ok=True) 146 | copy_tree('./SFIT', logdir + 'scripts/SFIT') 147 | for script in os.listdir('.'): 148 | if script.split('.')[-1] == 'py': 149 | dst_file = os.path.join(logdir, 'scripts', os.path.basename(script)) 150 | shutil.copyfile(script, dst_file) 151 | sys.stdout = Logger(os.path.join(logdir, 'log.txt'), ) 152 | print('Settings:') 153 | print(vars(args)) 154 | 155 | # model 156 | net_D = Discriminator(args.bottleneck_dim).cuda() 157 | net_S = ClassifierShot(n_classes, args.arch, args.bottleneck_dim, 'shot' in args.da_setting).cuda() 158 | net_T = ClassifierShot(n_classes, args.arch, args.bottleneck_dim, 'shot' in args.da_setting).cuda() 159 | 160 | # optimizers 161 | optimizer_D = optim.SGD(net_D.parameters(), lr=args.lr_D, weight_decay=1e-3, momentum=0.9, nesterov=True) 162 | if 'resnet' not in args.arch: 163 | optimizer_S = optim.SGD(net_S.parameters(), lr=args.lr_S, weight_decay=1e-3, momentum=0.9, nesterov=True) 164 | optimizer_T = optim.SGD(list(net_T.base.parameters()), # + list(net_T.bottleneck.parameters()), 165 | lr=args.lr_T, weight_decay=1e-3, momentum=0.9, nesterov=True) 166 | else: 167 | optimizer_S = optim.SGD([{'params': net_S.base.parameters(), 'lr': args.lr_S * 0.1}, 168 | {'params': net_S.bottleneck.parameters()}, 169 | {'params': net_S.classifier.parameters()}], 170 | lr=args.lr_S, weight_decay=1e-3, momentum=0.9, nesterov=True) 171 | optimizer_T = optim.SGD([{'params': net_T.base.parameters(), 'lr': args.lr_T * 0.1}, ], 172 | # {'params': net_T.bottleneck.parameters()}], 173 | lr=args.lr_T, weight_decay=1e-3, momentum=0.9, nesterov=True) 174 | 175 | # schedulers 176 | scheduler_D = 
optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_D, args.epochs_T, 1) 177 | scheduler_S = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_S, args.epochs_S, 1) 178 | scheduler_T = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_T, args.epochs_T, 1) 179 | 180 | trainer = DATrainer(net_D, net_S, net_T, logdir, args.da_setting, args.source_LSR, 181 | args.dataset == 'visda') 182 | 183 | # source model 184 | net_S_fpath = f'logs/{args.da_setting}/{args.dataset}/s_{args.source}/source_model.pth' 185 | if os.path.exists(net_S_fpath) and not args.force_train_S: 186 | print(f'Loading source model at: {net_S_fpath}...') 187 | net_S.load_state_dict(torch.load(net_S_fpath)) 188 | pass 189 | else: 190 | print('Training source model...') 191 | for epoch in tqdm(range(1, args.epochs_S + 1)): 192 | trainer.train_net_S(epoch, source_train_loader_64, optimizer_S, scheduler_S) 193 | if epoch % (max(args.epochs_S // 10, 1)) == 0: 194 | if use_src_test: 195 | print('Testing source model on [source]...') 196 | trainer.test_net_S(source_test_loader) 197 | print('Testing source model on [target]...') 198 | trainer.test_net_S(target_test_loader) 199 | torch.save(net_S.state_dict(), net_S_fpath) 200 | torch.save(net_S.state_dict(), logdir + 'source_model.pth') 201 | print('Testing source model on [source]...') 202 | trainer.test_net_S(source_test_loader) 203 | print('##############################################################') 204 | print('Testing source model on [target]...') 205 | print('##############################################################') 206 | trainer.test_net_S(target_test_loader) 207 | 208 | # target model & discriminator 209 | net_T_fpath = f'logs/{args.da_setting}/{args.dataset}/s_{args.source}/t_{args.target}/target_model.pth' 210 | print('Initializing target model from the source model...') 211 | net_T.load_state_dict(net_S.state_dict()) 212 | for epoch in tqdm(range(1, args.epochs_T + 1)): 213 | print('Training target model...') 214 | trainer.train_net_T(epoch, source_train_loader, target_train_loader, optimizer_T, optimizer_D, 215 | [scheduler_T, scheduler_D]) 216 | if use_src_test: 217 | print('Testing target model on [source]...') 218 | trainer.test_net_T(source_test_loader) 219 | print('Testing target model on [target]...') 220 | trainer.test_net_T(target_test_loader) 221 | torch.save(net_T.state_dict(), net_T_fpath) 222 | torch.save(net_T.state_dict(), logdir + 'target_model.pth') 223 | print('##############################################################') 224 | print('Testing target model on [target]...') 225 | print('##############################################################') 226 | trainer.test_net_T(target_test_loader) 227 | 228 | 229 | if __name__ == '__main__': 230 | # settings 231 | parser = argparse.ArgumentParser(description='Train source and target models with UDA methods') 232 | parser.add_argument('-d', '--dataset', type=str, default='digits', choices=['digits', 'office31', 'visda']) 233 | parser.add_argument('--source', type=str) 234 | parser.add_argument('--target', type=str) 235 | parser.add_argument('-a', '--arch', type=str, default=None, 236 | choices=['alexnet', 'vgg16', 'resnet18', 'resnet50', 'digits']) 237 | parser.add_argument('--bottleneck_dim', type=int, default=256) 238 | parser.add_argument('-j', '--num_workers', type=int, default=4) 239 | parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N', 240 | help='input batch size for training (default: 32)') 241 | parser.add_argument('--force_train_S', action='store_true', default=False) 242 | # source 
model 243 | parser.add_argument('--source_LSR', type=str2bool, default=True) 244 | # target model 245 | parser.add_argument('--da_setting', type=str, default='shot', choices=['shot', 'mmd', 'adda']) 246 | parser.add_argument('--epochs_S', type=int, default=30, help='number of epochs to train the source model') 247 | parser.add_argument('--epochs_T', type=int, default=30, help='number of epochs to train the target model') 248 | parser.add_argument('--restart', type=float, default=1) 249 | parser.add_argument('--lr_D', type=float, default=1e-3, help='discriminator learning rate') 250 | parser.add_argument('--lr_S', type=float, default=1e-2, help='source model learning rate') 251 | parser.add_argument('--lr_T', type=float, default=1e-2, help='target model learning rate') 252 | parser.add_argument('--seed', type=int, default=None, help='random seed') 253 | args = parser.parse_args() 254 | main(args) 255 | -------------------------------------------------------------------------------- /train_SFIT.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ['OMP_NUM_THREADS'] = '1' 4 | import sys 5 | import shutil 6 | from distutils.dir_util import copy_tree 7 | import datetime 8 | from tqdm import tqdm 9 | import argparse 10 | import numpy as np 11 | import random 12 | import torch 13 | import torch.optim as optim 14 | import torchvision.transforms as T 15 | from torch.utils.data import DataLoader 16 | from SFIT import datasets 17 | from SFIT.models.classifier_shot import ClassifierShot 18 | from SFIT.models.cyclegan import GeneratorResNet 19 | from SFIT.trainers import SFITTrainer 20 | from SFIT.utils.str2bool import str2bool 21 | from SFIT.utils.logger import Logger 22 | 23 | 24 | def main(args): 25 | # check if in debug mode 26 | gettrace = getattr(sys, 'gettrace', None) 27 | if gettrace(): 28 | print('Hmm, Big Debugger is watching me') 29 | is_debug = True 30 | else: 31 | print('No sys.gettrace') 32 | is_debug = False 33 | 34 | # seed 35 | if args.seed is not None: 36 | np.random.seed(args.seed) 37 | torch.manual_seed(args.seed) 38 | torch.cuda.manual_seed(args.seed) 39 | random.seed(args.seed) 40 | torch.backends.cudnn.deterministic = True 41 | torch.backends.cudnn.benchmark = False 42 | # torch.backends.cudnn.benchmark = True 43 | else: 44 | torch.backends.cudnn.benchmark = True 45 | 46 | # dataset 47 | num_colors = 3 48 | data_path = os.path.expanduser(f'~/Data/{args.dataset}') 49 | if args.dataset == 'digits': 50 | n_classes = 10 51 | use_src_test = True 52 | args.batch_size = 64 53 | args.id_ratio = 3e-2 54 | args.tv_ratio = 3e-2 55 | 56 | if args.source == 'svhn' and args.target == 'mnist': 57 | source_trans = T.Compose([T.Resize(32), T.ToTensor(), T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 58 | target_trans = T.Compose([T.Resize(32), T.Lambda(lambda x: x.convert("RGB")), 59 | T.ToTensor(), T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 60 | source_test_dataset = datasets.SVHN(f'{data_path}/svhn', split='test', download=True, 61 | transform=source_trans) 62 | target_train_dataset = datasets.MNIST(f'{data_path}/mnist', train=True, download=True, 63 | transform=target_trans) 64 | target_test_dataset = datasets.MNIST(f'{data_path}/mnist', train=False, download=True, 65 | transform=target_trans) 66 | args.arch = 'dtn' 67 | elif args.source == 'usps' and args.target == 'mnist': 68 | source_trans = T.Compose([T.RandomCrop(28, padding=4), T.RandomRotation(10), 69 | T.ToTensor(), T.Normalize([0.5, ], [0.5, ])]) 70 | target_trans = T.Compose([T.ToTensor(), 
T.Normalize([0.5, ], [0.5, ])]) 71 | source_test_dataset = datasets.USPS(f'{data_path}/usps', train=False, download=True, transform=source_trans) 72 | target_train_dataset = datasets.MNIST(f'{data_path}/mnist', train=True, download=True, 73 | transform=target_trans) 74 | target_test_dataset = datasets.MNIST(f'{data_path}/mnist', train=False, download=True, 75 | transform=target_trans) 76 | args.arch = 'lenet' 77 | num_colors = 1 78 | elif args.source == 'mnist' and args.target == 'usps': 79 | source_trans = T.Compose([T.ToTensor(), T.Normalize([0.5, ], [0.5, ])]) 80 | target_trans = T.Compose([T.ToTensor(), T.Normalize([0.5, ], [0.5, ])]) 81 | source_test_dataset = datasets.MNIST(f'{data_path}/mnist', train=False, download=True, 82 | transform=source_trans) 83 | target_train_dataset = datasets.USPS(f'{data_path}/usps', train=True, download=True, transform=target_trans) 84 | target_test_dataset = datasets.USPS(f'{data_path}/usps', train=False, download=True, transform=target_trans) 85 | args.arch = 'lenet' 86 | num_colors = 1 87 | else: 88 | raise Exception('digits supports the pairs svhn->mnist, usps->mnist, and mnist->usps') 89 | elif args.dataset == 'office31': 90 | n_classes = 31 91 | use_src_test = False 92 | args.epochs_T = 15 93 | args.G_wait = 50 94 | args.epochs_G = 50 95 | if args.arch is None: args.arch = 'resnet50' 96 | train_trans = T.Compose([T.Resize([256, 256]), T.RandomCrop(224), T.RandomHorizontalFlip(), T.ToTensor(), 97 | T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), ]) 98 | test_trans = T.Compose([T.Resize([256, 256]), T.CenterCrop(224), T.ToTensor(), 99 | T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), ]) 100 | 101 | source_test_dataset = datasets.ImageFolder(f'{data_path}/{args.source}/images', transform=train_trans) 102 | target_train_dataset = datasets.ImageFolder(f'{data_path}/{args.target}/images', transform=train_trans) 103 | target_test_dataset = datasets.ImageFolder(f'{data_path}/{args.target}/images', transform=test_trans) 104 | elif args.dataset == 'visda': 105 | n_classes = 12 106 | use_src_test = False 107 | args.lr_T *= 0.1 108 | args.epochs_T = 5 109 | args.G_wait = 5 110 | args.epochs_G = 20 111 | if args.arch is None: args.arch = 'resnet101' 112 | args.source, args.target = 'syn', 'real' 113 | train_trans = T.Compose([T.Resize([256, 256]), T.RandomCrop(224), T.RandomHorizontalFlip(), T.ToTensor(), 114 | T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), ]) 115 | test_trans = T.Compose([T.Resize([256, 256]), T.CenterCrop(224), T.ToTensor(), 116 | T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), ]) 117 | 118 | source_test_dataset = datasets.ImageFolder(f'{data_path}/train', transform=train_trans) 119 | target_train_dataset = datasets.ImageFolder(f'{data_path}/validation', transform=train_trans) 120 | target_test_dataset = datasets.ImageFolder(f'{data_path}/validation', transform=test_trans) 121 | else: 122 | raise Exception('please choose dataset from [digits, office31, visda]') 123 | 124 | source_test_loader = DataLoader(source_test_dataset, batch_size=64, shuffle=True, 125 | num_workers=args.num_workers) 126 | target_train_loader = DataLoader(target_train_dataset, batch_size=args.batch_size, shuffle=True, 127 | num_workers=args.num_workers, drop_last=True) 128 | target_train_loader_32 = DataLoader(target_train_dataset, batch_size=32, shuffle=True, 129 | num_workers=args.num_workers, drop_last=True) 130 | target_test_loader = DataLoader(target_test_dataset, batch_size=64, shuffle=True, 131 | num_workers=args.num_workers) 132 | 133 | args.force_pretrain_G = args.force_pretrain_G or not 
os.path.exists( 134 | f'logs/SFIT/{args.dataset}/s_{args.source}/t_{args.target}/model_G_transparent.pth') 135 | if args.resume: 136 | splits = args.resume.split('_') 137 | args.da_setting = splits[0] 138 | fname = f'{"debug_" if is_debug else ""}{args.da_setting}_R' 139 | # args.force_pretrain_G, args.train_G = False, False 140 | else: 141 | fname = f'{"debug_" if is_debug else ""}{args.da_setting}_' 142 | if args.force_pretrain_G: 143 | fname += 'G0' 144 | if args.train_G: 145 | fname += 'G' 146 | if args.retrain_T: 147 | fname += 'T' 148 | 149 | logdir = f'logs/SFIT/{args.dataset}/s_{args.source}/t_{args.target}/{fname}' \ 150 | f'_conf{args.conf_ratio}_bn{args.bn_ratio}_channel{args.channel_ratio}_content{args.content_ratio}_' \ 151 | f'kd{args.kd_ratio}_{datetime.datetime.today():%Y-%m-%d_%H-%M-%S}/' 152 | print(logdir) 153 | 154 | # logging 155 | if True: 156 | os.makedirs(logdir + 'imgs', exist_ok=True) 157 | copy_tree('./SFIT', logdir + 'scripts/SFIT') 158 | for script in os.listdir('.'): 159 | if script.split('.')[-1] == 'py': 160 | dst_file = os.path.join(logdir, 'scripts', os.path.basename(script)) 161 | shutil.copyfile(script, dst_file) 162 | sys.stdout = Logger(os.path.join(logdir, 'log.txt'), ) 163 | print('Settings:') 164 | print(vars(args)) 165 | 166 | # model 167 | net_G = GeneratorResNet(num_colors=num_colors).cuda() 168 | net_S = ClassifierShot(n_classes, args.arch, args.bottleneck_dim, 'shot' in args.da_setting).cuda() 169 | net_T = ClassifierShot(n_classes, args.arch, args.bottleneck_dim, 'shot' in args.da_setting).cuda() 170 | 171 | optimizer_G = optim.Adam(net_G.parameters(), lr=args.lr_G) 172 | 173 | trainer = SFITTrainer(net_G, net_S, net_T, logdir, args, test_visda=args.dataset == 'visda') 174 | 175 | # source network 176 | fpath = f'logs/{args.da_setting}/{args.dataset}/s_{args.source}/source_model.pth' 177 | if os.path.exists(fpath): 178 | print(f'Loading source network at: {fpath}...') 179 | net_S.load_state_dict(torch.load(fpath)) 180 | pass 181 | else: 182 | raise FileNotFoundError(f'source model not found at {fpath}; run train_DA.py first') 183 | print('Testing source network on [source]...') 184 | trainer.test_net_S(source_test_loader, 'src') 185 | print('##############################################################') 186 | print('Testing source network on [target]...') 187 | print('##############################################################') 188 | trainer.test_net_S(target_test_loader, 'tgt') 189 | 190 | # target network 191 | fpath = f'logs/{args.da_setting}/{args.dataset}/s_{args.source}/t_{args.target}/target_model.pth' 192 | if os.path.exists(fpath): 193 | print(f'Loading pre-trained target model at: {fpath}...') 194 | net_T.load_state_dict(torch.load(fpath)) 195 | else: 196 | raise FileNotFoundError(f'target model not found at {fpath}; run train_DA.py first') 197 | print('##############################################################') 198 | print('Testing target model on [target]...') 199 | print('##############################################################') 200 | trainer.test(target_test_loader) 201 | 202 | # pre-train generator 203 | fpath = f'logs/SFIT/{args.dataset}/s_{args.source}/t_{args.target}/model_G_transparent.pth' 204 | if not args.force_pretrain_G: 205 | print(f'Loading pre-trained Generator from: {fpath}...') 206 | net_G.load_state_dict(torch.load(fpath)) 207 | elif args.train_G: 208 | scheduler_G = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_G, args.G_wait, 1) 209 | for epoch in tqdm(range(1, args.G_wait + 1)): 210 | print('Pre-training Generator...') 211 | trainer.train_net_G(epoch, target_train_loader, optimizer_G, pretrain=True, 
scheduler=scheduler_G) 212 | print('Testing Generator on [target]...') 213 | trainer.test(target_test_loader, epoch, use_generator=True) 214 | torch.save(net_G.state_dict(), fpath) 215 | else: 216 | print('skip pre-training Generator') 217 | pass 218 | print('##############################################################') 219 | print('Testing pre-trained Generator on [target]...') 220 | print('##############################################################') 221 | trainer.test(target_test_loader, use_generator=True) 222 | 223 | # generator 224 | if args.resume: 225 | fpath = f'logs/SFIT/{args.dataset}/s_{args.source}/t_{args.target}/{args.resume}/model_G.pth' 226 | print(f'Load trained Generator at: {fpath}') 227 | net_G.load_state_dict(torch.load(fpath)) 228 | elif args.train_G: 229 | scheduler_G = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_G, args.epochs_G, 1) 230 | for epoch in tqdm(range(1, args.epochs_G + 1)): 231 | print('Training Generator...') 232 | trainer.train_net_G(epoch, target_train_loader, optimizer_G, scheduler=scheduler_G) 233 | print('Testing Generator on [target]...') 234 | trainer.test(target_test_loader, epoch, use_generator=True) 235 | torch.save(net_G.state_dict(), os.path.join(logdir, 'model_G.pth')) 236 | else: 237 | print('skip training Generator') 238 | pass 239 | print('##############################################################') 240 | print('Testing Generator on [target]...') 241 | print('##############################################################') 242 | trainer.test(target_test_loader, use_generator=True) 243 | 244 | # retrain target network 245 | if args.retrain_T: 246 | args.lr_T *= 0.5 247 | if 'resnet' not in args.arch: 248 | optimizer_T = optim.SGD(list(net_T.base.parameters()) + list(net_T.bottleneck.parameters()), 249 | lr=args.lr_T, weight_decay=1e-3, momentum=0.9, nesterov=True) 250 | else: 251 | optimizer_T = optim.SGD([{'params': net_T.base.parameters(), 'lr': args.lr_T * 0.1}, 252 | {'params': net_T.bottleneck.parameters()}], 253 | lr=args.lr_T, weight_decay=1e-3, momentum=0.9, nesterov=True) 254 | scheduler_T = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_T, args.epochs_T, 1) 255 | for epoch in tqdm(range(1, args.epochs_T + 1)): 256 | print('Re-training target model...') 257 | trainer.train_net_T(epoch, target_train_loader_32, optimizer_T, scheduler_T, use_generator=args.train_G) 258 | if use_src_test: 259 | print('Testing re-trained target model on [source]...') 260 | trainer.test(source_test_loader) 261 | print('Testing re-trained target model on [target]...') 262 | trainer.test(target_test_loader) 263 | torch.save(net_T.state_dict(), os.path.join(logdir, 'target_model_retrain.pth')) 264 | print('##############################################################') 265 | print('Testing re-trained target model on [target]...') 266 | print('##############################################################') 267 | trainer.test(target_test_loader) 268 | else: 269 | print('skip re-training target model') 270 | pass 271 | 272 | 273 | if __name__ == '__main__': 274 | # settings 275 | parser = argparse.ArgumentParser(description='Train SFIT') 276 | parser.add_argument('-d', '--dataset', type=str, default='digits', choices=['digits', 'office31', 'visda']) 277 | parser.add_argument('--source', type=str) 278 | parser.add_argument('--target', type=str) 279 | parser.add_argument('-a', '--arch', type=str, default=None, 280 | choices=['alexnet', 'vgg16', 'resnet18', 'resnet50', 'digits']) 281 | parser.add_argument('--bottleneck_dim', 
type=int, default=256) 282 | parser.add_argument('-j', '--num_workers', type=int, default=4) 283 | parser.add_argument('-b', '--batch-size', type=int, default=16, metavar='N', 284 | help='input batch size for training (default: 16)') 285 | parser.add_argument('--da_setting', type=str, default='shot', choices=['shot', 'mmd', 'adda']) 286 | parser.add_argument('--force_pretrain_G', default=False, action='store_true') 287 | parser.add_argument('--train_G', type=str2bool, default=True) 288 | parser.add_argument('--resume', type=str, default=None) 289 | parser.add_argument('--retrain_T', default=False, action='store_true') 290 | # source model 291 | parser.add_argument('--source_LSR', type=str2bool, default=True) 292 | # generator 293 | parser.add_argument('--mAvrgAlpha', type=float, default=1) 294 | parser.add_argument('--a_ratio', type=float, default=0) 295 | parser.add_argument('--conf_ratio', type=float, default=0) 296 | parser.add_argument('--div_ratio', type=float, default=0) 297 | parser.add_argument('--js_ratio', type=float, default=0) 298 | parser.add_argument('--bn_ratio', type=float, default=0) 299 | parser.add_argument('--style_ratio', type=float, default=0) 300 | parser.add_argument('--channel_ratio', type=float, default=1) 301 | parser.add_argument('--content_ratio', type=float, default=0) 302 | parser.add_argument('--id_ratio', type=float, default=0) 303 | parser.add_argument('--kd_ratio', type=float, default=1) 304 | parser.add_argument('--pixel_ratio', type=float, default=0) 305 | parser.add_argument('--batch_ratio', type=float, default=0) 306 | parser.add_argument('--tv_ratio', type=float, default=0) 307 | parser.add_argument('--use_channel', type=str2bool, default=True) 308 | parser.add_argument('--KD_T', type=float, default=1, 309 | help='>1 to smooth probabilities in divergence loss, or <1 to sharpen them') 310 | # target model 311 | parser.add_argument('--thres_confidence', type=float, default=0.95) 312 | parser.add_argument('--T_pixel_ratio', type=float, default=0) 313 | parser.add_argument('--T_batch_ratio', type=float, default=0) 314 | parser.add_argument('--G_wait', type=int, default=10) 315 | parser.add_argument('--epochs_G', type=int, default=30, help='number of epochs to train the generator') 316 | parser.add_argument('--epochs_T', type=int, default=30, help='number of epochs to train the target model') 317 | parser.add_argument('--restart', type=int, default=1) 318 | parser.add_argument('--lr_G', type=float, default=3e-4, help='generator learning rate') 319 | parser.add_argument('--lr_T', type=float, default=1e-2, help='target model learning rate') 320 | parser.add_argument('--seed', type=int, default=None, help='random seed') 321 | args = parser.parse_args() 322 | main(args) 323 | --------------------------------------------------------------------------------
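
A note on the objectives exercised above: `train_net_T` in `SFIT/trainers/sfit_trainer.py` optimizes a SHOT-style information-maximization loss (per-sample prediction entropy minus the entropy of the batch-averaged class distribution), and `--KD_T` controls the temperature of the knowledge-distillation divergence logged as `kd` in `train_net_G`. The sketches below are minimal reconstructions for reference only: the helper names `im_loss` / `kd_loss` and the `1e-8` stabilizer are our additions, and the repository's own implementations live in `SFIT/loss/` (e.g. `entropy.py`, `knowledge_distillation.py`), which may differ in detail.

```python
import torch
import torch.nn.functional as F


def im_loss(logits: torch.Tensor) -> torch.Tensor:
    """SHOT-style IM objective on target-model logits of shape [B, C]."""
    probs = F.softmax(logits, dim=1)
    log_probs = F.log_softmax(logits, dim=1)
    # confidence: mean per-sample entropy (minimized -> sharper predictions)
    loss_conf = -(probs * log_probs).sum(dim=1).mean()
    # diversity: entropy of the batch-averaged class distribution
    # (subtracted below, i.e. maximized -> predictions spread over classes)
    avg_cls = probs.mean(dim=0)
    loss_div = -(avg_cls * torch.log(avg_cls + 1e-8)).sum()
    return loss_conf - loss_div


def kd_loss(logits_a: torch.Tensor, logits_b: torch.Tensor, T: float = 1.0) -> torch.Tensor:
    """Temperature-scaled KL divergence between two branches; T > 1 softens both."""
    log_p_a = F.log_softmax(logits_a / T, dim=1)
    p_b = F.softmax(logits_b / T, dim=1)
    return F.kl_div(log_p_a, p_b, reduction='batchmean') * T ** 2


# usage sketch: push the source branch (on generated images) toward the
# frozen target branch (on the original target images)
# loss = kd_loss(net_S(net_G(img)), net_T(img).detach(), T=args.KD_T)
```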