├── LICENSE
├── README.md
├── _config.yml
├── docs
│   ├── README.md
│   ├── _config.yml
│   ├── method.png
│   └── results.png
└── src
    ├── cmd_options.txt
    ├── data_manager
    │   ├── dataset_read.py
    │   ├── datasets.py
    │   ├── mnist.py
    │   ├── svhn.py
    │   └── unaligned_data_loader.py
    ├── main4.py
    ├── model
    │   ├── build_gen.py
    │   ├── svhn2mnist.py
    │   ├── syn2gtrsb.py
    │   └── usps.py
    ├── solver.py
    └── utils
        └── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 seqam-lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | **This is the project page for Unsupervised Visual Domain Adaptation: A Deep Max-Margin Gaussian Process Approach.
2 | The work was accepted by CVPR 2019 Oral.**
3 | [[Paper Link]](http://openaccess.thecvf.com/content_CVPR_2019/html/Kim_Unsupervised_Visual_Domain_Adaptation_A_Deep_Max-Margin_Gaussian_Process_Approach_CVPR_2019_paper.html).
4 |
5 |
6 |
7 | ## Citation
8 | If you use this code for your research, please cite our paper (this entry will be updated once the CVPR paper is published).
9 | ```
10 | @article{kim2019unsupervised,
11 | title={Unsupervised Visual Domain Adaptation: A Deep Max-Margin Gaussian Process Approach},
12 | author={Kim, Minyoung and Sahu, Pritish and Gholami, Behnam and Pavlovic, Vladimir},
13 | journal={arXiv preprint arXiv:1902.08727},
14 | year={2019}
15 | }
16 | ```
17 |
--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-slate
2 | title: Unsupervised Visual Domain Adaptation:A Deep Max-Margin Gaussian Process Approach
3 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | **This is the project page for Unsupervised Visual Domain Adaptation: A Deep Max-Margin Gaussian Process Approach.
2 | The work was accepted by CVPR 2019 Oral.**
3 | [[Paper Link]](http://openaccess.thecvf.com/content_CVPR_2019/html/Kim_Unsupervised_Visual_Domain_Adaptation_A_Deep_Max-Margin_Gaussian_Process_Approach_CVPR_2019_paper.html)[[Youtube Link]](https://youtu.be/OYbiWSM0u8U).
4 |
5 |
6 | ## Abstract
7 | In unsupervised domain adaptation, it is widely known that the target domain error can be provably reduced by having
8 | a shared input representation that makes the source and target domains indistinguishable from each other. Very recently it
9 | has been studied that not just matching the marginal input distributions, but the alignment of output (class) distributions is
10 | also critical. The latter can be achieved by minimizing the maximum discrepancy of predictors (classifiers). In this paper,
11 | we adopt this principle, but propose a more systematic and effective way to achieve hypothesis consistency via Gaussian
12 | processes (GP). The GP allows us to define/induce a hypothesis space of the classifiers from the posterior distribution of the
13 | latent random functions, turning the learning into a simple large-margin posterior separation problem, far easier to solve
14 | than previous approaches based on adversarial minimax optimization. We formulate a learning objective that effectively
15 | pushes the posterior to minimize the maximum discrepancy. This is further shown to be equivalent to maximizing margins
16 | and minimizing uncertainty of the class predictions in the target domain, a well-established principle in classical
17 | (semi-)supervised learning. Empirical results demonstrate that our approach is comparable or superior to the existing methods on
18 | several benchmark domain adaptation datasets.
19 |
20 | 
21 |
22 |
23 | ## Results
24 | 
25 |
26 |
27 | ## Codes
28 | [[Classification]](https://github.com/seqam-lab/GPDA/tree/master/src)
29 |
30 | ## Citation
31 | If you use this code for your research, please cite our paper (this entry will be updated once the CVPR paper is published).
32 | ```
33 | @article{kim2019unsupervised,
34 | title={Unsupervised Visual Domain Adaptation: A Deep Max-Margin Gaussian Process Approach},
35 | author={Kim, Minyoung and Sahu, Pritish and Gholami, Behnam and Pavlovic, Vladimir},
36 | journal={arXiv preprint arXiv:1902.08727},
37 | year={2019}
38 | }
39 | ```
40 |
--------------------------------------------------------------------------------
/docs/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-slate
2 | title: Unsupervised Visual Domain Adaptation:A Deep Max-Margin Gaussian Process Approach
3 |
--------------------------------------------------------------------------------
/docs/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seqam-lab/GPDA/1c7b2462f41b8eeb905f0909ff5f59fd0ba94e48/docs/method.png
--------------------------------------------------------------------------------
/docs/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seqam-lab/GPDA/1c7b2462f41b8eeb905f0909ff5f59fd0ba94e48/docs/results.png
--------------------------------------------------------------------------------
/src/cmd_options.txt:
--------------------------------------------------------------------------------
1 | --num_k 3 --num_kq 3 --lamb_marg_loss 10.0 --max_epoch 2000 --save_model --save_epoch 10 --fix_randomness
--------------------------------------------------------------------------------
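The option string above is intended for `main4.py`. As a rough illustration of how it maps onto parsed values, here is a minimal sketch that assumes a reduced stand-in parser covering only the flags in `cmd_options.txt` (defaults copied from `main4.py`; the parser itself is not the repository's):

```python
import argparse
import shlex

# Stand-in parser that mirrors only the flags appearing in cmd_options.txt.
# Defaults are copied from main4.py; this is an illustration, not main4.py.
parser = argparse.ArgumentParser()
parser.add_argument('--num_k', type=int, default=4)
parser.add_argument('--num_kq', type=int, default=4)
parser.add_argument('--lamb_marg_loss', type=float, default=10.0)
parser.add_argument('--max_epoch', type=int, default=200)
parser.add_argument('--save_model', action='store_true', default=False)
parser.add_argument('--save_epoch', type=int, default=10)
parser.add_argument('--fix_randomness', action='store_true', default=False)

options = ('--num_k 3 --num_kq 3 --lamb_marg_loss 10.0 --max_epoch 2000 '
           '--save_model --save_epoch 10 --fix_randomness')

# shlex.split turns the option string into an argv-style list.
args = parser.parse_args(shlex.split(options))
print(args.num_k, args.num_kq, args.max_epoch, args.save_model)
```

Note that these options override the `main4.py` defaults for `--num_k`, `--num_kq`, and `--max_epoch`, and switch on both boolean flags.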
/src/data_manager/dataset_read.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from data_manager.unaligned_data_loader import UnalignedDataLoader
4 | from data_manager.svhn import load_svhn
5 | from data_manager.mnist import load_mnist
6 | #from datasets.usps import load_usps
7 | #from datasets.gtsrb import load_gtsrb
8 | #from datasets.synth_traffic import load_syntraffic
9 |
10 | ###############################################################################
11 |
12 | def return_dataset(data, scale=False, usps=False, all_use='no'):
13 |
14 | '''
15 | load a specified dataset
16 |
17 | input:
18 | data = dataset to load (eg, 'svhn', 'mnist'); string
19 | scale = whether to scale up images to (32 x 32) or not (28 x 28)
20 | usps, all_use = whether or not to subsample the training set
21 |
22 | output:
23 | train_image = train images; (ntr x C x H x W)
24 | train_label = {0...9}-valued train labels; ntr-dim
25 | test_image = test images; (nte x C x H x W)
26 | test_label = {0...9}-valued test labels; nte-dim
27 | '''
28 |
29 | if data == 'svhn':
30 | train_image, train_label, test_image, test_label = load_svhn()
31 |
32 | if data == 'mnist':
33 | train_image, train_label, test_image, test_label = \
34 | load_mnist( scale=scale, usps=usps, all_use=all_use )
35 | sys.stdout.write('mnist image shape = '); print(train_image.shape)
36 |
37 | # if data == 'usps':
38 | # train_image, train_label, test_image, test_label = \
39 | # load_usps(all_use=all_use)
40 | #
41 | # if data == 'synth':
42 | # train_image, train_label, test_image, test_label = \
43 | # load_syntraffic()
44 | #
45 | # if data == 'gtsrb':
46 | # train_image, train_label, test_image, test_label = load_gtsrb()
47 |
48 | return train_image, train_label, test_image, test_label
49 |
50 | ###############################################################################
51 |
52 | def dataset_read( source, target, batch_size, scale=False, all_use='no' ):
53 |
54 | if source == 'usps' or target == 'usps':
55 | usps = True
56 | else:
57 | usps = False
58 |
59 | S = {}; S_test = {}
60 | T = {}; T_test = {}
61 |
62 | # read source data
63 | train_source, s_label_train, test_source, s_label_test = \
64 | return_dataset( source, scale=scale, usps=usps, all_use=all_use )
65 |
66 | # read target data
67 | train_target, t_label_train, test_target, t_label_test = \
68 | return_dataset( target, scale=scale, usps=usps, all_use=all_use )
69 |
70 | # prepare source/target data
71 | S['imgs'] = train_source
72 | S['labels'] = s_label_train
73 | T['imgs'] = train_target
74 | T['labels'] = t_label_train
75 |
76 | # test samples: both S_test and T_test are filled with target-domain test data
77 | S_test['imgs'] = test_target
78 | S_test['labels'] = t_label_test
79 | T_test['imgs'] = test_target
80 | T_test['labels'] = t_label_test
81 |
82 | scale = 40 if source == 'synth' else 28 if usps else 32
83 |
84 | # (train) do some image transform and create a minibatch generator
85 | train_loader = UnalignedDataLoader()
86 | train_loader.initialize(S, T, batch_size, batch_size, scale=scale)
87 | dataset = train_loader.load_data()
88 |
89 | # (test) do some image transform and create a minibatch generator
90 | test_loader = UnalignedDataLoader()
91 | test_loader.initialize(S_test, T_test, batch_size, batch_size, scale=scale)
92 | dataset_test = test_loader.load_data()
93 |
94 | return dataset, dataset_test
95 |
--------------------------------------------------------------------------------
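`dataset_read` chooses the image side length passed to the transform from the domain pair. The chained conditional it uses can be read as in the sketch below, where `pick_image_size` is a hypothetical helper name introduced only for illustration:

```python
def pick_image_size(source, target):
    """Mirror the size selection in dataset_read (hypothetical helper)."""
    usps = (source == 'usps' or target == 'usps')
    # Chained conditional expressions associate to the right:
    # 40 if the source is 'synth', otherwise (28 if usps else 32).
    return 40 if source == 'synth' else 28 if usps else 32


print(pick_image_size('svhn', 'mnist'))
print(pick_image_size('usps', 'mnist'))
print(pick_image_size('synth', 'gtsrb'))
```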
/src/data_manager/datasets.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from PIL import Image
3 | import numpy as np
4 |
5 | ###############################################################################
6 |
7 | class Dataset(data.Dataset):
8 |
9 |     def __init__(self, data, label,
10 |                  transform=None, target_transform=None):
11 |         self.transform = transform
12 |         self.target_transform = target_transform
13 |         self.data = data
14 |         self.labels = label
15 |
16 |     def __getitem__(self, index):
17 |
18 |         img, target = self.data[index], self.labels[index]
19 |         # doing this so that it is consistent with all other datasets
20 |         # to return a PIL Image
21 |         # print(img.shape)
22 |         if img.shape[0] != 1:
23 |             #print(img)
24 |             img = Image.fromarray(
25 |                 np.uint8(np.asarray(img.transpose((1, 2, 0)))) )
26 |             #
27 |         elif img.shape[0] == 1:
28 |             im = np.uint8(np.asarray(img))
29 |             # print(np.vstack([im,im,im]).shape)
30 |             im = np.vstack([im, im, im]).transpose((1, 2, 0))
31 |             img = Image.fromarray(im)
32 |
33 |         if self.target_transform is not None:
34 |             target = self.target_transform(target)
35 |         if self.transform is not None:
36 |             img = self.transform(img)
37 |         # return img, target
38 |         return img, target
39 |     def __len__(self):
40 |         return len(self.data)
41 |
--------------------------------------------------------------------------------
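`Dataset.__getitem__` converts each `(C x H x W)` array into a three-channel, channel-last image before handing it to PIL. The array manipulation on its own can be sketched as follows; `to_hwc_uint8` is a hypothetical helper name, and the PIL step is left out so that the example depends only on NumPy:

```python
import numpy as np


def to_hwc_uint8(img):
    """Convert a (C x H x W) array to an (H x W x 3) uint8 array."""
    if img.shape[0] != 1:
        # Colour image: only move the channel axis to the end.
        return np.uint8(np.asarray(img.transpose((1, 2, 0))))
    # Grayscale image: stack the single channel three times, then reorder.
    im = np.uint8(np.asarray(img))
    return np.vstack([im, im, im]).transpose((1, 2, 0))


rgb = to_hwc_uint8(np.zeros((3, 32, 32), dtype=np.float32))
gray = to_hwc_uint8(np.ones((1, 28, 28), dtype=np.float32))
print(rgb.shape, gray.shape)
```

Stacking along the first axis works here because the grayscale input has shape `(1 x H x W)`, so three copies give `(3 x H x W)`.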
/src/data_manager/mnist.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.io import loadmat
3 |
4 | ###############################################################################
5 |
6 | def load_mnist(scale=True, usps=False, all_use='no'):
7 |
8 | '''
9 | load mnist dataset
10 |
11 | input:
12 | scale = whether to scale up images to (32 x 32) or not (28 x 28)
13 | if scale==True, also duplicate channels to (32 x 32 x 3)
14 | usps, all_use = whether or not to subsample the training set
15 | use 2000 random subsamples from training if usps==True & all_use='no'
16 | use ALL training samples otherwise
17 |
18 | output:
19 | mnist_train = training images;
20 | (55000 x 3 x 32 x 32) or (55000 x 1 x 28 x 28)
21 | train_label = {0...9}-valued training labels; 55000-dim
22 | mnist_test = test images;
23 | (10000 x 3 x 32 x 32) or (10000 x 1 x 28 x 28)
24 | test_label = {0...9}-valued test labels; 10000-dim
25 | '''
26 |
27 | mnist_data = loadmat('data/mnist/mnist_data.mat')
28 | # load the following dict composed of:
29 | # mnist_data['train_32', 'test_32'] = (n x 32 x 32)
30 | # mnist_data['train_28', 'test_28'] = (n x 28 x 28 x 1)
31 | # mnist_data['label_train', 'label_test'] = (n x 10) one-hot
32 |
33 | if scale: # scale up and channel-duplicate images to (32 x 32 x 3)
34 |
35 | mnist_train = np.reshape(mnist_data['train_32'], (55000, 32, 32, 1))
36 | mnist_test = np.reshape(mnist_data['test_32'], (10000, 32, 32, 1))
37 |
38 | # duplicate channels
39 | mnist_train = np.concatenate(
40 | [mnist_train, mnist_train, mnist_train], 3 )
41 | mnist_test = np.concatenate(
42 | [mnist_test, mnist_test, mnist_test], 3 )
43 |
44 | # reshape to (n x C x H x W) format
45 | mnist_train = mnist_train.transpose(0, 3, 1, 2).astype(np.float32)
46 | mnist_test = mnist_test.transpose(0, 3, 1, 2).astype(np.float32)
47 |
48 | else: # use original (28 x 28 x 1)
49 |
50 | mnist_train = mnist_data['train_28']
51 | mnist_test = mnist_data['test_28']
52 |
53 | # reshape to (n x C x H x W) format
54 | mnist_train = mnist_train.transpose((0, 3, 1, 2)).astype(np.float32)
55 | mnist_test = mnist_test.transpose((0, 3, 1, 2)).astype(np.float32)
56 |
57 | # labels in one-hot format
58 | mnist_labels_train = mnist_data['label_train']
59 | mnist_labels_test = mnist_data['label_test']
60 |
61 | # convert one-hot to 0~9 labels
62 | train_label = np.argmax(mnist_labels_train, axis=1)
63 | test_label = np.argmax(mnist_labels_test, axis=1)
64 |
65 | # randomly shuffle training data
66 | inds = np.random.permutation(mnist_train.shape[0])
67 | mnist_train = mnist_train[inds]
68 | train_label = train_label[inds]
69 |
70 | # subsample training images
71 | if usps and all_use != 'yes':
72 | mnist_train = mnist_train[:2000]
73 | train_label = train_label[:2000]
74 |
75 | return mnist_train, train_label, mnist_test, test_label
76 |
--------------------------------------------------------------------------------
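The core array steps in `load_mnist` (channel reordering, one-hot decoding, and a joint shuffle) can be tried on toy data. The sketch below uses small random arrays as stand-ins for the contents of `mnist_data.mat`; the toy shapes and the fixed seed are assumptions made for the example:

```python
import numpy as np

rng = np.random.RandomState(0)  # fixed seed so the sketch is repeatable

# Toy stand-ins for the arrays stored in mnist_data.mat (assumed shapes).
images = rng.rand(6, 28, 28, 1).astype(np.float32)   # (n x H x W x C)
one_hot = np.eye(10)[[3, 1, 4, 1, 5, 9]]             # (n x 10)

images = images.transpose((0, 3, 1, 2))              # -> (n x C x H x W)
labels = np.argmax(one_hot, axis=1)                  # -> integer labels

# A single permutation indexes both arrays, so image/label pairs stay aligned.
inds = rng.permutation(images.shape[0])
shuffled_images, shuffled_labels = images[inds], labels[inds]
print(shuffled_images.shape, sorted(shuffled_labels.tolist()))
```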
/src/data_manager/svhn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.io import loadmat
3 |
4 | from utils.utils import convert_label_10_to_0
5 |
6 | ###############################################################################
7 |
8 | def load_svhn():
9 |
10 | '''
11 | load svhn dataset
12 |
13 | input: N/A
14 |
15 | output:
16 | svhn_train_im = training images; (73257 x 3 x 32 x 32)
17 | svhn_label = {0...9}-valued training labels; 73257-dim
18 | svhn_test_im = test images; (26032 x 3 x 32 x 32)
19 | svhn_label_test = {0...9}-valued test labels; 26032-dim
20 | '''
21 |
22 | svhn_train = loadmat('data/svhn/train_32x32.mat')
23 | svhn_test = loadmat('data/svhn/test_32x32.mat')
24 | svhn_train_im = svhn_train['X']
25 | svhn_train_im = svhn_train_im.transpose(3, 2, 0, 1).astype(np.float32)
26 | svhn_label = convert_label_10_to_0(svhn_train['y'])
27 | svhn_test_im = svhn_test['X']
28 | svhn_test_im = svhn_test_im.transpose(3, 2, 0, 1).astype(np.float32)
29 | svhn_label_test = convert_label_10_to_0(svhn_test['y'])
30 |
31 | return svhn_train_im, svhn_label, svhn_test_im, svhn_label_test
32 |
--------------------------------------------------------------------------------
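`load_svhn` reorders the SVHN arrays from `(H x W x C x n)` to `(n x C x H x W)` and remaps labels through `convert_label_10_to_0`, whose source is not shown in this listing. The sketch below assumes that helper maps the label 10 (which SVHN uses for the digit zero) to 0 and flattens the label array; `convert_label_10_to_0_sketch` is a hypothetical stand-in, not the repository's function:

```python
import numpy as np


def convert_label_10_to_0_sketch(y):
    """Assumed behaviour of utils.convert_label_10_to_0: map label 10 to 0."""
    y = np.asarray(y).reshape(-1).astype(np.int64)  # astype returns a copy
    y[y == 10] = 0
    return y


X = np.zeros((32, 32, 3, 5), dtype=np.uint8)   # (H x W x C x n), toy data
y = np.array([[1], [10], [3], [10], [9]])      # (n x 1), toy labels

images = X.transpose(3, 2, 0, 1).astype(np.float32)  # -> (n x C x H x W)
labels = convert_label_10_to_0_sketch(y)
print(images.shape, labels.tolist())
```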
/src/data_manager/unaligned_data_loader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | import torchnet as tnt
3 | from builtins import object
4 | import torchvision.transforms as transforms
5 |
6 | from data_manager.datasets import Dataset
7 |
8 | ###############################################################################
9 |
10 | class PairedData(object):
11 |
12 | def __init__(self, data_loader_A, data_loader_B, max_dataset_size):
13 |
14 | self.data_loader_A = data_loader_A
15 | self.data_loader_B = data_loader_B
16 | self.stop_A = False
17 | self.stop_B = False
18 | self.max_dataset_size = max_dataset_size
19 |
20 | def __iter__(self):
21 |
22 | self.stop_A = False
23 | self.stop_B = False
24 | self.data_loader_A_iter = iter(self.data_loader_A)
25 | self.data_loader_B_iter = iter(self.data_loader_B)
26 | self.iter = 0
27 | return self
28 |
29 | def __next__(self):
30 |
31 | A, A_paths = None, None
32 | B, B_paths = None, None
33 | try:
34 | A, A_paths = next(self.data_loader_A_iter)
35 | except StopIteration:
36 | if A is None or A_paths is None:
37 | self.stop_A = True
38 | self.data_loader_A_iter = iter(self.data_loader_A)
39 | A, A_paths = next(self.data_loader_A_iter)
40 |
41 | try:
42 | B, B_paths = next(self.data_loader_B_iter)
43 | except StopIteration:
44 | if B is None or B_paths is None:
45 | self.stop_B = True
46 | self.data_loader_B_iter = iter(self.data_loader_B)
47 | B, B_paths = next(self.data_loader_B_iter)
48 |
49 | if (self.stop_A and self.stop_B) or self.iter>self.max_dataset_size:
50 | self.stop_A = False
51 | self.stop_B = False
52 | raise StopIteration()
53 | else:
54 | self.iter += 1
55 | return {'S': A, 'S_label': A_paths,
56 | 'T': B, 'T_label': B_paths}
57 |
58 | ###############################################################################
59 |
60 | class UnalignedDataLoader():
61 |
62 | def initialize(self, source, target, batch_size1, batch_size2, scale=32):
63 |
64 | transform = transforms.Compose([
65 | transforms.Resize(scale),  # transforms.Scale was renamed to Resize in torchvision
66 | transforms.ToTensor(),
67 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
68 | ])
69 |
70 | dataset_source = Dataset(
71 | source['imgs'], source['labels'], transform=transform )
72 | dataset_target = Dataset(
73 | target['imgs'], target['labels'], transform=transform )
74 |
75 | data_loader_s = torch.utils.data.DataLoader(
76 | dataset_source,
77 | batch_size=batch_size1,
78 | shuffle=True,
79 | num_workers=4 )
80 |
81 | data_loader_t = torch.utils.data.DataLoader(
82 | dataset_target,
83 | batch_size=batch_size2,
84 | shuffle=True,
85 | num_workers=4 )
86 |
87 | self.dataset_s = dataset_source
88 | self.dataset_t = dataset_target
89 | self.paired_data = PairedData(
90 | data_loader_s, data_loader_t, float("inf") )
91 |
92 | def name(self):
93 | return 'UnalignedDataLoader'
94 |
95 | def load_data(self):
96 | return self.paired_data
97 |
98 | def __len__(self):
99 | return min(max(len(self.dataset_s),len(self.dataset_t)), float("inf"))
100 |
--------------------------------------------------------------------------------
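`PairedData` pairs batches from two loaders of different lengths by restarting whichever runs out first, and stops once both have been exhausted at least once. The same control flow can be sketched with plain Python iterables; `paired_batches` is a hypothetical name, and the sketch assumes both inputs are non-empty:

```python
def paired_batches(loader_a, loader_b):
    """Yield (a, b) pairs, restarting an exhausted iterable as needed.

    Stops once both iterables have been exhausted at least once.
    Assumes both iterables are non-empty.
    """
    iter_a, iter_b = iter(loader_a), iter(loader_b)
    stop_a = stop_b = False
    while True:
        try:
            a = next(iter_a)
        except StopIteration:
            stop_a = True
            iter_a = iter(loader_a)
            a = next(iter_a)
        try:
            b = next(iter_b)
        except StopIteration:
            stop_b = True
            iter_b = iter(loader_b)
            b = next(iter_b)
        if stop_a and stop_b:
            return
        yield a, b


pairs = list(paired_batches(['s0', 's1'], ['t0', 't1', 't2']))
print(pairs)
```

With two source items and three target items, the source side wraps around once, so the number of pairs equals the length of the longer input.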
/src/main4.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 |
5 | from solver import Solver
6 |
7 | ###############################################################################
8 |
9 | #
10 | # hyperparameters
11 | #
12 |
13 | parser = argparse.ArgumentParser()
14 |
15 | parser.add_argument( '--nsamps_q',
16 | type=int, default=50,
17 | help='# of samples from variational density q(w) (default: 50)' )
18 |
19 | parser.add_argument( '--lamb_marg_loss',
20 | type=float, default=10.0,
21 | help='impact of margin loss (default: 10.0)' )
22 |
23 | parser.add_argument( '--all_use',
24 | type=str, default='no',
25 | help='use all training data? (default: "no")' )
26 |
27 | parser.add_argument( '--batch-size',
28 | type=int, default=128,
29 | help='input batch size for training (default: 128)' )
30 |
31 | #parser.add_argument( '--checkpoint_dir',
32 | # type=str, default='checkpoint',
33 | # help='source only or not (default: "checkpoint")' )
34 |
35 | parser.add_argument( '--eval_only',
36 | action='store_true', default=False,
37 | help='evaluation only option' )
38 |
39 | parser.add_argument( '--lr',
40 | type=float, default=0.0002,
41 | help='learning rate (default: 0.0002)' )
42 |
43 | parser.add_argument( '--max_epoch',
44 | type=int, default=200,
45 | help='maximum number of epochs (default: 200)' )
46 |
47 | parser.add_argument( '--no-cuda',
48 | action='store_true', default=False,
49 | help='disables CUDA training' )
50 |
51 | parser.add_argument( '--num_k',
52 | type=int, default=4,
53 | help='# gradient descent iterations for phi(G(x)) learning (default: 4)' )
54 |
55 | parser.add_argument( '--num_kq',
56 | type=int, default=4,
57 | help='# gradient descent iterations for q(w) learning (default: 4)' )
58 |
59 | #parser.add_argument( '--one_step',
60 | # action='store_true', default=False,
61 | # help='one step training with gradient reversal layer' )
62 |
63 | parser.add_argument( '--optimizer',
64 | type=str, default='adam',
65 | help='optimizer (default: "adam")' )
66 |
67 | parser.add_argument( '--resume_epoch',
68 | type=int, default=100,
69 | help='epoch to resume (default: 100)' )
70 |
71 | parser.add_argument( '--save_epoch',
72 | type=int, default=10,
73 | help='save the model every this many epochs (default: 10)' )
74 |
75 | parser.add_argument( '--save_model',
76 | action='store_true', default=False,
77 | help='save_model or not' )
78 |
79 | parser.add_argument( '--seed',
80 | type=int, default=1,
81 | help='random seed (default: 1)' )
82 |
83 | parser.add_argument( '--source',
84 | type=str, default='svhn',
85 | help='source dataset (default: "svhn")' )
86 |
87 | parser.add_argument( '--target',
88 | type=str, default='mnist',
89 | help='target dataset (default: "mnist")' )
90 |
91 | parser.add_argument( '--use_abs_diff',
92 | action='store_true', default=False,
93 | help='use absolute difference value as a measurement' )
94 |
95 | parser.add_argument( '--fix_randomness',
96 | action='store_true', default=False,
97 | help='fix randomness' )
98 |
99 | args = parser.parse_args()
100 |
101 | args.cuda = not args.no_cuda and torch.cuda.is_available()
102 | torch.manual_seed(args.seed)
103 | if args.cuda:
104 | torch.cuda.manual_seed(args.seed)
105 |
106 | print(args)
107 |
108 | if args.fix_randomness:
109 | import numpy as np
110 | np.random.seed(10)
111 | torch.backends.cudnn.deterministic = True
112 |
113 |
114 | ###############################################################################
115 |
116 | def main():
117 |
118 | # make a string that describes the current running setup
119 | num = 0
120 | run_setup_str = \
121 | '%s2%s_k_%s_kq_%s_lamb_%s' % \
122 | ( args.source, args.target, args.num_k, args.num_kq, args.lamb_marg_loss)
123 | while os.path.exists('record/%s_run_%s.txt' % (run_setup_str, num)):
124 | num += 1
125 | run_setup_str = '%s_run_%s' % (run_setup_str, num)
126 | # eg, svhn2mnist_k_4_kq_4_lamb_10.0_run_5
127 |
128 | # set file names for records (storing training stats)
129 | record_train = 'record/%s.txt' % (run_setup_str,)
130 | record_test = 'record/%s_test.txt' % (run_setup_str,)
131 | if not os.path.exists('record'):
132 | os.mkdir('record') # create a folder for records if not exist
133 |
134 | # set the checkpoint dir name (storing model params)
135 | checkpoint_dir = 'checkpoint/%s' % (run_setup_str,)
136 | if not os.path.exists('checkpoint'):
137 | os.mkdir('checkpoint') # create a folder if not exist
138 | if not os.path.exists(checkpoint_dir):
139 | os.mkdir(checkpoint_dir) # create a folder if not exist
140 |
141 | ####
142 |
143 | # create a solver: load data, create models (or load existing models),
144 | # and create optimizers
145 | solver = Solver( args,
146 | source = args.source,
147 | target = args.target,
148 | nsamps_q = args.nsamps_q,
149 | lamb_marg_loss = args.lamb_marg_loss,
150 | learning_rate = args.lr,
151 | batch_size = args.batch_size,
152 | optimizer = args.optimizer,
153 | num_k = args.num_k,
154 | num_kq = args.num_kq,
155 | all_use = args.all_use,
156 | checkpoint_dir = checkpoint_dir,
157 | save_epoch = args.save_epoch )
158 |
159 | # run it (test or training)
160 | if args.eval_only:
161 | solver.test(0)
162 | else: # training
163 | count = 0
164 | for t in range(args.max_epoch):
165 | num = solver.train(t, record_file=record_train)
166 | count += num
167 | if t % 1 == 0: # run it on test data every epoch (and save models)
168 | solver.test( t, record_file=record_test,
169 | save_model=args.save_model )
170 | if count >= 20000*10:
171 | break
172 |
173 | ###############################################################################
174 |
175 | if __name__ == '__main__':
176 | main()
177 |
178 |
--------------------------------------------------------------------------------
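`main()` names each run by appending the smallest unused run index to a string built from the source, target, and hyperparameters. A minimal sketch of that naming scheme is given below; `next_run_name` is a hypothetical helper, and a temporary directory stands in for the `record` folder:

```python
import os
import tempfile


def next_run_name(record_dir, setup_str):
    """Return '<setup_str>_run_<n>' for the smallest n with no record file."""
    num = 0
    while os.path.exists(
            os.path.join(record_dir, '%s_run_%s.txt' % (setup_str, num))):
        num += 1
    return '%s_run_%s' % (setup_str, num)


setup = '%s2%s_k_%s_kq_%s_lamb_%s' % ('svhn', 'mnist', 3, 3, 10.0)
with tempfile.TemporaryDirectory() as record_dir:
    first = next_run_name(record_dir, setup)
    # Simulate a finished run by creating its record file.
    open(os.path.join(record_dir, first + '.txt'), 'w').close()
    second = next_run_name(record_dir, setup)

print(first)
print(second)
```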
/src/model/build_gen.py:
--------------------------------------------------------------------------------
1 | import model.svhn2mnist as svhn2mnist
2 | #import model.usps as usps
3 | #import model.syn2gtrsb as syn2gtrsb
4 |
5 | ###############################################################################
6 |
7 | def PhiGnet(source, target):
8 |     # only the svhn source is wired up; the usps and syn2gtrsb imports
9 |     # above are commented out, so those branches would raise NameError
10 |     if source == 'svhn' and target != 'usps':
11 |         return svhn2mnist.PhiGnetwork()
12 |     raise NotImplementedError(
13 |         'no PhiGnetwork for %s -> %s' % (source, target) )
14 |
15 | ###############################################################################
16 |
17 | def QWnet(source, target):
18 |     # same restriction as PhiGnet: only the svhn source is wired up,
19 |     # since the usps and syn2gtrsb imports above are commented out
20 |     if source == 'svhn' and target != 'usps':
21 |         return svhn2mnist.QWnetwork()
22 |     raise NotImplementedError(
23 |         'no QWnetwork for %s -> %s' % (source, target) )
24 |
25 | ###############################################################################
26 |
27 | #def Generator(source, target):
28 | # if source == 'usps' or target == 'usps':
29 | # return usps.Feature()
30 | # elif source == 'svhn':
31 | # return svhn2mnist.Feature()
32 | # elif source == 'synth':
33 | # return syn2gtrsb.Feature()
34 |
35 | ###############################################################################
36 |
37 | #def Classifier(source, target):
38 | # if source == 'usps' or target == 'usps':
39 | # return usps.Predictor()
40 | # if source == 'svhn':
41 | # return svhn2mnist.Predictor()
42 | # if source == 'synth':
43 | # return syn2gtrsb.Predictor()
44 |
--------------------------------------------------------------------------------
/src/model/svhn2mnist.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | #from model.grad_reverse import grad_reverse
6 |
7 | ###############################################################################
8 |
9 | #
10 | # PhiGnetwork returns u = phi(G(x)) where
11 | #
12 | # x = image
13 | # z = G(x) = exactly Feature() in MCD-DA
14 | # u = phi(z) = the last hidden layer of Predictor() in MCD-DA
15 | #
16 |
17 | class PhiGnetwork(nn.Module):
18 |
19 | def __init__(self):
20 |
21 | super(PhiGnetwork, self).__init__()
22 |
23 | self.g_conv1 = nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=2)
24 | self.g_bn1 = nn.BatchNorm2d(64)
25 | self.g_conv2 = nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=2)
26 | self.g_bn2 = nn.BatchNorm2d(64)
27 | self.g_conv3 = nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=2)
28 | self.g_bn3 = nn.BatchNorm2d(128)
29 | self.g_fc1 = nn.Linear(8192, 3072)
30 | self.g_bn1_fc = nn.BatchNorm1d(3072)
31 |
32 | #self.phi_fc1 = nn.Linear(8192, 3072)
33 | #self.phi_bn1_fc = nn.BatchNorm1d(3072)
34 | self.phi_fc2 = nn.Linear(3072, 2048)
35 | self.phi_bn2_fc = nn.BatchNorm1d(2048)
36 |
37 | self.p = 2048
38 |
39 | def forward(self, x):
40 |
41 | x = F.max_pool2d( F.relu(self.g_bn1(self.g_conv1(x))),
42 | stride=2, kernel_size=3, padding=1 )
43 | x = F.max_pool2d( F.relu(self.g_bn2(self.g_conv2(x))),
44 | stride=2, kernel_size=3, padding=1 )
45 | x = F.relu(self.g_bn3(self.g_conv3(x)))
46 | x = x.view(x.size(0), 8192)
47 | x = F.relu(self.g_bn1_fc(self.g_fc1(x)))
48 | z = F.dropout(x, training=self.training)
49 |
50 | u = F.relu(self.phi_bn2_fc(self.phi_fc2(z)))
51 |
52 | return u
53 |
54 | ###############################################################################
55 |
56 | #
57 | # QWnetwork returns w^m_j = mu_j + sd_j.*eps^m_j, for m=1...M samples from
58 | # q(w) = \prod_{j=1}^K N(w_j; mu_j, diag(sd_j)^2) with dim(w_j) = p
59 | #
60 | # eps = samples from N(0,1); (M x p x K) -- input
61 | # mu = K mean vectors of q(w); (p x K) -- model params
62 | # logsd = K log-stdev vectors of q(w); (p x K) -- model params
63 | # (sd = exp(logsd))
64 | #
65 |
66 | class QWnetwork(nn.Module):
67 |
68 | def __init__(self):
69 |
70 | super(QWnetwork, self).__init__()
71 |
72 | self.mu = nn.Parameter(0.01*torch.randn(2048, 10))
73 | self.logsd = nn.Parameter(0.01*torch.randn(2048, 10))
74 |
75 | def forward(self, eps):
76 |
77 | mu3 = self.mu.unsqueeze(0) # (1 x p x K)
78 | sd3 = torch.exp(self.logsd).unsqueeze(0) # (1 x p x K)
79 | w = mu3 + sd3*eps # (M x p x K)
80 |
81 | return w
82 |
83 | ###############################################################################
84 |
85 | #class Feature(nn.Module):
86 | #
87 | # def __init__(self):
88 | #
89 | # super(Feature, self).__init__()
90 | #
91 | # self.conv1 = nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=2)
92 | # self.bn1 = nn.BatchNorm2d(64)
93 | # self.conv2 = nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=2)
94 | # self.bn2 = nn.BatchNorm2d(64)
95 | # self.conv3 = nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=2)
96 | # self.bn3 = nn.BatchNorm2d(128)
97 | # self.fc1 = nn.Linear(8192, 3072)
98 | # self.bn1_fc = nn.BatchNorm1d(3072)
99 | #
100 | # def forward(self, x):
101 | #
102 | # x = F.max_pool2d( F.relu(self.bn1(self.conv1(x))),
103 | # stride=2, kernel_size=3, padding=1 )
104 | # x = F.max_pool2d( F.relu(self.bn2(self.conv2(x))),
105 | # stride=2, kernel_size=3, padding=1 )
106 | # x = F.relu(self.bn3(self.conv3(x)))
107 | # x = x.view(x.size(0), 8192)
108 | # x = F.relu(self.bn1_fc(self.fc1(x)))
109 | # x = F.dropout(x, training=self.training)
110 | #
111 | # return x
112 |
113 | ###############################################################################
114 |
115 | #class Predictor(nn.Module):
116 | #
117 | # def __init__(self, prob=0.5):
118 | #
119 | # super(Predictor, self).__init__()
120 | #
121 | # self.fc1 = nn.Linear(8192, 3072)
122 | # self.bn1_fc = nn.BatchNorm1d(3072)
123 | # self.fc2 = nn.Linear(3072, 2048)
124 | # self.bn2_fc = nn.BatchNorm1d(2048)
125 | # self.fc3 = nn.Linear(2048, 10)
126 | # self.bn_fc3 = nn.BatchNorm1d(10)
127 | # self.prob = prob
128 | #
129 | # def set_lambda(self, lambd):
130 | #
131 | # self.lambd = lambd
132 | #
133 | # def forward(self, x, reverse=False):
134 | #
135 | # if reverse:
136 | # x = grad_reverse(x, self.lambd)
137 | # x = F.relu(self.bn2_fc(self.fc2(x)))
138 | # x = self.fc3(x)
139 | #
140 | # return x
141 |
--------------------------------------------------------------------------------
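`QWnetwork.forward` draws samples `w = mu + exp(logsd) * eps` by broadcasting the `(p x K)` parameters against `(M x p x K)` noise. The NumPy sketch below reproduces that sampling step with the shapes used in `svhn2mnist.py`. The final projection of features onto the sampled weights is an assumption about how such samples might be used; `solver.py` is not shown here, so it should not be read as the repository's computation:

```python
import numpy as np

rng = np.random.RandomState(0)
M, p, K = 5, 2048, 10   # samples, feature dim, classes (p and K as in QWnetwork)

mu = 0.01 * rng.randn(p, K)       # mean vectors
logsd = 0.01 * rng.randn(p, K)    # log standard deviations
eps = rng.randn(M, p, K)          # eps ~ N(0, 1)

# Broadcasting (1 x p x K) against (M x p x K) yields M samples of w.
w = mu[None] + np.exp(logsd)[None] * eps

# Hypothetical use: class scores for a batch of features u (B x p).
u = rng.randn(4, p)
scores = np.einsum('bp,mpk->mbk', u, w)   # (M x B x K)
print(w.shape, scores.shape)
```

Parameterizing the log of the standard deviation keeps `sd = exp(logsd)` positive without any constraint on the optimizer.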
/src/model/syn2gtrsb.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F

from model.grad_reverse import grad_reverse


class Feature(nn.Module):
    def __init__(self):
        super(Feature, self).__init__()
        self.conv1 = nn.Conv2d(3, 96, kernel_size=5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm2d(96)
        self.conv2 = nn.Conv2d(96, 144, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(144)
        self.conv3 = nn.Conv2d(144, 256, kernel_size=5, stride=1, padding=2)
        self.bn3 = nn.BatchNorm2d(256)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), stride=2, kernel_size=2, padding=0)
        x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), stride=2, kernel_size=2, padding=0)
        x = F.max_pool2d(F.relu(self.bn3(self.conv3(x))), stride=2, kernel_size=2, padding=0)
        x = x.view(x.size(0), 6400)
        return x


class Predictor(nn.Module):
    def __init__(self):
        super(Predictor, self).__init__()
        self.fc2 = nn.Linear(6400, 512)
        self.bn2_fc = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, 43)
        self.bn_fc3 = nn.BatchNorm1d(43)

    def set_lambda(self, lambd):
        self.lambd = lambd

    def forward(self, x, reverse=False):
        if reverse:
            x = grad_reverse(x, self.lambd)
        x = F.relu(self.bn2_fc(self.fc2(x)))
        x = F.dropout(x, training=self.training)
        x = self.fc3(x)
        return x

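The hard-coded flatten size 6400 in Feature.forward can be checked by hand. The sketch below is only shape arithmetic in plain Python, not part of the repository; the helper names conv_out and feature_flat_size are hypothetical, and the 40x40 input size is an assumption, since the listing does not state the image resolution.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a convolution or pooling window along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def feature_flat_size(input_size=40):
    """Flattened feature size after the three conv + max-pool stages of Feature.

    input_size=40 is an assumption; it is not stated in the listing.
    """
    size = input_size
    channels = 3
    # (out_channels, conv kernel, conv padding) for conv1, conv2, conv3
    stages = [(96, 5, 2), (144, 3, 1), (256, 5, 2)]
    for channels, kernel, padding in stages:
        # stride-1 convolution with this padding keeps the spatial size
        size = conv_out(size, kernel, stride=1, padding=padding)
        # 2x2 max-pool with stride 2 halves the spatial size
        size = conv_out(size, kernel=2, stride=2)
    return channels * size * size


print(feature_flat_size(40))  # 6400
```

Under that assumption the spatial size goes 40, 20, 10, 5, so the flattened vector has 256 * 5 * 5 = 6400 entries, which matches both the view call and the input width of fc2.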
--------------------------------------------------------------------------------
/src/model/usps.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from model.grad_reverse import grad_reverse


class Feature(nn.Module):
    def __init__(self):
        super(Feature, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5, stride=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 48, kernel_size=5, stride=1)
        self.bn2 = nn.BatchNorm2d(48)

    def forward(self, x):
        # average over the channel axis so the input becomes (B, 1, H, W)
        x = torch.mean(x,1).view(x.size()[0],1,x.size()[2],x.size()[3])
        x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), stride=2, kernel_size=2, dilation=(1, 1))
        x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), stride=2, kernel_size=2, dilation=(1, 1))
        # flatten the (B, 48, 4, 4) feature map
        x = x.view(x.size(0), 48*4*4)
        return x


class Predictor(nn.Module):
    def __init__(self, prob=0.5):
        super(Predictor, self).__init__()
        self.fc1 = nn.Linear(48*4*4, 100)
        self.bn1_fc = nn.BatchNorm1d(100)
        self.fc2 = nn.Linear(100, 100)
        self.bn2_fc = nn.BatchNorm1d(100)
        self.fc3 = nn.Linear(100, 10)
        self.bn_fc3 = nn.BatchNorm1d(10)
        self.prob = prob

    def set_lambda(self, lambd):
        self.lambd = lambd

    def forward(self, x, reverse=False):
        if reverse:
            x = grad_reverse(x, self.lambd)
        x = F.dropout(x, training=self.training, p=self.prob)
        x = F.relu(self.bn1_fc(self.fc1(x)))
        x = F.dropout(x, training=self.training, p=self.prob)
        x = F.relu(self.bn2_fc(self.fc2(x)))
        x = F.dropout(x, training=self.training, p=self.prob)
        x = self.fc3(x)
        return x

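The first line of Feature.forward averages the input over its channel axis and reshapes the result back to a four-dimensional batch with a single channel. The sketch below reproduces that pattern with NumPy so it can be run without a GPU; the function name to_single_channel and the toy batch are illustrative assumptions, not code from the repository.

```python
import numpy as np


def to_single_channel(x):
    """Average a (B, C, H, W) batch over C, keeping a size-1 channel axis."""
    return x.mean(axis=1, keepdims=True)


# toy batch: 2 images, 3 channels, 4x4 pixels (sizes chosen for illustration)
batch = np.arange(2 * 3 * 4 * 4, dtype=np.float64).reshape(2, 3, 4, 4)

gray = to_single_channel(batch)

# the mean-then-reshape pattern used in forward(), written with NumPy calls
reshaped = batch.mean(axis=1).reshape(batch.shape[0], 1, batch.shape[2], batch.shape[3])

print(gray.shape)                      # (2, 1, 4, 4)
print(np.array_equal(gray, reshaped))  # True
```

In PyTorch the same effect can be had in one call with x.mean(dim=1, keepdim=True), which avoids spelling out the target shape by hand.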
--------------------------------------------------------------------------------
/src/solver.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.autograd import Variable

from model.build_gen import PhiGnet, QWnet
from data_manager.dataset_read import dataset_read

###############################################################################

class Solver(object):

    ########
    def __init__( self, args, batch_size=128, source='svhn', target='mnist',
        nsamps_q=50, lamb_marg_loss=10.0,
        learning_rate=0.0002, interval=100, optimizer='adam', num_k=4, num_kq=4,
        all_use=False, checkpoint_dir=None, save_epoch=10 ):

        # set hyperparameters
        self.batch_size = batch_size
        self.source = source
        self.target = target
        self.num_k = num_k
        self.num_kq = num_kq
        self.checkpoint_dir = checkpoint_dir
        self.save_epoch = save_epoch
        self.use_abs_diff = args.use_abs_diff
        self.all_use = all_use
        if self.source == 'svhn':
            self.scale = True
        else:
            self.scale = False
        self.lamb_marg_loss = lamb_marg_loss

        # load data, do image transform, and create a mini-batch generator
        print('dataset loading')
        self.datasets, self.dataset_test = \
            dataset_read( source, target, self.batch_size,
                scale=self.scale, all_use=self.all_use )
        print('load finished!')

        # number of source training samples (size of the SVHN training split);
        # note: self.Ns is only defined for source == 'svhn', although train()
        # always uses it to scale the nll loss
        if source == 'svhn':
            self.Ns = 73257

        # create models
        self.phig = PhiGnet(source=source, target=target)
        self.qw = QWnet(source=source, target=target)

        # load the previously learned models from files (if evaluation only)
        if args.eval_only:
            self.phig = torch.load( '%s/model_epoch%s_phig.pt' %
                (self.checkpoint_dir , args.resume_epoch) )
            self.qw = torch.load( '%s/model_epoch%s_qw.pt' %
                (self.checkpoint_dir, args.resume_epoch) )

        # move models to GPU
        self.phig.cuda()
        self.qw.cuda()

        # create optimizer objects (one for each model)
        self.set_optimizer(which_opt=optimizer, lr=learning_rate)

        # print stats every interval (default: 100) minibatch iters
        self.interval = interval

        self.lr = learning_rate

        # some dimensions
        self.p = self.phig.p  # dim(phi(G(x)))
        self.M = nsamps_q  # number of samples from variational density q(w)


    ########
    def set_optimizer(self, which_opt='momentum', lr=0.001, momentum=0.9):

        if which_opt == 'momentum':

            self.opt_phig = optim.SGD( self.phig.parameters(),
                lr=lr, weight_decay=0.0005, momentum=momentum )

            self.opt_qw = optim.SGD( self.qw.parameters(),
                lr=lr, weight_decay=0.0005, momentum=momentum )

        elif which_opt == 'adam':

            self.opt_phig = optim.Adam( self.phig.parameters(),
                lr=lr, weight_decay=0.0005 )

            self.opt_qw = optim.Adam( self.qw.parameters(),
                lr=lr, weight_decay=0.0005 )

        else:

            # fail early instead of leaving the optimizers undefined
            raise ValueError('unknown optimizer: %s' % (which_opt,))


    ########
    def reset_grad(self):

        # zero out all gradients of model params registered in the optimizers
        self.opt_phig.zero_grad()
        self.opt_qw.zero_grad()


    ########
    def ent(self, output):

        return -torch.mean(output * torch.log(output + 1e-6))


    ########
    def kl_loss(self):

        # KL( q(w) || N(0,I) ) for a diagonal Gaussian q(w) with mean qw.mu and
        # std exp(qw.logsd); self.p*10 is the number of entries of w (p x K
        # with K = 10 classes hard-coded)
        kl = 0.5 * ( -self.p*10 +
            torch.sum( (torch.exp(self.qw.logsd))**2 + self.qw.mu**2 -
                2.0*self.qw.logsd )
        )

        return kl


    ########
    def train(self, epoch, record_file=None):

        '''
        train models for one epoch (i.e., one pass over the training data,
        capped at roughly 500 mini-batches)
        '''

        criterion = nn.CrossEntropyLoss().cuda()

        # turn models into "training" mode
        # (required if models contain "BatchNorm"-like layers)
        self.phig.train()
        self.qw.train()

        torch.cuda.manual_seed(1)

        # for each batch
        for batch_idx, data in enumerate(self.datasets):

            img_t = data['T']
            img_s = data['S']
            label_s = data['S_label']
            # stop at the first incomplete mini-batch
            if img_s.size()[0] < self.batch_size or \
                    img_t.size()[0] < self.batch_size:
                break
            img_s = img_s.cuda()
            img_t = img_t.cuda()
            # imgs = Variable(torch.cat((img_s, img_t), 0))
            label_s = Variable(label_s.long().cuda())
            img_s = Variable(img_s)
            img_t = Variable(img_t)

            # (M x p x K) samples from N(0,1)
            eps = Variable(torch.randn(self.M, self.p, 10))
            eps = eps.cuda()

            #### step A: min_{qw} (nll + kl)

            self.reset_grad()

            for i in range(self.num_kq):

                phig_s = self.phig(img_s)  # phi(G(xs))
                wsamp = self.qw(eps)  # samples from q(w)

                # w'*phi(G(xs)) = (M x B x K)
                wphig_s = torch.sum(
                    wsamp.unsqueeze(1) * phig_s.unsqueeze(0).unsqueeze(3),
                    dim=2 )

                # nll loss
                loss_nll = criterion(
                    wphig_s.view(-1,10), label_s.repeat(self.M) ) * self.Ns

                # kl loss
                loss_kl = self.kl_loss()

                loss = loss_nll + loss_kl

                # compute gradient of the loss
                loss.backward()

                # update models
                self.opt_qw.step()

                self.reset_grad()

            #### step B: min_{phig} (nll + kl + marg)

            self.reset_grad()

            for i in range(self.num_k):

                phig_s = self.phig(img_s)  # phi(G(xs))
                phig_t = self.phig(img_t)  # phi(G(xt))
                wsamp = self.qw(eps)  # samples from q(w)

                # w'*phi(G(xs)) = (M x B x K)
                wphig_s = torch.sum(
                    wsamp.unsqueeze(1) * phig_s.unsqueeze(0).unsqueeze(3),
                    dim=2 )

                # nll loss
                loss_nll = criterion(
                    wphig_s.view(-1,10), label_s.repeat(self.M) ) * self.Ns

                # kl loss
                loss_kl = self.kl_loss()

                # margin loss on target
                f_t = torch.mm(phig_t, self.qw.mu)  # (B x K)
                top2 = torch.topk(f_t, k=2, dim=1)[0]  # (B x 2)
                # top2[i,0] = max_j f_t[i,j], top2[i,1] = max2_j f_t[i,j]
                gap21 = top2[:,1] - top2[:,0]  # B-dim
                std_f_t = torch.sqrt(
                    torch.mm(phig_t**2, torch.exp(self.qw.logsd)**2) )  # (B x K)
                max_std = torch.max(std_f_t, dim=1)[0]  # B-dim
                loss_marg = torch.mean( F.relu(1.0 + gap21 + 1.96*max_std) )

                loss = loss_nll + loss_kl + self.lamb_marg_loss*loss_marg

                # compute gradient of the loss
                loss.backward()

                # update models
                self.opt_phig.step()

                self.reset_grad()

            #### wrap up

            # cap each epoch at roughly 500 mini-batches
            if batch_idx > 500:
                return batch_idx

            if batch_idx % self.interval == 0:
                prn_str = ('Train Epoch: %d [batch-idx: %d] ' + \
                    'nll: %.6f, kl: %.6f, marg: %.6f') % \
                    ( epoch, batch_idx, loss_nll.item(), loss_kl.item(),
                    loss_marg.item() )
                print(prn_str)
                if record_file:
                    with open(record_file, 'a') as record:
                        record.write('%s\n' % (prn_str,))

        return batch_idx


    ########
    def test(self, epoch, record_file=None, save_model=False):

        '''
        evaluate the current models on the entire test set
        '''

        criterion = nn.CrossEntropyLoss().cuda()

        # turn models into evaluation mode
        self.phig.eval()
        self.qw.eval()

        test_loss = 0  # test nll loss
        corrects = 0  # number of correct predictions by MAP
        size = 0  # total number of test samples

        # disable autograd (no gradient history is tracked during evaluation)
        with torch.no_grad():

            for batch_idx, data in enumerate(self.dataset_test):

                img = data['T']
                label = data['T_label']

                img, label = img.cuda(), label.long().cuda()

                #img, label = Variable(img, volatile=True), Variable(label)
                img, label = Variable(img), Variable(label)

                # (M x p x K) samples from N(0,1)
                #eps = Variable(torch.randn(self.M, self.p, 10))
                #eps = eps.cuda()

                phig = self.phig(img)  # phi(G(x))
                wmode = self.qw.mu  # mode of q(w)
                #wsamp = self.qw(eps)  # samples from q(w)

                # w'*phi(G(x)) = (B x K)
                output = torch.mm(phig, wmode)

                # w'*phi(G(x)) = (M x B x K)
                #wphig = torch.sum(
                #    wsamp.unsqueeze(1) * phig.unsqueeze(0).unsqueeze(3), dim=2 )

                # nll loss (equivalent to cross entropy loss)
                # note: criterion already averages within a batch, so the value
                # reported below is a sum of per-batch means divided by the
                # number of samples
                test_loss += criterion(output, label).item()

                # class prediction
                pred = output.data.max(1)[1]  # n-dim {0,...,K-1}-valued
                # tensor.max(j) returns a tuple (A, B) where
                #   A = max of tensor over j-th dim
                #   B = argmax of tensor over j-th dim

                corrects += pred.eq(label.data).cpu().numpy().sum()

                size += label.data.size()[0]

        test_loss = test_loss / size

        prn_str = ( 'Test set: Average nll loss: %.4f, ' + \
            'Accuracy: %d/%d (%.4f%%)\n' ) % \
            ( test_loss, corrects, size, 100. * corrects / size )
        print(prn_str)

        # save (append) the test scores/stats to files
        if record_file:
            with open(record_file, 'a') as record:
                print('recording %s\n' % record_file)
                record.write('%s\n' % (prn_str,))

        # save the models as files
        if save_model and epoch % self.save_epoch == 0:
            torch.save( self.phig,
                '%s/model_epoch%s_phig.pt' % (self.checkpoint_dir, epoch) )
            torch.save( self.qw,
                '%s/model_epoch%s_qw.pt' % (self.checkpoint_dir, epoch) )

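Two of the loss terms in Solver have closed forms that are easy to check in isolation: the KL term of kl_loss and the target margin term of train. The sketch below restates both with NumPy on toy arrays. It is an illustration under assumptions, not the repository's code: the function names kl_diag_gaussian_to_standard and margin_loss are hypothetical, the toy sizes p = 4 and K = 3 are chosen for readability, and a standard normal prior on w is assumed because that is what the expression in kl_loss corresponds to.

```python
import numpy as np


def kl_diag_gaussian_to_standard(mu, logsd):
    """KL( N(mu, diag(exp(logsd)^2)) || N(0, I) ), summed over all entries of w."""
    return 0.5 * np.sum(np.exp(logsd) ** 2 + mu ** 2 - 1.0 - 2.0 * logsd)


def margin_loss(phi_t, mu, logsd, z=1.96):
    """Mean hinge on 1 + (second-best score - best score) + z * largest score std."""
    f_t = phi_t @ mu                                       # (B, K) mean scores
    top2 = np.sort(f_t, axis=1)[:, ::-1][:, :2]            # best, second-best per row
    gap21 = top2[:, 1] - top2[:, 0]                        # never positive
    std_f_t = np.sqrt((phi_t ** 2) @ np.exp(logsd) ** 2)   # (B, K) std of each score
    max_std = std_f_t.max(axis=1)
    return np.mean(np.maximum(0.0, 1.0 + gap21 + z * max_std))


p, K = 4, 3                      # toy sizes, chosen for illustration
mu = np.zeros((p, K))
logsd = np.zeros((p, K))
phi_t = np.array([[1.0, 0.0, 0.0, 0.0]])

# q(w) equal to the prior gives zero KL
print(kl_diag_gaussian_to_standard(mu, logsd))  # 0.0

# all scores tie (gap 0) and every score has std 1, so the hinge is 1 + 1.96
loss_tied = margin_loss(phi_t, mu, logsd)

# a confident, low-variance prediction incurs no margin loss
mu_conf = mu.copy()
mu_conf[0, 0] = 5.0
loss_conf = margin_loss(phi_t, mu_conf, np.full((p, K), -20.0))
```

The -1.0 inside the sum is where the constant -self.p*10 in kl_loss comes from: it is one unit per entry of the p x K weight matrix with K = 10. The margin term is zero only when the best score beats the runner-up by more than 1 plus 1.96 times the largest predictive standard deviation.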
--------------------------------------------------------------------------------
/src/utils/utils.py:
--------------------------------------------------------------------------------
import numpy as np

###############################################################################

def weights_init(m):

    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.01)
        m.bias.data.normal_(0.0, 0.01)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.01)
        m.bias.data.fill_(0)

###############################################################################

def convert_label_10_to_0(labels):

    '''
    convert class label 10 to 0
    '''

    labels2 = np.zeros((len(labels),))
    labels = list(labels)
    for i, t in enumerate(labels):
        if t == 10:
            labels2[i] = 0
        else:
            labels2[i] = t

    return labels2

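The loop in convert_label_10_to_0 can also be written as a single array operation. The variant below is a sketch, not a drop-in replacement taken from the repository; the name convert_label_10_to_0_vectorized is hypothetical. It keeps the float return type of the original, which starts from np.zeros.

```python
import numpy as np


def convert_label_10_to_0_vectorized(labels):
    """Map label 10 to 0 and leave every other label unchanged (returns floats)."""
    labels = np.asarray(labels, dtype=np.float64)
    return np.where(labels == 10, 0.0, labels)


print(convert_label_10_to_0_vectorized([10, 1, 9, 10]))  # [0. 1. 9. 0.]
```

The remapping matters for SVHN, whose original label files store the digit 0 as class 10.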
--------------------------------------------------------------------------------