├── LICENSE ├── README.md ├── _config.yml ├── docs ├── README.md ├── _config.yml ├── method.png └── results.png └── src ├── cmd_options.txt ├── data_manager ├── dataset_read.py ├── datasets.py ├── mnist.py ├── svhn.py └── unaligned_data_loader.py ├── main4.py ├── model ├── build_gen.py ├── svhn2mnist.py ├── syn2gtrsb.py └── usps.py ├── solver.py └── utils └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 seqam-lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **This is the project page for Unsupervised Visual Domain Adaptation:A Deep Max-Margin Gaussian Process Approach. 
2 | The work was accepted to CVPR 2019 as an oral presentation.** 3 | [[Paper Link]](http://openaccess.thecvf.com/content_CVPR_2019/html/Kim_Unsupervised_Visual_Domain_Adaptation_A_Deep_Max-Margin_Gaussian_Process_Approach_CVPR_2019_paper.html). 4 |
5 | 6 | 7 | ## Citation 8 | If you use this code for your research, please cite our paper (this entry will be updated once the CVPR version is published). 9 | ``` 10 | @article{kim2019unsupervised, 11 | title={Unsupervised Visual Domain Adaptation: A Deep Max-Margin Gaussian Process Approach}, 12 | author={Kim, Minyoung and Sahu, Pritish and Gholami, Behnam and Pavlovic, Vladimir}, 13 | journal={arXiv preprint arXiv:1902.08727}, 14 | year={2019} 15 | } 16 | ``` 17 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-slate 2 | title: Unsupervised Visual Domain Adaptation:A Deep Max-Margin Gaussian Process Approach 3 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | **This is the project page for Unsupervised Visual Domain Adaptation: A Deep Max-Margin Gaussian Process Approach. 2 | The work was accepted to CVPR 2019 as an oral presentation.** 3 | [[Paper Link]](http://openaccess.thecvf.com/content_CVPR_2019/html/Kim_Unsupervised_Visual_Domain_Adaptation_A_Deep_Max-Margin_Gaussian_Process_Approach_CVPR_2019_paper.html)[[Youtube Link]](https://youtu.be/OYbiWSM0u8U). 4 |
5 | 6 | ## Abstract 7 | In unsupervised domain adaptation, it is widely known that the target domain error can be provably reduced by having 8 | a shared input representation that makes the source and target domains indistinguishable from each other. Very recently it 9 | has been studied that not just matching the marginal input distributions, but the alignment of output (class) distributions is 10 | also critical. The latter can be achieved by minimizing the maximum discrepancy of predictors (classifiers). In this paper, 11 | we adopt this principle, but propose a more systematic and effective way to achieve hypothesis consistency via Gaussian 12 | processes (GP). The GP allows us to define/induce a hypothesis space of the classifiers from the posterior distribution of the 13 | latent random functions, turning the learning into a simple large-margin posterior separation problem, far easier to solve 14 | than previous approaches based on adversarial minimax optimization. We formulate a learning objective that effectively 15 | pushes the posterior to minimize the maximum discrepancy. This is further shown to be equivalent to maximizing margins 16 | and minimizing uncertainty of the class predictions in the target domain, a well-established principle in classical (semi- 17 | )supervised learning. Empirical results demonstrate that our approach is comparable or superior to the existing methods on 18 | several benchmark domain adaptation datasets. 19 | 20 | ![Method](method.png) 21 |
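
The hypothesis space described above can be pictured with a small sketch. It is an illustration only, not the training objective used in this repository: it draws classifier weights from a diagonal Gaussian posterior using the same rule documented for `QWnetwork` in `src/model/svhn2mnist.py` (`w = mu + exp(logsd) * eps`), then reports two quantities the abstract ties together, the margin between the two best posterior-mean class scores and the disagreement among sampled classifiers. The function names `sample_weights`, `class_scores`, `posterior_margin` and `disagreement` are hypothetical, and plain Python lists stand in for tensors so the block runs without dependencies.

```python
import math
import random


def sample_weights(mu, logsd, num_samples, rng):
    """Draw w^m = mu + exp(logsd) * eps^m, with eps^m ~ N(0, 1) elementwise.

    mu, logsd: p x K nested lists (p feature dimensions, K classes).
    Returns num_samples matrices, each p x K.
    """
    return [
        [
            [m + math.exp(s) * rng.gauss(0.0, 1.0) for m, s in zip(mu_row, sd_row)]
            for mu_row, sd_row in zip(mu, logsd)
        ]
        for _ in range(num_samples)
    ]


def class_scores(u, w):
    """Scores f_j = <u, w_j> for one feature vector u and one p x K sample w."""
    return [sum(u_i * row[j] for u_i, row in zip(u, w)) for j in range(len(w[0]))]


def posterior_margin(u, weight_samples):
    """Gap between the largest and second-largest posterior-mean class score."""
    scores = [class_scores(u, w) for w in weight_samples]
    mean = [sum(col) / len(scores) for col in zip(*scores)]
    top, runner_up = sorted(mean, reverse=True)[:2]
    return top - runner_up


def disagreement(u, weight_samples):
    """Fraction of sampled classifiers whose argmax differs from the majority."""
    votes = [
        max(range(len(s)), key=s.__getitem__)
        for s in (class_scores(u, w) for w in weight_samples)
    ]
    majority = max(set(votes), key=votes.count)
    return sum(v != majority for v in votes) / len(votes)


rng = random.Random(0)
mu = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]  # p = 2, K = 3 (toy sizes)
tight = sample_weights(mu, [[-5.0] * 3 for _ in range(2)], 200, rng)
loose = sample_weights(mu, [[0.0] * 3 for _ in range(2)], 200, rng)

# A point far from the class boundary under a concentrated posterior.
print(posterior_margin([2.0, 0.1], tight), disagreement([2.0, 0.1], tight))
# A point on the boundary under a diffuse posterior.
print(posterior_margin([1.0, 1.0], loose), disagreement([1.0, 1.0], loose))
```

In this toy setting the first point has a large margin and the sampled classifiers agree on its label, while the second sits between two classes and the samples split, which is the situation the learning objective is meant to push target-domain points away from.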
22 | 23 | ## Results 24 | ![Results](results.png) 25 |
26 | 27 | ## Code 28 | [[Classification]](https://github.com/seqam-lab/GPDA/tree/master/src) 29 | 30 | ## Citation 31 | If you use this code for your research, please cite our paper (this entry will be updated once the CVPR version is published). 32 | ``` 33 | @article{kim2019unsupervised, 34 | title={Unsupervised Visual Domain Adaptation: A Deep Max-Margin Gaussian Process Approach}, 35 | author={Kim, Minyoung and Sahu, Pritish and Gholami, Behnam and Pavlovic, Vladimir}, 36 | journal={arXiv preprint arXiv:1902.08727}, 37 | year={2019} 38 | } 39 | ``` 40 | -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-slate 2 | title: Unsupervised Visual Domain Adaptation:A Deep Max-Margin Gaussian Process Approach 3 | -------------------------------------------------------------------------------- /docs/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seqam-lab/GPDA/1c7b2462f41b8eeb905f0909ff5f59fd0ba94e48/docs/method.png -------------------------------------------------------------------------------- /docs/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seqam-lab/GPDA/1c7b2462f41b8eeb905f0909ff5f59fd0ba94e48/docs/results.png -------------------------------------------------------------------------------- /src/cmd_options.txt: -------------------------------------------------------------------------------- 1 | --num_k 3 --num_kq 3 --lamb_marg_loss 10.0 --max_epoch 2000 --save_model --save_epoch 10 --fix_randomness -------------------------------------------------------------------------------- /src/data_manager/dataset_read.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from data_manager.unaligned_data_loader import
UnalignedDataLoader 4 | from data_manager.svhn import load_svhn 5 | from data_manager.mnist import load_mnist 6 | #from datasets.usps import load_usps 7 | #from datasets.gtsrb import load_gtsrb 8 | #from datasets.synth_traffic import load_syntraffic 9 | 10 | ############################################################################### 11 | 12 | def return_dataset(data, scale=False, usps=False, all_use='no'): 13 | 14 | ''' 15 | load a specified dataset 16 | 17 | input: 18 | data = dataset to load (e.g., 'svhn', 'mnist'); string 19 | scale = whether to scale up images to (32 x 32) or not (28 x 28) 20 | usps, all_use = whether or not to subsample the training set 21 | 22 | output: 23 | train_image = train images; (ntr x C x H x W) 24 | train_label = {0...9}-valued train labels; ntr-dim 25 | test_image = test images; (nte x C x H x W) 26 | test_label = {0...9}-valued test labels; nte-dim 27 | ''' 28 | 29 | if data == 'svhn': 30 | train_image, train_label, test_image, test_label = load_svhn() 31 | 32 | if data == 'mnist': 33 | train_image, train_label, test_image, test_label = \ 34 | load_mnist( scale=scale, usps=usps, all_use=all_use ) 35 | sys.stdout.write('mnist image shape = '); print(train_image.shape) 36 | 37 | # if data == 'usps': 38 | # train_image, train_label, test_image, test_label = \ 39 | # load_usps(all_use=all_use) 40 | # 41 | # if data == 'synth': 42 | # train_image, train_label, test_image, test_label = \ 43 | # load_syntraffic() 44 | # 45 | # if data == 'gtsrb': 46 | # train_image, train_label, test_image, test_label = load_gtsrb() 47 | 48 | return train_image, train_label, test_image, test_label 49 | 50 | ############################################################################### 51 | 52 | def dataset_read( source, target, batch_size, scale=False, all_use='no' ): 53 | 54 | if source == 'usps' or target == 'usps': 55 | usps = True 56 | else: 57 | usps = False 58 | 59 | S = {}; S_test = {} 60 | T = {}; T_test = {} 61 | 62 | # read source data 
63 | train_source, s_label_train, test_source, s_label_test = \ 64 | return_dataset( source, scale=scale, usps=usps, all_use=all_use ) 65 | 66 | # read target data 67 | train_target, t_label_train, test_target, t_label_test = \ 68 | return_dataset( target, scale=scale, usps=usps, all_use=all_use ) 69 | 70 | # prepare source/target data 71 | S['imgs'] = train_source 72 | S['labels'] = s_label_train 73 | T['imgs'] = train_target 74 | T['labels'] = t_label_train 75 | 76 | # test samples for source/target 77 | S_test['imgs'] = test_target 78 | S_test['labels'] = t_label_test 79 | T_test['imgs'] = test_target 80 | T_test['labels'] = t_label_test 81 | 82 | scale = 40 if source == 'synth' else 28 if usps else 32 83 | 84 | # (train) do some image transform and create a minibatch generator 85 | train_loader = UnalignedDataLoader() 86 | train_loader.initialize(S, T, batch_size, batch_size, scale=scale) 87 | dataset = train_loader.load_data() 88 | 89 | # (test) do some image transform and create a minibatch generator 90 | test_loader = UnalignedDataLoader() 91 | test_loader.initialize(S_test, T_test, batch_size, batch_size, scale=scale) 92 | dataset_test = test_loader.load_data() 93 | 94 | return dataset, dataset_test 95 | -------------------------------------------------------------------------------- /src/data_manager/datasets.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import numpy as np 4 | 5 | ############################################################################### 6 | 7 | class Dataset(data.Dataset): 8 | 9 | def __init__(self, data, label, 10 | transform=None,target_transform=None): 11 | self.transform = transform 12 | self.target_transform = target_transform 13 | self.data = data 14 | self.labels = label 15 | 16 | def __getitem__(self, index): 17 | 18 | img, target = self.data[index], self.labels[index] 19 | # doing this so that it is consistent with all 
other datasets 20 | # to return a PIL Image 21 | # print(img.shape) 22 | if img.shape[0] != 1: 23 | #print(img) 24 | img = Image.fromarray( 25 | np.uint8(np.asarray(img.transpose((1, 2, 0)))) ) 26 | # 27 | elif img.shape[0] == 1: 28 | im = np.uint8(np.asarray(img)) 29 | # print(np.vstack([im,im,im]).shape) 30 | im = np.vstack([im, im, im]).transpose((1, 2, 0)) 31 | img = Image.fromarray(im) 32 | 33 | if self.target_transform is not None: 34 | target = self.target_transform(target) 35 | if self.transform is not None: 36 | img = self.transform(img) 37 | # return img, target 38 | return img, target 39 | def __len__(self): 40 | return len(self.data) 41 | -------------------------------------------------------------------------------- /src/data_manager/mnist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.io import loadmat 3 | 4 | ############################################################################### 5 | 6 | def load_mnist(scale=True, usps=False, all_use='no'): 7 | 8 | ''' 9 | load mnist dataset 10 | 11 | input: 12 | scale = whether to scale up images to (32 x 32) or not (28 x 28) 13 | if scale==True, also duplicate channels to (32 x 32 x 3) 14 | usps, all_use = whether or not to subsample the training set 15 | use 2000 random subsamples from training if usps==True & all_use!='yes' 16 | use ALL training samples otherwise 17 | 18 | output: 19 | mnist_train = training images; 20 | (55000 x 3 x 32 x 32) or (55000 x 1 x 28 x 28) 21 | train_label = {0...9}-valued training labels; 55000-dim 22 | mnist_test = test images; 23 | (10000 x 3 x 32 x 32) or (10000 x 1 x 28 x 28) 24 | test_label = {0...9}-valued test labels; 10000-dim 25 | ''' 26 | 27 | mnist_data = loadmat('data/mnist/mnist_data.mat') 28 | # load the following dict composed of: 29 | # mnist_data['train_32', 'test_32'] = (n x 32 x 32) 30 | # mnist_data['train_28', 'test_28'] = (n x 28 x 28 x 1) 31 | # 
mnist_data['label_train', 'label_test'] = (n x 10) one-hot 32 | 33 | if scale: # scale up and channel-duplicate images to (32 x 32 x 3) 34 | 35 | mnist_train = np.reshape(mnist_data['train_32'], (55000, 32, 32, 1)) 36 | mnist_test = np.reshape(mnist_data['test_32'], (10000, 32, 32, 1)) 37 | 38 | # duplicate channels 39 | mnist_train = np.concatenate( 40 | [mnist_train, mnist_train, mnist_train], 3 ) 41 | mnist_test = np.concatenate( 42 | [mnist_test, mnist_test, mnist_test], 3 ) 43 | 44 | # reshape to (n x C x H x W) format 45 | mnist_train = mnist_train.transpose(0, 3, 1, 2).astype(np.float32) 46 | mnist_test = mnist_test.transpose(0, 3, 1, 2).astype(np.float32) 47 | 48 | else: # use original (28 x 28 x 1) 49 | 50 | mnist_train = mnist_data['train_28'] 51 | mnist_test = mnist_data['test_28'] 52 | 53 | # reshape to (n x C x H x W) format 54 | mnist_train = mnist_train.transpose((0, 3, 1, 2)).astype(np.float32) 55 | mnist_test = mnist_test.transpose((0, 3, 1, 2)).astype(np.float32) 56 | 57 | # labels in one-hot format 58 | mnist_labels_train = mnist_data['label_train'] 59 | mnist_labels_test = mnist_data['label_test'] 60 | 61 | # convert one-hot to 0~9 labels 62 | train_label = np.argmax(mnist_labels_train, axis=1) 63 | test_label = np.argmax(mnist_labels_test, axis=1) 64 | 65 | # randomly shuffle training data 66 | inds = np.random.permutation(mnist_train.shape[0]) 67 | mnist_train = mnist_train[inds] 68 | train_label = train_label[inds] 69 | 70 | # subsample training images 71 | if usps and all_use != 'yes': 72 | mnist_train = mnist_train[:2000] 73 | train_label = train_label[:2000] 74 | 75 | return mnist_train, train_label, mnist_test, test_label 76 | -------------------------------------------------------------------------------- /src/data_manager/svhn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.io import loadmat 3 | 4 | from utils.utils import convert_label_10_to_0 5 | 6 | 
############################################################################### 7 | 8 | def load_svhn(): 9 | 10 | ''' 11 | load svhn dataset 12 | 13 | input: N/A 14 | 15 | output: 16 | svhn_train_im = training images; (73257 x 3 x 32 x 32) 17 | svhn_label = {0...9}-valued training labels; 73257-dim 18 | svhn_test_im = test images; (26032 x 3 x 32 x 32) 19 | svhn_label_test = {0...9}-valued test labels; 26032-dim 20 | ''' 21 | 22 | svhn_train = loadmat('data/svhn/train_32x32.mat') 23 | svhn_test = loadmat('data/svhn/test_32x32.mat') 24 | svhn_train_im = svhn_train['X'] 25 | svhn_train_im = svhn_train_im.transpose(3, 2, 0, 1).astype(np.float32) 26 | svhn_label = convert_label_10_to_0(svhn_train['y']) 27 | svhn_test_im = svhn_test['X'] 28 | svhn_test_im = svhn_test_im.transpose(3, 2, 0, 1).astype(np.float32) 29 | svhn_label_test = convert_label_10_to_0(svhn_test['y']) 30 | 31 | return svhn_train_im, svhn_label, svhn_test_im, svhn_label_test 32 | -------------------------------------------------------------------------------- /src/data_manager/unaligned_data_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | import torchnet as tnt 3 | from builtins import object 4 | import torchvision.transforms as transforms 5 | 6 | from data_manager.datasets import Dataset 7 | 8 | ############################################################################### 9 | 10 | class PairedData(object): 11 | 12 | def __init__(self, data_loader_A, data_loader_B, max_dataset_size): 13 | 14 | self.data_loader_A = data_loader_A 15 | self.data_loader_B = data_loader_B 16 | self.stop_A = False 17 | self.stop_B = False 18 | self.max_dataset_size = max_dataset_size 19 | 20 | def __iter__(self): 21 | 22 | self.stop_A = False 23 | self.stop_B = False 24 | self.data_loader_A_iter = iter(self.data_loader_A) 25 | self.data_loader_B_iter = iter(self.data_loader_B) 26 | self.iter = 0 27 | return self 28 | 29 | def __next__(self): 30 | 
31 | A, A_paths = None, None 32 | B, B_paths = None, None 33 | try: 34 | A, A_paths = next(self.data_loader_A_iter) 35 | except StopIteration: 36 | if A is None or A_paths is None: 37 | self.stop_A = True 38 | self.data_loader_A_iter = iter(self.data_loader_A) 39 | A, A_paths = next(self.data_loader_A_iter) 40 | 41 | try: 42 | B, B_paths = next(self.data_loader_B_iter) 43 | except StopIteration: 44 | if B is None or B_paths is None: 45 | self.stop_B = True 46 | self.data_loader_B_iter = iter(self.data_loader_B) 47 | B, B_paths = next(self.data_loader_B_iter) 48 | 49 | if (self.stop_A and self.stop_B) or self.iter>self.max_dataset_size: 50 | self.stop_A = False 51 | self.stop_B = False 52 | raise StopIteration() 53 | else: 54 | self.iter += 1 55 | return {'S': A, 'S_label': A_paths, 56 | 'T': B, 'T_label': B_paths} 57 | 58 | ############################################################################### 59 | 60 | class UnalignedDataLoader(): 61 | 62 | def initialize(self, source, target, batch_size1, batch_size2, scale=32): 63 | 64 | transform = transforms.Compose([ 65 | transforms.Resize(scale), # Resize replaces the deprecated transforms.Scale 66 | transforms.ToTensor(), 67 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 68 | ]) 69 | 70 | dataset_source = Dataset( 71 | source['imgs'], source['labels'], transform=transform ) 72 | dataset_target = Dataset( 73 | target['imgs'], target['labels'], transform=transform ) 74 | 75 | data_loader_s = torch.utils.data.DataLoader( 76 | dataset_source, 77 | batch_size=batch_size1, 78 | shuffle=True, 79 | num_workers=4 ) 80 | 81 | data_loader_t = torch.utils.data.DataLoader( 82 | dataset_target, 83 | batch_size=batch_size2, 84 | shuffle=True, 85 | num_workers=4 ) 86 | 87 | self.dataset_s = dataset_source 88 | self.dataset_t = dataset_target 89 | self.paired_data = PairedData( 90 | data_loader_s, data_loader_t, float("inf") ) 91 | 92 | def name(self): 93 | return 'UnalignedDataLoader' 94 | 95 | def load_data(self): 96 | return self.paired_data 97 | 98 | def 
__len__(self): 99 | return min(max(len(self.dataset_s),len(self.dataset_t)), float("inf")) 100 | -------------------------------------------------------------------------------- /src/main4.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | 5 | from solver import Solver 6 | 7 | ############################################################################### 8 | 9 | # 10 | # hyperparameters 11 | # 12 | 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument( '--nsamps_q', 16 | type=int, default=50, 17 | help='# of samples from variational density q(w) (default: 50)' ) 18 | 19 | parser.add_argument( '--lamb_marg_loss', 20 | type=float, default=10.0, 21 | help='impact of margin loss (default: 10.0)' ) 22 | 23 | parser.add_argument( '--all_use', 24 | type=str, default='no', 25 | help='use all training data? (default: "no")' ) 26 | 27 | parser.add_argument( '--batch-size', 28 | type=int, default=128, 29 | help='input batch size for training (default: 128)' ) 30 | 31 | #parser.add_argument( '--checkpoint_dir', 32 | # type=str, default='checkpoint', 33 | # help='source only or not (default: "checkpoint")' ) 34 | 35 | parser.add_argument( '--eval_only', 36 | action='store_true', default=False, 37 | help='evaluation only option' ) 38 | 39 | parser.add_argument( '--lr', 40 | type=float, default=0.0002, 41 | help='learning rate (default: 0.0002)' ) 42 | 43 | parser.add_argument( '--max_epoch', 44 | type=int, default=200, 45 | help='maximum number of epochs (default: 200)' ) 46 | 47 | parser.add_argument( '--no-cuda', 48 | action='store_true', default=False, 49 | help='disables CUDA training' ) 50 | 51 | parser.add_argument( '--num_k', 52 | type=int, default=4, 53 | help='# gradient descent iterations for phi(G(x)) learning (default: 4)' ) 54 | 55 | parser.add_argument( '--num_kq', 56 | type=int, default=4, 57 | help='# gradient descent iterations for q(w) learning (default: 
4)' ) 58 | 59 | #parser.add_argument( '--one_step', 60 | # action='store_true', default=False, 61 | # help='one step training with gradient reversal layer' ) 62 | 63 | parser.add_argument( '--optimizer', 64 | type=str, default='adam', 65 | help='optimizer (default: "adam")' ) 66 | 67 | parser.add_argument( '--resume_epoch', 68 | type=int, default=100, 69 | help='epoch to resume (default: 100)' ) 70 | 71 | parser.add_argument( '--save_epoch', 72 | type=int, default=10, 73 | help='when to restore the model (default: 10)' ) 74 | 75 | parser.add_argument( '--save_model', 76 | action='store_true', default=False, 77 | help='save_model or not' ) 78 | 79 | parser.add_argument( '--seed', 80 | type=int, default=1, 81 | help='random seed (default: 1)' ) 82 | 83 | parser.add_argument( '--source', 84 | type=str, default='svhn', 85 | help='source dataset (default: "svhn")' ) 86 | 87 | parser.add_argument( '--target', 88 | type=str, default='mnist', 89 | help='target dataset (default: "mnist")' ) 90 | 91 | parser.add_argument( '--use_abs_diff', 92 | action='store_true', default=False, 93 | help='use absolute difference value as a measurement' ) 94 | 95 | parser.add_argument( '--fix_randomness', 96 | action='store_true', default=False, 97 | help='fix randomness' ) 98 | 99 | args = parser.parse_args() 100 | 101 | args.cuda = not args.no_cuda and torch.cuda.is_available() 102 | torch.manual_seed(args.seed) 103 | if args.cuda: 104 | torch.cuda.manual_seed(args.seed) 105 | 106 | print(args) 107 | 108 | if args.fix_randomness: 109 | import numpy as np 110 | np.random.seed(10) 111 | torch.backends.cudnn.deterministic = True 112 | 113 | 114 | ############################################################################### 115 | 116 | def main(): 117 | 118 | # make a string that describes the current running setup 119 | num = 0 120 | run_setup_str = \ 121 | '%s2%s_k_%s_kq_%s_lamb_%s' % \ 122 | ( args.source, args.target, args.num_k, args.num_kq, args.lamb_marg_loss) 123 | while 
os.path.exists('record/%s_run_%s.txt' % (run_setup_str, num)): 124 | num += 1 125 | run_setup_str = '%s_run_%s' % (run_setup_str, num) 126 | # eg, svhn2mnist_k_4_kq_4_lamb_10.0_run_5 127 | 128 | # set file names for records (storing training stats) 129 | record_train = 'record/%s.txt' % (run_setup_str,) 130 | record_test = 'record/%s_test.txt' % (run_setup_str,) 131 | if not os.path.exists('record'): 132 | os.mkdir('record') # create a folder for records if not exist 133 | 134 | # set the checkpoint dir name (storing model params) 135 | checkpoint_dir = 'checkpoint/%s' % (run_setup_str,) 136 | if not os.path.exists('checkpoint'): 137 | os.mkdir('checkpoint') # create a folder if not exist 138 | if not os.path.exists(checkpoint_dir): 139 | os.mkdir(checkpoint_dir) # create a folder if not exist 140 | 141 | #### 142 | 143 | # create a solver: load data, create models (or load existing models), 144 | # and create optimizers 145 | solver = Solver( args, 146 | source = args.source, 147 | target = args.target, 148 | nsamps_q = args.nsamps_q, 149 | lamb_marg_loss = args.lamb_marg_loss, 150 | learning_rate = args.lr, 151 | batch_size = args.batch_size, 152 | optimizer = args.optimizer, 153 | num_k = args.num_k, 154 | num_kq = args.num_kq, 155 | all_use = args.all_use, 156 | checkpoint_dir = checkpoint_dir, 157 | save_epoch = args.save_epoch ) 158 | 159 | # run it (test or training) 160 | if args.eval_only: 161 | solver.test(0) 162 | else: # training 163 | count = 0 164 | for t in range(args.max_epoch): 165 | num = solver.train(t, record_file=record_train) 166 | count += num 167 | if t % 1 == 0: # run it on test data every epoch (and save models) 168 | solver.test( t, record_file=record_test, 169 | save_model=args.save_model ) 170 | if count >= 20000*10: 171 | break 172 | 173 | ############################################################################### 174 | 175 | if __name__ == '__main__': 176 | main() 177 | 178 | 
-------------------------------------------------------------------------------- /src/model/build_gen.py: -------------------------------------------------------------------------------- 1 | import model.svhn2mnist as svhn2mnist 2 | #import model.usps as usps 3 | #import model.syn2gtrsb as syn2gtrsb 4 | 5 | ############################################################################### 6 | 7 | def PhiGnet(source, target): 8 | if source == 'usps' or target == 'usps': 9 | raise NotImplementedError('usps model is disabled: model.usps is not imported') 10 | elif source == 'svhn': 11 | return svhn2mnist.PhiGnetwork() 12 | elif source == 'synth': 13 | raise NotImplementedError('synth model is disabled: model.syn2gtrsb is not imported') 14 | 15 | ############################################################################### 16 | 17 | def QWnet(source, target): 18 | if source == 'usps' or target == 'usps': 19 | raise NotImplementedError('usps model is disabled: model.usps is not imported') 20 | elif source == 'svhn': 21 | return svhn2mnist.QWnetwork() 22 | elif source == 'synth': 23 | raise NotImplementedError('synth model is disabled: model.syn2gtrsb is not imported') 24 | 25 | ############################################################################### 26 | 27 | #def Generator(source, target): 28 | # if source == 'usps' or target == 'usps': 29 | # return usps.Feature() 30 | # elif source == 'svhn': 31 | # return svhn2mnist.Feature() 32 | # elif source == 'synth': 33 | # return syn2gtrsb.Feature() 34 | 35 | ############################################################################### 36 | 37 | #def Classifier(source, target): 38 | # if source == 'usps' or target == 'usps': 39 | # return usps.Predictor() 40 | # if source == 'svhn': 41 | # return svhn2mnist.Predictor() 42 | # if source == 'synth': 43 | # return syn2gtrsb.Predictor() 44 | -------------------------------------------------------------------------------- /src/model/svhn2mnist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | #from model.grad_reverse import grad_reverse 6 | 7 | 
############################################################################### 8 | 9 | # 10 | # PhiGnetwork returns u = phi(G(x)) where 11 | # 12 | # x = image 13 | # z = G(x) = exactly Feature() in MCD-DA 14 | # u = phi(z) = the last hidden layer of Predictor() in MCD-DA 15 | # 16 | 17 | class PhiGnetwork(nn.Module): 18 | 19 | def __init__(self): 20 | 21 | super(PhiGnetwork, self).__init__() 22 | 23 | self.g_conv1 = nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=2) 24 | self.g_bn1 = nn.BatchNorm2d(64) 25 | self.g_conv2 = nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=2) 26 | self.g_bn2 = nn.BatchNorm2d(64) 27 | self.g_conv3 = nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=2) 28 | self.g_bn3 = nn.BatchNorm2d(128) 29 | self.g_fc1 = nn.Linear(8192, 3072) 30 | self.g_bn1_fc = nn.BatchNorm1d(3072) 31 | 32 | #self.phi_fc1 = nn.Linear(8192, 3072) 33 | #self.phi_bn1_fc = nn.BatchNorm1d(3072) 34 | self.phi_fc2 = nn.Linear(3072, 2048) 35 | self.phi_bn2_fc = nn.BatchNorm1d(2048) 36 | 37 | self.p = 2048 38 | 39 | def forward(self, x): 40 | 41 | x = F.max_pool2d( F.relu(self.g_bn1(self.g_conv1(x))), 42 | stride=2, kernel_size=3, padding=1 ) 43 | x = F.max_pool2d( F.relu(self.g_bn2(self.g_conv2(x))), 44 | stride=2, kernel_size=3, padding=1 ) 45 | x = F.relu(self.g_bn3(self.g_conv3(x))) 46 | x = x.view(x.size(0), 8192) 47 | x = F.relu(self.g_bn1_fc(self.g_fc1(x))) 48 | z = F.dropout(x, training=self.training) 49 | 50 | u = F.relu(self.phi_bn2_fc(self.phi_fc2(z))) 51 | 52 | return u 53 | 54 | ############################################################################### 55 | 56 | # 57 | # QWnetwork returns w^m_j = mu_j + sd_j.*eps^m_j, for m=1...M samples from 58 | # q(w) = \prod_{j=1}^K N(w_j; mu_j, diag(sd_j)^2) with dim(w_j) = p 59 | # 60 | # eps = samples from N(0,1); (M x p x K) -- input 61 | # mu = K mean vectors of q(w); (p x K) -- model params 62 | # logsd = K log-stdev vectors of q(w); (p x K) -- model params 63 | # (sd = exp(logsd)) 64 | # 65 | 66 | class 
QWnetwork(nn.Module): 67 | 68 | def __init__(self): 69 | 70 | super(QWnetwork, self).__init__() 71 | 72 | self.mu = nn.Parameter(0.01*torch.randn(2048, 10)) 73 | self.logsd = nn.Parameter(0.01*torch.randn(2048, 10)) 74 | 75 | def forward(self, eps): 76 | 77 | mu3 = self.mu.unsqueeze(0) # (1 x p x K) 78 | sd3 = torch.exp(self.logsd).unsqueeze(0) # (1 x p x K) 79 | w = mu3 + sd3*eps # (M x p x K) 80 | 81 | return w 82 | 83 | ############################################################################### 84 | 85 | #class Feature(nn.Module): 86 | # 87 | # def __init__(self): 88 | # 89 | # super(Feature, self).__init__() 90 | # 91 | # self.conv1 = nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=2) 92 | # self.bn1 = nn.BatchNorm2d(64) 93 | # self.conv2 = nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=2) 94 | # self.bn2 = nn.BatchNorm2d(64) 95 | # self.conv3 = nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=2) 96 | # self.bn3 = nn.BatchNorm2d(128) 97 | # self.fc1 = nn.Linear(8192, 3072) 98 | # self.bn1_fc = nn.BatchNorm1d(3072) 99 | # 100 | # def forward(self, x): 101 | # 102 | # x = F.max_pool2d( F.relu(self.bn1(self.conv1(x))), 103 | # stride=2, kernel_size=3, padding=1 ) 104 | # x = F.max_pool2d( F.relu(self.bn2(self.conv2(x))), 105 | # stride=2, kernel_size=3, padding=1 ) 106 | # x = F.relu(self.bn3(self.conv3(x))) 107 | # x = x.view(x.size(0), 8192) 108 | # x = F.relu(self.bn1_fc(self.fc1(x))) 109 | # x = F.dropout(x, training=self.training) 110 | # 111 | # return x 112 | 113 | ############################################################################### 114 | 115 | #class Predictor(nn.Module): 116 | # 117 | # def __init__(self, prob=0.5): 118 | # 119 | # super(Predictor, self).__init__() 120 | # 121 | # self.fc1 = nn.Linear(8192, 3072) 122 | # self.bn1_fc = nn.BatchNorm1d(3072) 123 | # self.fc2 = nn.Linear(3072, 2048) 124 | # self.bn2_fc = nn.BatchNorm1d(2048) 125 | # self.fc3 = nn.Linear(2048, 10) 126 | # self.bn_fc3 = nn.BatchNorm1d(10) 127 | # 
self.prob = prob 128 | # 129 | # def set_lambda(self, lambd): 130 | # 131 | # self.lambd = lambd 132 | # 133 | # def forward(self, x, reverse=False): 134 | # 135 | # if reverse: 136 | # x = grad_reverse(x, self.lambd) 137 | # x = F.relu(self.bn2_fc(self.fc2(x))) 138 | # x = self.fc3(x) 139 | # 140 | # return x 141 | -------------------------------------------------------------------------------- /src/model/syn2gtrsb.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from model.grad_reverse import grad_reverse 5 | 6 | 7 | class Feature(nn.Module): 8 | def __init__(self): 9 | super(Feature, self).__init__() 10 | self.conv1 = nn.Conv2d(3, 96, kernel_size=5, stride=1, padding=2) 11 | self.bn1 = nn.BatchNorm2d(96) 12 | self.conv2 = nn.Conv2d(96, 144, kernel_size=3, stride=1, padding=1) 13 | self.bn2 = nn.BatchNorm2d(144) 14 | self.conv3 = nn.Conv2d(144, 256, kernel_size=5, stride=1, padding=2) 15 | self.bn3 = nn.BatchNorm2d(256) 16 | 17 | def forward(self, x): 18 | x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), stride=2, kernel_size=2, padding=0) 19 | x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), stride=2, kernel_size=2, padding=0) 20 | x = F.max_pool2d(F.relu(self.bn3(self.conv3(x))), stride=2, kernel_size=2, padding=0) 21 | x = x.view(x.size(0), 6400) 22 | return x 23 | 24 | 25 | class Predictor(nn.Module): 26 | def __init__(self): 27 | super(Predictor, self).__init__() 28 | self.fc2 = nn.Linear(6400, 512) 29 | self.bn2_fc = nn.BatchNorm1d(512) 30 | self.fc3 = nn.Linear(512, 43) 31 | self.bn_fc3 = nn.BatchNorm1d(43) 32 | 33 | def set_lambda(self, lambd): 34 | self.lambd = lambd 35 | 36 | def forward(self, x, reverse=False): 37 | if reverse: 38 | x = grad_reverse(x, self.lambd) 39 | x = F.relu(self.bn2_fc(self.fc2(x))) 40 | x = F.dropout(x, training=self.training) 41 | x = self.fc3(x) 42 | return x 43 | 
-------------------------------------------------------------------------------- /src/model/usps.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from model.grad_reverse import grad_reverse 6 | 7 | 8 | class Feature(nn.Module): 9 | def __init__(self): 10 | super(Feature, self).__init__() 11 | self.conv1 = nn.Conv2d(1, 32, kernel_size=5, stride=1) 12 | self.bn1 = nn.BatchNorm2d(32) 13 | self.conv2 = nn.Conv2d(32, 48, kernel_size=5, stride=1) 14 | self.bn2 = nn.BatchNorm2d(48) 15 | 16 | def forward(self, x): 17 | x = torch.mean(x,1).view(x.size()[0],1,x.size()[2],x.size()[3]) 18 | x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), stride=2, kernel_size=2, dilation=(1, 1)) 19 | x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), stride=2, kernel_size=2, dilation=(1, 1)) 20 | #print(x.size()) 21 | x = x.view(x.size(0), 48*4*4) 22 | return x 23 | 24 | 25 | class Predictor(nn.Module): 26 | def __init__(self, prob=0.5): 27 | super(Predictor, self).__init__() 28 | self.fc1 = nn.Linear(48*4*4, 100) 29 | self.bn1_fc = nn.BatchNorm1d(100) 30 | self.fc2 = nn.Linear(100, 100) 31 | self.bn2_fc = nn.BatchNorm1d(100) 32 | self.fc3 = nn.Linear(100, 10) 33 | self.bn_fc3 = nn.BatchNorm1d(10) 34 | self.prob = prob 35 | 36 | def set_lambda(self, lambd): 37 | self.lambd = lambd 38 | def forward(self, x, reverse=False): 39 | if reverse: 40 | x = grad_reverse(x, self.lambd) 41 | x = F.dropout(x, training=self.training, p=self.prob) 42 | x = F.relu(self.bn1_fc(self.fc1(x))) 43 | x = F.dropout(x, training=self.training, p=self.prob) 44 | x = F.relu(self.bn2_fc(self.fc2(x))) 45 | x = F.dropout(x, training=self.training, p=self.prob) 46 | x = self.fc3(x) 47 | return x 48 | -------------------------------------------------------------------------------- /src/solver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | 6 | from torch.autograd import Variable 7 | 8 | from model.build_gen import PhiGnet, QWnet 9 | from data_manager.dataset_read import dataset_read 10 | 11 | ############################################################################### 12 | 13 | class Solver(object): 14 | 15 | ######## 16 | def __init__( self, args, batch_size=128, source='svhn', target='mnist', 17 | nsamps_q=50, lamb_marg_loss=10.0, 18 | learning_rate=0.0002, interval=100, optimizer='adam', num_k=4, num_kq=4, 19 | all_use=False, checkpoint_dir=None, save_epoch=10 ): 20 | 21 | # set hyperparameters 22 | self.batch_size = batch_size 23 | self.source = source 24 | self.target = target 25 | self.num_k = num_k 26 | self.num_kq = num_kq 27 | self.checkpoint_dir = checkpoint_dir 28 | self.save_epoch = save_epoch 29 | self.use_abs_diff = args.use_abs_diff 30 | self.all_use = all_use 31 | if self.source == 'svhn': 32 | self.scale = True 33 | else: 34 | self.scale = False 35 | self.lamb_marg_loss = lamb_marg_loss 36 | 37 | # load data, do image transform, and create a mini-batch generator 38 | print('dataset loading') 39 | self.datasets, self.dataset_test = \ 40 | dataset_read( source, target, self.batch_size, 41 | scale=self.scale, all_use=self.all_use ) 42 | print('load finished!') 43 | 44 | if source == 'svhn': 45 | self.Ns = 73257 46 | 47 | # create models 48 | self.phig = PhiGnet(source=source, target=target) 49 | self.qw = QWnet(source=source, target=target) 50 | 51 | # load the previously learned models from files (if evaluations only) 52 | if args.eval_only: 53 | self.phig = torch.load( '%s/model_epoch%s_phig.pt' % 54 | (self.checkpoint_dir , args.resume_epoch) ) 55 | self.qw = torch.load( '%s/model_epoch%s_qw.pt' % 56 | (self.checkpoint_dir, args.resume_epoch) ) 57 | 58 | # move models to GPU 59 | self.phig.cuda() 60 | self.qw.cuda() 61 | 62 | # create optimizer objects (one for each model) 63 | 
self.set_optimizer(which_opt=optimizer, lr=learning_rate) 64 | 65 | # print stats every interval (default: 100) minibatch iters 66 | self.interval = interval 67 | 68 | self.lr = learning_rate 69 | 70 | # some dimensions 71 | self.p = self.phig.p # dim(phi(G(x))) 72 | self.M = nsamps_q # number of samples from variational density q(w) 73 | 74 | 75 | ######## 76 | def set_optimizer(self, which_opt='momentum', lr=0.001, momentum=0.9): 77 | 78 | if which_opt == 'momentum': 79 | 80 | self.opt_phig = optim.SGD( self.phig.parameters(), 81 | lr=lr, weight_decay=0.0005, momentum=momentum ) 82 | 83 | self.opt_qw = optim.SGD( self.qw.parameters(), 84 | lr=lr, weight_decay=0.0005, momentum=momentum ) 85 | 86 | if which_opt == 'adam': 87 | 88 | self.opt_phig = optim.Adam( self.phig.parameters(), 89 | lr=lr, weight_decay=0.0005 ) 90 | 91 | self.opt_qw = optim.Adam( self.qw.parameters(), 92 | lr=lr, weight_decay=0.0005 ) 93 | 94 | 95 | ######## 96 | def reset_grad(self): 97 | 98 | # zero out all gradients of model params registered in the optimizers 99 | self.opt_phig.zero_grad() 100 | self.opt_qw.zero_grad() 101 | 102 | 103 | ######## 104 | def ent(self, output): 105 | 106 | return -torch.mean(output * torch.log(output + 1e-6)) 107 | 108 | 109 | ######## 110 | def kl_loss(self): 111 | 112 | kl = 0.5 * ( -self.p*10 + 113 | torch.sum( (torch.exp(self.qw.logsd))**2 + self.qw.mu**2 - 114 | 2.0*self.qw.logsd ) 115 | ) 116 | 117 | return kl 118 | 119 | 120 | ######## 121 | def train(self, epoch, record_file=None): 122 | 123 | ''' 124 | train models for one epoch (ie, one pass of whole training data) 125 | ''' 126 | 127 | criterion = nn.CrossEntropyLoss().cuda() 128 | 129 | # turn models into "training" mode 130 | # (required if models contain "BatchNorm"-like layers) 131 | self.phig.train() 132 | self.qw.train() 133 | 134 | torch.cuda.manual_seed(1) 135 | 136 | # for each batch 137 | for batch_idx, data in enumerate(self.datasets): 138 | 139 | img_t = data['T'] 140 | img_s = data['S'] 
141 | label_s = data['S_label'] 142 | if img_s.size()[0] < self.batch_size or \ 143 | img_t.size()[0] < self.batch_size: 144 | break 145 | img_s = img_s.cuda() 146 | img_t = img_t.cuda() 147 | # imgs = Variable(torch.cat((img_s, img_t), 0)) 148 | label_s = Variable(label_s.long().cuda()) 149 | img_s = Variable(img_s) 150 | img_t = Variable(img_t) 151 | 152 | # (M x p x K) samples from N(0,1) 153 | eps = Variable(torch.randn(self.M, self.p, 10)) 154 | eps = eps.cuda() 155 | 156 | #### step A: min_{qw} (nll + kl) 157 | 158 | self.reset_grad() 159 | 160 | for i in range(self.num_kq): 161 | 162 | phig_s = self.phig(img_s) # phi(G(xs)) 163 | wsamp = self.qw(eps) # samples from q(w) 164 | 165 | # w'*phi(G(xs)) = (M x B x K) 166 | wphig_s = torch.sum( 167 | wsamp.unsqueeze(1) * phig_s.unsqueeze(0).unsqueeze(3), 168 | dim=2 ) 169 | 170 | # nll loss 171 | loss_nll = criterion( 172 | wphig_s.view(-1,10), label_s.repeat(self.M) ) * self.Ns 173 | 174 | # kl loss 175 | loss_kl = self.kl_loss() 176 | 177 | loss = loss_nll + loss_kl 178 | 179 | # compute gradient of the loss 180 | loss.backward() 181 | 182 | # update models 183 | self.opt_qw.step() 184 | 185 | self.reset_grad() 186 | 187 | #### step B: min_{phig} (nll + kl + marg) 188 | 189 | self.reset_grad() 190 | 191 | for i in range(self.num_k): 192 | 193 | phig_s = self.phig(img_s) # phi(G(xs)) 194 | phig_t = self.phig(img_t) # phi(G(xt)) 195 | wsamp = self.qw(eps) # samples from q(w) 196 | 197 | # w'*phi(G(xs)) = (M x B x K) 198 | wphig_s = torch.sum( 199 | wsamp.unsqueeze(1) * phig_s.unsqueeze(0).unsqueeze(3), 200 | dim=2 ) 201 | 202 | # nll loss 203 | loss_nll = criterion( 204 | wphig_s.view(-1,10), label_s.repeat(self.M) ) * self.Ns 205 | 206 | # kl loss 207 | loss_kl = self.kl_loss() 208 | 209 | # margin loss on target 210 | f_t = torch.mm(phig_t, self.qw.mu) # (B x K) 211 | top2 = torch.topk(f_t, k=2, dim=1)[0] # (B x 2) 212 | # top2[i,0] = max_j f_t[i,j], top2[:,1] = max2_j f_t[i,j] 213 | gap21 = top2[:,1] - top2[:,0] 
# B-dim 214 | std_f_t = torch.sqrt( 215 | torch.mm(phig_t**2, torch.exp(self.qw.logsd)**2) ) # (B x K) 216 | max_std = torch.max(std_f_t, dim=1)[0] # B-dim 217 | loss_marg = torch.mean( F.relu(1.0 + gap21 + 1.96*max_std) ) 218 | 219 | loss = loss_nll + loss_kl + self.lamb_marg_loss*loss_marg 220 | 221 | # compute gradient of the loss 222 | loss.backward() 223 | 224 | # update models 225 | self.opt_phig.step() 226 | 227 | self.reset_grad() 228 | 229 | #### wrap up 230 | 231 | if batch_idx > 500: 232 | return batch_idx 233 | 234 | if batch_idx % self.interval == 0: 235 | prn_str = ('Train Epoch: %d [batch-idx: %d] ' + \ 236 | 'nll: %.6f, kl: %.6f, marg: %.6f') % \ 237 | ( epoch, batch_idx, loss_nll.item(), loss_kl.item(), 238 | loss_marg.item() ) 239 | print(prn_str) 240 | if record_file: 241 | record = open(record_file, 'a') 242 | record.write('%s\n' % (prn_str,)) 243 | record.close() 244 | 245 | return batch_idx 246 | 247 | 248 | ######## 249 | def test(self, epoch, record_file=None, save_model=False): 250 | 251 | ''' 252 | evaluate the current models on the entire test set 253 | ''' 254 | 255 | criterion = nn.CrossEntropyLoss().cuda() 256 | 257 | # turn models into evaluation mode 258 | self.phig.eval() 259 | self.qw.eval() 260 | 261 | test_loss = 0 # test nll loss 262 | corrects = 0 # number of correct predictions by MAP 263 | size = 0 # total number of test samples 264 | 265 | # turn off autograd feature (no evaluation history tracking) 266 | with torch.no_grad(): 267 | 268 | for batch_idx, data in enumerate(self.dataset_test): 269 | 270 | img = data['T'] 271 | label = data['T_label'] 272 | 273 | img, label = img.cuda(), label.long().cuda() 274 | 275 | #img, label = Variable(img, volatile=True), Variable(label) 276 | img, label = Variable(img), Variable(label) 277 | 278 | # (M x p x K) samples from N(0,1) 279 | #eps = Variable(torch.randn(self.M, self.p, 10)) 280 | #eps = eps.cuda() 281 | 282 | phig = self.phig(img) # phi(G(x)) 283 | wmode = self.qw.mu # mode of 
q(w) 284 | #wsamp = self.qw(eps) # samples from q(w) 285 | 286 | # w'*phi(G(x)) = (B x K) 287 | output = torch.mm(phig, wmode) 288 | 289 | # w'*phi(G(x)) = (M x B x K) 290 | #wphig = torch.sum( 291 | # wsamp.unsqueeze(1) * phig.unsqueeze(0).unsqueeze(3), dim=2 ) 292 | 293 | # nll loss (cross entropy); criterion returns the batch mean, so rescale to a batch sum 294 | test_loss += criterion(output, label).item() * label.size(0) 295 | 296 | # class prediction 297 | pred = output.data.max(1)[1] # n-dim {0,...,K-1}-valued 298 | # tensor.max(j) returns a tuple (A, B) where 299 | # A = max of tensor over j-th dim 300 | # B = argmax of tensor over j-th dim 301 | 302 | corrects += pred.eq(label.data).cpu().numpy().sum() 303 | 304 | size += label.data.size()[0] 305 | 306 | test_loss = test_loss / size 307 | 308 | prn_str = ( 'Test set: Average nll loss: %.4f, ' + \ 309 | 'Accuracy: %d/%d (%.4f%%)\n' ) % \ 310 | ( test_loss, corrects, size, 100. * corrects / size ) 311 | print(prn_str) 312 | 313 | # save (append) the test scores/stats to files 314 | if record_file: 315 | record = open(record_file, 'a') 316 | print('recording %s\n' % record_file) 317 | record.write('%s\n' % (prn_str,)) 318 | record.close() 319 | 320 | # save the models as files 321 | if save_model and epoch % self.save_epoch == 0: 322 | torch.save( self.phig, 323 | '%s/model_epoch%s_phig.pt' % (self.checkpoint_dir, epoch) ) 324 | torch.save( self.qw, 325 | '%s/model_epoch%s_qw.pt' % (self.checkpoint_dir, epoch) ) 326 | 327 | 328 | 329 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | ############################################################################### 4 | 5 | def weights_init(m): 6 | 7 | classname = m.__class__.__name__ 8 | if classname.find('Conv') != -1: 9 | m.weight.data.normal_(0.0, 0.01) 10 | m.bias.data.normal_(0.0, 0.01) 11 | elif classname.find('BatchNorm') != -1: 12 | 
m.weight.data.normal_(1.0, 0.01) 13 | m.bias.data.fill_(0) 14 | 15 | ############################################################################### 16 | 17 | def convert_label_10_to_0(labels): 18 | 19 | ''' 20 | convert class label 10 to 0 21 | ''' 22 | 23 | labels2 = np.zeros((len(labels),)) 24 | labels = list(labels) 25 | for i, t in enumerate(labels): 26 | if t == 10: 27 | labels2[i] = 0 28 | else: 29 | labels2[i] = t 30 | 31 | return labels2 32 | --------------------------------------------------------------------------------
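The `convert_label_10_to_0` helper above remaps class 10 to class 0, presumably for SVHN, whose original annotation files store the digit 0 as label 10. It fills an array created with `np.zeros`, which defaults to `float64`, so callers have to cast the result back to an integer type before using it as class indices. A minimal dependency-free sketch of the same mapping that returns integers directly is shown below. The name `relabel_ten_as_zero` is hypothetical and not part of this repository.

```python
def relabel_ten_as_zero(labels):
    """Return a list of int labels where class 10 is mapped to 0.

    Accepts any iterable of numeric labels (list, tuple, 1-D array);
    every label other than 10 is passed through unchanged.
    """
    result = []
    for t in labels:
        t = int(t)                      # also handles float or NumPy scalars
        result.append(0 if t == 10 else t)
    return result


print(relabel_ten_as_zero([10, 1, 9, 10, 3]))  # [0, 1, 9, 0, 3]
```

Returning plain integers keeps the labels directly usable as targets for a loss such as `nn.CrossEntropyLoss`, which expects integer class indices in the range 0 to K-1.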