├── images
│   ├── logo.jpg
│   └── framework.png
├── tensorbayes.tar
├── model
│   ├── alexnet
│   │   ├── download_weights.sh
│   │   ├── examples.sh
│   │   ├── ckpt2npy.py
│   │   ├── model.py
│   │   └── finetune.py
│   ├── resnet
│   │   ├── download_weights.sh
│   │   ├── examples.sh
│   │   ├── preprocessor.py
│   │   ├── finetune.py
│   │   └── model.py
│   ├── generic_utils.py
│   ├── layers.py
│   ├── dataLoader.py
│   ├── test_da_template_AlexNet_train_feat.py
│   ├── test_da_template_digits.py
│   ├── run_most_AlexNet_finetune.py
│   ├── run_most_AlexNet_train_feat.py
│   ├── test_da_template_AlexNet_finetune.py
│   ├── run_most_digits.py
│   ├── most_AlexNet_train_feat.py
│   ├── most_digits.py
│   └── most_AlexNet_finetune.py
├── LICENSE
├── tf1.9py3.5.yml
└── README.md
/images/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanrpt/MOST/HEAD/images/logo.jpg
--------------------------------------------------------------------------------
/tensorbayes.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanrpt/MOST/HEAD/tensorbayes.tar
--------------------------------------------------------------------------------
/images/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuanrpt/MOST/HEAD/images/framework.png
--------------------------------------------------------------------------------
/model/alexnet/download_weights.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # By using curl
4 | curl -O http://www.cs.toronto.edu/~guerzhoy/tf_alexnet/bvlc_alexnet.npy
5 |
6 | # or by using wget
7 | # wget http://www.cs.toronto.edu/~guerzhoy/tf_alexnet/bvlc_alexnet.npy
8 |
--------------------------------------------------------------------------------
/model/alexnet/examples.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | python finetune.py \
4 | --learning_rate "0.00001" \
5 | --train_layers "fc8,fc7,fc6"
6 |
7 | python finetune.py \
8 | --num_epochs 30 \
9 | --multi_scale "228,256" \
10 | --train_layers "fc8,fc7,fc6,conv5,conv4,conv3,conv2,conv1" # full training
11 |
--------------------------------------------------------------------------------
/model/resnet/download_weights.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # By using curl
4 | curl -O https://deniz.co/static/tensorflow-cnn-finetune/ResNet-L50.npy
5 | curl -O https://deniz.co/static/tensorflow-cnn-finetune/ResNet-L101.npy
6 | curl -O https://deniz.co/static/tensorflow-cnn-finetune/ResNet-L152.npy
7 |
8 | # or by using wget
9 | # wget https://www.dropbox.com/s/txb16f1khleyvdh/ResNet-L50.npy?dl=1
10 | # wget https://www.dropbox.com/s/8w10cs6v4rd9616/ResNet-L101.npy?dl=1
11 | # wget https://www.dropbox.com/s/n5inup5a7fi8lom/ResNet-L152.npy?dl=1
12 |
--------------------------------------------------------------------------------
/model/resnet/examples.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | python finetune.py \
4 | --learning_rate "0.00001" \
5 | --train_layers "fc"
6 |
7 | python finetune.py \
8 | --learning_rate "0.00001" \
9 | --train_layers "fc,scale5/block3"
10 |
11 | python finetune.py \
12 | --learning_rate "0.00001" \
13 | --train_layers "fc,scale5/block3,scale5/block2"
14 |
15 | python finetune.py \
16 | --learning_rate "0.00001" \
17 | --train_layers "fc,scale5/block3,scale5/block2,scale5/block1"
18 |
19 | python finetune.py \
20 | --learning_rate "0.00001" \
21 | --multi_scale "225,256" \
22 | --train_layers "fc,scale5"
23 |
24 | python finetune.py \
25 | --learning_rate "0.00001" \
26 | --multi_scale "225,256" \
27 | --train_layers "fc,scale5,scale4/block6"
28 |
29 | python finetune.py \
30 | --learning_rate "0.00001" \
31 | --multi_scale "225,256" \
32 | --train_layers "fc,scale5,scale4/block6,scale4/block5"
33 |
--------------------------------------------------------------------------------
/model/generic_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Tuan Nguyen.
2 | # All rights reserved.
3 |
4 | from __future__ import division
5 | from __future__ import print_function
6 | from __future__ import absolute_import
7 |
8 | import time
9 | import numpy as np
10 | from pathlib import Path
11 | import os
12 | _RANDOM_SEED = 6789
13 |
14 |
15 | def model_dir():
16 | cur_dir = Path(os.path.abspath(__file__))
17 | return str(cur_dir.parent.parent)
18 |
19 |
20 | def feat_dir():
21 | cur_dir = Path(os.path.abspath(__file__))
22 | par_dir = cur_dir.parent.parent
23 | return str(par_dir / "features")
24 |
25 | def data_dir():
26 | cur_dir = Path(os.path.abspath(__file__))
27 | par_dir = cur_dir.parent.parent
28 | return str(par_dir / "data")
29 |
30 |
31 | def random_seed():
32 | return _RANDOM_SEED
33 |
34 |
35 | def tuid():
36 | '''
37 | Create a string ID based on current time
38 | :return: a string formatted using current time
39 | '''
40 | random_num = np.random.randint(0, 100)
41 | return time.strftime('%Y-%m-%d_%H.%M.%S') + str(random_num)
42 |
--------------------------------------------------------------------------------
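
Note: a minimal usage sketch for the path helpers above, assuming the module is imported from <repo>/model/ as in the scripts below (the printed values are illustrative):

    # Illustrative only: paths resolve relative to <repo>/model/generic_utils.py.
    from generic_utils import model_dir, feat_dir, data_dir, tuid

    print(model_dir())  # repository root (the parent of model/)
    print(feat_dir())   # <repo>/features
    print(data_dir())   # <repo>/data
    print(tuid())       # e.g. '2021-03-01_14.05.33' followed by a random 0-99 suffix
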
/model/layers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Tuan Nguyen.
2 | # All rights reserved.
3 |
4 | import tensorflow as tf
5 | from tensorflow.contrib.framework import add_arg_scope
6 |
7 |
8 | @add_arg_scope
9 | def noise(x, std, phase, scope=None, reuse=None):
10 | with tf.name_scope(scope, 'noise'):
11 | eps = tf.random_normal(tf.shape(x), 0.0, std)
12 | output = tf.where(phase, x + eps, x)
13 | return output
14 |
15 |
16 | @add_arg_scope
17 | def leaky_relu(x, a=0.2, name=None):
18 | with tf.name_scope(name, 'leaky_relu'):
19 | return tf.maximum(x, a * x)
20 |
21 | @add_arg_scope
22 | def basic_accuracy(a, b, scope=None):
23 | with tf.name_scope(scope, 'basic_acc'):
24 | a = tf.argmax(a, 1)
25 | b = tf.argmax(b, 1)
26 | eq = tf.cast(tf.equal(a, b), 'float32')
27 | output = tf.reduce_mean(eq)
28 | return output
29 |
30 | @add_arg_scope
31 | def batch_ema_acc(a, b, scope=None):
32 |     with tf.name_scope(scope, 'batch_ema_acc'):
33 | a = tf.argmax(a, 1)
34 | b = tf.argmax(b, 1)
35 | output = tf.cast(tf.equal(a, b), 'float32')
36 | return output
37 |
--------------------------------------------------------------------------------
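
Note: a short TF 1.x sketch of how these helpers can be wired up (the placeholder shapes are assumptions, not taken from the repository):

    # Illustrative sketch (TF 1.x graph mode); shapes are assumed for the example.
    import tensorflow as tf
    from layers import leaky_relu, basic_accuracy, batch_ema_acc

    x = tf.placeholder(tf.float32, [None, 128])
    h = leaky_relu(x, a=0.2)                      # elementwise max(x, 0.2 * x)

    logits = tf.placeholder(tf.float32, [None, 10])
    labels = tf.placeholder(tf.float32, [None, 10])
    acc = basic_accuracy(labels, logits)          # scalar: mean argmax agreement
    per_example = batch_ema_acc(labels, logits)   # per-example 0/1 correctness
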
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Tuan Nguyen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/model/alexnet/ckpt2npy.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import sys
4 | from model import AlexNetModel
5 |
6 |
7 | # Edit just these
8 | FILE_PATH = '/Users/dgurkaynak/Projects/marvel-finetuning/training/alexnet_20171125_124517/checkpoint/model_epoch7.ckpt'
9 | NUM_CLASSES = 26
10 | OUTPUT_FILE = 'alexnet_20171125_124517_epoch7.npy'
11 |
12 |
13 | if __name__ == '__main__':
14 | x = tf.placeholder(tf.float32, [128, 227, 227, 3])
15 | model = AlexNetModel(num_classes=NUM_CLASSES)
16 | model.inference(x)
17 |
18 | saver = tf.train.Saver()
19 | layers = ['conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'fc6', 'fc7', 'fc8']
20 | data = {
21 | 'conv1': [],
22 | 'conv2': [],
23 | 'conv3': [],
24 | 'conv4': [],
25 | 'conv5': [],
26 | 'fc6': [],
27 | 'fc7': [],
28 | 'fc8': []
29 | }
30 |
31 | with tf.Session() as sess:
32 | saver.restore(sess, FILE_PATH)
33 |
34 | for op_name in layers:
35 | with tf.variable_scope(op_name, reuse = True):
36 | biases_variable = tf.get_variable('biases')
37 | weights_variable = tf.get_variable('weights')
38 | data[op_name].append(sess.run(biases_variable))
39 | data[op_name].append(sess.run(weights_variable))
40 |
41 | np.save(OUTPUT_FILE, data)
42 |
43 |
--------------------------------------------------------------------------------
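
Note: the script stores each layer's parameters as a [biases, weights] pair in a single .npy dictionary; a quick way to inspect the result (the file name mirrors OUTPUT_FILE above):

    # Illustrative: inspect the converted weight dictionary.
    import numpy as np

    weights = np.load('alexnet_20171125_124517_epoch7.npy',
                      encoding='bytes', allow_pickle=True).item()
    for layer, params in weights.items():
        print(layer, [p.shape for p in params])  # [biases shape, weights shape]
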
/tf1.9py3.5.yml:
--------------------------------------------------------------------------------
1 | name: tf1.9py3.5
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - asn1crypto=1.4.0=py_0
7 | - blas=1.0=mkl
8 | - bzip2=1.0.8=h7b6447c_0
9 | - ca-certificates=2020.7.22=0
10 | - cairo=1.14.12=h8948797_3
11 | - certifi=2018.8.24=py35_1
12 | - cffi=1.11.5=py35he75722e_1
13 | - chardet=3.0.4=py35_1
14 | - cryptography=2.3.1=py35hc365091_0
15 | - cycler=0.10.0=py35hc4d5149_0
16 | - dbus=1.13.16=hb2f20db_0
17 | - dill=0.3.2=py_0
18 | - expat=2.2.9=he6710b0_2
19 | - ffmpeg=4.0=hcdf2ecd_0
20 | - fontconfig=2.13.0=h9420a91_0
21 | - freeglut=3.0.0=hf484d3e_5
22 | - freetype=2.10.2=h5ab3b9f_0
23 | - glib=2.63.1=h5a9c865_0
24 | - graphite2=1.3.14=h23475e2_0
25 | - gst-plugins-base=1.14.0=hbbd80ab_1
26 | - gstreamer=1.14.0=hb453b48_1
27 | - harfbuzz=1.8.8=hffaf4a1_0
28 | - hdf5=1.10.2=hba1933b_1
29 | - icu=58.2=he6710b0_3
30 | - idna=2.10=py_0
31 | - intel-openmp=2019.4=243
32 | - jasper=2.0.14=h07fcdf6_1
33 | - jpeg=9b=h024ee3a_2
34 | - kiwisolver=1.0.1=py35hf484d3e_0
35 | - libedit=3.1.20191231=h14c3975_1
36 | - libffi=3.2.1=hd88cf55_4
37 | - libgcc-ng=9.1.0=hdf63c60_0
38 | - libgfortran-ng=7.3.0=hdf63c60_0
39 | - libglu=9.0.0=hf484d3e_1
40 | - libopencv=3.4.2=hb342d67_1
41 | - libopus=1.3.1=h7b6447c_0
42 | - libpng=1.6.37=hbc83047_0
43 | - libprotobuf=3.6.0=hdbcaa40_0
44 | - libstdcxx-ng=9.1.0=hdf63c60_0
45 | - libtiff=4.1.0=h2733197_1
46 | - libuuid=1.0.3=h1bed415_2
47 | - libvpx=1.7.0=h439df22_0
48 | - libxcb=1.14=h7b6447c_0
49 | - libxml2=2.9.10=he19cac6_1
50 | - lz4-c=1.9.2=he6710b0_1
51 | - matplotlib=3.0.0=py35h5429711_0
52 | - mkl=2018.0.3=1
53 | - mkl_fft=1.0.6=py35h7dd41cf_0
54 | - mkl_random=1.0.1=py35h4414c95_1
55 | - ncurses=6.2=he6710b0_1
56 | - numpy=1.15.2=py35h1d66e8a_0
57 | - opencv=3.4.2=py35h6fd60c2_1
58 | - openssl=1.0.2u=h7b6447c_0
59 | - pandas=0.22.0=py35hf484d3e_0
60 | - pcre=8.44=he6710b0_0
61 | - pip=10.0.1=py35_0
62 | - pixman=0.40.0=h7b6447c_0
63 | - protobuf=3.6.0=py35hf484d3e_0
64 | - py-opencv=3.4.2=py35hb342d67_1
65 | - pycparser=2.20=py_2
66 | - pyopenssl=18.0.0=py35_0
67 | - pyparsing=2.4.7=py_0
68 | - pyqt=5.9.2=py35h05f1152_2
69 | - pysocks=1.6.8=py35_0
70 | - python=3.5.6=hc3d631a_0
71 | - python-dateutil=2.8.1=py_0
72 | - pytz=2020.1=py_0
73 | - qt=5.9.6=h8703b6f_2
74 | - readline=7.0=h7b6447c_5
75 | - requests=2.24.0=py_0
76 | - scikit-learn=0.20.0=py35h4989274_1
77 | - sip=4.19.8=py35hf484d3e_0
78 | - six=1.15.0=py_0
79 | - sqlite=3.33.0=h62c20be_0
80 | - tbb=2020.2=hfd86e86_0
81 | - tbb4py=2018.0.5=py35h6bb024c_0
82 | - tk=8.6.10=hbc83047_0
83 | - tornado=5.1.1=py35h7b6447c_0
84 | - urllib3=1.23=py35_0
85 | - wheel=0.35.1=py_0
86 | - xz=5.2.5=h7b6447c_0
87 | - zlib=1.2.11=h7b6447c_3
88 | - zstd=1.4.5=h9ceee32_0
89 | - pip:
90 | - absl-py==0.10.0
91 | - astor==0.8.1
92 | - gast==0.4.0
93 | - grpcio==1.31.0
94 | - h5py==2.10.0
95 | - importlib-metadata==1.7.0
96 | - keras==2.2.5
97 | - keras-applications==1.0.8
98 | - keras-preprocessing==1.1.2
99 | - markdown==3.2.2
100 | - pyyaml==5.3.1
101 | - scipy==1.4.1
102 | - setuptools==39.1.0
103 | - tensorbayes==0.4.0
104 | - tensorboard==1.9.0
105 | - tensorflow-gpu==1.9.0
106 | - termcolor==1.1.0
107 | - werkzeug==1.0.1
108 | - zipp==1.2.0
109 | prefix: /opt/conda/envs/tf1.9py3.5
110 |
--------------------------------------------------------------------------------
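
Note: the environment can be recreated with conda; tensorbayes==0.4.0 is installed via pip during creation, and the bundled tensorbayes.tar is a possible fallback if the PyPI package is unavailable (the archive layout assumed below is an assumption):

    conda env create -f tf1.9py3.5.yml
    conda activate tf1.9py3.5
    # optional fallback, assuming the tar unpacks to a 'tensorbayes' source directory:
    # tar -xf tensorbayes.tar && pip install ./tensorbayes
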
/model/resnet/preprocessor.py:
--------------------------------------------------------------------------------
1 | """
2 | Derived from: https://github.com/kratzert/finetune_alexnet_with_tensorflow/
3 | """
4 | import numpy as np
5 | import cv2
6 |
7 |
8 | class BatchPreprocessor(object):
9 |
10 | def __init__(self, dataset_file_path, num_classes, output_size=[227, 227], horizontal_flip=False, shuffle=False,
11 | mean_color=[132.2766, 139.6506, 146.9702], multi_scale=None):
12 | self.num_classes = num_classes
13 | self.output_size = output_size
14 | self.horizontal_flip = horizontal_flip
15 | self.shuffle = shuffle
16 | self.mean_color = mean_color
17 | self.multi_scale = multi_scale
18 |
19 | self.pointer = 0
20 | self.images = []
21 | self.labels = []
22 |
23 | # Read the dataset file
24 | dataset_file = open(dataset_file_path)
25 | lines = dataset_file.readlines()
26 | for line in lines:
27 | items = line.split()
28 | self.images.append(items[0])
29 | self.labels.append(int(items[1]))
30 |
31 | # Shuffle the data
32 | # if self.shuffle:
33 | # self.shuffle_data()
34 |
35 | def shuffle_data(self):
36 | images = self.images[:]
37 | labels = self.labels[:]
38 | self.images = []
39 | self.labels = []
40 |
41 | idx = np.random.permutation(len(labels))
42 | for i in idx:
43 | self.images.append(images[i])
44 | self.labels.append(labels[i])
45 |
46 | def reset_pointer(self):
47 | self.pointer = 0
48 |
49 | # if self.shuffle:
50 | # self.shuffle_data()
51 |
52 | def next_batch(self, batch_size):
53 | if self.shuffle:
54 | self.shuffle_data()
55 | # Get next batch of image (path) and labels
56 | paths = self.images[:batch_size]
57 | labels = self.labels[:batch_size]
58 | else:
59 | # Get next batch of image (path) and labels
60 | paths = self.images[self.pointer:(self.pointer+batch_size)]
61 | labels = self.labels[self.pointer:(self.pointer+batch_size)]
62 |
63 | # Update pointer
64 | self.pointer += batch_size
65 |
66 | # Read images
67 | images = np.ndarray([batch_size, self.output_size[0], self.output_size[1], 3])
68 | for i in range(len(paths)):
69 | img = cv2.imread(paths[i])
70 |
71 | # Flip image at random if flag is selected
72 | if self.horizontal_flip and np.random.random() < 0.5:
73 | img = cv2.flip(img, 1)
74 |
75 | if self.multi_scale is None or len(self.multi_scale) == 0:
76 | # Resize the image for output
77 | img = cv2.resize(img, (self.output_size[0], self.output_size[0]))
78 | img = img.astype(np.float32)
79 | elif isinstance(self.multi_scale, list):
80 | # Resize to random scale
81 | new_size = np.random.randint(self.multi_scale[0], self.multi_scale[1], 1)[0]
82 | img = cv2.resize(img, (new_size, new_size))
83 | img = img.astype(np.float32)
84 |
85 | # random crop at output size
86 | diff_size = new_size - self.output_size[0]
87 | random_offset_x = np.random.randint(0, diff_size, 1)[0]
88 | random_offset_y = np.random.randint(0, diff_size, 1)[0]
89 | img = img[random_offset_x:(random_offset_x+self.output_size[0]),
90 | random_offset_y:(random_offset_y+self.output_size[0])]
91 |
92 | # Subtract mean color
93 | img = img - np.array(self.mean_color)
94 |
95 | images[i] = img
96 |
97 | # Expand labels to one hot encoding
98 | one_hot_labels = np.zeros((batch_size, self.num_classes))
99 | for i in range(len(labels)):
100 | one_hot_labels[i][labels[i]] = 1
101 |
102 | if len(paths) < batch_size:
103 | return images[:len(paths)], one_hot_labels[:len(paths)]
104 |
105 | # Return array of images and labels
106 | return images, one_hot_labels
107 |
--------------------------------------------------------------------------------
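
Note: a hedged usage sketch of BatchPreprocessor; the list-file path, class count, and scale range are assumptions (each line of the list file is expected to contain "image_path label"):

    # Illustrative sketch; the path and hyper-parameters are assumptions.
    from preprocessor import BatchPreprocessor

    train_prep = BatchPreprocessor(dataset_file_path='data/office31/amazon_train.txt',
                                   num_classes=31,
                                   output_size=[227, 227],
                                   horizontal_flip=True,
                                   shuffle=True,
                                   multi_scale=[256, 288])
    images, one_hot_labels = train_prep.next_batch(64)  # (64, 227, 227, 3), (64, 31)
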
/model/dataLoader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Tuan Nguyen.
2 | # All rights reserved.
3 |
4 | import os
5 |
6 | import numpy as np
7 | from scipy.io import loadmat
8 | from keras.utils.np_utils import to_categorical
9 | from generic_utils import random_seed
10 |
11 |
12 | def load_mat_file_single_label(filename):
13 | filename_list = ['mnist', 'stl32', 'synsign', 'gtsrb', 'cifar32', 'usps32']
14 | data = loadmat(filename)
15 | x = data['X']
16 | y = data['y']
17 | if any(fn in filename for fn in filename_list):
18 | if 'mnist32_60_10' not in filename:
19 | y = y[0]
20 | else:
21 | y = np.argmax(y, axis=1)
22 | elif len(y.shape) > 1:
23 | y = np.argmax(y, axis=1)
24 | return x, y
25 |
26 |
27 | def load_mat_office31_AlexNet(filename):
28 | data = loadmat(filename)
29 | # x = data['feas']
30 | x = np.reshape(data['feas'], (-1, 8, 8, 64))
31 | y = data['labels'][0]
32 | return x, y
33 |
34 |
35 | def u2t(x):
36 | """Convert uint8 to [-1, 1] float
37 | """
38 | max_num = 50000
39 | if len(x) > max_num:
40 | y = np.empty_like(x, dtype='float32')
41 | for i in range(len(x) // max_num):
42 | y[i*max_num: (i+1)*max_num] = (x[i*max_num: (i+1)*max_num].astype('float32') / 255) * 2 - 1
43 |
44 | y[(i + 1) * max_num:] = (x[(i + 1) * max_num:].astype('float32') / 255) * 2 - 1
45 | else:
46 | y = (x.astype('float32') / 255) * 2 - 1
47 | return y
48 |
49 |
50 | class DataLoader:
51 |     def __init__(self, src_domain='mnistm', trg_domain='mnist', data_path='./data', data_format='mat',
52 | shuffle_data=False, dataset_name='digits', cast_data=True):
53 | self.num_src_domain = len(src_domain.split(','))
54 | self.src_domain_name = src_domain
55 | self.trg_domain_name = trg_domain
56 | self.data_path = data_path
57 | self.data_format = data_format
58 | self.shuffle_data = shuffle_data
59 | self.dataset_name = dataset_name
60 | self.cast_data = cast_data
61 |
62 | self.src_train = {}
63 | self.trg_train = {}
64 | self.src_test = {}
65 | self.trg_test = {}
66 |
67 | print("Source domains", self.src_domain_name)
68 | print("Target domain", self.trg_domain_name)
69 | self._load_data_train()
70 | self._load_data_test()
71 |
72 | self.data_shape = self.src_train[0][1][0].shape
73 | self.num_domain = len(self.src_train.keys())
74 | self.num_class = self.src_train[0][2].shape[-1]
75 |
76 | def _load_data_train(self, tail_name="_train"):
77 | if not self.src_train:
78 | self.src_train = self._load_file(self.src_domain_name, tail_name, self.shuffle_data)
79 | self.trg_train = self._load_file(self.trg_domain_name, tail_name, self.shuffle_data)
80 |
81 | def _load_data_test(self, tail_name="_test"):
82 | if not self.src_test:
83 | self.src_test = self._load_file(self.src_domain_name, tail_name, self.shuffle_data)
84 | self.trg_test = self._load_file(self.trg_domain_name, tail_name, self.shuffle_data)
85 |
86 |     def _load_file(self, name_file='', tail_name="_train", shuffle_data=False):
87 | data_list = {}
88 | name_file = name_file.split(',')
89 | for idx, s_n in enumerate(name_file):
90 | file_path_train = os.path.join(self.data_path, '{}{}.{}'.format(s_n, tail_name, self.data_format))
91 | if os.path.isfile(file_path_train):
92 | if self.dataset_name == 'digits':
93 | x_train, y_train = load_mat_file_single_label(file_path_train)
94 | elif self.dataset_name == 'office31_AlexNet_feat':
95 | x_train, y_train = load_mat_office31_AlexNet(file_path_train)
96 | if shuffle_data:
97 | x_train, y_train = self.shuffle(x_train, y_train)
98 | if 'mnist32_60_10' not in s_n and self.cast_data:
99 | x_train = u2t(x_train)
100 | data_list.update({idx: [s_n, x_train, to_categorical(y_train)]})
101 | else:
102 |                 raise FileNotFoundError('File not found: {}'.format(file_path_train))
103 | return data_list
104 |
105 | def shuffle(self, x, y=None):
106 | np.random.seed(random_seed())
107 | idx_train = np.random.permutation(x.shape[0])
108 | x = x[idx_train]
109 | if y is not None:
110 | y = y[idx_train]
111 | return x, y
112 |
113 | def onehot2scalar(self, onehot_vectors, axis=1):
114 | return np.argmax(onehot_vectors, axis=axis)
115 |
--------------------------------------------------------------------------------
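
Note: a hedged sketch of constructing DataLoader for the digits setting; the domain names follow the defaults used by run_most_digits.py, and the data directory is an assumption (it expects <name>_train.mat / <name>_test.mat files there):

    # Illustrative sketch; data_path is an assumption.
    from dataLoader import DataLoader

    loader = DataLoader(src_domain='mnist32_60_10',   # comma-separate multiple sources
                        trg_domain='mnistm32_60_10',
                        data_path='./features',
                        data_format='mat',
                        dataset_name='digits',
                        cast_data=True)

    # src_train[i] holds [domain_name, images, one-hot labels] for source domain i
    name, x_train, y_train = loader.src_train[0]
    print(loader.num_src_domain, loader.num_class, loader.data_shape)
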
/model/test_da_template_AlexNet_train_feat.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Tuan Nguyen.
2 | # All rights reserved.
3 |
4 | from __future__ import division
5 | from __future__ import print_function
6 | from __future__ import absolute_import
7 |
8 | import os
9 | import sys
10 |
11 | import numpy as np
12 | import tensorflow as tf
13 |
14 | from generic_utils import random_seed
15 | from generic_utils import feat_dir
16 | from dataLoader import DataLoader
17 |
18 |
19 | def test_real_dataset(create_obj_func, src_name=None, trg_name=None):
20 | print('Running {} ...'.format(os.path.basename(__file__)))
21 |
22 | if src_name is None:
23 | if len(sys.argv) > 2:
24 | src_name = sys.argv[2]
25 | else:
26 |             raise Exception('Source dataset not specified')
27 | if trg_name is None:
28 | if len(sys.argv) > 3:
29 | trg_name = sys.argv[3]
30 | else:
31 |             raise Exception('Target dataset not specified')
32 |
33 | np.random.seed(random_seed())
34 | tf.set_random_seed(random_seed())
35 | tf.reset_default_graph()
36 |
37 | print("========== Test on real data ==========")
38 |
39 | users_params = dict()
40 | users_params = parse_arguments(users_params)
41 | data_format = 'mat'
42 |
43 | if 'format' in users_params:
44 | data_format, users_params = extract_param('format', data_format, users_params)
45 |
46 | data_folder = feat_dir()
47 | if len(users_params['data_dir']) != 0:
48 | data_folder = users_params['data_dir']
49 | print("data path", data_folder)
50 |
51 | data_loader = DataLoader(src_domain=src_name,
52 | trg_domain=trg_name,
53 | data_path=data_folder,
54 | data_format=data_format,
55 | dataset_name='office31_AlexNet_feat',
56 | cast_data=users_params['cast_data'])
57 |
58 | assert users_params['batch_size'] % data_loader.num_src_domain == 0
59 |
60 | print('users_params:', users_params)
61 |
62 | learner = create_obj_func(users_params)
63 | learner.dim_src = data_loader.data_shape
64 | learner.dim_trg = data_loader.data_shape
65 |
66 |     learner.x_trg_test = data_loader.trg_test[0][1]
67 |     learner.y_trg_test = data_loader.trg_test[0][2]
68 |
69 | learner._init(data_loader)
70 | learner._build_model()
71 | learner._fit_loop()
72 |
73 |
74 | def main_func(
75 | create_obj_func,
76 | choice_default=0,
77 | src_name_default='svmguide1',
78 | trg_name_default='svmguide1',
79 | run_exp=False):
80 |
81 | if not run_exp:
82 | choice_lst = [0, 1, 2]
83 | src_name = src_name_default
84 | trg_name = trg_name_default
85 | elif len(sys.argv) > 1:
86 | choice_lst = [int(sys.argv[1])]
87 | src_name = None
88 | trg_name = None
89 | else:
90 | choice_lst = [choice_default]
91 | src_name = src_name_default
92 | trg_name = trg_name_default
93 |
94 | for choice in choice_lst:
95 | if choice == 0:
96 | pass
97 | # add another function here
98 | elif choice == 1:
99 | test_real_dataset(create_obj_func, src_name, trg_name)
100 |
101 |
102 | def parse_arguments(params, as_array=False):
103 | for it in range(4, len(sys.argv), 2):
104 | params[sys.argv[it]] = parse_argument(sys.argv[it + 1], as_array)
105 | return params
106 |
107 |
108 | def parse_argument(string, as_array=False):
109 | try:
110 | result = int(string)
111 | except ValueError:
112 | try:
113 | result = float(string)
114 | except ValueError:
115 | if str.lower(string) == 'true':
116 | result = True
117 | elif str.lower(string) == 'false':
118 | result = False
119 | elif string == "[]":
120 | return []
121 | elif ('|' in string) and ('[' in string) and (']' in string):
122 | result = [float(item) for item in string[1:-1].split('|')]
123 | return result
124 | elif (',' in string) and ('(' in string) and (')' in string):
125 | split = string[1:-1].split(',')
126 | result = float(split[0]) ** np.arange(float(split[1]), float(split[2]), float(split[3]))
127 | return result
128 | else:
129 | result = string
130 |
131 | return [result] if as_array else result
132 |
133 |
134 | def resolve_conflict_params(primary_params, secondary_params):
135 | for key in primary_params.keys():
136 | if key in secondary_params.keys():
137 | del secondary_params[key]
138 | return secondary_params
139 |
140 |
141 | def extract_param(key, value, params_gridsearch, scalar=False):
142 | if key in params_gridsearch.keys():
143 | value = params_gridsearch[key]
144 | del params_gridsearch[key]
145 | if scalar and (value is not None):
146 | value = value[0]
147 | return value, params_gridsearch
148 |
149 |
150 | def dict2string(params):
151 | result = ''
152 | for key, value in params.items():
153 | if type(value) is np.ndarray:
154 | if value.size < 16:
155 | result += key + ': ' + '|'.join('{0:.4f}'.format(x) for x in value.ravel()) + ', '
156 | else:
157 | result += key + ': ' + str(value) + ', '
158 | return '{' + result[:-2] + '}'
159 |
--------------------------------------------------------------------------------
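
Note: the template above reads argv[1] as the experiment choice, argv[2]/argv[3] as the comma-separated source domains and the target domain, and every following pair as a key/value hyper-parameter via parse_arguments(); it also accesses batch_size, cast_data, and data_dir from those pairs. A hedged invocation through the corresponding runner (the feature-file names and values are assumptions):

    python run_most_AlexNet_train_feat.py 1 amazon_fc7,webcam_fc7 dslr_fc7 \
        batch_size 128 cast_data False data_dir ./features
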
/model/test_da_template_digits.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Tuan Nguyen.
2 | # All rights reserved.
3 |
4 | from __future__ import division
5 | from __future__ import print_function
6 | from __future__ import absolute_import
7 |
8 | import os
9 | import sys
10 | from scipy.io import loadmat
11 | import numpy as np
12 | import tensorflow as tf
13 | from generic_utils import random_seed
14 | from generic_utils import feat_dir
15 | from dataLoader import DataLoader
16 |
17 |
18 | def test_real_dataset(create_obj_func, src_name=None, trg_name=None, show=False, block_figure_on_end=False):
19 | print('Running {} ...'.format(os.path.basename(__file__)))
20 |
21 | if src_name is None:
22 | if len(sys.argv) > 2:
23 | src_name = sys.argv[2]
24 | else:
25 |             raise Exception('Source dataset not specified')
26 | if trg_name is None:
27 | if len(sys.argv) > 3:
28 | trg_name = sys.argv[3]
29 | else:
30 |             raise Exception('Target dataset not specified')
31 |
32 | np.random.seed(random_seed())
33 | tf.set_random_seed(random_seed())
34 | tf.reset_default_graph()
35 |
36 | print("========== Test on real data ==========")
37 | users_params = dict()
38 | users_params = parse_arguments(users_params)
39 | data_format = 'mat'
40 |
41 | if 'format' in users_params:
42 | data_format, users_params = extract_param('format', data_format, users_params)
43 |
44 | data_loader = DataLoader(src_domain=src_name,
45 | trg_domain=trg_name,
46 | data_path=feat_dir(),
47 | data_format=data_format,
48 | cast_data=users_params['cast_data'])
49 |
50 | assert users_params['batch_size'] % data_loader.num_src_domain == 0
51 | print('users_params:', users_params)
52 |
53 | learner = create_obj_func(users_params)
54 | learner.dim_src = data_loader.data_shape
55 | learner.dim_trg = data_loader.data_shape
56 |
57 |     learner.x_trg_test = data_loader.trg_test[0][1]
58 |     learner.y_trg_test = data_loader.trg_test[0][2]
59 | learner._init(data_loader)
60 | learner._build_model()
61 | learner._fit_loop()
62 |
63 |
64 | def main_func(
65 | create_obj_func,
66 | choice_default=0,
67 | src_name_default='svmguide1',
68 | trg_name_default='svmguide1',
69 | run_exp=False,
70 | keep_vars=[],
71 | **kwargs):
72 |
73 | if not run_exp:
74 | choice_lst = [0, 1, 2]
75 | src_name = src_name_default
76 | trg_name = trg_name_default
77 | elif len(sys.argv) > 1:
78 | choice_lst = [int(sys.argv[1])]
79 | src_name = None
80 | trg_name = None
81 | else:
82 | choice_lst = [choice_default]
83 | src_name = src_name_default
84 | trg_name = trg_name_default
85 |
86 | for choice in choice_lst:
87 | if choice == 0:
88 | pass # for synthetic data if possible
89 | elif choice == 1:
90 | test_real_dataset(create_obj_func, src_name, trg_name, show=False, block_figure_on_end=run_exp)
91 |
92 |
93 | def parse_arguments(params, as_array=False):
94 | for it in range(4, len(sys.argv), 2):
95 | params[sys.argv[it]] = parse_argument(sys.argv[it + 1], as_array)
96 | return params
97 |
98 |
99 | def parse_argument(string, as_array=False):
100 | try:
101 | result = int(string)
102 | except ValueError:
103 | try:
104 | result = float(string)
105 | except ValueError:
106 | if str.lower(string) == 'true':
107 | result = True
108 | elif str.lower(string) == 'false':
109 | result = False
110 | elif string == "[]":
111 | return []
112 | elif ('|' in string) and ('[' in string) and (']' in string):
113 | result = [float(item) for item in string[1:-1].split('|')]
114 | return result
115 | elif (',' in string) and ('(' in string) and (')' in string):
116 | split = string[1:-1].split(',')
117 | result = float(split[0]) ** np.arange(float(split[1]), float(split[2]), float(split[3]))
118 | return result
119 | else:
120 | result = string
121 |
122 | return [result] if as_array else result
123 |
124 |
125 | def resolve_conflict_params(primary_params, secondary_params):
126 | for key in primary_params.keys():
127 | if key in secondary_params.keys():
128 | del secondary_params[key]
129 | return secondary_params
130 |
131 |
132 | def extract_param(key, value, params_gridsearch, scalar=False):
133 | if key in params_gridsearch.keys():
134 | value = params_gridsearch[key]
135 | del params_gridsearch[key]
136 | if scalar and (value is not None):
137 | value = value[0]
138 | return value, params_gridsearch
139 |
140 |
141 | def dict2string(params):
142 | result = ''
143 | for key, value in params.items():
144 | if type(value) is np.ndarray:
145 | if value.size < 16:
146 | result += key + ': ' + '|'.join('{0:.4f}'.format(x) for x in value.ravel()) + ', '
147 | else:
148 | result += key + ': ' + str(value) + ', '
149 | return '{' + result[:-2] + '}'
150 |
151 |
152 | def load_mat_file_single_label(filename):
153 | filename_list = ['mnist', 'stl32', 'synsign', 'gtsrb', 'cifar32', 'usps32']
154 | data = loadmat(filename)
155 | x = data['X']
156 | y = data['y']
157 | if any(fn in filename for fn in filename_list):
158 | if 'mnist32_60_10' not in filename and 'mnistg' not in filename:
159 | y = y[0]
160 | else:
161 | y = np.argmax(y, axis=1)
162 | # process one-hot label encoder
163 | elif len(y.shape) > 1:
164 | y = np.argmax(y, axis=1)
165 | return x, y
166 |
167 |
168 | def u2t(x):
169 | """Convert uint8 to [-1, 1] float
170 | """
171 | return x.astype('float32') / 255 * 2 - 1
172 |
--------------------------------------------------------------------------------
/model/alexnet/model.py:
--------------------------------------------------------------------------------
1 | """
2 | Derived from: https://github.com/kratzert/finetune_alexnet_with_tensorflow/
3 | """
4 | import tensorflow as tf
5 | import numpy as np
6 |
7 |
8 | class AlexNetModel(object):
9 |
10 | def __init__(self, num_classes=1000, is_training=False, dropout_keep_prob=0.5):
11 | self.num_classes = num_classes
12 | self.dropout_keep_prob = dropout_keep_prob
13 | self.is_training = is_training
14 |
15 | def inference(self, x, reuse=None, extract_feat=False):
16 | with tf.variable_scope('network', reuse=reuse):
17 | # 1st Layer: Conv (w ReLu) -> Pool -> Lrn
18 | conv1 = conv(x, 11, 11, 96, 4, 4, padding='VALID', name='conv1')
19 | pool1 = max_pool(conv1, 3, 3, 2, 2, padding='VALID', name='pool1')
20 | norm1 = lrn(pool1, 2, 2e-05, 0.75, name='norm1')
21 |
22 | # 2nd Layer: Conv (w ReLu) -> Pool -> Lrn with 2 groups
23 | conv2 = conv(norm1, 5, 5, 256, 1, 1, groups=2, name='conv2')
24 | pool2 = max_pool(conv2, 3, 3, 2, 2, padding='VALID', name ='pool2')
25 | norm2 = lrn(pool2, 2, 2e-05, 0.75, name='norm2')
26 |
27 | # 3rd Layer: Conv (w ReLu)
28 | conv3 = conv(norm2, 3, 3, 384, 1, 1, name='conv3')
29 |
30 |             # 4th Layer: Conv (w ReLu) split into two groups
31 | conv4 = conv(conv3, 3, 3, 384, 1, 1, groups=2, name='conv4')
32 |
33 |             # 5th Layer: Conv (w ReLu) -> Pool split into two groups
34 | conv5 = conv(conv4, 3, 3, 256, 1, 1, groups=2, name='conv5')
35 | pool5 = max_pool(conv5, 3, 3, 2, 2, padding='VALID', name='pool5')
36 |
37 | # 6th Layer: Flatten -> FC (w ReLu) -> Dropout
38 | flattened = tf.reshape(pool5, [-1, 6*6*256])
39 | fc6 = fc(flattened, 6*6*256, 4096, name='fc6')
40 |
41 | # if self.is_training:
42 | # fc6 = dropout(fc6, self.dropout_keep_prob)
43 | fc6 = tf.layers.dropout(fc6, rate=1.0-self.dropout_keep_prob, training=self.is_training)
44 |
45 | # 7th Layer: FC (w ReLu) -> Dropout
46 | fc7 = fc(fc6, 4096, 4096, name='fc7')
47 |
48 | # if self.is_training:
49 | # fc7 = dropout(fc7, self.dropout_keep_prob)
50 | fc7 = tf.layers.dropout(fc7, rate=1.0-self.dropout_keep_prob, training=self.is_training)
51 |
52 |
53 | if extract_feat:
54 | return fc7
55 |
56 | # 8th Layer: FC and return unscaled activations (for tf.nn.softmax_cross_entropy_with_logits)
57 | self.score = fc(fc7, 4096, self.num_classes, relu=False, name='fc8')
58 | return self.score
59 |
60 | def loss(self, batch_x, batch_y=None):
61 |         y_predict = self.inference(batch_x)
62 | self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y_predict, labels=batch_y))
63 | return self.loss
64 |
65 | def optimize(self, learning_rate, train_layers=[]):
66 | var_list = [v for v in tf.trainable_variables() if v.name.split('/')[0] in train_layers]
67 | return tf.train.AdamOptimizer(learning_rate).minimize(self.loss, var_list=var_list)
68 |
69 | def load_original_weights(self, session, skip_layers=[]):
70 | weights_dict = np.load('bvlc_alexnet.npy', encoding='bytes', allow_pickle=True).item()
71 |
72 | for op_name in weights_dict:
73 | # if op_name in skip_layers:
74 | # continue
75 |
76 | if op_name == 'fc8' and self.num_classes != 1000:
77 | continue
78 |
79 | with tf.variable_scope('network/' + op_name, reuse=True):
80 | for data in weights_dict[op_name]:
81 | if len(data.shape) == 1:
82 | var = tf.get_variable('biases')
83 | session.run(var.assign(data))
84 | else:
85 | var = tf.get_variable('weights')
86 | session.run(var.assign(data))
87 |
88 |
89 | """
90 | Helper methods
91 | """
92 | def conv(x, filter_height, filter_width, num_filters, stride_y, stride_x, name, padding='SAME', groups=1):
93 | input_channels = int(x.get_shape()[-1])
94 | convolve = lambda i, k: tf.nn.conv2d(i, k, strides=[1, stride_y, stride_x, 1], padding=padding)
95 |
96 | with tf.variable_scope(name) as scope:
97 | weights = tf.get_variable('weights', shape=[filter_height, filter_width, input_channels/groups, num_filters])
98 | biases = tf.get_variable('biases', shape=[num_filters])
99 |
100 | if groups == 1:
101 | conv = convolve(x, weights)
102 | else:
103 | input_groups = tf.split(axis=3, num_or_size_splits=groups, value=x)
104 | weight_groups = tf.split(axis=3, num_or_size_splits=groups, value=weights)
105 | output_groups = [convolve(i, k) for i,k in zip(input_groups, weight_groups)]
106 | conv = tf.concat(axis=3, values=output_groups)
107 |
108 | # original
109 | # bias = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape().as_list())
110 |         bias = tf.nn.bias_add(conv, biases)
111 | relu = tf.nn.relu(bias, name=scope.name)
112 | return relu
113 |
114 | def fc(x, num_in, num_out, name, relu=True):
115 | with tf.variable_scope(name) as scope:
116 | weights = tf.get_variable('weights', shape=[num_in, num_out])
117 | biases = tf.get_variable('biases', [num_out])
118 | act = tf.nn.xw_plus_b(x, weights, biases, name=scope.name)
119 |
120 |         if relu:
121 | relu = tf.nn.relu(act)
122 | return relu
123 | else:
124 | return act
125 |
126 |
127 | def max_pool(x, filter_height, filter_width, stride_y, stride_x, name, padding='SAME'):
128 | return tf.nn.max_pool(x, ksize=[1, filter_height, filter_width, 1], strides = [1, stride_y, stride_x, 1],
129 | padding = padding, name=name)
130 |
131 | def lrn(x, radius, alpha, beta, name, bias=1.0):
132 | return tf.nn.local_response_normalization(x, depth_radius=radius, alpha=alpha, beta=beta, bias=bias, name=name)
133 |
134 | def dropout(x, keep_prob):
135 | return tf.nn.dropout(x, keep_prob)
136 |
--------------------------------------------------------------------------------
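
Note: a hedged TF 1.x sketch of building the model above and loading the BVLC weights fetched by download_weights.sh; the batch shape and class count are assumptions, and with num_classes != 1000 the pretrained fc8 weights are skipped (as in load_original_weights):

    # Illustrative sketch; assumes bvlc_alexnet.npy sits in the working directory.
    import tensorflow as tf
    from model import AlexNetModel

    x = tf.placeholder(tf.float32, [None, 227, 227, 3])
    model = AlexNetModel(num_classes=31, is_training=False)
    logits = model.inference(x)                                  # fc8 scores
    feats = model.inference(x, reuse=True, extract_feat=True)    # fc7 features

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        model.load_original_weights(sess)  # fc8 keeps its random initialization
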
/model/run_most_AlexNet_finetune.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Tuan Nguyen.
2 | # All rights reserved.
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | from most_AlexNet_finetune import MOST
9 |
10 | from layers import noise
11 | from test_da_template_AlexNet_finetune import main_func, resolve_conflict_params
12 |
13 | from tensorflow.python.layers.core import dropout
14 | from tensorbayes.layers import dense, conv2d, avg_pool, max_pool
15 |
16 | import warnings
17 | import os
18 | from generic_utils import tuid, model_dir
19 | import signal
20 | import sys
21 | import time
22 | import datetime
23 | from pprint import pprint
24 |
25 | choice_default = 1
26 | warnings.simplefilter("ignore", category=DeprecationWarning)
27 |
28 | model_name = "MOST-results"
29 | current_time = tuid()
30 |
31 |
32 | def encode_layout(preprocess, training_phase=True, cnn_size='large'):
33 | layout = []
34 | if cnn_size == 'small':
35 | layout = [
36 | (conv2d, (64, 3, 1), {}),
37 | (max_pool, (2, 2), {}),
38 | (dropout, (), dict(training=training_phase)),
39 | (noise, (1,), dict(phase=training_phase)),
40 | ]
41 | elif cnn_size == 'large':
42 | layout = [
43 | (preprocess, (), {}),
44 | (conv2d, (96, 3, 1), {}),
45 | (conv2d, (96, 3, 1), {}),
46 | (conv2d, (96, 3, 1), {}),
47 | (max_pool, (2, 2), {}),
48 | (dropout, (), dict(training=training_phase)),
49 | (noise, (1,), dict(phase=training_phase)),
50 | (conv2d, (192, 3, 1), {}),
51 | (conv2d, (192, 3, 1), {}),
52 | (conv2d, (192, 3, 1), {}),
53 | (max_pool, (2, 2), {}),
54 | (dropout, (), dict(training=training_phase)),
55 | (noise, (1,), dict(phase=training_phase)),
56 | ]
57 | return layout
58 |
59 |
60 | def class_discriminator_layout(num_classes=None, global_pool=True, activation=None, cnn_size='large'):
61 | layout = []
62 | if cnn_size == 'small':
63 | layout = [
64 | (dense, (num_classes,), dict(activation=activation))
65 | ]
66 |
67 | elif cnn_size == 'large':
68 | layout = [
69 | (conv2d, (192, 3, 1), {}),
70 | (conv2d, (192, 3, 1), {}),
71 | (conv2d, (192, 3, 1), {}),
72 | (avg_pool, (), dict(global_pool=global_pool)),
73 | (dense, (num_classes,), dict(activation=activation))
74 | ]
75 | return layout
76 |
77 |
78 | def domain_layout(c):
79 | layout = [
80 | (dense, (c,), dict(activation=None))
81 | ]
82 | return layout
83 |
84 |
85 | def phi_layout(c):
86 | layout = [
87 | (dense, (c,), dict(activation=None))
88 | ]
89 | return layout
90 |
91 |
92 | def create_obj_func(params):
93 | if len(sys.argv) > 1:
94 | my_choice = int(sys.argv[1])
95 | else:
96 | my_choice = choice_default
97 | if my_choice == 0:
98 | default_params = {
99 | }
100 | else:
101 | default_params = {
102 | 'batch_size': 128,
103 | 'learning_rate': 1e-4,
104 | 'num_iters': 80000,
105 | 'src_class_trade_off': 1.0,
106 | 'src_domain_trade_off': '1.0,1.0',
107 | 'ot_trade_off': 0.1,
108 | 'domain_trade_off': 0.1,
109 | 'src_vat_trade_off': 1.0,
110 | 'g_network_trade_off': 1.0,
111 | 'theta': 10.0,
112 | 'mdaot_model_id': '',
113 | 'classify_layout': class_discriminator_layout,
114 | 'encode_layout': encode_layout,
115 | 'domain_layout': domain_layout,
116 | 'phi_layout': phi_layout,
117 | 'log_path': os.path.join(model_dir(), model_name, "logs", "{}".format(current_time)),
118 | 'summary_freq': 400,
119 | 'current_time': current_time,
120 | 'inorm': True,
121 | 'cast_data': False,
122 | 'only_save_final_model': True,
123 | 'cnn_size': 'large',
124 | 'sample_size': 20,
125 | 'data_shift_troff': 10.0,
126 | 'num_classes': 10,
127 | 'multi_scale': '',
128 | 'resnet_depth': 101,
129 | 'train_layers': 'fc7,fc6'
130 | }
131 |
132 | default_params = resolve_conflict_params(params, default_params)
133 |
134 | print('Default parameters:')
135 | pprint(default_params)
136 |
137 | learner = MOST(
138 | **params,
139 | **default_params,
140 | )
141 | return learner
142 |
143 |
144 | def main_test(run_exp=False):
145 | main_func(
146 | create_obj_func,
147 | choice_default=choice_default,
148 | src_name_default='mnist32_60_10',
149 | trg_name_default='mnistm32_60_10',
150 | run_exp=run_exp
151 | )
152 |
153 |
154 | class Logger(object):
155 | def __init__(self):
156 | self.terminal = sys.stdout
157 | self.console_log_path = os.path.join(model_dir(), model_name, "console_output", "{}.txt".format(current_time))
158 | if not os.path.exists(os.path.dirname(self.console_log_path)):
159 | os.makedirs(os.path.dirname(self.console_log_path))
160 | self.log = open(self.console_log_path, 'a')
161 | signal.signal(signal.SIGINT, self.signal_handler)
162 |
163 | def signal_handler(self, sig, frame):
164 | print('You pressed Ctrl+C.')
165 | self.log.close()
166 |
167 | # Remove logfile
168 | os.remove(self.console_log_path)
169 | print('Removed console_output file')
170 | sys.exit(0)
171 |
172 | def write(self, message):
173 | self.terminal.write(message)
174 | self.log.write(message)
175 |
176 | def flush(self):
177 | # this flush method is needed for python 3 compatibility.
178 | # this handles the flush command by doing nothing.
179 | # you might want to specify some extra behavior here.
180 | pass
181 |
182 |
183 | if __name__ == '__main__':
184 | sys.stdout = Logger()
185 | start_time = time.time()
186 | print('Running {} ...'.format(os.path.basename(__file__)))
187 | main_test(run_exp=True)
188 | training_time = time.time() - start_time
189 | print('Total time: %s' % str(datetime.timedelta(seconds=training_time)))
190 | print("============ LOG-ID: %s ============" % current_time)
191 |
--------------------------------------------------------------------------------
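
Note: a hedged example of launching this runner; the template it drives (test_da_template_AlexNet_finetune.py) expects Office-31 style list files under <repo>/data/office31/ and reads batch_size, num_classes, and multi_scale from the command line (the domain names and values below are assumptions):

    python run_most_AlexNet_finetune.py 1 amazon,webcam dslr \
        batch_size 64 num_classes 31 multi_scale 256,272 train_layers fc8,fc7,fc6
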
/model/run_most_AlexNet_train_feat.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Tuan Nguyen.
2 | # All rights reserved.
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | from most_AlexNet_train_feat import MOST
9 |
10 | from layers import noise
11 | from test_da_template_AlexNet_train_feat import main_func, resolve_conflict_params
12 |
13 | from tensorflow.python.layers.core import dropout
14 | from tensorbayes.layers import dense, conv2d, avg_pool, max_pool
15 |
16 | import warnings
17 | import os
18 | from generic_utils import tuid, model_dir
19 | import signal
20 | import sys
21 | import time
22 | import datetime
23 | from pprint import pprint
24 |
25 | choice_default = 1
26 | warnings.simplefilter("ignore", category=DeprecationWarning)
27 |
28 | model_name = "MOST-results"
29 | current_time = tuid()
30 |
31 |
32 | def encode_layout(preprocess, training_phase=True, cnn_size='large'):
33 | layout = []
34 | if cnn_size == 'small':
35 | layout = [
36 | (conv2d, (64, 3, 1), {}),
37 | (max_pool, (2, 2), {}),
38 | (dropout, (), dict(training=training_phase)),
39 | (noise, (1,), dict(phase=training_phase)),
40 | ]
41 | elif cnn_size == 'large':
42 | layout = [
43 | (preprocess, (), {}),
44 | (conv2d, (96, 3, 1), {}),
45 | (conv2d, (96, 3, 1), {}),
46 | (conv2d, (96, 3, 1), {}),
47 | (max_pool, (2, 2), {}),
48 | (dropout, (), dict(training=training_phase)),
49 | (noise, (1,), dict(phase=training_phase)),
50 | (conv2d, (192, 3, 1), {}),
51 | (conv2d, (192, 3, 1), {}),
52 | (conv2d, (192, 3, 1), {}),
53 | (max_pool, (2, 2), {}),
54 | (dropout, (), dict(training=training_phase)),
55 | (noise, (1,), dict(phase=training_phase)),
56 | ]
57 | return layout
58 |
59 |
60 | def class_discriminator_layout(num_classes=None, global_pool=True, activation=None, cnn_size='large'):
61 | layout = []
62 | if cnn_size == 'small':
63 | layout = [
64 | (dense, (num_classes,), dict(activation=activation))
65 | ]
66 |
67 | elif cnn_size == 'large':
68 | layout = [
69 | (conv2d, (192, 3, 1), {}),
70 | (conv2d, (192, 3, 1), {}),
71 | (conv2d, (192, 3, 1), {}),
72 | (avg_pool, (), dict(global_pool=global_pool)),
73 | (dense, (num_classes,), dict(activation=activation))
74 | ]
75 | return layout
76 |
77 |
78 | def domain_layout(c):
79 | layout = [
80 | (dense, (c,), dict(activation=None))
81 | ]
82 | return layout
83 |
84 |
85 | def phi_layout(c):
86 | layout = [
87 | (dense, (c,), dict(activation=None))
88 | ]
89 | return layout
90 |
91 |
92 | def create_obj_func(params):
93 | if len(sys.argv) > 1:
94 | my_choice = int(sys.argv[1])
95 | else:
96 | my_choice = choice_default
97 | if my_choice == 0:
98 | default_params = {
99 | }
100 | else:
101 | default_params = {
102 | 'batch_size': 128,
103 | 'learning_rate': 1e-4,
104 | 'num_iters': 80000,
105 | 'phase1_iters': 20000,
106 | 'src_class_trade_off': 1.0,
107 | 'src_domain_trade_off': '1.0,1.0',
108 | 'ot_trade_off': 0.1,
109 | 'domain_trade_off': 0.1,
110 | 'src_vat_trade_off': 1.0,
111 | 'trg_vat_troff': 0.1,
112 | 'trg_ent_troff': 0.1,
113 | 'g_network_trade_off': 1.0,
114 | 'mimic_trade_off': 0.1,
115 | 'theta': 10.0,
116 | 'mdaot_model_id': '',
117 | 'classify_layout': class_discriminator_layout,
118 | 'encode_layout': encode_layout,
119 | 'domain_layout': domain_layout,
120 | 'phi_layout': phi_layout,
121 | 'log_path': os.path.join(model_dir(), model_name, "logs", "{}".format(current_time)),
122 | 'summary_freq': 400,
123 | 'current_time': current_time,
124 | 'inorm': True,
125 | 'cast_data': False,
126 | 'only_save_final_model': True,
127 | 'cnn_size': 'large',
128 | 'sample_size': 20,
129 | 'data_shift_troff': 10.0,
130 | 'data_dir': ''
131 | }
132 |
133 | default_params = resolve_conflict_params(params, default_params)
134 |
135 | print('Default parameters:')
136 | pprint(default_params)
137 |
138 | learner = MOST(
139 | **params,
140 | **default_params,
141 | )
142 | return learner
143 |
144 |
145 | def main_test(run_exp=False):
146 | main_func(
147 | create_obj_func,
148 | choice_default=choice_default,
149 | src_name_default='mnist32_60_10',
150 | trg_name_default='mnistm32_60_10',
151 | run_exp=run_exp
152 | )
153 |
154 |
155 | class Logger(object):
156 | def __init__(self):
157 | self.terminal = sys.stdout
158 | self.console_log_path = os.path.join(model_dir(), model_name, "console_output", "{}.txt".format(current_time))
159 | if not os.path.exists(os.path.dirname(self.console_log_path)):
160 | os.makedirs(os.path.dirname(self.console_log_path))
161 | self.log = open(self.console_log_path, 'a')
162 | signal.signal(signal.SIGINT, self.signal_handler)
163 |
164 | def signal_handler(self, sig, frame):
165 | print('You pressed Ctrl+C.')
166 | self.log.close()
167 |
168 | # Remove logfile
169 | os.remove(self.console_log_path)
170 | print('Removed console_output file')
171 | sys.exit(0)
172 |
173 | def write(self, message):
174 | self.terminal.write(message)
175 | self.log.write(message)
176 |
177 | def flush(self):
178 | # this flush method is needed for python 3 compatibility.
179 | # this handles the flush command by doing nothing.
180 | # you might want to specify some extra behavior here.
181 | pass
182 |
183 |
184 | if __name__ == '__main__':
185 | sys.stdout = Logger()
186 | start_time = time.time()
187 | print('Running {} ...'.format(os.path.basename(__file__)))
188 | main_test(run_exp=True)
189 | training_time = time.time() - start_time
190 | print('Total time: %s' % str(datetime.timedelta(seconds=training_time)))
191 | print("============ LOG-ID: %s ============" % current_time)
192 |
--------------------------------------------------------------------------------
/model/test_da_template_AlexNet_finetune.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Tuan Nguyen.
2 | # All rights reserved.
3 |
4 | from __future__ import division
5 | from __future__ import print_function
6 | from __future__ import absolute_import
7 |
8 | import os
9 | import sys
10 | import numpy as np
11 | import tensorflow as tf
12 |
13 | from generic_utils import random_seed
14 | from generic_utils import data_dir
15 | from resnet.preprocessor import BatchPreprocessor
16 |
17 |
18 | def test_real_dataset(create_obj_func, src_name=None, trg_name=None):
19 | print('Running {} ...'.format(os.path.basename(__file__)))
20 |
21 | if src_name is None:
22 | if len(sys.argv) > 2:
23 | src_name = sys.argv[2]
24 | else:
25 |             raise Exception('Source dataset not specified')
26 | if trg_name is None:
27 | if len(sys.argv) > 3:
28 | trg_name = sys.argv[3]
29 | else:
30 |             raise Exception('Target dataset not specified')
31 |
32 | np.random.seed(random_seed())
33 | tf.set_random_seed(random_seed())
34 | tf.reset_default_graph()
35 |
36 | print("========== Test on real data ==========")
37 |
38 | users_params = dict()
39 | users_params = parse_arguments(users_params)
40 | data_format = 'mat'
41 |
42 | if 'format' in users_params:
43 | data_format, users_params = extract_param('format', data_format, users_params)
44 |
45 | src_domains = src_name.split(',')
46 | num_src_domain = len(src_domains)
47 | input_size = [227, 227]
48 | n_channels = 3
49 | src_preprocessors = []
50 | dataset_path = os.path.join(data_dir(), 'office31')
51 | multi_scale = list(map(int, users_params['multi_scale'].split(',')))
52 |
53 | for src_domain in src_domains:
54 | file_path_train = os.path.join(dataset_path, '{}_train.txt'.format(src_domain))
55 | src_preprocessor_i = BatchPreprocessor(dataset_file_path=file_path_train,
56 | num_classes=users_params['num_classes'],
57 | output_size=input_size, horizontal_flip=True, shuffle=True,
58 | multi_scale=multi_scale)
59 | src_preprocessors.append(src_preprocessor_i)
60 |
61 | trg_train_preprocessor = BatchPreprocessor(dataset_file_path=os.path.join(dataset_path, '{}_train.txt'.format(trg_name)),
62 | num_classes=users_params['num_classes'], output_size=input_size, horizontal_flip=True, shuffle=True,
63 | multi_scale=multi_scale)
64 |
65 | trg_test_preprocessor = BatchPreprocessor(dataset_file_path=os.path.join(dataset_path, '{}_test.txt'.format(trg_name)),
66 | num_classes=users_params['num_classes'], output_size=input_size)
67 |
68 | assert users_params['batch_size'] % num_src_domain == 0
69 |
70 | print('users_params:', users_params)
71 | print('src_name:', src_name, ', trg_name:', trg_name)
72 | for i in range(len(src_domains)):
73 | print(src_domains[i], len(src_preprocessors[i].labels))
74 | print(trg_name, len(trg_test_preprocessor.labels))
75 |
76 | learner = create_obj_func(users_params)
77 | learner.dim_src = tuple(input_size + [n_channels])
78 | learner.dim_trg = tuple(input_size + [n_channels])
79 |
80 | learner._init(src_preprocessors, trg_train_preprocessor, trg_test_preprocessor, num_src_domain)
81 | learner._build_model()
82 | learner._fit_loop()
83 |
84 |
85 | def main_func(
86 | create_obj_func,
87 | choice_default=0,
88 | src_name_default='svmguide1',
89 | trg_name_default='svmguide1',
90 | run_exp=False):
91 |
92 | if not run_exp:
93 | choice_lst = [0, 1, 2]
94 | src_name = src_name_default
95 | trg_name = trg_name_default
96 | elif len(sys.argv) > 1:
97 | choice_lst = [int(sys.argv[1])]
98 | src_name = None
99 | trg_name = None
100 | else:
101 | choice_lst = [choice_default]
102 | src_name = src_name_default
103 | trg_name = trg_name_default
104 |
105 | for choice in choice_lst:
106 | if choice == 0:
107 | pass
108 | # add another function here
109 | elif choice == 1:
110 | test_real_dataset(create_obj_func, src_name, trg_name)
111 |
112 |
113 | def parse_arguments(params, as_array=False):
114 | for it in range(4, len(sys.argv), 2):
115 | params[sys.argv[it]] = parse_argument(sys.argv[it + 1], as_array)
116 | return params
117 |
118 |
119 | def parse_argument(string, as_array=False):
120 | try:
121 | result = int(string)
122 | except ValueError:
123 | try:
124 | result = float(string)
125 | except ValueError:
126 | if str.lower(string) == 'true':
127 | result = True
128 | elif str.lower(string) == 'false':
129 | result = False
130 | elif string == "[]":
131 | return []
132 | elif ('|' in string) and ('[' in string) and (']' in string):
133 | result = [float(item) for item in string[1:-1].split('|')]
134 | return result
135 | elif (',' in string) and ('(' in string) and (')' in string):
136 | split = string[1:-1].split(',')
137 | result = float(split[0]) ** np.arange(float(split[1]), float(split[2]), float(split[3]))
138 | return result
139 | else:
140 | result = string
141 |
142 | return [result] if as_array else result
143 |
144 |
145 | def resolve_conflict_params(primary_params, secondary_params):
146 | for key in primary_params.keys():
147 | if key in secondary_params.keys():
148 | del secondary_params[key]
149 | return secondary_params
150 |
151 |
152 | def extract_param(key, value, params_gridsearch, scalar=False):
153 | if key in params_gridsearch.keys():
154 | value = params_gridsearch[key]
155 | del params_gridsearch[key]
156 | if scalar and (value is not None):
157 | value = value[0]
158 | return value, params_gridsearch
159 |
160 |
161 | def dict2string(params):
162 | result = ''
163 | for key, value in params.items():
164 | if type(value) is np.ndarray:
165 | if value.size < 16:
166 | result += key + ': ' + '|'.join('{0:.4f}'.format(x) for x in value.ravel()) + ', '
167 | else:
168 | result += key + ': ' + str(value) + ', '
169 | return '{' + result[:-2] + '}'
170 |
--------------------------------------------------------------------------------
/model/run_most_digits.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Tuan Nguyen.
2 | # All rights reserved.
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | from most_digits import MOST
9 | from layers import noise
10 | from test_da_template_digits import main_func, resolve_conflict_params
11 | from tensorflow.python.layers.core import dropout
12 | from tensorbayes.layers import dense, conv2d, avg_pool, max_pool
13 |
14 | import warnings
15 | import os
16 | from generic_utils import tuid, model_dir
17 | import signal
18 | import sys
19 | import time
20 | import datetime
21 |
22 | choice_default = 1
23 | warnings.simplefilter("ignore", category=DeprecationWarning)
24 |
25 | model_name = "MOST-results"
26 | current_time = tuid()
27 |
28 |
29 | # generator
30 | def encode_layout(preprocess, training_phase=True, cnn_size='large'):
31 | layout = []
32 | if cnn_size == 'small':
33 | layout = [
34 | (preprocess, (), {}),
35 | (conv2d, (64, 3, 1), {}),
36 | (conv2d, (64, 3, 1), {}),
37 | (conv2d, (64, 3, 1), {}),
38 | (max_pool, (2, 2), {}),
39 | (dropout, (), dict(training=training_phase)),
40 | (noise, (1,), dict(phase=training_phase)),
41 | (conv2d, (64, 3, 1), {}),
42 | (conv2d, (64, 3, 1), {}),
43 | (conv2d, (64, 3, 1), {}),
44 | (max_pool, (2, 2), {}),
45 | (dropout, (), dict(training=training_phase)),
46 | (noise, (1,), dict(phase=training_phase)),
47 | ]
48 | elif cnn_size == 'large':
49 | layout = [
50 | (preprocess, (), {}),
51 | (conv2d, (96, 3, 1), {}),
52 | (conv2d, (96, 3, 1), {}),
53 | (conv2d, (96, 3, 1), {}),
54 | (max_pool, (2, 2), {}),
55 | (dropout, (), dict(training=training_phase)),
56 | (noise, (1,), dict(phase=training_phase)),
57 | (conv2d, (192, 3, 1), {}),
58 | (conv2d, (192, 3, 1), {}),
59 | (conv2d, (192, 3, 1), {}),
60 | (max_pool, (2, 2), {}),
61 | (dropout, (), dict(training=training_phase)),
62 | (noise, (1,), dict(phase=training_phase)),
63 | ]
64 | return layout
65 |
66 |
67 | # classifier
68 | def class_discriminator_layout(num_classes=None, global_pool=True, activation=None, cnn_size='large'):
69 | layout = []
70 | if cnn_size == 'small':
71 | layout = [
72 | (conv2d, (64, 3, 1), {}),
73 | (conv2d, (64, 3, 1), {}),
74 | (conv2d, (64, 3, 1), {}),
75 | (avg_pool, (), dict(global_pool=global_pool)),
76 | (dense, (num_classes,), dict(activation=activation))
77 | ]
78 |
79 | elif cnn_size == 'large':
80 | layout = [
81 | (conv2d, (192, 3, 1), {}),
82 | (conv2d, (192, 3, 1), {}),
83 | (conv2d, (192, 3, 1), {}),
84 | (avg_pool, (), dict(global_pool=global_pool)),
85 | (dense, (num_classes,), dict(activation=activation))
86 | ]
87 | return layout
88 |
89 |
90 | # discriminator
91 | def domain_layout(c):
92 | layout = [
93 | (dense, (100,), {}),
94 | (dense, (c,), dict(activation=None))
95 | ]
96 | return layout
97 |
98 |
99 | def phi_layout(c):
100 | layout = [
101 | (dense, (100,), {}),
102 | (dense, (c,), dict(activation=None))
103 | ]
104 | return layout
105 |
106 |
107 | def create_obj_func(params):
108 | if len(sys.argv) > 1:
109 | my_choice = int(sys.argv[1])
110 | else:
111 | my_choice = choice_default
112 | if my_choice == 0:
113 | default_params = {
114 | }
115 | else:
116 | default_params = {
117 | 'batch_size': 200,
118 | 'learning_rate': 0.0002,
119 | 'num_iters': 80000,
120 | 'phase1_iters': 20000,
121 | 'src_class_trade_off': 1.0,
122 | 'src_domain_trade_off': '1.0,1.0,1.0,1.0',
123 | 'ot_trade_off': 0.1,
124 | 'domain_trade_off': 1.0,
125 | 'trg_vat_troff': 0.1,
126 | 'trg_ent_troff': 0.1,
127 | 'g_network_trade_off': 1.0,
128 | 'mimic_trade_off': 1.0,
129 | 'theta': 10.0,
130 | 'mdaot_model_id': '',
131 | 'classify_layout': class_discriminator_layout,
132 | 'encode_layout': encode_layout,
133 | 'domain_layout': domain_layout,
134 | 'phi_layout': phi_layout,
135 | 'log_path': os.path.join(model_dir(), model_name, "logs", "{}".format(current_time)),
136 | 'summary_freq': 800,
137 | 'current_time': current_time,
138 | 'inorm': True,
139 | 'cast_data': True,
140 | 'only_save_final_model': True,
141 | 'cnn_size': 'small',
142 | 'sample_size': 20,
143 | 'data_shift_troff': 10.0,
144 | 'lbl_shift_troff': 1.0
145 | }
146 |
147 | default_params = resolve_conflict_params(params, default_params)
148 | learner = MOST(
149 | **params,
150 | **default_params,
151 | )
152 | return learner
153 |
154 |
155 | def main_test(run_exp=False):
156 | main_func(
157 | create_obj_func,
158 | choice_default=choice_default,
159 | src_name_default='mnist32_60_10',
160 | trg_name_default='mnistm32_60_10',
161 | run_exp=run_exp,
162 | freq_predict_display=10,
163 | summary_freq=100,
164 | current_time=current_time,
165 | log_path=os.path.join(model_dir(), model_name, "logs", "{}".format(current_time))
166 | )
167 |
168 |
169 | class Logger(object):
170 | def __init__(self):
171 | self.terminal = sys.stdout
172 | self.console_log_path = os.path.join(model_dir(), model_name, "console_output", "{}.txt".format(current_time))
173 | if not os.path.exists(os.path.dirname(self.console_log_path)):
174 | os.makedirs(os.path.dirname(self.console_log_path))
175 | self.log = open(self.console_log_path, 'a')
176 | signal.signal(signal.SIGINT, self.signal_handler)
177 |
178 | def signal_handler(self, sig, frame):
179 | print('You pressed Ctrl+C.')
180 | self.log.close()
181 |
182 | # Remove logfile
183 | os.remove(self.console_log_path)
184 | print('Removed console_output file')
185 | sys.exit(0)
186 |
187 | def write(self, message):
188 | self.terminal.write(message)
189 | self.log.write(message)
190 |
191 | def flush(self):
192 | # this flush method is needed for python 3 compatibility.
193 | # this handles the flush command by doing nothing.
194 | # you might want to specify some extra behavior here.
195 | pass
196 |
197 |
198 | if __name__ == '__main__':
199 | sys.stdout = Logger()
200 | start_time = time.time()
201 | print('Running {} ...'.format(os.path.basename(__file__)))
202 | main_test(run_exp=True)
203 | training_time = time.time() - start_time
204 | print('Total time: %s' % str(datetime.timedelta(seconds=training_time)))
205 | print("============ LOG-ID: %s ============" % current_time)
206 |
--------------------------------------------------------------------------------
/model/alexnet/finetune.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import numpy as np
3 | import tensorflow as tf
4 | import datetime
5 | from model import AlexNetModel
6 | sys.path.insert(0, '../utils')
7 | from preprocessor import BatchPreprocessor
8 |
9 |
10 | tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'Learning rate for adam optimizer')
11 | tf.app.flags.DEFINE_float('dropout_keep_prob', 0.5, 'Dropout keep probability')
12 | tf.app.flags.DEFINE_integer('num_epochs', 10, 'Number of epochs for training')
13 | tf.app.flags.DEFINE_integer('num_classes', 26, 'Number of classes')
14 | tf.app.flags.DEFINE_integer('batch_size', 128, 'Batch size')
15 | tf.app.flags.DEFINE_string('train_layers', 'fc8,fc7', 'Finetuning layers, separated by commas')
16 | tf.app.flags.DEFINE_string('multi_scale', '', 'As preprocessing: randomly scale the image between two sizes and crop randomly at the network\'s input size')
17 | tf.app.flags.DEFINE_string('training_file', '../data/train.txt', 'Training dataset file')
18 | tf.app.flags.DEFINE_string('val_file', '../data/val.txt', 'Validation dataset file')
19 | tf.app.flags.DEFINE_string('tensorboard_root_dir', '../training', 'Root directory to put the training logs and weights')
20 | tf.app.flags.DEFINE_integer('log_step', 10, 'Logging period in terms of iteration')
21 |
22 | FLAGS = tf.app.flags.FLAGS
23 |
24 |
25 | def main(_):
26 | # Create training directories
27 | now = datetime.datetime.now()
28 | train_dir_name = now.strftime('alexnet_%Y%m%d_%H%M%S')
29 | train_dir = os.path.join(FLAGS.tensorboard_root_dir, train_dir_name)
30 | checkpoint_dir = os.path.join(train_dir, 'checkpoint')
31 | tensorboard_dir = os.path.join(train_dir, 'tensorboard')
32 | tensorboard_train_dir = os.path.join(tensorboard_dir, 'train')
33 | tensorboard_val_dir = os.path.join(tensorboard_dir, 'val')
34 |
35 | if not os.path.isdir(FLAGS.tensorboard_root_dir): os.mkdir(FLAGS.tensorboard_root_dir)
36 | if not os.path.isdir(train_dir): os.mkdir(train_dir)
37 | if not os.path.isdir(checkpoint_dir): os.mkdir(checkpoint_dir)
38 | if not os.path.isdir(tensorboard_dir): os.mkdir(tensorboard_dir)
39 | if not os.path.isdir(tensorboard_train_dir): os.mkdir(tensorboard_train_dir)
40 | if not os.path.isdir(tensorboard_val_dir): os.mkdir(tensorboard_val_dir)
41 |
42 | # Write flags to txt
43 | flags_file_path = os.path.join(train_dir, 'flags.txt')
44 | flags_file = open(flags_file_path, 'w')
45 | flags_file.write('learning_rate={}\n'.format(FLAGS.learning_rate))
46 | flags_file.write('dropout_keep_prob={}\n'.format(FLAGS.dropout_keep_prob))
47 | flags_file.write('num_epochs={}\n'.format(FLAGS.num_epochs))
48 | flags_file.write('batch_size={}\n'.format(FLAGS.batch_size))
49 | flags_file.write('train_layers={}\n'.format(FLAGS.train_layers))
50 | flags_file.write('multi_scale={}\n'.format(FLAGS.multi_scale))
51 | flags_file.write('tensorboard_root_dir={}\n'.format(FLAGS.tensorboard_root_dir))
52 | flags_file.write('log_step={}\n'.format(FLAGS.log_step))
53 | flags_file.close()
54 |
55 | # Placeholders
56 | x = tf.placeholder(tf.float32, [FLAGS.batch_size, 227, 227, 3])
57 | y = tf.placeholder(tf.float32, [None, FLAGS.num_classes])
58 | dropout_keep_prob = tf.placeholder(tf.float32)
59 |
60 | # Model
61 | train_layers = FLAGS.train_layers.split(',')
62 | model = AlexNetModel(num_classes=FLAGS.num_classes, dropout_keep_prob=dropout_keep_prob)
63 | loss = model.loss(x, y)
64 | train_op = model.optimize(FLAGS.learning_rate, train_layers)
65 |
66 | # Training accuracy of the model
67 | correct_pred = tf.equal(tf.argmax(model.score, 1), tf.argmax(y, 1))
68 | accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
69 |
70 | # Summaries
71 | tf.summary.scalar('train_loss', loss)
72 | tf.summary.scalar('train_accuracy', accuracy)
73 | merged_summary = tf.summary.merge_all()
74 |
75 | train_writer = tf.summary.FileWriter(tensorboard_train_dir)
76 | val_writer = tf.summary.FileWriter(tensorboard_val_dir)
77 | saver = tf.train.Saver()
78 |
79 | # Batch preprocessors
80 | multi_scale = FLAGS.multi_scale.split(',')
81 | if len(multi_scale) == 2:
82 | multi_scale = [int(multi_scale[0]), int(multi_scale[1])]
83 | else:
84 | multi_scale = None
85 |
86 | train_preprocessor = BatchPreprocessor(dataset_file_path=FLAGS.training_file, num_classes=FLAGS.num_classes,
87 | output_size=[227, 227], horizontal_flip=True, shuffle=True, multi_scale=multi_scale)
88 | val_preprocessor = BatchPreprocessor(dataset_file_path=FLAGS.val_file, num_classes=FLAGS.num_classes, output_size=[227, 227])
89 |
90 | # Get the number of training/validation steps per epoch
91 | train_batches_per_epoch = np.floor(len(train_preprocessor.labels) / FLAGS.batch_size).astype(np.int16)
92 | val_batches_per_epoch = np.floor(len(val_preprocessor.labels) / FLAGS.batch_size).astype(np.int16)
93 |
94 |
95 | with tf.Session() as sess:
96 | sess.run(tf.global_variables_initializer())
97 | train_writer.add_graph(sess.graph)
98 |
99 | # Load the pretrained weights
100 | model.load_original_weights(sess, skip_layers=train_layers)
101 |
102 | # Directly restore (your model should be exactly the same with checkpoint)
103 | # saver.restore(sess, "/Users/dgurkaynak/Projects/marvel-training/alexnet64-fc6/model_epoch10.ckpt")
104 |
105 | print("{} Start training...".format(datetime.datetime.now()))
106 | print("{} Open Tensorboard at --logdir {}".format(datetime.datetime.now(), tensorboard_dir))
107 |
108 | for epoch in range(FLAGS.num_epochs):
109 | print("{} Epoch number: {}".format(datetime.datetime.now(), epoch+1))
110 | step = 1
111 |
112 | # Start training
113 | while step < train_batches_per_epoch:
114 | batch_xs, batch_ys = train_preprocessor.next_batch(FLAGS.batch_size)
115 | sess.run(train_op, feed_dict={x: batch_xs, y: batch_ys, dropout_keep_prob: FLAGS.dropout_keep_prob})
116 |
117 | # Logging
118 | if step % FLAGS.log_step == 0:
119 | s = sess.run(merged_summary, feed_dict={x: batch_xs, y: batch_ys, dropout_keep_prob: 1.})
120 | train_writer.add_summary(s, epoch * train_batches_per_epoch + step)
121 |
122 | step += 1
123 |
124 | # Epoch completed, start validation
125 | print("{} Start validation".format(datetime.datetime.now()))
126 | test_acc = 0.
127 | test_count = 0
128 |
129 | for _ in range(val_batches_per_epoch):
130 | batch_tx, batch_ty = val_preprocessor.next_batch(FLAGS.batch_size)
131 | acc = sess.run(accuracy, feed_dict={x: batch_tx, y: batch_ty, dropout_keep_prob: 1.})
132 | test_acc += acc
133 | test_count += 1
134 |
135 | test_acc /= test_count
136 | s = tf.Summary(value=[
137 | tf.Summary.Value(tag="validation_accuracy", simple_value=test_acc)
138 | ])
139 | val_writer.add_summary(s, epoch+1)
140 | print("{} Validation Accuracy = {:.4f}".format(datetime.datetime.now(), test_acc))
141 |
142 | # Reset the dataset pointers
143 | val_preprocessor.reset_pointer()
144 | train_preprocessor.reset_pointer()
145 |
146 | print("{} Saving checkpoint of model...".format(datetime.datetime.now()))
147 |
148 | #save checkpoint of the model
149 | checkpoint_path = os.path.join(checkpoint_dir, 'model_epoch'+str(epoch+1)+'.ckpt')
150 | save_path = saver.save(sess, checkpoint_path)
151 |
152 | print("{} Model checkpoint saved at {}".format(datetime.datetime.now(), checkpoint_path))
153 |
154 | if __name__ == '__main__':
155 | tf.app.run()
156 |
--------------------------------------------------------------------------------
/model/resnet/finetune.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import numpy as np
3 | import tensorflow as tf
4 | import datetime
5 | from model import ResNetModel
6 | sys.path.insert(0, '../utils')
7 | from preprocessor import BatchPreprocessor
8 |
9 |
10 | tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'Learning rate for adam optimizer')
11 | tf.app.flags.DEFINE_integer('resnet_depth', 101, 'ResNet architecture to be used: 50, 101 or 152')
12 | tf.app.flags.DEFINE_integer('num_epochs', 100, 'Number of epochs for training')
13 | tf.app.flags.DEFINE_integer('num_classes', 10, 'Number of classes')
14 | tf.app.flags.DEFINE_integer('batch_size', 128, 'Batch size')
15 | tf.app.flags.DEFINE_string('train_layers', 'fc', 'Finetuning layers, separated by commas')
16 | tf.app.flags.DEFINE_string('multi_scale', '', 'As preprocessing: randomly scale the image between two sizes and crop randomly at the network\'s input size')
17 | tf.app.flags.DEFINE_string('training_file', '../data/office_caltech10/amazon.txt', 'Training dataset file')
18 | tf.app.flags.DEFINE_string('val_file', '../data/office_caltech10/amazon.txt', 'Validation dataset file')
19 | tf.app.flags.DEFINE_string('tensorboard_root_dir', '../training', 'Root directory to put the training logs and weights')
20 | tf.app.flags.DEFINE_integer('log_step', 10, 'Logging period in terms of iteration')
21 |
22 | FLAGS = tf.app.flags.FLAGS
23 |
24 |
25 | def main(_):
26 | # Create training directories
27 | now = datetime.datetime.now()
28 | train_dir_name = now.strftime('resnet_%Y%m%d_%H%M%S')
29 | train_dir = os.path.join(FLAGS.tensorboard_root_dir, train_dir_name)
30 | checkpoint_dir = os.path.join(train_dir, 'checkpoint')
31 | tensorboard_dir = os.path.join(train_dir, 'tensorboard')
32 | tensorboard_train_dir = os.path.join(tensorboard_dir, 'train')
33 | tensorboard_val_dir = os.path.join(tensorboard_dir, 'val')
34 |
35 | if not os.path.isdir(FLAGS.tensorboard_root_dir): os.mkdir(FLAGS.tensorboard_root_dir)
36 | if not os.path.isdir(train_dir): os.mkdir(train_dir)
37 | if not os.path.isdir(checkpoint_dir): os.mkdir(checkpoint_dir)
38 | if not os.path.isdir(tensorboard_dir): os.mkdir(tensorboard_dir)
39 | if not os.path.isdir(tensorboard_train_dir): os.mkdir(tensorboard_train_dir)
40 | if not os.path.isdir(tensorboard_val_dir): os.mkdir(tensorboard_val_dir)
41 |
42 | # Write flags to txt
43 | flags_file_path = os.path.join(train_dir, 'flags.txt')
44 | flags_file = open(flags_file_path, 'w')
45 | flags_file.write('learning_rate={}\n'.format(FLAGS.learning_rate))
46 | flags_file.write('resnet_depth={}\n'.format(FLAGS.resnet_depth))
47 | flags_file.write('num_epochs={}\n'.format(FLAGS.num_epochs))
48 | flags_file.write('batch_size={}\n'.format(FLAGS.batch_size))
49 | flags_file.write('train_layers={}\n'.format(FLAGS.train_layers))
50 | flags_file.write('multi_scale={}\n'.format(FLAGS.multi_scale))
51 | flags_file.write('tensorboard_root_dir={}\n'.format(FLAGS.tensorboard_root_dir))
52 | flags_file.write('log_step={}\n'.format(FLAGS.log_step))
53 | flags_file.close()
54 |
55 | # Placeholders
56 | x = tf.placeholder(tf.float32, [FLAGS.batch_size, 224, 224, 3])
57 | y = tf.placeholder(tf.float32, [None, FLAGS.num_classes])
58 | is_training = tf.placeholder('bool', [])
59 |
60 | # Model
61 | train_layers = FLAGS.train_layers.split(',')
62 | model = ResNetModel(is_training, depth=FLAGS.resnet_depth, num_classes=FLAGS.num_classes)
63 | loss = model.loss(x, y)
64 | train_op = model.optimize(FLAGS.learning_rate, train_layers)
65 |
66 | # Training accuracy of the model
67 | correct_pred = tf.equal(tf.argmax(model.prob, 1), tf.argmax(y, 1))
68 | accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
69 |
70 | # Summaries
71 | tf.summary.scalar('train_loss', loss)
72 | tf.summary.scalar('train_accuracy', accuracy)
73 | merged_summary = tf.summary.merge_all()
74 |
75 | train_writer = tf.summary.FileWriter(tensorboard_train_dir)
76 | val_writer = tf.summary.FileWriter(tensorboard_val_dir)
77 | saver = tf.train.Saver()
78 |
79 | # Batch preprocessors
80 | multi_scale = FLAGS.multi_scale.split(',')
81 | if len(multi_scale) == 2:
82 | multi_scale = [int(multi_scale[0]), int(multi_scale[1])]
83 | else:
84 | multi_scale = None
85 |
86 | train_preprocessor = BatchPreprocessor(dataset_file_path=FLAGS.training_file, num_classes=FLAGS.num_classes,
87 | output_size=[224, 224], horizontal_flip=True, shuffle=True, multi_scale=multi_scale)
88 | val_preprocessor = BatchPreprocessor(dataset_file_path=FLAGS.val_file, num_classes=FLAGS.num_classes, output_size=[224, 224])
89 |
90 | # Get the number of training/validation steps per epoch
91 | train_batches_per_epoch = np.floor(len(train_preprocessor.labels) / FLAGS.batch_size).astype(np.int16)
92 | val_batches_per_epoch = np.floor(len(val_preprocessor.labels) / FLAGS.batch_size).astype(np.int16)
93 |
94 |
95 | with tf.Session() as sess:
96 | sess.run(tf.global_variables_initializer())
97 | train_writer.add_graph(sess.graph)
98 |
99 | # Load the pretrained weights
100 | model.load_original_weights(sess, skip_layers=train_layers)
101 |
102 | # Directly restore (your model should be exactly the same with checkpoint)
103 | # saver.restore(sess, "/Users/dgurkaynak/Projects/marvel-training/alexnet64-fc6/model_epoch10.ckpt")
104 |
105 | print("{} Start training...".format(datetime.datetime.now()))
106 | print("{} Open Tensorboard at --logdir {}".format(datetime.datetime.now(), tensorboard_dir))
107 |
108 | for epoch in range(FLAGS.num_epochs):
109 | print("{} Epoch number: {}".format(datetime.datetime.now(), epoch+1))
110 | step = 1
111 |
112 | # Start training
113 | while step < train_batches_per_epoch:
114 | batch_xs, batch_ys = train_preprocessor.next_batch(FLAGS.batch_size)
115 | sess.run(train_op, feed_dict={x: batch_xs, y: batch_ys, is_training: True})
116 |
117 | # Logging
118 | if step % FLAGS.log_step == 0:
119 | s = sess.run(merged_summary, feed_dict={x: batch_xs, y: batch_ys, is_training: False})
120 | train_writer.add_summary(s, epoch * train_batches_per_epoch + step)
121 |
122 | step += 1
123 |
124 | # Epoch completed, start validation
125 | print("{} Start validation".format(datetime.datetime.now()))
126 | test_acc = 0.
127 | test_count = 0
128 |
129 | for _ in range(val_batches_per_epoch):
130 | batch_tx, batch_ty = val_preprocessor.next_batch(FLAGS.batch_size)
131 | acc = sess.run(accuracy, feed_dict={x: batch_tx, y: batch_ty, is_training: False})
132 | test_acc += acc
133 | test_count += 1
134 |
135 | test_acc /= test_count
136 | s = tf.Summary(value=[
137 | tf.Summary.Value(tag="validation_accuracy", simple_value=test_acc)
138 | ])
139 | val_writer.add_summary(s, epoch+1)
140 | print("{} Validation Accuracy = {:.4f}".format(datetime.datetime.now(), test_acc))
141 |
142 | # Reset the dataset pointers
143 | val_preprocessor.reset_pointer()
144 | train_preprocessor.reset_pointer()
145 |
146 | print("{} Saving checkpoint of model...".format(datetime.datetime.now()))
147 |
148 | #save checkpoint of the model
149 | checkpoint_path = os.path.join(checkpoint_dir, 'model_epoch'+str(epoch+1)+'.ckpt')
150 | save_path = saver.save(sess, checkpoint_path)
151 |
152 | print("{} Model checkpoint saved at {}".format(datetime.datetime.now(), checkpoint_path))
153 |
154 | if __name__ == '__main__':
155 | tf.app.run()
156 |
--------------------------------------------------------------------------------
/model/resnet/model.py:
--------------------------------------------------------------------------------
1 | """
2 | Derived from: https://github.com/ry/tensorflow-resnet
3 | """
4 | import tensorflow as tf
5 | import numpy as np
6 | from tensorflow.python.ops import control_flow_ops
7 | from tensorflow.python.training import moving_averages
8 |
9 |
10 | NUM_BLOCKS = {
11 | 50: [3, 4, 6, 3],
12 | 101: [3, 4, 23, 3],
13 | 152: [3, 8, 36, 3]
14 | }
15 | CONV_WEIGHT_DECAY = 0.00004
16 | CONV_WEIGHT_STDDEV = 0.1
17 | MOVING_AVERAGE_DECAY = 0.9997
18 | BN_DECAY = MOVING_AVERAGE_DECAY
19 | BN_EPSILON = 0.001
20 | UPDATE_OPS_COLLECTION = 'resnet_update_ops'
21 | FC_WEIGHT_STDDEV = 0.01
22 |
23 |
24 | class ResNetModel(object):
25 |
26 | def __init__(self, is_training, depth=50, num_classes=1000):
27 | self.is_training = is_training
28 | self.num_classes = num_classes
29 | self.depth = depth
30 |
31 | if depth in NUM_BLOCKS:
32 | self.num_blocks = NUM_BLOCKS[depth]
33 | else:
34 | raise ValueError('Depth is not supported; it must be 50, 101 or 152')
35 |
36 | def inference(self, x, reuse=None, extract_feat=False):
37 | # Scale 1
38 | with tf.variable_scope('scale1', reuse=reuse):
39 | s1_conv = conv(x, ksize=7, stride=2, filters_out=64)
40 | s1_bn = bn(s1_conv, is_training=self.is_training)
41 | s1 = tf.nn.relu(s1_bn)
42 |
43 | # Scale 2
44 | with tf.variable_scope('scale2', reuse=reuse):
45 | s2_mp = tf.nn.max_pool(s1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')
46 | s2 = stack(s2_mp, is_training=self.is_training, num_blocks=self.num_blocks[0], stack_stride=1, block_filters_internal=64)
47 |
48 | # Scale 3
49 | with tf.variable_scope('scale3', reuse=reuse):
50 | s3 = stack(s2, is_training=self.is_training, num_blocks=self.num_blocks[1], stack_stride=2, block_filters_internal=128)
51 |
52 | # Scale 4
53 | with tf.variable_scope('scale4', reuse=reuse):
54 | s4 = stack(s3, is_training=self.is_training, num_blocks=self.num_blocks[2], stack_stride=2, block_filters_internal=256)
55 |
56 | # Scale 5
57 | with tf.variable_scope('scale5', reuse=reuse):
58 | s5 = stack(s4, is_training=self.is_training, num_blocks=self.num_blocks[3], stack_stride=2, block_filters_internal=512)
59 |
60 | # post-net
61 | avg_pool = tf.reduce_mean(s5, reduction_indices=[1, 2], name='avg_pool') # (bs/k, 2048)
62 |
63 | if extract_feat:
64 | return avg_pool
65 |
66 | with tf.variable_scope('fc', reuse=reuse):
67 | self.prob = fc(avg_pool, num_units_out=256)
68 | return self.prob
69 | # return avg_pool
70 |
71 | def loss(self, batch_x, batch_y=None):
72 | y_predict = self.inference(batch_x)
73 | # cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y_predict, labels=batch_y)
74 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=y_predict, labels=batch_y)
75 | cross_entropy_mean = tf.reduce_mean(cross_entropy)
76 | regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
77 | self.loss = tf.add_n([cross_entropy_mean] + regularization_losses)
78 | return self.loss
79 |
80 | def optimize(self, learning_rate, other_var_list, train_layers, loss):
81 | # all_other_variables = []
82 | # for var_lst in other_var_list:
83 | # for var in var_lst:
84 | # all_other_variables += var
85 | trainable_var_names = ['weights', 'biases', 'beta', 'gamma']
86 | var_list = [[v for v in tf.trainable_variables() if
87 | v.name.split(':')[0].split('/')[-1] in trainable_var_names and
88 | contains(v.name, train_layers)]] + other_var_list
89 | print('len(var_list)', len(var_list))
90 | train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss, var_list=var_list)
91 |
92 | ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY)
93 | tf.add_to_collection(UPDATE_OPS_COLLECTION, ema.apply([loss]))
94 |
95 | batchnorm_updates = tf.get_collection(UPDATE_OPS_COLLECTION)
96 | batchnorm_updates_op = tf.group(*batchnorm_updates)
97 |
98 | return tf.group(train_op, batchnorm_updates_op)
99 |
100 | def load_original_weights(self, session, skip_layers=[]):
101 | weights_path = 'ResNet-L{}.npy'.format(self.depth)
102 | weights_dict = np.load(weights_path, encoding='bytes', allow_pickle=True).item()
103 |
104 | for op_name in weights_dict:
105 | parts = op_name.split('/')
106 |
107 | # if contains(op_name, skip_layers):
108 | # continue
109 |
110 | if parts[0] == 'fc' and self.num_classes != 1000:
111 | continue
112 |
113 | full_name = "{}:0".format(op_name)
114 | var = [v for v in tf.global_variables() if v.name == full_name][0] # also assign mean, var of each ResNet layer
115 | session.run(var.assign(weights_dict[op_name]))
116 |
117 |
118 | """
119 | Helper methods
120 | """
121 | def _get_variable(name, shape, initializer, weight_decay=0.0, dtype='float', trainable=True):
122 | "A little wrapper around tf.get_variable to do weight decay"
123 |
124 | if weight_decay > 0:
125 | regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
126 | else:
127 | regularizer = None
128 |
129 | return tf.get_variable(name, shape=shape, initializer=initializer, dtype=dtype, regularizer=regularizer,
130 | trainable=trainable)
131 |
132 | def conv(x, ksize, stride, filters_out):
133 | filters_in = x.get_shape()[-1]
134 | shape = [ksize, ksize, filters_in, filters_out]
135 | initializer = tf.truncated_normal_initializer(stddev=CONV_WEIGHT_STDDEV)
136 | weights = _get_variable('weights', shape=shape, dtype='float', initializer=initializer,
137 | weight_decay=CONV_WEIGHT_DECAY)
138 | return tf.nn.conv2d(x, weights, [1, stride, stride, 1], padding='SAME')
139 |
140 | def bn(x, is_training):
141 | x_shape = x.get_shape()
142 | params_shape = x_shape[-1:]
143 |
144 | axis = list(range(len(x_shape) - 1))
145 |
146 | beta = _get_variable('beta', params_shape, initializer=tf.zeros_initializer())
147 | gamma = _get_variable('gamma', params_shape, initializer=tf.ones_initializer())
148 |
149 | moving_mean = _get_variable('moving_mean', params_shape, initializer=tf.zeros_initializer(), trainable=False)
150 | moving_variance = _get_variable('moving_variance', params_shape, initializer=tf.ones_initializer(), trainable=False)
151 |
152 |     # These ops will only be performed when training.
153 | mean, variance = tf.nn.moments(x, axis)
154 | update_moving_mean = moving_averages.assign_moving_average(moving_mean, mean, BN_DECAY)
155 | update_moving_variance = moving_averages.assign_moving_average(moving_variance, variance, BN_DECAY)
156 | tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_mean)
157 | tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_variance)
158 |
159 | mean, variance = control_flow_ops.cond(
160 | is_training, lambda: (mean, variance),
161 | lambda: (moving_mean, moving_variance))
162 |
163 | return tf.nn.batch_normalization(x, mean, variance, beta, gamma, BN_EPSILON)
164 |
165 | def stack(x, is_training, num_blocks, stack_stride, block_filters_internal):
166 | for n in range(num_blocks):
167 | block_stride = stack_stride if n == 0 else 1
168 | with tf.variable_scope('block%d' % (n + 1)):
169 | x = block(x, is_training, block_filters_internal=block_filters_internal, block_stride=block_stride)
170 | return x
171 |
172 |
173 | def block(x, is_training, block_filters_internal, block_stride):
174 | filters_in = x.get_shape()[-1]
175 |
176 | m = 4
177 | filters_out = m * block_filters_internal
178 | shortcut = x
179 |
180 | with tf.variable_scope('a'):
181 | a_conv = conv(x, ksize=1, stride=block_stride, filters_out=block_filters_internal)
182 | a_bn = bn(a_conv, is_training)
183 | a = tf.nn.relu(a_bn)
184 |
185 | with tf.variable_scope('b'):
186 | b_conv = conv(a, ksize=3, stride=1, filters_out=block_filters_internal)
187 | b_bn = bn(b_conv, is_training)
188 | b = tf.nn.relu(b_bn)
189 |
190 | with tf.variable_scope('c'):
191 | c_conv = conv(b, ksize=1, stride=1, filters_out=filters_out)
192 | c = bn(c_conv, is_training)
193 |
194 | with tf.variable_scope('shortcut'):
195 | if filters_out != filters_in or block_stride != 1:
196 | shortcut_conv = conv(x, ksize=1, stride=block_stride, filters_out=filters_out)
197 | shortcut = bn(shortcut_conv, is_training)
198 |
199 | return tf.nn.relu(c + shortcut)
200 |
201 |
202 | def fc(x, num_units_out):
203 | num_units_in = x.get_shape()[1]
204 | weights_initializer = tf.truncated_normal_initializer(stddev=FC_WEIGHT_STDDEV)
205 | weights = _get_variable('weights', shape=[num_units_in, num_units_out], initializer=weights_initializer,
206 | weight_decay=FC_WEIGHT_STDDEV)
207 | biases = _get_variable('biases', shape=[num_units_out], initializer=tf.zeros_initializer())
208 | return tf.nn.xw_plus_b(x, weights, biases)
209 |
210 | def contains(target_str, search_arr):
211 | rv = False
212 |
213 | for search_str in search_arr:
214 | if search_str in target_str:
215 | rv = True
216 | break
217 |
218 | return rv
219 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # MOST: Multi-Source Domain Adaptation via Optimal Transport for Student-Teacher Learning
6 |
7 | 


8 |
9 |
10 | This is the implementation of the paper **[MOST: Multi-Source Domain Adaptation via Optimal Transport for Student-Teacher Learning](https://proceedings.mlr.press/v161/nguyen21a/nguyen21a.pdf)** which has been accepted at UAI 2021.
11 |
12 |
13 |
14 |
15 |
16 | ## A. Setup
17 |
18 | #### **Install manually**
19 |
20 | ```
21 | Python Environment: >= 3.5
22 | Tensorflow: >= 1.9
23 | ```
24 |
25 | #### **Install automatically from YAML file**
26 |
27 | ```
28 | pip install --upgrade pip
29 | conda env create --file tf1.9py3.5.yml
30 | ```
31 |
32 | #### **Install *tensorbayes***
33 |
34 | Please note that tensorbayes 0.4.0 is out of date. Please copy the newer version provided in **tensorbayes.tar** into the *tf1.9py3.5* environment folder:
35 |
36 | ```
37 | pip install tensorbayes
38 | tar -xvf tensorbayes.tar
39 | cp -rf /tensorbayes/* /opt/conda/envs/tf1.9py3.5/lib/python3.5/site-packages/tensorbayes/
40 | ```
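
To quickly check that the patched *tensorbayes* is importable inside the environment, a minimal sanity check could look like the following (a sketch only, assuming the environment name *tf1.9py3.5* from the YAML file):

```bash
# Minimal sanity check; assumes the conda environment created from tf1.9py3.5.yml
conda activate tf1.9py3.5
python -c "import tensorflow as tf; import tensorbayes; print('TF', tf.__version__)"
```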
41 |
42 | ## B. Training
43 |
44 | #### 1. Digits-five
45 |
46 | We first navigate to the *model* folder, and then run the *run_most_digits.py* script as below:
47 |
48 | ```bash
49 | cd model
50 | ```
51 |
52 | To run on the *Digits-five* dataset, please create a new folder named *features* in the root folder.
53 | 
54 | Next, download the *Digits-five* dataset [here](https://drive.google.com/file/d/1OpoPALgaMdOlkSkJDqf-JwEvRUkyDqkw/view?usp=sharing) and extract the files into the *features* folder.
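
For orientation only, the layout below sketches where the extracted files are expected to live; the exact *.mat* file names are an assumption inferred from the dataset identifiers used in the commands below (e.g. `mnist32_60_10`, `mnistm32_60_10`):

```
<root>
├── features            # extracted Digits-five .mat files (assumed names)
│   ├── mnist32_60_10.mat
│   ├── mnistm32_60_10.mat
│   ├── usps32.mat
│   ├── svhn.mat
│   └── syn32.mat
└── model
```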
55 |
56 | 1. ''→ **mm**'' task
57 |
58 | ```python
59 | python run_most_digits.py 1 "mnist32_60_10,usps32,svhn,syn32" mnistm32_60_10 format mat num_iters 80000 phase1_iters 0 summary_freq 800 learning_rate 0.0002 batch_size 200 src_class_trade_off 1.0 src_domain_trade_off "1.0,1.0,1.0,1.0" ot_trade_off 0.1 domain_trade_off 1.0 trg_vat_troff 0.1 trg_ent_troff 0.1 data_shift_troff 10.0 mimic_trade_off 0.1 cast_data True cnn_size small theta 0.1 sample_size 5
60 | ```
61 |
62 | 2. ''→ **mt**'' task
63 | ```python
64 | python run_most_digits.py 1 "mnistm32_60_10,usps32,svhn,syn32" mnist32_60_10 format mat num_iters 80000 phase1_iters 0 summary_freq 800 learning_rate 0.0002 batch_size 200 src_class_trade_off 1.0 src_domain_trade_off "1.0,1.0,1.0,1.0" ot_trade_off 0.1 domain_trade_off 1.0 trg_vat_troff 0.1 trg_ent_troff 0.1 data_shift_troff 10.0 mimic_trade_off 1.0 cast_data True cnn_size small theta 0.1 sample_size 5
65 | ```
66 |
67 | 3. ''→ **up**'' task
68 | ```python
69 | python run_most_digits.py 1 "mnistm32_60_10,mnist32_60_10,svhn,syn32" usps32 format mat num_iters 80000 phase1_iters 0 summary_freq 800 learning_rate 0.0002 batch_size 200 src_class_trade_off 1.0 src_domain_trade_off "1.0,1.0,1.0,1.0" ot_trade_off 0.1 domain_trade_off 1.0 trg_vat_troff 0.1 trg_ent_troff 0.1 data_shift_troff 10.0 mimic_trade_off 1.0 cast_data True cnn_size small theta 0.1 sample_size 5
70 | ```
71 |
72 | 4. ''→ **sv**'' task
73 | ```python
74 | python run_most_digits.py 1 "mnistm32_60_10,mnist32_60_10,usps32,syn32" svhn format mat num_iters 80000 phase1_iters 0 summary_freq 800 learning_rate 0.0002 batch_size 200 src_class_trade_off 1.0 src_domain_trade_off "1.0,1.0,1.0,1.0" ot_trade_off 0.1 domain_trade_off 1.0 trg_vat_troff 0.1 trg_ent_troff 0.0 data_shift_troff 10.0 mimic_trade_off 1.0 cast_data True cnn_size small theta 0.1 sample_size 5
75 | ```
76 |
77 | 5. ''→ **sy**'' task
78 | ```python
79 | python run_most_digits.py 1 "mnistm32_60_10,mnist32_60_10,usps32,svhn" syn32 format mat num_iters 80000 phase1_iters 0 summary_freq 800 learning_rate 0.0002 batch_size 200 src_class_trade_off 1.0 src_domain_trade_off "1.0,1.0,1.0,1.0" ot_trade_off 0.1 domain_trade_off 1.0 trg_vat_troff 0.1 trg_ent_troff 0.0 data_shift_troff 10.0 mimic_trade_off 1.0 cast_data True cnn_size small theta 0.1 sample_size 5
80 | ```
81 |
82 | #### 2. Office-31
83 |
84 | #### Step 1: Train a shallow network using extracted features
85 |
86 | Please download the AlexNet-extracted features [here](https://drive.google.com/file/d/1dsrHn4S6lCmlTa4Eg4RAE5JRfZUIxR8G/view?usp=sharing) and save them to the *features* folder.
87 |
88 | 1. ''→ **D**'' task
89 |
90 | ```python
91 | python run_most_AlexNet_train_feat.py 1 "amazon_AlexNet,webcam_AlexNet" dslr_AlexNet format mat num_iters 20000 summary_freq 200 learning_rate 0.0001 inorm True batch_size 62 src_class_trade_off 1.0 src_domain_trade_off "1.0,1.0" ot_trade_off 0.1 domain_trade_off 1.0 src_vat_trade_off 0.0 trg_vat_troff 0.1 trg_ent_troff 0.1 data_shift_troff 10.0 mimic_trade_off 0.1 cast_data False cnn_size small theta 0.1 g_network_trade_off 1.0 sample_size 1 num_classes 31 multi_scale "" data_dir ""
92 | ```
93 |
94 | 2. ''→ **W**'' task
95 |
96 | ```python
97 | python run_most_AlexNet_train_feat.py 1 "amazon_AlexNet,dslr_AlexNet" webcam_AlexNet format mat num_iters 20000 summary_freq 200 learning_rate 0.0001 inorm True batch_size 62 src_class_trade_off 1.0 src_domain_trade_off "1.0,1.0" ot_trade_off 0.1 domain_trade_off 1.0 src_vat_trade_off 0.0 trg_vat_troff 0.1 trg_ent_troff 0.1 data_shift_troff 10.0 mimic_trade_off 0.1 cast_data False cnn_size small theta 0.1 g_network_trade_off 1.0 sample_size 1 num_classes 31 multi_scale "" data_dir ""
98 | ```
99 |
100 | 3. ''→ **A**'' task
101 |
102 | ```python
103 | python run_most_AlexNet_train_feat.py 1 "dslr_AlexNet,webcam_AlexNet" amazon_AlexNet format mat num_iters 20000 summary_freq 200 learning_rate 0.0001 inorm True batch_size 62 src_class_trade_off 1.0 src_domain_trade_off "1.0,1.0" ot_trade_off 0.1 domain_trade_off 1.0 src_vat_trade_off 0.0 trg_vat_troff 1.0 trg_ent_troff 0.1 data_shift_troff 10.0 mimic_trade_off 1.0 cast_data False cnn_size small theta 0.1 g_network_trade_off 1.0 sample_size 1 num_classes 31 multi_scale "" data_dir ""
104 | ```
105 |
106 | Training produces a model saved in the *MOST-results/saved-model* folder; its unique id (*LOG-ID*) is printed out at the end of training.
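
For orientation, the run scripts in this repository (e.g. *run_most_digits.py*) print the *LOG-ID* in their final console lines, and the Step-1 runner is assumed to do the same; the timestamp below is only an illustration:

```
Total time: 2:07:41.532011
============ LOG-ID: 2021-10-07_02.30.5748 ============
```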
107 |
108 | #### Step 2: Finetune the entire model including AlexNet and the shallow network.
109 |
110 | A few preparatory steps are needed for [finetuning](https://github.com/dgurkaynak/tensorflow-cnn-finetune):
111 |
112 | - Download the *Office-31* image data [here](https://drive.google.com/file/d/1GdbY8GJ-HCp-YrDhDaLC-7BCZ0oxY6G9/view?usp=sharing) and extract it into a new folder named *data* in the root folder.
113 |
114 | - Download the pre-trained AlexNet weights using the following command, and save them to the *model* folder.
115 |
116 | ```bash
117 | wget http://www.cs.toronto.edu/~guerzhoy/tf_alexnet/bvlc_alexnet.npy
118 | ```
119 |
120 | Finally, please take your model id from Step 1 and use it as the value of `mdaot_model_id` in the following scripts (an illustrative example is given right after this paragraph). The model id is a timestamp string generated when the model was trained at Step 1, for example “2021-10-07_02.30.5748”.
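
For illustration only: assuming Step 1 printed *LOG-ID* `2021-10-07_02.30.5748`, the ''→ **D**'' command would pass that value to `mdaot_model_id` as follows (substitute your own id):

```bash
# Hypothetical example -- replace 2021-10-07_02.30.5748 with the LOG-ID printed by your own Step-1 run
python run_most_AlexNet_finetune.py 1 "amazon,webcam" dslr format mat num_iters 2000 summary_freq 20 learning_rate 0.0001 inorm True batch_size 62 src_class_trade_off 1.0 src_domain_trade_off "1.0,1.0" ot_trade_off 0.1 domain_trade_off 1.0 src_vat_trade_off 0.0 trg_vat_troff 0.1 trg_ent_troff 0.1 data_shift_troff 10.0 mimic_trade_off 0.1 cast_data False cnn_size small theta 0.1 g_network_trade_off 1.0 sample_size 1 num_classes 31 multi_scale "228,256" data_dir "" mdaot_model_id "2021-10-07_02.30.5748" train_layers "fc7,fc6,conv5,conv4,conv3,conv2,conv1"
```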
121 |
122 | 1. ''→ **D**'' task
123 |
124 | ```python
125 | python run_most_AlexNet_finetune.py 1 "amazon,webcam" dslr format mat num_iters 2000 summary_freq 20 learning_rate 0.0001 inorm True batch_size 62 src_class_trade_off 1.0 src_domain_trade_off "1.0,1.0" ot_trade_off 0.1 domain_trade_off 1.0 src_vat_trade_off 0.0 trg_vat_troff 0.1 trg_ent_troff 0.1 data_shift_troff 10.0 mimic_trade_off 0.1 cast_data False cnn_size small theta 0.1 g_network_trade_off 1.0 sample_size 1 num_classes 31 multi_scale "228,256" data_dir "" mdaot_model_id train_layers "fc7,fc6,conv5,conv4,conv3,conv2,conv1"
126 | ```
127 |
128 | 2. ''→ **W**'' task
129 |
130 | ```python
131 | python run_most_AlexNet_finetune.py 1 "amazon,dslr" webcam format mat num_iters 2000 summary_freq 20 learning_rate 0.0001 inorm True batch_size 62 src_class_trade_off 1.0 src_domain_trade_off "1.0,1.0" ot_trade_off 0.1 domain_trade_off 1.0 src_vat_trade_off 0.0 trg_vat_troff 0.1 trg_ent_troff 0.1 data_shift_troff 10.0 mimic_trade_off 0.1 cast_data False cnn_size small theta 0.1 g_network_trade_off 1.0 sample_size 1 num_classes 31 multi_scale "228,256" data_dir "" mdaot_model_id train_layers "fc7,fc6,conv5,conv4,conv3,conv2,conv1"
132 | ```
133 |
134 | 3. ''→ **A**'' task
135 |
136 | ```python
137 | python run_most_AlexNet_finetune.py 1 "dslr,webcam" amazon format mat num_iters 5000 summary_freq 50 learning_rate 0.0001 inorm True batch_size 62 src_class_trade_off 1.0 src_domain_trade_off "1.0,1.0" ot_trade_off 0.1 domain_trade_off 1.0 src_vat_trade_off 0.0 trg_vat_troff 0.1 trg_ent_troff 0.1 data_shift_troff 10.0 mimic_trade_off 0.1 cast_data False cnn_size small theta 0.1 g_network_trade_off 1.0 sample_size 1 num_classes 31 multi_scale "228,256" data_dir "" mdaot_model_id train_layers "fc7,fc6,conv5,conv4,conv3,conv2,conv1"
138 | ```
139 |
140 | ## C. Results
141 |
142 | #### Table 1: Classification accuracy (%) on Digits-five.
143 |
144 | | Methods | → mm | → mt | → us | → sv | → sy | Avg |
145 | | :-------------: | :--------: | :--------: | :--------: | :--------: | :--------: | :------: |
146 | | MDAN [1] | 69.5 | 98.0 | 92.4 | 69.2 | 87.4 | 83.3 |
147 | | DCTN [2] | 70.5 | 96.2 | 92.8 | 77.6 | 86.8 | 84.8 |
148 | | M3SDA [3] | 72.8 | 98.4 | 96.1 | 81.3 | 89.6 | 87.7 |
149 | | MDDA [4] | 78.6 | 98.8 | 93.9 | 79.3 | 89.7 | 88.1 |
150 | | LtC-MSDA [5] | 85.6 | 99.0 | 98.3 | 83.2 | 93.0 | 91.8 |
151 | | **MOST** (ours) | **91.5** | **99.6** | **98.4** | **90.9** | **96.4** | **95.4** |
152 |
153 | #### Table 2: Classification accuracy (%) on Office-31 using pretrained AlexNet.
154 |
155 | | Methods | → D | → W | → A | Avg |
156 | | :-------------: | :-------: | :-------: | :-------: | :------: |
157 | | MDAN [1] | 99.2 | 95.4 | 55.2 | 83.3 |
158 | | DCTN [2] | 99.6 | 96.9 | 54.9 | 83.8 |
159 | | M3SDA [3] | 99.4 | 96.2 | 55.4 | 83.7 |
160 | | MDDA [4] | 99.2 | 97.1 | 56.2 | 84.2 |
161 | | LtC-MSDA [5] | 99.6 | 97.2 | 56.9 | 84.6 |
162 | | **MOST** (ours) | **100** | **98.7** | **60.6** | **86.4** |
163 |
164 | ## D. Citations
165 |
166 | Please cite the paper if MOST is helpful for your research:
167 |
168 | ```
169 | @InProceedings{tuan2021most,
170 | author = {Nguyen, Tuan and Le, Trung and Zhao, He and Tran, Quan Hung and Nguyen, Truyen and Phung, Dinh},
171 | title = {Most: multi-source domain adaptation via optimal transport for student-teacher learning},
172 | booktitle= {Proceedings of the 37th Conference on Uncertainty in Artificial Intelligence (UAI)},
173 | year = {2021},
174 | abstract = {Multi-source domain adaptation (DA) is more challenging than conventional DA because the knowledge is transferred from several source domains to a target domain. To this end, we propose in this paper a novel model for multi-source DA using the theory of optimal transport and imitation learning. More specifically, our approach consists of two cooperative agents: a teacher classifier and a student classifier. The teacher classifier is a combined expert that leverages knowledge of domain experts that can be theoretically guaranteed to handle perfectly source examples, while the student classifier acting on the target domain tries to imitate the teacher classifier acting on the source domains. Our rigorous theory developed based on optimal transport makes this cross-domain imitation possible and also helps to mitigate not only the data shift but also the label shift, which are inherently thorny issues in DA research. We conduct comprehensive experiments on real-world datasets to demonstrate the merit of our approach and its optimal transport based imitation learning viewpoint. Experimental results show that our proposed method achieves state-of-the-art performance on benchmark datasets for multi-source domain adaptation including Digits-five, Office-Caltech10, and Office-31 to the best of our knowledge.}
175 | }
176 | ```
177 |
178 | ## E. References
179 |
180 | #### Baselines:
181 |
182 | - [1] H. Zhao, S. Zhang, G. Wu, J. M. F. Moura, J. P. Costeira, and G. J Gordon. Adversarial multiple source domain adaptation. In S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. CesaBianchi, and R. Garnett, editors, Advances in Neural Information Processing Systems 31, pages 8559-8570. Curran Associates, Inc., 2018 .
183 | - [2] R. Xu, Z. Chen, W. Zuo, J. Yan, and L. Lin. Deep cocktail network: Multi-source unsupervised domain adaptation with category shift. In 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 3964-3973, 2018.
184 | - [3] X. Peng, Q. Bai, X. Xia, Z. Huang, K. Saenko, and B. Wang. Moment matching for multi-source domain adaptation. In Proceedings of the IEEE International Conference on Computer Vision, pages 1406-1415, 2019.
185 | - [4] S. Zhao, G. Wang, S. Zhang, Y. Gu, Y. Li, Z. Song, P. Xu, R. Hu, H. Chai, and K. Keutzer. Multi-source distilling domain adaptation. In The Thirty-Fourth AAAI Conference on Artificial Intelligence, AAAI 2020, The Thirty-Second Innovative Applications of Artificial Intelligence Conference, IAAI 2020, The Tenth AAAI Symposium on Educational Advances in Artificial Intelligence, EAAI 2020, New York, NY, USA, February 7-12, 2020, pages 12975-12983. AAAI Press, 2020.
186 | - [5] H. Wang, M. Xu, B. Ni, and W. Zhang. Learning to combine: Knowledge aggregation for multisource domain adaptation. In Computer Vision - ECCV, 2020.
187 |
188 | #### GitHub repositories:
189 |
190 | - Folders *alexnet* and *resnet* are cloned from [Deniz Gurkaynak’s repository](https://github.com/dgurkaynak/tensorflow-cnn-finetune.git)
191 | - Some parts of our code (e.g., VAT, evaluation, …) are rewritten with modifications from [DIRT-T](https://github.com/RuiShu/dirt-t).
192 |
--------------------------------------------------------------------------------
/model/most_AlexNet_train_feat.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Tuan Nguyen.
2 | # All rights reserved.
3 |
4 | from __future__ import division
5 | from __future__ import print_function
6 | from __future__ import absolute_import
7 |
8 | import tensorflow as tf
9 | from tensorflow.contrib.framework import arg_scope
10 | from tensorflow.contrib.framework import add_arg_scope
11 | from tensorbayes.layers import dense, conv2d, batch_norm, instance_norm
12 | from tensorbayes.tfutils import softmax_cross_entropy_with_two_logits as softmax_x_entropy_two
13 |
14 | from keras import backend as K
15 | from keras.preprocessing.image import ImageDataGenerator
16 |
17 | from generic_utils import random_seed
18 |
19 | from layers import leaky_relu
20 | import os
21 | from generic_utils import model_dir
22 | import numpy as np
23 | import tensorbayes as tb
24 | from layers import batch_ema_acc
25 | from keras.utils.np_utils import to_categorical
26 |
27 |
28 | def build_block(input_layer, layout, info=1):
29 | x = input_layer
30 | for i in range(0, len(layout)):
31 | with tf.variable_scope('l{:d}'.format(i)):
32 | f, f_args, f_kwargs = layout[i]
33 | x = f(x, *f_args, **f_kwargs)
34 | if info > 1:
35 | print(x)
36 | return x
37 |
38 |
39 | @add_arg_scope
40 | def normalize_perturbation(d, scope=None):
41 | with tf.name_scope(scope, 'norm_pert'):
42 | output = tf.nn.l2_normalize(d, axis=np.arange(1, len(d.shape)))
43 | return output
44 |
45 |
46 | def build_encode_template(
47 | input_layer, training_phase, scope, encode_layout,
48 | reuse=None, internal_update=False, getter=None, inorm=True, cnn_size='large'):
49 | with tf.variable_scope(scope, reuse=reuse, custom_getter=getter):
50 | with arg_scope([leaky_relu], a=0.1), \
51 | arg_scope([conv2d, dense], activation=leaky_relu, bn=True, phase=training_phase), \
52 | arg_scope([batch_norm], internal_update=internal_update):
53 |
54 | preprocess = instance_norm if inorm else tf.identity
55 |
56 | layout = encode_layout(preprocess=preprocess, training_phase=training_phase, cnn_size=cnn_size)
57 | output_layer = build_block(input_layer, layout)
58 |
59 | return output_layer
60 |
61 |
62 | def build_class_discriminator_template(
63 | input_layer, training_phase, scope, num_classes, class_discriminator_layout,
64 | reuse=None, internal_update=False, getter=None, cnn_size='large'):
65 | with tf.variable_scope(scope, reuse=reuse, custom_getter=getter):
66 | with arg_scope([leaky_relu], a=0.1), \
67 | arg_scope([conv2d, dense], activation=leaky_relu, bn=True, phase=training_phase), \
68 | arg_scope([batch_norm], internal_update=internal_update):
69 | layout = class_discriminator_layout(num_classes=num_classes, global_pool=True, activation=None,
70 | cnn_size=cnn_size)
71 | output_layer = build_block(input_layer, layout)
72 |
73 | return output_layer
74 |
75 |
76 | def build_domain_discriminator_template(x, domain_layout, c=1, reuse=None, scope='domain_disc'):
77 | with tf.variable_scope(scope, reuse=reuse):
78 | with arg_scope([dense], activation=tf.nn.relu):
79 | layout = domain_layout(c=c)
80 | output_layer = build_block(x, layout)
81 |
82 | return output_layer
83 |
84 |
85 | def build_phi_network_template(x, domain_layout, c=1, reuse=None):
86 | with tf.variable_scope('phi_net', reuse=reuse):
87 | with arg_scope([dense], activation=tf.nn.relu):
88 | layout = domain_layout(c=c)
89 | output_layer = build_block(x, layout)
90 |
91 | return output_layer
92 |
93 |
94 | def get_default_config():
95 | tf_config = tf.ConfigProto()
96 | tf_config.gpu_options.allow_growth = True
97 | tf_config.log_device_placement = False
98 | tf_config.allow_soft_placement = True
99 | return tf_config
100 |
101 |
102 | class MOST():
103 | def __init__(self,
104 | model_name="MOST-results",
105 | learning_rate=0.001,
106 | batch_size=128,
107 | num_iters=80000,
108 | summary_freq=400,
109 | src_class_trade_off=1.0,
110 | src_domain_trade_off='1.0,1.0',
111 | src_vat_trade_off=1.0,
112 | trg_vat_troff=0.1,
113 | trg_ent_troff=0.1,
114 | ot_trade_off=0.1,
115 | domain_trade_off=0.1,
116 | mimic_trade_off=0.1,
117 | encode_layout=None,
118 | classify_layout=None,
119 | domain_layout=None,
120 | phi_layout=None,
121 | current_time='',
122 | inorm=True,
123 | theta=0.1,
124 | g_network_trade_off=1.0,
125 | mdaot_model_id='',
126 | only_save_final_model=True,
127 | cnn_size='large',
128 | sample_size=50,
129 | data_shift_troff=10.0,
130 | train_layers='fc8',
131 | **kwargs):
132 | self.model_name = model_name
133 | self.batch_size = batch_size
134 | self.learning_rate = learning_rate
135 | self.num_iters = num_iters
136 | self.summary_freq = summary_freq
137 | self.src_class_trade_off = src_class_trade_off
138 | self.src_domain_trade_off = [float(item) for item in src_domain_trade_off.split(',')]
139 | self.src_vat_trade_off = src_vat_trade_off
140 | self.trg_vat_troff = trg_vat_troff
141 | self.trg_ent_troff = trg_ent_troff
142 | self.ot_trade_off = ot_trade_off
143 | self.domain_trade_off = domain_trade_off
144 | self.mimic_trade_off = mimic_trade_off
145 |
146 | self.encode_layout = encode_layout
147 | self.classify_layout = classify_layout
148 | self.domain_layout = domain_layout
149 | self.phi_layout = phi_layout
150 |
151 | self.current_time = current_time
152 | self.inorm = inorm
153 |
154 | self.theta = theta
155 | self.g_network_trade_off = g_network_trade_off
156 |
157 | self.mdaot_model_id = mdaot_model_id
158 | self.only_save_final_model = only_save_final_model
159 |
160 | self.cnn_size = cnn_size
161 | self.sample_size = sample_size
162 | self.data_shift_troff = data_shift_troff
163 | self.train_layers = train_layers
164 |
165 | def _init(self, data_loader):
166 | np.random.seed(random_seed())
167 | tf.set_random_seed(random_seed())
168 | tf.reset_default_graph()
169 |
170 | self.tf_graph = tf.get_default_graph()
171 | self.tf_config = get_default_config()
172 | self.tf_session = tf.Session(config=self.tf_config, graph=self.tf_graph)
173 |
174 | self.data_loader = data_loader
175 | self.num_classes = self.data_loader.num_class
176 | self.batch_size_src = self.batch_size // self.data_loader.num_src_domain
177 | assert len(self.src_domain_trade_off) == self.data_loader.num_src_domain
178 |
179 | def _get_variables(self, list_scopes):
180 | variables = []
181 | for scope_name in list_scopes:
182 | variables.append(tf.get_collection('trainable_variables', scope_name))
183 | return variables
184 |
185 | def convert_one_hot(self, y):
186 | y_idx = y.reshape(-1).astype(int) if y is not None else None
187 | y = np.eye(self.num_classes)[y_idx] if y is not None else None
188 | return y
189 |
190 | def _get_scope(self, part_name, side_name, same_network=True):
191 | suffix = ''
192 | if not same_network:
193 | suffix = '/' + side_name
194 | return part_name + suffix
195 |
196 | def _get_teacher_scopes(self):
197 | return ['generator', 'classifier', 'domain_disc']
198 |
199 | def _get_student_primary_scopes(self):
200 | return ['generator', 'c-trg']
201 |
202 | def _get_student_secondary_scopes(self):
203 | return ['phi_net']
204 |
205 | def _build_source_middle(self, x_src, is_reused):
206 | scope_name = self._get_scope('generator', 'src')
207 | if is_reused == 0:
208 | generator_model = build_encode_template(x_src, encode_layout=self.encode_layout,
209 | scope=scope_name, training_phase=self.is_training, inorm=self.inorm, cnn_size=self.cnn_size)
210 | else:
211 | generator_model = build_encode_template(x_src, encode_layout=self.encode_layout,
212 | scope=scope_name, training_phase=self.is_training, inorm=self.inorm,
213 | reuse=True, internal_update=True,
214 | cnn_size=self.cnn_size)
215 | return generator_model
216 |
217 | def _build_target_middle(self, x_trg, reuse=None):
218 | scope_name = 'generator'
219 | return build_encode_template(
220 | x_trg, encode_layout=self.encode_layout,
221 | scope=scope_name, training_phase=self.is_training, inorm=self.inorm,
222 | reuse=reuse, internal_update=True, cnn_size=self.cnn_size
223 | )
224 |
225 | def _build_classifier(self, x, num_classes, ema=None, is_teacher=False):
226 | g_teacher_scope = self._get_scope('generator', 'teacher', same_network=False)
227 | g_x = build_encode_template(
228 | x, encode_layout=self.encode_layout,
229 | scope=g_teacher_scope if is_teacher else 'generator', training_phase=False, inorm=self.inorm,
230 | reuse=False if is_teacher else True, getter=None if is_teacher else tb.tfutils.get_getter(ema),
231 | cnn_size=self.cnn_size
232 | )
233 |
234 | h_teacher_scope = self._get_scope('c-trg', 'teacher', same_network=False)
235 | h_g_x = build_class_discriminator_template(
236 | g_x, training_phase=False, scope=h_teacher_scope if is_teacher else 'c-trg', num_classes=num_classes,
237 | reuse=False if is_teacher else True, class_discriminator_layout=self.classify_layout,
238 | getter=None if is_teacher else tb.tfutils.get_getter(ema), cnn_size=self.cnn_size
239 | )
240 | return h_g_x
241 |
242 | def _build_domain_discriminator(self, x_mid, reuse=None, scope='domain_disc'):
243 | return build_domain_discriminator_template(x_mid, domain_layout=self.domain_layout, c=self.data_loader.num_src_domain, reuse=reuse, scope=scope)
244 |
245 | def _build_phi_network(self, x_mid, reuse=None):
246 | return build_phi_network_template(x_mid, domain_layout=self.phi_layout, c=1, reuse=reuse)
247 |
248 | def _build_class_src_discriminator(self, x_src, num_src_classes, i, reuse=None):
249 | classifier_model = build_class_discriminator_template(
250 | x_src, training_phase=self.is_training, scope='classifier/{}'.format(i), num_classes=num_src_classes,
251 | reuse=reuse, internal_update=True, class_discriminator_layout=self.classify_layout, cnn_size=self.cnn_size
252 | )
253 | return classifier_model
254 |
255 | def _build_class_trg_discriminator(self, x_trg, num_trg_classes):
256 | return build_class_discriminator_template(
257 | x_trg, training_phase=self.is_training, scope='c-trg', num_classes=num_trg_classes,
258 | reuse=False, internal_update=True, class_discriminator_layout=self.classify_layout, cnn_size=self.cnn_size
259 | )
260 |
261 | def perturb_image(self, x, p, num_classes, class_discriminator_layout, encode_layout,
262 | pert='vat', scope=None, radius=3.5, scope_classify=None, scope_encode=None, training_phase=None):
263 | with tf.name_scope(scope, 'perturb_image'):
264 | eps = 1e-6 * normalize_perturbation(tf.random_normal(shape=tf.shape(x)))
265 |
266 | # Predict on randomly perturbed image
267 | x_eps_mid = build_encode_template(
268 | x + eps, encode_layout=encode_layout, scope=scope_encode, training_phase=training_phase, reuse=True,
269 | inorm=self.inorm, cnn_size=self.cnn_size)
270 | x_eps_pred = build_class_discriminator_template(
271 | x_eps_mid, class_discriminator_layout=class_discriminator_layout,
272 | training_phase=training_phase, scope=scope_classify, reuse=True, num_classes=num_classes,
273 | cnn_size=self.cnn_size
274 | )
275 | # eps_p = classifier(x + eps, phase=True, reuse=True)
276 | loss = softmax_x_entropy_two(labels=p, logits=x_eps_pred)
277 |
278 | # Based on perturbed image, get direction of greatest error
279 | eps_adv = tf.gradients(loss, [eps], aggregation_method=2)[0]
280 |
281 | # Use that direction as adversarial perturbation
282 | eps_adv = normalize_perturbation(eps_adv)
283 | x_adv = tf.stop_gradient(x + radius * eps_adv)
284 |
285 | return x_adv
286 |
287 | def vat_loss(self, x, p, num_classes, class_discriminator_layout, encode_layout,
288 | scope=None, scope_classify=None, scope_encode=None, training_phase=None):
289 |
290 | with tf.name_scope(scope, 'smoothing_loss'):
291 | x_adv = self.perturb_image(
292 | x, p, num_classes, class_discriminator_layout=class_discriminator_layout, encode_layout=encode_layout,
293 | scope_classify=scope_classify, scope_encode=scope_encode, training_phase=training_phase)
294 |
295 | x_adv_mid = build_encode_template(
296 | x_adv, encode_layout=encode_layout, scope=scope_encode, training_phase=training_phase, inorm=self.inorm,
297 | reuse=True, cnn_size=self.cnn_size)
298 | x_adv_pred = build_class_discriminator_template(
299 | x_adv_mid, training_phase=training_phase, scope=scope_classify, reuse=True, num_classes=num_classes,
300 | class_discriminator_layout=class_discriminator_layout, cnn_size=self.cnn_size
301 | )
302 | # p_adv = classifier(x_adv, phase=True, reuse=True)
303 | loss = tf.reduce_mean(softmax_x_entropy_two(labels=tf.stop_gradient(p), logits=x_adv_pred))
304 |
305 | return loss
306 |
307 | def _build_vat_loss(self, x, p, num_classes, scope=None, scope_classify=None, scope_encode=None):
308 | return self.vat_loss(
309 | x, p, num_classes,
310 | class_discriminator_layout=self.classify_layout,
311 | encode_layout=self.encode_layout,
312 | scope=scope, scope_classify=scope_classify, scope_encode=scope_encode,
313 | training_phase=self.is_training
314 | )
315 |
316 | def _compute_cosine_similarity(self, x_trg_mid, x_src_mid_all):
317 | x_trg_mid_flatten = tf.layers.Flatten()(x_trg_mid)
318 | x_src_mid_all_flatten = tf.layers.Flatten()(x_src_mid_all)
319 | similarity = tf.reduce_sum(x_trg_mid_flatten[:, tf.newaxis] * x_src_mid_all_flatten, axis=-1)
320 | similarity /= tf.norm(x_trg_mid_flatten[:, tf.newaxis], axis=-1) * tf.norm(x_src_mid_all_flatten, axis=-1)
321 | distance = 1.0 - similarity
322 | return distance
323 |
324 | def _compute_data_shift_loss(self, x_src_mid, x_trg_mid):
325 | x_src_mid_flatten = tf.layers.Flatten()(x_src_mid)
326 | x_trg_mid_flatten = tf.layers.Flatten()(x_trg_mid)
327 |
328 | data_shift_loss = tf.norm(tf.subtract(x_src_mid_flatten, tf.expand_dims(x_trg_mid_flatten, 1)), axis=2)
329 | return data_shift_loss
330 |
331 |     def _compute_teacher_hs(self, y_label_trg_output_each_h, y_d_trg_softmax_output):
332 |         y_label_trg_output_each_h = tf.transpose(tf.stack(y_label_trg_output_each_h), perm=[1, 0, 2])
333 |         y_d_trg_softmax_output_multi_y = y_d_trg_softmax_output
334 |         y_d_trg_softmax_output_multi_y = tf.expand_dims(y_d_trg_softmax_output_multi_y, axis=-1)
335 |         y_d_trg_softmax_output_multi_y = tf.tile(y_d_trg_softmax_output_multi_y, [1, 1, self.num_classes])
336 |         y_label_trg_output = y_d_trg_softmax_output_multi_y * y_label_trg_output_each_h
337 |         y_label_trg_output = tf.reduce_sum(y_label_trg_output, axis=1)
338 |         return y_label_trg_output
339 |
340 | def get_distances(self, a, b, name='L2'):
341 | if name == 'L1':
342 | return tf.reduce_sum(tf.abs(tf.expand_dims(a, 0) - tf.expand_dims(b, 1)), axis=-1)
343 | elif name == 'L2':
344 | return tf.sqrt(tf.reduce_sum(tf.square(tf.expand_dims(a, 0) - tf.expand_dims(b, 1)), axis=-1))
345 | elif name == 'CE':
346 | a_prob = tf.nn.softmax(a)
347 | b_prob = tf.nn.softmax(b)
348 | loss = -tf.reduce_sum(tf.multiply(tf.expand_dims(a_prob, 0), tf.log(tf.expand_dims(b_prob, 1) + 1e-12)), axis=-1)
349 | return loss
350 |
351 | def _build_model(self):
352 | self.x_src_lst = []
353 | self.y_src_lst = []
354 | for i in range(self.data_loader.num_src_domain):
355 | x_src = tf.placeholder(dtype=tf.float32, shape=(None, 8, 8, 64), name='x_src_{}_input'.format(i))
356 | y_src = tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes),
357 | name='y_src_{}_input'.format(i))
358 |
359 | self.x_src_lst.append(x_src)
360 | self.y_src_lst.append(y_src)
361 |
362 | self.x_trg = tf.placeholder(dtype=tf.float32, shape=(None, 8, 8, 64), name='x_trg_input')
363 | self.y_trg = tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes),
364 | name='y_trg_input')
365 | self.y_src_domain = tf.placeholder(dtype=tf.float32, shape=(None, self.data_loader.num_src_domain),
366 | name='y_src_domain_input')
367 |
368 | T = tb.utils.TensorDict(dict(
369 | x_tmp=tf.placeholder(dtype=tf.float32, shape=(None, 8, 8, 64)),
370 | y_tmp=tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes))
371 | ))
372 |
373 | self.is_training = tf.placeholder(tf.bool, shape=(), name='is_training')
374 |
375 | self.x_src_mid_lst = []
376 | for i in range(self.data_loader.num_src_domain):
377 | x_src_mid = self._build_source_middle(self.x_src_lst[i], is_reused=i)
378 | self.x_src_mid_lst.append(x_src_mid)
379 | self.x_trg_mid = self._build_target_middle(self.x_trg, reuse=True)
380 |
381 | #
382 | self.y_src_logit_lst = []
383 | for i in range(self.data_loader.num_src_domain):
384 | y_src_logit = self._build_class_src_discriminator(self.x_src_mid_lst[i], self.num_classes, i)
385 | self.y_src_logit_lst.append(y_src_logit)
386 | self.y_trg_logit = self._build_class_trg_discriminator(self.x_trg_mid,
387 | self.num_classes)
388 | #
389 |
390 | #
391 | self.src_loss_class_lst = []
392 | self.src_loss_class_sum = tf.constant(0.0)
393 | for i in range(self.data_loader.num_src_domain):
394 | src_loss_class_detail = tf.nn.softmax_cross_entropy_with_logits_v2(
395 | logits=self.y_src_logit_lst[i], labels=self.y_src_lst[i])
396 | src_loss_class = tf.reduce_mean(src_loss_class_detail)
397 | self.src_loss_class_lst.append(self.src_domain_trade_off[i]*src_loss_class)
398 | self.src_loss_class_sum += self.src_domain_trade_off[i]*src_loss_class
399 | #
400 |
401 | #
402 | self.x_src_mid_all = tf.concat(self.x_src_mid_lst, axis=0)
403 | self.y_src_discriminator_logit = self._build_domain_discriminator(self.x_src_mid_all)
404 |
405 | self.src_loss_discriminator_detail = tf.nn.softmax_cross_entropy_with_logits_v2(
406 | logits=self.y_src_discriminator_logit, labels=self.y_src_domain)
407 | self.src_loss_discriminator = tf.reduce_mean(self.src_loss_discriminator_detail)
408 | #
409 |
410 | #
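    | # Teacher labels for each source mini-batch: pass the batch's features through every source
    | # hypothesis (reuse=True) and mix the softmax predictions using the domain discriminator's
    | # probabilities for the corresponding slice of the concatenated source batch.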
411 | self.y_src_teacher_all = []
412 | for i, bs in zip(range(self.data_loader.num_src_domain),
413 | range(0, self.batch_size_src * self.data_loader.num_src_domain, self.batch_size_src)):
414 | y_src_logit_each_h_lst = []
415 | for j in range(self.data_loader.num_src_domain):
416 | y_src_logit_each_h = self._build_class_src_discriminator(self.x_src_mid_lst[i], self.num_classes,
417 | j, reuse=True)
418 | y_src_logit_each_h_lst.append(y_src_logit_each_h)
419 | y_src_logit_each_h_lst = tf.nn.softmax(tf.convert_to_tensor(y_src_logit_each_h_lst))
420 |
421 | y_src_discriminator_prob = tf.nn.softmax(tf.gather(self.y_src_discriminator_logit,
422 | tf.range(bs, bs + self.batch_size_src,
423 | dtype=tf.int32), axis=0))
424 | y_src_teacher = self._compute_teacher_hs(y_src_logit_each_h_lst, y_src_discriminator_prob)
425 | self.y_src_teacher_all.append(y_src_teacher)
426 | self.y_src_teacher_all = tf.concat(self.y_src_teacher_all, axis=0)
427 | #
428 |
429 | #
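    | # Teacher labels for the target mini-batch, built the same way from the source hypotheses'
    | # predictions on the target features and the domain discriminator's probabilities.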
430 | y_trg_logit_each_h_lst = []
431 | for j in range(self.data_loader.num_src_domain):
432 | y_trg_logit_each_h = self._build_class_src_discriminator(self.x_trg_mid, self.num_classes,
433 | j, reuse=True)
434 | y_trg_logit_each_h_lst.append(y_trg_logit_each_h)
435 | y_trg_logit_each_h_lst = tf.nn.softmax(tf.convert_to_tensor(y_trg_logit_each_h_lst))
436 | self.y_trg_src_domains_logit = self._build_domain_discriminator(self.x_trg_mid, reuse=True)
437 | y_trg_discriminator_prob = tf.nn.softmax(self.y_trg_src_domains_logit)
438 | self.y_trg_teacher = self._compute_teacher_hs(y_trg_logit_each_h_lst, y_trg_discriminator_prob)
439 | #
440 |
441 | #
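    | # Mimic (pseudo-label) loss: the target classifier 'c-trg' is trained to match the teacher
    | # labels on the concatenated source features and on the target batch.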
442 | self.ht_g_xs = build_class_discriminator_template(
443 | self.x_src_mid_all, training_phase=self.is_training, scope='c-trg', num_classes=self.num_classes,
444 | reuse=True, internal_update=True, class_discriminator_layout=self.classify_layout, cnn_size=self.cnn_size
445 | )
446 | self.mimic_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
447 | logits=self.ht_g_xs, labels=self.y_src_teacher_all)) + \
448 | tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
449 | logits=self.y_trg_logit, labels=self.y_trg_teacher))
450 | #
451 |
452 | #
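    | # Optimal-transport (OT) loss between the target and source mini-batches, in a smoothed
    | # dual-style form. The transport cost combines the cosine distance between features and a
    | # cross-entropy term between the target classifier's outputs on target and source features:
    | #   cost_ij = data_shift_troff * cosine_distance_ij + CE_ij
    | #   OT_loss = mean_i[ -theta * log( (1/batch_size) * sum_j exp((g_j - cost_ij) / theta) ) ]
    | #             + g_network_trade_off * mean(g)
    | # where g is the phi-network potential and theta the smoothing temperature. The phi-network
    | # is trained to maximize this loss (secondary optimizer below); the other networks minimize it.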
453 | self.data_shift_loss = self._compute_cosine_similarity(self.x_trg_mid, self.x_src_mid_all)
454 | self.label_shift_loss = self.get_distances(self.y_trg_logit, self.ht_g_xs, 'CE')
455 | self.data_label_shift_loss = self.data_shift_troff*self.data_shift_loss + self.label_shift_loss
456 | self.g_network = tf.reshape(self._build_phi_network(self.x_trg_mid), [-1])
457 | self.exp_term = (- self.data_label_shift_loss + self.g_network) / self.theta
458 | self.g_network_loss = tf.reduce_mean(self.g_network)
459 | self.OT_loss = tf.reduce_mean(
460 | - self.theta * \
461 | (
462 | tf.log(1.0 / self.batch_size) +
463 | tf.reduce_logsumexp(self.exp_term, axis=1)
464 | )
465 | ) + self.g_network_trade_off * self.g_network_loss
466 | #
467 |
468 | #
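    | # Virtual adversarial training (VAT) on the target inputs: penalize the divergence between
    | # predictions on x_trg and on an adversarially perturbed copy (see perturb_image / vat_loss).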
469 | self.trg_loss_vat = self._build_vat_loss(
470 | self.x_trg, self.y_trg_logit, self.num_classes,
471 | scope_encode=self._get_scope('generator', 'trg'), scope_classify='c-trg'
472 | )
473 | #
474 |
475 | #
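    | # Conditional entropy of the target predictions (cross-entropy of the target logits with themselves).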
476 | self.trg_loss_cond_entropy = tf.reduce_mean(softmax_x_entropy_two(labels=self.y_trg_logit,
477 | logits=self.y_trg_logit))
478 | #
479 |
480 | #
481 | self.src_accuracy_lst = []
482 | for i in range(self.data_loader.num_src_domain):
483 | y_src_pred = tf.argmax(self.y_src_logit_lst[i], 1, output_type=tf.int32)
484 | y_src_sparse = tf.argmax(self.y_src_lst[i], 1, output_type=tf.int32)
485 | src_accuracy = tf.reduce_mean(tf.cast(tf.equal(y_src_sparse, y_src_pred), 'float32'))
486 | self.src_accuracy_lst.append(src_accuracy)
487 | # compute acc for target domain
488 | self.y_trg_pred = tf.argmax(self.y_trg_logit, 1, output_type=tf.int32)
489 | self.y_trg_sparse = tf.argmax(self.y_trg, 1, output_type=tf.int32)
490 | self.trg_accuracy = tf.reduce_mean(tf.cast(tf.equal(self.y_trg_sparse, self.y_trg_pred), 'float32'))
491 | # compute acc for src domain disc
492 | self.y_src_domain_pred = tf.argmax(self.y_src_discriminator_logit, 1, output_type=tf.int32)
493 | self.y_src_domain_sparse = tf.argmax(self.y_src_domain, 1, output_type=tf.int32)
494 | self.src_domain_acc = tf.reduce_mean(tf.cast(tf.equal(self.y_src_domain_sparse, self.y_src_domain_pred), 'float32'))
495 | #
496 |
497 | #
498 | lst_losses = [
499 | (self.src_class_trade_off, self.src_loss_class_sum),
500 | (self.ot_trade_off, self.OT_loss),
501 | (self.domain_trade_off, self.src_loss_discriminator),
502 | (self.trg_vat_troff, self.trg_loss_vat),
503 | (self.trg_ent_troff, self.trg_loss_cond_entropy),
504 | (self.mimic_trade_off, self.mimic_loss)
505 | ]
506 | self.total_loss = tf.constant(0.0)
507 | for trade_off, loss in lst_losses:
508 | self.total_loss += trade_off * loss
509 | #
510 |
511 | #
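    | # Keep an exponential moving average (decay 0.998) of the generator and target-classifier
    | # weights; the EMA copy (ema_p) is what batch_ema_acc evaluates. The primary optimizer
    | # minimizes the total loss over the teacher and target-classifier variables, while the
    | # secondary optimizer maximizes the OT loss (by minimizing its negative) w.r.t. the phi-network.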
512 | primary_student_variables = self._get_variables(self._get_student_primary_scopes())
513 | ema = tf.train.ExponentialMovingAverage(decay=0.998)
514 | var_list_for_ema = primary_student_variables[0] + primary_student_variables[1]
515 | ema_op = ema.apply(var_list=var_list_for_ema)
516 | self.ema_p = self._build_classifier(T.x_tmp, self.num_classes, ema)
517 |
518 | self.batch_ema_acc = batch_ema_acc(T.y_tmp, self.ema_p)
519 | self.fn_batch_ema_acc = tb.function(self.tf_session, [T.x_tmp, T.y_tmp], self.batch_ema_acc)
520 |
521 | teacher_variables = self._get_variables(self._get_teacher_scopes())
522 | self.train_student_main = \
523 | tf.train.AdamOptimizer(self.learning_rate, 0.5).minimize(self.total_loss,
524 | var_list=teacher_variables + [primary_student_variables[1]])
525 | self.primary_train_student_op = tf.group(self.train_student_main, ema_op)
526 |
527 | secondary_variables = self._get_variables(self._get_student_secondary_scopes())
528 |
529 | self.secondary_train_student_op = \
530 | tf.train.AdamOptimizer(self.learning_rate, 0.5).minimize(-self.OT_loss,
531 | var_list=secondary_variables)
532 | #
533 |
534 | #
535 | tf.summary.scalar('loss/total_loss', self.total_loss)
536 | tf.summary.scalar('loss/W_distance', self.OT_loss)
537 | tf.summary.scalar('loss/src_loss_discriminator', self.src_loss_discriminator)
538 | tf.summary.scalar('loss/data_shift_loss', tf.reduce_mean(self.data_shift_loss))
539 | tf.summary.scalar('loss/label_shift_loss', tf.reduce_mean(self.label_shift_loss))
540 | tf.summary.scalar('loss/data_label_shift_loss', tf.reduce_mean(self.data_label_shift_loss))
541 | tf.summary.scalar('loss/exp_term', tf.reduce_mean(self.exp_term))
542 | tf.summary.histogram('loss/g_batch', self.g_network)
543 | tf.summary.scalar('loss/g_network_loss', self.g_network_loss)
544 |
545 | for i in range(self.data_loader.num_src_domain):
546 | tf.summary.scalar('loss/src_loss_class_{}'.format(i), self.src_loss_class_lst[i])
547 | tf.summary.scalar('acc/src_acc_{}'.format(i), self.src_accuracy_lst[i])
548 | tf.summary.scalar('acc/src_domain_acc', self.src_domain_acc)
549 | tf.summary.scalar('acc/trg_acc', self.trg_accuracy)
550 |
551 | tf.summary.scalar('hyperparameters/learning_rate', self.learning_rate)
552 | tf.summary.scalar('hyperparameters/src_class_trade_off', self.src_class_trade_off)
553 | tf.summary.scalar('hyperparameters/g_network_trade_off', self.g_network_trade_off)
554 | tf.summary.scalar('hyperparameters/domain_trade_off', self.domain_trade_off)
555 | tf.summary.scalar('hyperparameters/src_vat_trade_off', self.src_vat_trade_off)
556 | tf.summary.scalar('hyperparameters/trg_vat_troff', self.trg_vat_troff)
557 | tf.summary.scalar('hyperparameters/trg_ent_troff', self.trg_ent_troff)
558 | self.tf_merged_summaries = tf.summary.merge_all()
559 | #
560 |
561 | def _fit_loop(self):
562 | print('Start training MOST at', os.path.abspath(__file__))
563 | print('============ LOG-ID: %s ============' % self.current_time)
564 |
565 | num_src_samples_lst = []
566 | for k in range(self.data_loader.num_src_domain):
567 | num_src_samples = self.data_loader.src_train[k][2].shape[0]
568 | num_src_samples_lst.append(num_src_samples)
569 |
570 | num_trg_samples = self.data_loader.trg_train[0][1].shape[0]
571 | src_batchsize = self.batch_size // self.data_loader.num_src_domain
572 |
573 | self.tf_session.run(tf.global_variables_initializer())
574 | saver = tf.train.Saver(tf.global_variables(), max_to_keep=101)
575 | self.log_path = os.path.join(model_dir(), self.model_name, "logs",
576 | "{}".format(self.current_time))
577 | self.tf_summary_writer = tf.summary.FileWriter(self.log_path, self.tf_session.graph)
578 |
579 | self.checkpoint_path = os.path.join(model_dir(), self.model_name, "saved-model", "{}".format(self.mdaot_model_id))
580 | check_point = tf.train.get_checkpoint_state(self.checkpoint_path)
581 |
582 | if check_point and tf.train.checkpoint_exists(check_point.model_checkpoint_path):
583 | print("Load model parameters from %s\n" % check_point.model_checkpoint_path)
584 | saver.restore(self.tf_session, check_point.model_checkpoint_path)
585 |
586 | feed_y_src_domain = to_categorical(np.repeat(np.arange(self.data_loader.num_src_domain),
587 | repeats=self.batch_size//self.data_loader.num_src_domain, axis=0))
588 |
589 | for it in range(self.num_iters):
590 | idx_trg_samples = np.random.permutation(num_trg_samples)[:self.batch_size]
591 | feed_data = dict()
592 | for k in range(self.data_loader.num_src_domain):
593 | idx_src_samples = np.random.permutation(num_src_samples_lst[k])[:src_batchsize]
594 | feed_data[self.x_src_lst[k]] = self.data_loader.src_train[k][1][idx_src_samples, :]
595 | feed_data[self.y_src_lst[k]] = self.data_loader.src_train[k][2][idx_src_samples]
596 |
597 | feed_data[self.x_trg] = self.data_loader.trg_train[0][1][idx_trg_samples, :]
598 | feed_data[self.y_trg] = self.data_loader.trg_train[0][2][idx_trg_samples]
599 |
600 | feed_data[self.y_src_domain] = feed_y_src_domain
601 | feed_data[self.is_training] = True
602 |
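    | # Each iteration: five ascent steps on the phi-network (maximizing the OT loss) on fresh
    | # mini-batches, followed by one update of the main networks on the full objective.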
603 | for i in range(0, 5):
604 | g_idx_trg_samples = np.random.permutation(num_trg_samples)[:self.batch_size]
605 | g_feed_data = dict()
606 | for k in range(self.data_loader.num_src_domain):
607 | g_idx_src_samples = np.random.permutation(num_src_samples_lst[k])[:src_batchsize]
608 | g_feed_data[self.x_src_lst[k]] = self.data_loader.src_train[k][1][g_idx_src_samples, :]
609 | g_feed_data[self.y_src_lst[k]] = self.data_loader.src_train[k][2][g_idx_src_samples]
610 |
611 | g_feed_data[self.x_trg] = self.data_loader.trg_train[0][1][g_idx_trg_samples, :]
612 | g_feed_data[self.y_trg] = self.data_loader.trg_train[0][2][g_idx_trg_samples]
613 | g_feed_data[self.is_training] = True
614 |
615 | _, W_dist = \
616 | self.tf_session.run(
617 | [self.secondary_train_student_op, self.OT_loss],
618 | feed_dict=g_feed_data
619 | )
620 | _, total_loss, src_loss_class_sum, src_loss_class_lst, src_loss_discriminator, src_acc_lst, trg_acc, src_domain_acc, mimic_loss = \
621 | self.tf_session.run(
622 | [self.primary_train_student_op, self.total_loss, self.src_loss_class_sum, self.src_loss_class_lst, self.src_loss_discriminator,
623 | self.src_accuracy_lst, self.trg_accuracy, self.src_domain_acc, self.mimic_loss],
624 | feed_dict=feed_data
625 | )
626 |
627 | if it == 0 or (it + 1) % self.summary_freq == 0:
628 | print(
629 | "iter %d/%d total_loss %.3f; src_loss_class_sum %.3f; W_dist %.3f;\n src_loss_discriminator %.3f, pseudo_lbl_loss %.3f" % (
630 | it + 1, self.num_iters, total_loss, src_loss_class_sum, W_dist,
631 | src_loss_discriminator, mimic_loss))
632 | for k in range(self.data_loader.num_src_domain):
633 | print('src_loss_class_{}: {:.3f} acc {:.2f}'.format(k, src_loss_class_lst[k], src_acc_lst[k]*100))
634 | print("src_domain_disc_acc: %.2f, trg_acc: %.2f;" % (src_domain_acc*100, trg_acc*100))
635 |
636 | summary = self.tf_session.run(self.tf_merged_summaries, feed_dict=feed_data)
637 | self.tf_summary_writer.add_summary(summary, it + 1)
638 | self.tf_summary_writer.flush()
639 |
640 | if it == 0 or (it + 1) % self.summary_freq == 0:
641 | if not self.only_save_final_model:
642 | self.save_trained_model(saver, it + 1)
643 | elif it + 1 == self.num_iters:
644 | self.save_trained_model(saver, it + 1)
645 | if (it + 1) % (self.num_iters // 50) == 0:
646 | self.save_value(step=it + 1)
647 |
648 | def save_trained_model(self, saver, step):
649 | checkpoint_path = os.path.join(model_dir(), self.model_name, "saved-model",
650 | "{}".format(self.current_time))
651 | checkpoint_path = os.path.join(checkpoint_path, "mdaot_" + self.current_time + ".ckpt")
652 |
653 | directory = os.path.dirname(checkpoint_path)
654 | if not os.path.exists(directory):
655 | os.makedirs(directory)
656 | saver.save(self.tf_session, checkpoint_path, global_step=step)
657 |
658 | def save_value(self, step):
659 | ema_acc, summary = self.compute_value(x_full=self.data_loader.trg_test[0][1],
660 | y=self.data_loader.trg_test[0][2], labeler=None)
661 |
662 | self.tf_summary_writer.add_summary(summary, step)
663 | self.tf_summary_writer.flush()
664 |
665 | print_list = ['ema_acc', round(ema_acc * 100, 2)]
666 | print(print_list)
667 |
668 | def compute_value(self, x_full, y, labeler, full=True):
669 | with tb.nputils.FixedSeed(0):
670 | shuffle = np.random.permutation(len(x_full))
671 |
672 | xs = x_full[shuffle]
673 | ys = y[shuffle] if y is not None else None
674 |
675 | if not full:
676 | xs = xs[:1000]
677 | ys = ys[:1000] if ys is not None else None
678 |
679 | n = len(xs)
680 | bs = 200
681 | ema_acc_full = np.ones(n, dtype=float)
682 |
683 | for i in range(0, n, bs):
684 | x = xs[i:i + bs]
685 | y = ys[i:i + bs] if ys is not None else labeler(x)
686 | ema_acc_batch = self.fn_batch_ema_acc(x, y)
687 | ema_acc_full[i:i + bs] = ema_acc_batch
688 |
689 | ema_acc = np.mean(ema_acc_full)
690 | summary = tf.Summary.Value(tag='trg_test/ema_acc', simple_value=ema_acc)
691 | summary = tf.Summary(value=[summary])
692 | return ema_acc, summary
693 |
--------------------------------------------------------------------------------
/model/most_digits.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Tuan Nguyen.
2 | # All rights reserved.
3 |
4 | from __future__ import division
5 | from __future__ import print_function
6 | from __future__ import absolute_import
7 |
8 | import tensorflow as tf
9 | from tensorflow.contrib.framework import arg_scope
10 | from tensorflow.contrib.framework import add_arg_scope
11 | from tensorbayes.layers import dense, conv2d, batch_norm, instance_norm
12 | from tensorbayes.tfutils import softmax_cross_entropy_with_two_logits as softmax_x_entropy_two
13 | from keras import backend as K
14 | from generic_utils import random_seed
15 | from layers import leaky_relu
16 | import os
17 | from generic_utils import model_dir
18 | import numpy as np
19 | import tensorbayes as tb
20 | from layers import batch_ema_acc
21 | from keras.utils.np_utils import to_categorical
22 |
23 |
24 | def build_block(input_layer, layout, info=1):
25 | x = input_layer
26 | for i in range(0, len(layout)):
27 | with tf.variable_scope('l{:d}'.format(i)):
28 | f, f_args, f_kwargs = layout[i]
29 | x = f(x, *f_args, **f_kwargs)
30 | if info > 1:
31 | print(x)
32 | return x
33 |
34 |
35 | @add_arg_scope
36 | def normalize_perturbation(d, scope=None):
37 | with tf.name_scope(scope, 'norm_pert'):
38 | output = tf.nn.l2_normalize(d, axis=np.arange(1, len(d.shape)))
39 | return output
40 |
41 |
42 | def build_encode_template(
43 | input_layer, training_phase, scope, encode_layout,
44 | reuse=None, internal_update=False, getter=None, inorm=True, cnn_size='large'):
45 | with tf.variable_scope(scope, reuse=reuse, custom_getter=getter):
46 | with arg_scope([leaky_relu], a=0.1), \
47 | arg_scope([conv2d, dense], activation=leaky_relu, bn=True, phase=training_phase), \
48 | arg_scope([batch_norm], internal_update=internal_update):
49 | preprocess = instance_norm if inorm else tf.identity
50 |
51 | layout = encode_layout(preprocess=preprocess, training_phase=training_phase, cnn_size=cnn_size)
52 | output_layer = build_block(input_layer, layout)
53 |
54 | return output_layer
55 |
56 |
57 | def build_class_discriminator_template(
58 | input_layer, training_phase, scope, num_classes, class_discriminator_layout,
59 | reuse=None, internal_update=False, getter=None, cnn_size='large'):
60 | with tf.variable_scope(scope, reuse=reuse, custom_getter=getter):
61 | with arg_scope([leaky_relu], a=0.1), \
62 | arg_scope([conv2d, dense], activation=leaky_relu, bn=True, phase=training_phase), \
63 | arg_scope([batch_norm], internal_update=internal_update):
64 | layout = class_discriminator_layout(num_classes=num_classes, global_pool=True, activation=None,
65 | cnn_size=cnn_size)
66 | output_layer = build_block(input_layer, layout)
67 |
68 | return output_layer
69 |
70 |
71 | def build_domain_discriminator_template(x, domain_layout, c=1, reuse=None, scope='domain_disc'):
72 | with tf.variable_scope(scope, reuse=reuse):
73 | with arg_scope([dense], activation=tf.nn.relu):
74 | layout = domain_layout(c=c)
75 | output_layer = build_block(x, layout)
76 |
77 | return output_layer
78 |
79 |
80 | def build_phi_network_template(x, domain_layout, c=1, reuse=None):
81 | with tf.variable_scope('phi_net', reuse=reuse):
82 | with arg_scope([dense], activation=tf.nn.relu):
83 | layout = domain_layout(c=c)
84 | output_layer = build_block(x, layout)
85 |
86 | return output_layer
87 |
88 |
89 | def get_default_config():
90 | tf_config = tf.ConfigProto()
91 | tf_config.gpu_options.allow_growth = True
92 | tf_config.log_device_placement = False
93 | tf_config.allow_soft_placement = True
94 | return tf_config
95 |
96 |
97 | class MOST():
98 | def __init__(self,
99 | model_name="MOST-results",
100 | learning_rate=0.001,
101 | batch_size=128,
102 | num_iters=80000,
103 | phase1_iters=20000,
104 | summary_freq=400,
105 | src_class_trade_off=1.0,
106 | src_domain_trade_off='1.0,1.0',
107 | trg_vat_troff=0.1,
108 | trg_ent_troff=0.1,
109 | ot_trade_off=0.1,
110 | domain_trade_off=0.1,
111 | mimic_trade_off=0.1,
112 | encode_layout=None,
113 | classify_layout=None,
114 | domain_layout=None,
115 | phi_layout=None,
116 | current_time='',
117 | inorm=True,
118 | theta=0.1,
119 | g_network_trade_off=1.0,
120 | mdaot_model_id='',
121 | only_save_final_model=True,
122 | cnn_size='small',
123 | sample_size=50,
124 | data_shift_troff=10.0,
125 | lbl_shift_troff=1.0,
126 | **kwargs):
127 | self.model_name = model_name
128 | self.batch_size = batch_size
129 | self.learning_rate = learning_rate
130 | self.num_iters = num_iters
131 | self.phase1_iters = phase1_iters
132 | self.summary_freq = summary_freq
133 | self.src_class_trade_off = src_class_trade_off
134 | self.src_domain_trade_off = [float(item) for item in src_domain_trade_off.split(',')]
135 | self.trg_vat_troff = trg_vat_troff
136 | self.trg_ent_troff = trg_ent_troff
137 | self.ot_trade_off = ot_trade_off
138 | self.domain_trade_off = domain_trade_off
139 | self.mimic_trade_off = mimic_trade_off
140 | self.encode_layout = encode_layout
141 | self.classify_layout = classify_layout
142 | self.domain_layout = domain_layout
143 | self.phi_layout = phi_layout
144 | self.current_time = current_time
145 | self.inorm = inorm
146 | self.theta = theta
147 | self.g_network_trade_off = g_network_trade_off
148 | self.mdaot_model_id = mdaot_model_id
149 | self.only_save_final_model = only_save_final_model
150 | self.cnn_size = cnn_size
151 | self.sample_size = sample_size
152 | self.data_shift_troff = data_shift_troff
153 | self.lbl_shift_troff = lbl_shift_troff
154 |
155 | def _init(self, data_loader):
156 | np.random.seed(random_seed())
157 | tf.set_random_seed(random_seed())
158 | tf.reset_default_graph()
159 |
160 | self.tf_graph = tf.get_default_graph()
161 | self.tf_config = get_default_config()
162 | self.tf_session = tf.Session(config=self.tf_config, graph=self.tf_graph)
163 |
164 | self.data_loader = data_loader
165 | self.num_classes = self.data_loader.num_class
166 | self.batch_size_src = self.sample_size * self.num_classes
167 |
168 | assert len(self.src_domain_trade_off) == self.data_loader.num_src_domain
169 | assert self.sample_size * self.num_classes * self.data_loader.num_src_domain == self.batch_size
170 |
171 | def _get_variables(self, list_scopes):
172 | variables = []
173 | for scope_name in list_scopes:
174 | variables.append(tf.get_collection('trainable_variables', scope_name))
175 | return variables
176 |
177 | def _get_scope(self, part_name, side_name, same_network=True):
178 | suffix = ''
179 | if not same_network:
180 | suffix = '/' + side_name
181 | return part_name + suffix
182 |
183 | def _get_teacher_scopes(self):
184 | return ['generator', 'classifier', 'domain_disc']
185 |
186 | def _get_student_primary_scopes(self):
187 | return ['generator', 'c-trg']
188 |
189 | def _get_student_secondary_scopes(self):
190 | return ['phi_net']
191 |
192 | def _build_source_middle(self, x_src, is_reused):
193 | scope_name = self._get_scope('generator', 'src')
194 | if is_reused == 0:
195 | generator_model = build_encode_template(x_src, encode_layout=self.encode_layout,
196 | scope=scope_name, training_phase=self.is_training, inorm=self.inorm,
197 | cnn_size=self.cnn_size)
198 | else:
199 | generator_model = build_encode_template(x_src, encode_layout=self.encode_layout,
200 | scope=scope_name, training_phase=self.is_training, inorm=self.inorm,
201 | reuse=True, internal_update=True,
202 | cnn_size=self.cnn_size)
203 | return generator_model
204 |
205 | def _build_target_middle(self, x_trg, reuse=None):
206 | scope_name = 'generator'
207 | return build_encode_template(
208 | x_trg, encode_layout=self.encode_layout,
209 | scope=scope_name, training_phase=self.is_training, inorm=self.inorm,
210 | reuse=reuse, internal_update=True, cnn_size=self.cnn_size
211 | )
212 |
213 | def _build_classifier(self, x, num_classes, ema=None, is_teacher=False):
214 | g_teacher_scope = self._get_scope('generator', 'teacher', same_network=False)
215 | g_x = build_encode_template(
216 | x, encode_layout=self.encode_layout,
217 | scope=g_teacher_scope if is_teacher else 'generator', training_phase=False, inorm=self.inorm,
218 | reuse=False if is_teacher else True, getter=None if is_teacher else tb.tfutils.get_getter(ema),
219 | cnn_size=self.cnn_size
220 | )
221 |
222 | h_teacher_scope = self._get_scope('c-trg', 'teacher', same_network=False)
223 | h_g_x = build_class_discriminator_template(
224 | g_x, training_phase=False, scope=h_teacher_scope if is_teacher else 'c-trg', num_classes=num_classes,
225 | reuse=False if is_teacher else True, class_discriminator_layout=self.classify_layout,
226 | getter=None if is_teacher else tb.tfutils.get_getter(ema), cnn_size=self.cnn_size
227 | )
228 | return h_g_x
229 |
230 | def _build_domain_discriminator(self, x_mid, reuse=None, scope='domain_disc'):
231 | return build_domain_discriminator_template(x_mid, domain_layout=self.domain_layout,
232 | c=self.data_loader.num_src_domain, reuse=reuse, scope=scope)
233 |
234 | def _build_phi_network(self, x_mid, reuse=None):
235 | return build_phi_network_template(x_mid, domain_layout=self.phi_layout, c=1, reuse=reuse)
236 |
237 | def _build_class_src_discriminator(self, x_src, num_src_classes, i, reuse=None):
238 | classifier_model = build_class_discriminator_template(
239 | x_src, training_phase=self.is_training, scope='classifier/{}'.format(i), num_classes=num_src_classes,
240 | reuse=reuse, internal_update=True, class_discriminator_layout=self.classify_layout, cnn_size=self.cnn_size
241 | )
242 | return classifier_model
243 |
244 | def _build_class_trg_discriminator(self, x_trg, num_trg_classes):
245 | return build_class_discriminator_template(
246 | x_trg, training_phase=self.is_training, scope='c-trg', num_classes=num_trg_classes,
247 | reuse=False, internal_update=True, class_discriminator_layout=self.classify_layout, cnn_size=self.cnn_size
248 | )
249 |
250 | def perturb_image(self, x, p, num_classes, class_discriminator_layout, encode_layout,
251 | pert='vat', scope=None, radius=3.5, scope_classify=None, scope_encode=None, training_phase=None):
252 | with tf.name_scope(scope, 'perturb_image'):
253 | eps = 1e-6 * normalize_perturbation(tf.random_normal(shape=tf.shape(x)))
254 |
255 | # Predict on randomly perturbed image
256 | x_eps_mid = build_encode_template(
257 | x + eps, encode_layout=encode_layout, scope=scope_encode, training_phase=training_phase, reuse=True,
258 | inorm=self.inorm, cnn_size=self.cnn_size)
259 | x_eps_pred = build_class_discriminator_template(
260 | x_eps_mid, class_discriminator_layout=class_discriminator_layout,
261 | training_phase=training_phase, scope=scope_classify, reuse=True, num_classes=num_classes,
262 | cnn_size=self.cnn_size
263 | )
264 | # eps_p = classifier(x + eps, phase=True, reuse=True)
265 | loss = softmax_x_entropy_two(labels=p, logits=x_eps_pred)
266 |
267 | # Based on perturbed image, get direction of greatest error
268 | eps_adv = tf.gradients(loss, [eps], aggregation_method=2)[0]
269 |
270 | # Use that direction as adversarial perturbation
271 | eps_adv = normalize_perturbation(eps_adv)
272 | x_adv = tf.stop_gradient(x + radius * eps_adv)
273 | return x_adv
274 |
275 | def vat_loss(self, x, p, num_classes, class_discriminator_layout, encode_layout,
276 | scope=None, scope_classify=None, scope_encode=None, training_phase=None):
277 | with tf.name_scope(scope, 'smoothing_loss'):
278 | x_adv = self.perturb_image(
279 | x, p, num_classes, class_discriminator_layout=class_discriminator_layout, encode_layout=encode_layout,
280 | scope_classify=scope_classify, scope_encode=scope_encode, training_phase=training_phase)
281 |
282 | x_adv_mid = build_encode_template(
283 | x_adv, encode_layout=encode_layout, scope=scope_encode, training_phase=training_phase, inorm=self.inorm,
284 | reuse=True, cnn_size=self.cnn_size)
285 | x_adv_pred = build_class_discriminator_template(
286 | x_adv_mid, training_phase=training_phase, scope=scope_classify, reuse=True, num_classes=num_classes,
287 | class_discriminator_layout=class_discriminator_layout, cnn_size=self.cnn_size
288 | )
289 | # p_adv = classifier(x_adv, phase=True, reuse=True)
290 | loss = tf.reduce_mean(softmax_x_entropy_two(labels=tf.stop_gradient(p), logits=x_adv_pred))
291 |
292 | return loss
293 |
294 | def _build_vat_loss(self, x, p, num_classes, scope=None, scope_classify=None, scope_encode=None):
295 | return self.vat_loss(
296 | x, p, num_classes,
297 | class_discriminator_layout=self.classify_layout,
298 | encode_layout=self.encode_layout,
299 | scope=scope, scope_classify=scope_classify, scope_encode=scope_encode,
300 | training_phase=self.is_training
301 | )
302 |
303 | def _compute_cosine_similarity(self, x_trg_mid, x_src_mid_all):
304 | x_trg_mid_flatten = tf.layers.Flatten()(x_trg_mid)
305 | x_src_mid_all_flatten = tf.layers.Flatten()(x_src_mid_all)
306 | similarity = tf.reduce_sum(x_trg_mid_flatten[:, tf.newaxis] * x_src_mid_all_flatten, axis=-1)
307 | similarity /= tf.norm(x_trg_mid_flatten[:, tf.newaxis], axis=-1) * tf.norm(x_src_mid_all_flatten, axis=-1)
308 | distance = 1.0 - similarity
309 | return distance
310 |
311 | def _compute_label_shift_loss(self, y_trg_logit, y_src_teacher):
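    | # Pairwise label-shift cost: cross-entropies between every teacher label and every target
    | # prediction, arranged as a batch_size x batch_size matrix.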
312 | y_trg_logit_rep = K.repeat_elements(tf.expand_dims(y_trg_logit, axis=0), rep=self.batch_size, axis=0)
313 | y_trg_logit_rep = tf.reshape(y_trg_logit_rep, [-1, y_trg_logit_rep.get_shape()[-1]])
314 | y_src_teacher_rep = K.repeat_elements(y_src_teacher, rep=self.batch_size, axis=0)
315 | label_shift_loss = softmax_x_entropy_two(labels=y_src_teacher_rep,
316 | logits=y_trg_logit_rep)
317 |
318 | label_shift_loss = tf.reshape(label_shift_loss, [self.batch_size, self.batch_size])
319 | return label_shift_loss
320 |
321 | def _compute_teacher_hs(self, y_label_trg_output_each_h, y_d_trg_softmax_output):
322 | y_label_trg_output_each_h = tf.transpose(tf.stack(y_label_trg_output_each_h), perm=[1, 0, 2])
323 | y_d_trg_softmax_output_multi_y = y_d_trg_softmax_output
324 | y_d_trg_softmax_output_multi_y = tf.expand_dims(y_d_trg_softmax_output_multi_y, axis=-1)
325 | y_d_trg_softmax_output_multi_y = tf.tile(y_d_trg_softmax_output_multi_y, [1, 1, self.num_classes])
326 | y_label_trg_output = y_d_trg_softmax_output_multi_y * y_label_trg_output_each_h
327 | y_label_trg_output = tf.reduce_sum(y_label_trg_output, axis=1)
328 | return y_label_trg_output
329 |
330 | def _build_model(self):
331 | self.x_src_lst = []
332 | self.y_src_lst = []
333 | for i in range(self.data_loader.num_src_domain):
334 | x_src = tf.placeholder(dtype=tf.float32, shape=tuple([None]) + self.dim_src,
335 | name='x_src_{}_input'.format(i))
336 | y_src = tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes),
337 | name='y_src_{}_input'.format(i))
338 | self.x_src_lst.append(x_src)
339 | self.y_src_lst.append(y_src)
340 |
341 | self.x_trg = tf.placeholder(dtype=tf.float32, shape=tuple([None]) + self.dim_trg, name='x_trg_input')
342 | self.y_trg = tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes),
343 | name='y_trg_input')
344 | self.y_src_domain = tf.placeholder(dtype=tf.float32, shape=(None, self.data_loader.num_src_domain),
345 | name='y_src_domain_input')
346 |
347 | T = tb.utils.TensorDict(dict(
348 | x_tmp=tf.placeholder(dtype=tf.float32, shape=tuple([None]) + self.dim_src),
349 | y_tmp=tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes))
350 | ))
351 |
352 | self.is_training = tf.placeholder(tf.bool, shape=(), name='is_training')
353 |
354 | self.x_src_mid_lst = []
355 | for i in range(self.data_loader.num_src_domain):
356 | x_src_mid = self._build_source_middle(self.x_src_lst[i], is_reused=i)
357 | self.x_src_mid_lst.append(x_src_mid)
358 | self.x_trg_mid = self._build_target_middle(self.x_trg, reuse=True)
359 |
360 | #
361 | self.y_src_logit_lst = []
362 | for i in range(self.data_loader.num_src_domain):
363 | y_src_logit = self._build_class_src_discriminator(self.x_src_mid_lst[i], self.num_classes, i)
364 | self.y_src_logit_lst.append(y_src_logit)
365 |
366 | self.y_trg_logit = self._build_class_trg_discriminator(self.x_trg_mid,
367 | self.num_classes)
368 | #
369 |
370 | #
371 | self.src_loss_class_lst = []
372 | self.src_loss_class_sum = tf.constant(0.0)
373 | for i in range(self.data_loader.num_src_domain):
374 | src_loss_class_detail = tf.nn.softmax_cross_entropy_with_logits_v2(
375 | logits=self.y_src_logit_lst[i], labels=self.y_src_lst[i])
376 | src_loss_class = tf.reduce_mean(src_loss_class_detail)
377 | self.src_loss_class_lst.append(self.src_domain_trade_off[i] * src_loss_class)
378 | self.src_loss_class_sum += self.src_domain_trade_off[i] * src_loss_class
379 | self.trg_loss_class_detail = tf.nn.softmax_cross_entropy_with_logits_v2(
380 | logits=self.y_trg_logit, labels=self.y_trg)
381 | self.trg_loss_class = tf.reduce_mean(self.trg_loss_class_detail)
382 | #
383 |
384 | #
385 | self.x_src_mid_all = tf.concat(self.x_src_mid_lst, axis=0)
386 | self.y_src_discriminator_logit = self._build_domain_discriminator(self.x_src_mid_all)
387 |
388 | self.src_loss_discriminator_detail = tf.nn.softmax_cross_entropy_with_logits_v2(
389 | logits=self.y_src_discriminator_logit, labels=self.y_src_domain)
390 | self.src_loss_discriminator = tf.reduce_mean(self.src_loss_discriminator_detail)
391 | #
392 |
393 | #
394 | self.y_src_teacher_all = []
395 | for i, bs in zip(range(self.data_loader.num_src_domain),
396 | range(0, self.batch_size_src * self.data_loader.num_src_domain, self.batch_size_src)):
397 | y_src_logit_each_h_lst = []
398 | for j in range(self.data_loader.num_src_domain):
399 | y_src_logit_each_h = self._build_class_src_discriminator(self.x_src_mid_lst[i], self.num_classes,
400 | j, reuse=True)
401 | y_src_logit_each_h_lst.append(y_src_logit_each_h)
402 | y_src_logit_each_h_lst = tf.nn.softmax(tf.convert_to_tensor(y_src_logit_each_h_lst))
403 | y_src_discriminator_prob = tf.nn.softmax(tf.gather(self.y_src_discriminator_logit,
404 | tf.range(bs, bs + self.batch_size_src,
405 | dtype=tf.int32), axis=0))
406 | y_src_teacher = self._compute_teacher_hs(y_src_logit_each_h_lst, y_src_discriminator_prob)
407 | self.y_src_teacher_all.append(y_src_teacher)
408 | self.y_src_teacher_all = tf.concat(self.y_src_teacher_all, axis=0)
409 | #
410 |
411 | #
412 | y_trg_logit_each_h_lst = []
413 | for j in range(self.data_loader.num_src_domain):
414 | y_trg_logit_each_h = self._build_class_src_discriminator(self.x_trg_mid, self.num_classes,
415 | j, reuse=True)
416 | y_trg_logit_each_h_lst.append(y_trg_logit_each_h)
417 | y_trg_logit_each_h_lst = tf.nn.softmax(tf.convert_to_tensor(y_trg_logit_each_h_lst))
418 | self.y_trg_src_domains_logit = self._build_domain_discriminator(self.x_trg_mid, reuse=True)
419 | y_trg_discriminator_prob = tf.nn.softmax(self.y_trg_src_domains_logit)
420 | self.y_trg_teacher = self._compute_teacher_hs(y_trg_logit_each_h_lst, y_trg_discriminator_prob)
421 | #
422 |
423 | #
424 | self.ht_g_xs = build_class_discriminator_template(
425 | self.x_src_mid_all, training_phase=self.is_training, scope='c-trg', num_classes=self.num_classes,
426 | reuse=True, internal_update=True, class_discriminator_layout=self.classify_layout, cnn_size=self.cnn_size
427 | )
428 | self.mimic_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
429 | logits=self.ht_g_xs, labels=self.y_src_teacher_all)) + \
430 | tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
431 | logits=self.y_trg_logit, labels=self.y_trg_teacher))
432 | #
433 |
434 | #
435 | self.data_shift_loss = self._compute_cosine_similarity(self.x_trg_mid, self.x_src_mid_all)
436 | self.label_shift_loss = self._compute_label_shift_loss(self.y_trg_logit, self.ht_g_xs)
437 | self.data_label_shift_loss = self.data_shift_troff * self.data_shift_loss + self.lbl_shift_troff * self.label_shift_loss
438 | self.g_network = tf.reshape(self._build_phi_network(self.x_trg_mid), [-1])
439 | self.exp_term = (- self.data_label_shift_loss + self.g_network) / self.theta
440 | self.g_network_loss = tf.reduce_mean(self.g_network)
441 | self.OT_loss = tf.reduce_mean(
442 | - self.theta * \
443 | (
444 | tf.log(1.0 / self.batch_size) +
445 | tf.reduce_logsumexp(self.exp_term, axis=1)
446 | )
447 | ) + self.g_network_trade_off * self.g_network_loss
448 | #
449 |
450 | #
451 | self.trg_loss_vat = self._build_vat_loss(
452 | self.x_trg, self.y_trg_logit, self.num_classes,
453 | scope_encode=self._get_scope('generator', 'trg'), scope_classify='c-trg'
454 | )
455 | #
456 |
457 | #
458 | self.trg_loss_cond_entropy = tf.reduce_mean(softmax_x_entropy_two(labels=self.y_trg_logit,
459 | logits=self.y_trg_logit))
460 | #
461 |
462 | #
463 | self.src_accuracy_lst = []
464 | for i in range(self.data_loader.num_src_domain):
465 | y_src_pred = tf.argmax(self.y_src_logit_lst[i], 1, output_type=tf.int32)
466 | y_src_sparse = tf.argmax(self.y_src_lst[i], 1, output_type=tf.int32)
467 | src_accuracy = tf.reduce_mean(tf.cast(tf.equal(y_src_sparse, y_src_pred), 'float32'))
468 | self.src_accuracy_lst.append(src_accuracy)
469 | # compute acc for target domain
470 | self.y_trg_pred = tf.argmax(self.y_trg_logit, 1, output_type=tf.int32)
471 | self.y_trg_sparse = tf.argmax(self.y_trg, 1, output_type=tf.int32)
472 | self.trg_accuracy = tf.reduce_mean(tf.cast(tf.equal(self.y_trg_sparse, self.y_trg_pred), 'float32'))
473 | # compute acc for src domain disc
474 | self.y_src_domain_pred = tf.argmax(self.y_src_discriminator_logit, 1, output_type=tf.int32)
475 | self.y_src_domain_sparse = tf.argmax(self.y_src_domain, 1, output_type=tf.int32)
476 | self.src_domain_acc = tf.reduce_mean(
477 | tf.cast(tf.equal(self.y_src_domain_sparse, self.y_src_domain_pred), 'float32'))
478 | #
479 |
480 | #
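    | # Two training phases: phase 1 uses only the source classification and domain-discriminator
    | # losses; phase 2 adds the OT, target VAT, conditional-entropy, and mimic losses.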
481 | lst_phase1_losses = [
482 | (self.src_class_trade_off, self.src_loss_class_sum),
483 | (self.domain_trade_off, self.src_loss_discriminator),
484 | ]
485 | self.phase1_loss = tf.constant(0.0)
486 | for trade_off, loss in lst_phase1_losses:
487 | self.phase1_loss += trade_off * loss
488 |
489 | lst_phase2_losses = [
490 | (self.src_class_trade_off, self.src_loss_class_sum),
491 | (self.ot_trade_off, self.OT_loss),
492 | (self.domain_trade_off, self.src_loss_discriminator),
493 | (self.trg_vat_troff, self.trg_loss_vat),
494 | (self.trg_ent_troff, self.trg_loss_cond_entropy),
495 | (self.mimic_trade_off, self.mimic_loss)
496 | ]
497 | self.phase2_loss = tf.constant(0.0)
498 | for trade_off, loss in lst_phase2_losses:
499 | self.phase2_loss += trade_off * loss
500 | #
501 |
502 | #
503 | primary_student_variables = self._get_variables(self._get_student_primary_scopes())
504 | ema = tf.train.ExponentialMovingAverage(decay=0.998)
505 | var_list_for_ema = primary_student_variables[0] + primary_student_variables[1]
506 | ema_op = ema.apply(var_list=var_list_for_ema)
507 | self.ema_p = self._build_classifier(T.x_tmp, self.num_classes, ema)
508 | self.batch_ema_acc = batch_ema_acc(T.y_tmp, self.ema_p)
509 | self.fn_batch_ema_acc = tb.function(self.tf_session, [T.x_tmp, T.y_tmp], self.batch_ema_acc)
510 | #
511 |
512 | teacher_variables = self._get_variables(self._get_teacher_scopes())
513 | self.train_teacher = \
514 | tf.train.AdamOptimizer(self.learning_rate, 0.5).minimize(self.phase1_loss,
515 | var_list=teacher_variables)
516 | self.train_student_main = \
517 | tf.train.AdamOptimizer(self.learning_rate, 0.5).minimize(self.phase2_loss,
518 | var_list=teacher_variables + [
519 | primary_student_variables[1]])
520 | self.primary_train_student_op = tf.group(self.train_student_main, ema_op)
521 |
522 | #
523 | secondary_variables = self._get_variables(self._get_student_secondary_scopes())
524 | self.secondary_train_student_op = \
525 | tf.train.AdamOptimizer(self.learning_rate, 0.5).minimize(-self.OT_loss,
526 | var_list=secondary_variables)
527 | #
528 |
529 | #
530 | tf.summary.scalar('loss/phase1_loss', self.phase1_loss)
531 | tf.summary.scalar('loss/phase2_loss', self.phase2_loss)
532 | tf.summary.scalar('loss/W_distance', self.OT_loss)
533 | tf.summary.scalar('loss/src_loss_discriminator', self.src_loss_discriminator)
534 |
535 | for i in range(self.data_loader.num_src_domain):
536 | tf.summary.scalar('loss/src_loss_class_{}'.format(i), self.src_loss_class_lst[i])
537 | tf.summary.scalar('acc/src_acc_{}'.format(i), self.src_accuracy_lst[i])
538 | tf.summary.scalar('acc/src_domain_acc', self.src_domain_acc)
539 | tf.summary.scalar('acc/trg_acc', self.trg_accuracy)
540 |
541 | tf.summary.scalar('trg_loss_class', self.trg_loss_class)
542 | tf.summary.scalar('hyperparameters/learning_rate', self.learning_rate)
543 | tf.summary.scalar('hyperparameters/src_class_trade_off', self.src_class_trade_off)
544 | tf.summary.scalar('hyperparameters/g_network_trade_off', self.g_network_trade_off)
545 | tf.summary.scalar('hyperparameters/domain_trade_off', self.domain_trade_off)
546 | tf.summary.scalar('hyperparameters/trg_vat_troff', self.trg_vat_troff)
547 | tf.summary.scalar('hyperparameters/trg_ent_troff', self.trg_ent_troff)
548 | self.tf_merged_summaries = tf.summary.merge_all()
549 | #
550 |
551 | def _fit_loop(self):
552 | print('Start training MOST model at', os.path.abspath(__file__))
553 | print('============ LOG-ID: %s ============' % self.current_time)
554 | num_src_samples_lst = []
555 | for k in range(self.data_loader.num_src_domain):
556 | num_src_samples = self.data_loader.src_train[k][2].shape[0]
557 | num_src_samples_lst.append(num_src_samples)
558 |
559 | num_trg_samples = self.data_loader.trg_train[0][1].shape[0]
560 | src_batchsize = self.batch_size // self.data_loader.num_src_domain
561 |
562 | self.tf_session.run(tf.global_variables_initializer())
563 | saver = tf.train.Saver(tf.global_variables(), max_to_keep=101)
564 | self.log_path = os.path.join(model_dir(), self.model_name, "logs",
565 | "{}".format(self.current_time))
566 | self.tf_summary_writer = tf.summary.FileWriter(self.log_path, self.tf_session.graph)
567 |
568 | self.checkpoint_path = os.path.join(model_dir(), self.model_name, "saved-model",
569 | "{}".format(self.mdaot_model_id))
570 | check_point = tf.train.get_checkpoint_state(self.checkpoint_path)
571 |
572 | if check_point and tf.train.checkpoint_exists(check_point.model_checkpoint_path):
573 | print("Load model parameters from %s\n" % check_point.model_checkpoint_path)
574 | saver.restore(self.tf_session, check_point.model_checkpoint_path)
575 |
576 | feed_y_src_domain = to_categorical(np.repeat(np.arange(self.data_loader.num_src_domain),
577 | repeats=self.sample_size * self.num_classes, axis=0))
578 |
579 | for it in range(self.num_iters):
580 | idx_trg_samples = np.random.permutation(num_trg_samples)[:self.batch_size]
581 | feed_data = dict()
582 | for k in range(self.data_loader.num_src_domain):
583 | idx_src_samples = np.random.permutation(num_src_samples_lst[k])[:src_batchsize]
584 | feed_data[self.x_src_lst[k]] = self.data_loader.src_train[k][1][idx_src_samples, :]
585 | feed_data[self.y_src_lst[k]] = self.data_loader.src_train[k][2][idx_src_samples]
586 | feed_data[self.x_trg] = self.data_loader.trg_train[0][1][idx_trg_samples, :]
587 | feed_data[self.y_trg] = self.data_loader.trg_train[0][2][idx_trg_samples]
588 | feed_data[self.y_src_domain] = feed_y_src_domain
589 | feed_data[self.is_training] = True
590 |
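    | # Phase 1 (first phase1_iters iterations): update the teacher networks on the phase-1 loss.
    | # Phase 2: five ascent steps on the phi-network, then one update on the phase-2 loss.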
591 | if it < self.phase1_iters:
592 | _, total_loss, W_dist, src_loss_class_sum, src_loss_class_lst, src_loss_discriminator, src_acc_lst, trg_acc, src_domain_acc, mimic_loss = \
593 | self.tf_session.run(
594 | [self.train_teacher, self.phase1_loss, self.OT_loss, self.src_loss_class_sum,
595 | self.src_loss_class_lst, self.src_loss_discriminator,
596 | self.src_accuracy_lst, self.trg_accuracy, self.src_domain_acc,
597 | self.mimic_loss],
598 | feed_dict=feed_data
599 | )
600 | else:
601 | for i in range(0, 5):
602 | g_idx_trg_samples = np.random.permutation(num_trg_samples)[:self.batch_size]
603 | g_feed_data = dict()
604 | for k in range(self.data_loader.num_src_domain):
605 | g_idx_src_samples = np.random.permutation(num_src_samples_lst[k])[:src_batchsize]
606 | g_feed_data[self.x_src_lst[k]] = self.data_loader.src_train[k][1][g_idx_src_samples, :]
607 | g_feed_data[self.y_src_lst[k]] = self.data_loader.src_train[k][2][g_idx_src_samples]
608 |
609 | g_feed_data[self.x_trg] = self.data_loader.trg_train[0][1][g_idx_trg_samples, :]
610 | g_feed_data[self.y_trg] = self.data_loader.trg_train[0][2][g_idx_trg_samples]
611 | g_feed_data[self.is_training] = True
612 |
613 | _, W_dist = \
614 | self.tf_session.run(
615 | [self.secondary_train_student_op, self.OT_loss],
616 | feed_dict=g_feed_data
617 | )
618 | _, total_loss, src_loss_class_sum, src_loss_class_lst, src_loss_discriminator, src_acc_lst, trg_acc, src_domain_acc, mimic_loss = \
619 | self.tf_session.run(
620 | [self.primary_train_student_op, self.phase2_loss, self.src_loss_class_sum,
621 | self.src_loss_class_lst, self.src_loss_discriminator,
622 | self.src_accuracy_lst, self.trg_accuracy, self.src_domain_acc,
623 | self.mimic_loss],
624 | feed_dict=feed_data
625 | )
626 |
627 | if it == 0 or (it + 1) % self.summary_freq == 0:
628 | print(
629 | "iter %d/%d total_loss %.3f; src_loss_class_sum %.3f; W_dist %.3f; \n src_loss_discriminator %.3f, pseudo_lbl_loss %.3f" % (
630 | it + 1, self.num_iters, total_loss, src_loss_class_sum, W_dist,
631 | src_loss_discriminator, mimic_loss))
632 | for k in range(self.data_loader.num_src_domain):
633 | print('src_loss_class_{}: {:.3f} acc {:.2f}'.format(k, src_loss_class_lst[k], src_acc_lst[k] * 100))
634 | print("src_domain_disc_acc: %.2f, trg_acc: %.2f;" % (src_domain_acc * 100, trg_acc * 100))
635 | summary = self.tf_session.run(self.tf_merged_summaries, feed_dict=feed_data)
636 | self.tf_summary_writer.add_summary(summary, it + 1)
637 | self.tf_summary_writer.flush()
638 |
639 | if it == 0 or (it + 1) % self.summary_freq == 0:
640 | if not self.only_save_final_model:
641 | self.save_trained_model(saver, it + 1)
642 | elif it + 1 == self.phase1_iters or it + 1 == self.num_iters:
643 | self.save_trained_model(saver, it + 1)
644 | if it >= self.phase1_iters and (it + 1) % (self.num_iters // 50) == 0:
645 | self.save_value(step=it + 1)
646 |
647 | def save_trained_model(self, saver, step):
648 | checkpoint_path = os.path.join(model_dir(), self.model_name, "saved-model",
649 | "{}".format(self.current_time))
650 | checkpoint_path = os.path.join(checkpoint_path, "mdaot_" + self.current_time + ".ckpt")
651 |
652 | directory = os.path.dirname(checkpoint_path)
653 | if not os.path.exists(directory):
654 | os.makedirs(directory)
655 | saver.save(self.tf_session, checkpoint_path, global_step=step)
656 |
657 | def save_value(self, step):
658 | ema_acc, summary = self.compute_value(x_full=self.data_loader.trg_test[0][1],
659 | y=self.data_loader.trg_test[0][2], labeler=None)
660 | self.tf_summary_writer.add_summary(summary, step)
661 | self.tf_summary_writer.flush()
662 | print_list = ['test_acc', round(ema_acc * 100, 2)]
663 | print(print_list)
664 |
665 | def compute_value(self, x_full, y, labeler, full=True):
666 | with tb.nputils.FixedSeed(0):
667 | shuffle = np.random.permutation(len(x_full))
668 |
669 | xs = x_full[shuffle]
670 | ys = y[shuffle] if y is not None else None
671 | if not full:
672 | xs = xs[:1000]
673 | ys = ys[:1000] if ys is not None else None
674 |
675 | n = len(xs)
676 | bs = 200
677 |
678 | ema_acc_full = np.ones(n, dtype=float)
679 | for i in range(0, n, bs):
680 | x = xs[i:i + bs]
681 | y = ys[i:i + bs] if ys is not None else labeler(x)
682 | ema_acc_batch = self.fn_batch_ema_acc(x, y)
683 | ema_acc_full[i:i + bs] = ema_acc_batch
684 |
685 | ema_acc = np.mean(ema_acc_full)
686 | summary = tf.Summary.Value(tag='trg_test/ema_acc', simple_value=ema_acc)
687 | summary = tf.Summary(value=[summary])
688 | return ema_acc, summary
689 |
--------------------------------------------------------------------------------
/model/most_AlexNet_finetune.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Tuan Nguyen.
2 | # All rights reserved.
3 |
4 | from __future__ import division
5 | from __future__ import print_function
6 | from __future__ import absolute_import
7 |
8 | import tensorflow as tf
9 | from tensorflow.contrib.framework import arg_scope
10 | from tensorflow.contrib.framework import add_arg_scope
11 | from tensorbayes.layers import dense, conv2d, batch_norm, instance_norm
12 | from tensorbayes.tfutils import softmax_cross_entropy_with_two_logits as softmax_x_entropy_two
13 |
14 | from keras import backend as K
15 | from keras.preprocessing.image import ImageDataGenerator
16 |
17 | from generic_utils import random_seed
18 |
19 | from layers import leaky_relu
20 | import os
21 | from generic_utils import model_dir
22 | import numpy as np
23 | import tensorbayes as tb
24 | from layers import batch_ema_acc
25 | from keras.utils.np_utils import to_categorical
26 | from alexnet.model import AlexNetModel
27 |
28 |
29 | def build_block(input_layer, layout, info=1):
30 | x = input_layer
31 | for i in range(0, len(layout)):
32 | with tf.variable_scope('l{:d}'.format(i)):
33 | f, f_args, f_kwargs = layout[i]
34 | x = f(x, *f_args, **f_kwargs)
35 | if info > 1:
36 | print(x)
37 | return x
38 |
39 |
40 | @add_arg_scope
41 | def normalize_perturbation(d, scope=None):
42 | with tf.name_scope(scope, 'norm_pert'):
43 | output = tf.nn.l2_normalize(d, axis=np.arange(1, len(d.shape)))
44 | return output
45 |
46 |
47 | def build_encode_template(
48 | input_layer, training_phase, scope, encode_layout,
49 | reuse=None, internal_update=False, getter=None, inorm=True, cnn_size='large'):
50 | with tf.variable_scope(scope, reuse=reuse, custom_getter=getter):
51 | with arg_scope([leaky_relu], a=0.1), \
52 | arg_scope([conv2d, dense], activation=leaky_relu, bn=True, phase=training_phase), \
53 | arg_scope([batch_norm], internal_update=internal_update):
54 |
55 | preprocess = instance_norm if inorm else tf.identity
56 |
57 | layout = encode_layout(preprocess=preprocess, training_phase=training_phase, cnn_size=cnn_size)
58 | output_layer = build_block(input_layer, layout)
59 |
60 | return output_layer
61 |
62 |
63 | def build_class_discriminator_template(
64 | input_layer, training_phase, scope, num_classes, class_discriminator_layout,
65 | reuse=None, internal_update=False, getter=None, cnn_size='large'):
66 | with tf.variable_scope(scope, reuse=reuse, custom_getter=getter):
67 | with arg_scope([leaky_relu], a=0.1), \
68 | arg_scope([conv2d, dense], activation=leaky_relu, bn=True, phase=training_phase), \
69 | arg_scope([batch_norm], internal_update=internal_update):
70 | layout = class_discriminator_layout(num_classes=num_classes, global_pool=True, activation=None,
71 | cnn_size=cnn_size)
72 | output_layer = build_block(input_layer, layout)
73 |
74 | return output_layer
75 |
76 |
77 | def build_domain_discriminator_template(x, domain_layout, c=1, reuse=None, scope='domain_disc'):
78 | with tf.variable_scope(scope, reuse=reuse):
79 | with arg_scope([dense], activation=tf.nn.relu):
80 | layout = domain_layout(c=c)
81 | output_layer = build_block(x, layout)
82 |
83 | return output_layer
84 |
85 |
86 | def build_phi_network_template(x, domain_layout, c=1, reuse=None):
87 | with tf.variable_scope('phi_net', reuse=reuse):
88 | with arg_scope([dense], activation=tf.nn.relu):
89 | layout = domain_layout(c=c)
90 | output_layer = build_block(x, layout)
91 |
92 | return output_layer
93 |
94 |
95 | def get_default_config():
96 | tf_config = tf.ConfigProto()
97 | tf_config.gpu_options.allow_growth = True
98 | tf_config.log_device_placement = False
99 | tf_config.allow_soft_placement = True
100 | return tf_config
101 |
102 |
103 | class MOST():
104 | def __init__(self,
105 | model_name="MOST-results",
106 | learning_rate=0.001,
107 | batch_size=128,
108 | num_iters=80000,
109 | summary_freq=400,
110 | src_class_trade_off=1.0,
111 | src_domain_trade_off='1.0,1.0',
112 | src_vat_trade_off=1.0,
113 | trg_vat_troff=0.1,
114 | trg_ent_troff=0.1,
115 | ot_trade_off=0.1,
116 | domain_trade_off=0.1,
117 | mimic_trade_off=0.1,
118 | encode_layout=None,
119 | classify_layout=None,
120 | domain_layout=None,
121 | phi_layout=None,
122 | current_time='',
123 | inorm=True,
124 | theta=0.1,
125 | g_network_trade_off=1.0,
126 | mdaot_model_id='',
127 | only_save_final_model=True,
128 | cnn_size='large',
129 | sample_size=50,
130 | data_shift_troff=10.0,
131 | lbl_shift_troff=1.0,
132 | num_classes=10,
133 | multi_scale='',
134 | resnet_depth=101,
135 | train_layers='fc8,fc7,fc6',
136 | **kwargs):
137 | self.model_name = model_name
138 | self.batch_size = batch_size
139 | self.learning_rate = learning_rate
140 | self.num_iters = num_iters
141 | self.summary_freq = summary_freq
142 | self.src_class_trade_off = src_class_trade_off
143 | self.src_domain_trade_off = [float(item) for item in src_domain_trade_off.split(',')]
144 | self.src_vat_trade_off = src_vat_trade_off
145 | self.trg_vat_troff = trg_vat_troff
146 | self.trg_ent_troff = trg_ent_troff
147 | self.ot_trade_off = ot_trade_off
148 | self.domain_trade_off = domain_trade_off
149 | self.mimic_trade_off = mimic_trade_off
150 |
151 | self.encode_layout = encode_layout
152 | self.classify_layout = classify_layout
153 | self.domain_layout = domain_layout
154 | self.phi_layout = phi_layout
155 |
156 | self.current_time = current_time
157 | self.inorm = inorm
158 |
159 | self.theta = theta
160 | self.g_network_trade_off = g_network_trade_off
161 |
162 | self.mdaot_model_id = mdaot_model_id
163 | self.only_save_final_model = only_save_final_model
164 |
165 | self.cnn_size = cnn_size
166 | self.sample_size = sample_size
167 | self.data_shift_troff = data_shift_troff
168 | self.lbl_shift_troff = lbl_shift_troff
169 |
170 | self.num_classes = num_classes
171 | self.multi_scale = multi_scale
172 | self.resnet_depth = resnet_depth
173 | self.train_layers = train_layers.split(',')
174 |
175 | def _init(self, src_preprocessors, trg_train_preprocessor, trg_test_preprocessor, num_src_domain):
176 | np.random.seed(random_seed())
177 | tf.set_random_seed(random_seed())
178 | tf.reset_default_graph()
179 |
180 | self.tf_graph = tf.get_default_graph()
181 | self.tf_config = get_default_config()
182 | self.tf_session = tf.Session(config=self.tf_config, graph=self.tf_graph)
183 |
184 | self.src_preprocessors = src_preprocessors
185 | self.trg_train_preprocessor = trg_train_preprocessor
186 | self.trg_test_preprocessor = trg_test_preprocessor
187 | self.num_src_domain = num_src_domain
188 | self.batch_size_src = self.sample_size*self.num_classes
189 |
190 | assert len(self.src_domain_trade_off) == self.num_src_domain
191 | assert self.sample_size*self.num_classes*self.num_src_domain == self.batch_size
192 |
193 | def _get_variables(self, list_scopes):
194 | variables = []
195 | for scope_name in list_scopes:
196 | variables.append(tf.get_collection('trainable_variables', scope_name))
197 | return variables
198 |
199 | def convert_one_hot(self, y):
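    | # Convert integer class labels to one-hot vectors with self.num_classes columns (returns None if y is None).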
200 | y_idx = y.reshape(-1).astype(int) if y is not None else None
201 | y = np.eye(self.num_classes)[y_idx] if y is not None else None
202 | return y
203 |
204 | def _get_scope(self, part_name, side_name, same_network=True):
205 | suffix = ''
206 | if not same_network:
207 | suffix = '/' + side_name
208 | return part_name + suffix
209 |
210 | def _get_all_g_h(self):
211 | return ['generator', 'classifier']
212 |
213 | def _get_teacher_scopes(self):
214 | return ['generator', 'classifier', 'domain_disc']
215 |
216 | def _get_student_primary_scopes(self):
217 | return ['generator', 'c-trg']
218 |
219 | def _get_student_secondary_scopes(self):
220 | return ['phi_net']
221 |
222 | def _build_source_middle(self, x_src, is_reused):
223 | scope_name = self._get_scope('generator', 'src')
224 | if is_reused == 0:
225 | generator_model = build_encode_template(x_src, encode_layout=self.encode_layout,
226 | scope=scope_name, training_phase=self.is_training, inorm=self.inorm, cnn_size=self.cnn_size)
227 | else:
228 | generator_model = build_encode_template(x_src, encode_layout=self.encode_layout,
229 | scope=scope_name, training_phase=self.is_training, inorm=self.inorm,
230 | reuse=True, internal_update=True,
231 | cnn_size=self.cnn_size)
232 | return generator_model
233 |
234 | def _build_target_middle(self, x_trg, reuse=None):
235 | scope_name = 'generator'
236 | return build_encode_template(
237 | x_trg, encode_layout=self.encode_layout,
238 | scope=scope_name, training_phase=self.is_training, inorm=self.inorm,
239 | reuse=reuse, internal_update=True, cnn_size=self.cnn_size
240 | )
241 |
242 | def _build_classifier(self, x, num_classes, ema=None, is_teacher=False):
243 | g_teacher_scope = self._get_scope('generator', 'teacher', same_network=False)
244 | g_x = build_encode_template(
245 | x, encode_layout=self.encode_layout,
246 | scope=g_teacher_scope if is_teacher else 'generator', training_phase=False, inorm=self.inorm,
247 | reuse=False if is_teacher else True, getter=None if is_teacher else tb.tfutils.get_getter(ema),
248 | cnn_size=self.cnn_size
249 | )
250 |
251 | h_teacher_scope = self._get_scope('c-trg', 'teacher', same_network=False)
252 | h_g_x = build_class_discriminator_template(
253 | g_x, training_phase=False, scope=h_teacher_scope if is_teacher else 'c-trg', num_classes=num_classes,
254 | reuse=False if is_teacher else True, class_discriminator_layout=self.classify_layout,
255 | getter=None if is_teacher else tb.tfutils.get_getter(ema), cnn_size=self.cnn_size
256 | )
257 | return h_g_x
258 |
259 | def _build_domain_discriminator(self, x_mid, reuse=None, scope='domain_disc'):
260 | return build_domain_discriminator_template(x_mid, domain_layout=self.domain_layout, c=self.num_src_domain, reuse=reuse, scope=scope)
261 |
262 | def _build_phi_network(self, x_mid, reuse=None):
263 | return build_phi_network_template(x_mid, domain_layout=self.phi_layout, c=1, reuse=reuse)
264 |
265 | def _build_class_src_discriminator(self, x_src, num_src_classes, i, reuse=None):
266 | classifier_model = build_class_discriminator_template(
267 | x_src, training_phase=self.is_training, scope='classifier/{}'.format(i), num_classes=num_src_classes,
268 | reuse=reuse, internal_update=True, class_discriminator_layout=self.classify_layout, cnn_size=self.cnn_size
269 | )
270 | return classifier_model
271 |
272 | def _build_class_trg_discriminator(self, x_trg, num_trg_classes):
273 | return build_class_discriminator_template(
274 | x_trg, training_phase=self.is_training, scope='c-trg', num_classes=num_trg_classes,
275 | reuse=False, internal_update=True, class_discriminator_layout=self.classify_layout, cnn_size=self.cnn_size
276 | )
277 |
278 | def perturb_image(self, x, p, num_classes, class_discriminator_layout, encode_layout,
279 | pert='vat', scope=None, radius=3.5, scope_classify=None, scope_encode=None, training_phase=None):
280 | with tf.name_scope(scope, 'perturb_image'):
281 | eps = 1e-6 * normalize_perturbation(tf.random_normal(shape=tf.shape(x)))
282 |
283 | # Predict on randomly perturbed image
284 | x_eps_mid = build_encode_template(
285 | x + eps, encode_layout=encode_layout, scope=scope_encode, training_phase=training_phase, reuse=True,
286 | inorm=self.inorm, cnn_size=self.cnn_size)
287 | x_eps_pred = build_class_discriminator_template(
288 | x_eps_mid, class_discriminator_layout=class_discriminator_layout,
289 | training_phase=training_phase, scope=scope_classify, reuse=True, num_classes=num_classes,
290 | cnn_size=self.cnn_size
291 | )
292 | # eps_p = classifier(x + eps, phase=True, reuse=True)
293 | loss = softmax_x_entropy_two(labels=p, logits=x_eps_pred)
294 |
295 | # Based on perturbed image, get direction of greatest error
296 | eps_adv = tf.gradients(loss, [eps], aggregation_method=2)[0]
297 |
298 | # Use that direction as adversarial perturbation
299 | eps_adv = normalize_perturbation(eps_adv)
300 | x_adv = tf.stop_gradient(x + radius * eps_adv)
301 |
302 | return x_adv
303 |
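# perturb_image() above and vat_loss() below implement virtual adversarial training
# (VAT): a tiny random perturbation eps is applied, the gradient of the cross-entropy
# between the current prediction p and the prediction at x + eps gives the most
# sensitive direction, and the adversarial example is x_adv = x + radius * g / ||g||
# (wrapped in stop_gradient). vat_loss() then penalises the divergence between the
# prediction at x (with stop-gradient on p) and the prediction at x_adv.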
304 | def vat_loss(self, x, p, num_classes, class_discriminator_layout, encode_layout,
305 | scope=None, scope_classify=None, scope_encode=None, training_phase=None):
306 |
307 | with tf.name_scope(scope, 'smoothing_loss'):
308 | x_adv = self.perturb_image(
309 | x, p, num_classes, class_discriminator_layout=class_discriminator_layout, encode_layout=encode_layout,
310 | scope_classify=scope_classify, scope_encode=scope_encode, training_phase=training_phase)
311 |
312 | x_adv_mid = build_encode_template(
313 | x_adv, encode_layout=encode_layout, scope=scope_encode, training_phase=training_phase, inorm=self.inorm,
314 | reuse=True, cnn_size=self.cnn_size)
315 | x_adv_pred = build_class_discriminator_template(
316 | x_adv_mid, training_phase=training_phase, scope=scope_classify, reuse=True, num_classes=num_classes,
317 | class_discriminator_layout=class_discriminator_layout, cnn_size=self.cnn_size
318 | )
319 | # p_adv = classifier(x_adv, phase=True, reuse=True)
320 | loss = tf.reduce_mean(softmax_x_entropy_two(labels=tf.stop_gradient(p), logits=x_adv_pred))
321 |
322 | return loss
323 |
324 | def _build_vat_loss(self, x, p, num_classes, scope=None, scope_classify=None, scope_encode=None):
325 | return self.vat_loss(
326 | x, p, num_classes,
327 | class_discriminator_layout=self.classify_layout,
328 | encode_layout=self.encode_layout,
329 | scope=scope, scope_classify=scope_classify, scope_encode=scope_encode,
330 | training_phase=self.is_training
331 | )
332 |
333 | def _compute_cosine_similarity(self, x_trg_mid, x_src_mid_all):
334 | x_trg_mid_flatten = tf.layers.Flatten()(x_trg_mid)
335 | x_src_mid_all_flatten = tf.layers.Flatten()(x_src_mid_all)
336 | similarity = tf.reduce_sum(x_trg_mid_flatten[:, tf.newaxis] * x_src_mid_all_flatten, axis=-1)
337 | similarity /= tf.norm(x_trg_mid_flatten[:, tf.newaxis], axis=-1) * tf.norm(x_src_mid_all_flatten, axis=-1)
338 | distance = 1.0 - similarity
339 | return distance
340 |
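# Shape note for _compute_cosine_similarity() above: x_trg_mid flattens to [B_trg, D]
# and x_src_mid_all to [B_src, D]; broadcasting x_trg[:, tf.newaxis] against x_src
# yields a [B_trg, B_src] matrix of cosine similarities, and the returned value is the
# cosine distance 1 - cos, indexed [target, source].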
341 | def _compute_data_shift_loss(self, x_src_mid, x_trg_mid):
342 | x_src_mid_flatten = tf.layers.Flatten()(x_src_mid)
343 | x_trg_mid_flatten = tf.layers.Flatten()(x_trg_mid)
344 |
345 | data_shift_loss = tf.norm(tf.subtract(x_src_mid_flatten, tf.expand_dims(x_trg_mid_flatten, 1)), axis=2)
346 | return data_shift_loss
347 |
348 | def _compute_label_shift_loss(self, y_trg_logit, y_src_teacher):
349 | y_trg_logit_rep = K.repeat_elements(tf.expand_dims(y_trg_logit, axis=0), rep=self.batch_size, axis=0)
350 | y_trg_logit_rep = tf.reshape(y_trg_logit_rep, [-1, y_trg_logit_rep.get_shape()[-1]])
351 | y_src_teacher_rep = K.repeat_elements(y_src_teacher, rep=self.batch_size, axis=0)
352 | label_shift_loss = softmax_x_entropy_two(labels=y_src_teacher_rep,
353 | logits=y_trg_logit_rep)
354 |
355 | label_shift_loss = tf.reshape(label_shift_loss, [self.batch_size, self.batch_size])
356 | return label_shift_loss
357 |
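# Shape note for _compute_label_shift_loss() above: both arguments are
# [batch_size, num_classes]; after repeating and reshaping, entry [i, j] of the
# returned [batch_size, batch_size] matrix is the cross-entropy between the soft
# labels supplied for source example i and the target logits of example j,
# i.e. the matrix is indexed [source, target].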
358 | def _compute_teacher_hs(self, y_label_trg_output_each_h, y_d_trg_softmax_output):
359 | y_label_trg_output_each_h = tf.transpose(tf.stack(y_label_trg_output_each_h), perm=[1, 0, 2])
360 | y_d_trg_softmax_output_multi_y = y_d_trg_softmax_output
361 | y_d_trg_softmax_output_multi_y = tf.expand_dims(y_d_trg_softmax_output_multi_y, axis=-1)
362 | y_d_trg_softmax_output_multi_y = tf.tile(y_d_trg_softmax_output_multi_y, [1, 1, self.num_classes])
363 | y_label_trg_output = y_d_trg_softmax_output_multi_y * y_label_trg_output_each_h
364 | y_label_trg_output = tf.reduce_sum(y_label_trg_output, axis=1)
365 | return y_label_trg_output
366 |
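# _compute_teacher_hs() above forms the teacher as a mixture of the per-source heads:
# with the per-source softmax outputs stacked to [batch, num_src, classes] and the
# domain-discriminator weights of shape [batch, num_src], it returns
# sum_k w_k(x) * p_k(y | x), broadcasting the weights over the class axis.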
367 | def _build_model(self):
368 | self.x_src_lst = []
369 | self.y_src_lst = []
370 | for i in range(self.num_src_domain):
371 | x_src = tf.placeholder(dtype=tf.float32, shape=tuple([None]) + self.dim_src,
372 | name='x_src_{}_input'.format(i))
373 | y_src = tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes),
374 | name='y_src_{}_input'.format(i))
375 |
376 | self.x_src_lst.append(x_src)
377 | self.y_src_lst.append(y_src)
378 |
379 | self.x_trg = tf.placeholder(dtype=tf.float32, shape=tuple([None]) + self.dim_trg, name='x_trg_input')
380 | self.y_trg = tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes),
381 | name='y_trg_input')
382 | self.y_src_domain = tf.placeholder(dtype=tf.float32, shape=(None, self.num_src_domain),
383 | name='y_src_domain_input')
384 |
385 | T = tb.utils.TensorDict(dict(
386 | x_tmp=tf.placeholder(dtype=tf.float32, shape=tuple([None]) + self.dim_src),
387 | y_tmp=tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes))
388 | ))
389 |
390 | self.is_training = tf.placeholder(tf.bool, shape=(), name='is_training')
391 | self.alexNet = AlexNetModel(num_classes=self.num_classes, is_training=self.is_training)
392 |
393 | self.x_src_mid_lst = []
394 | for i in range(self.num_src_domain):
395 | if i == 0:
396 | x_src_mid_feat = tf.reshape(self.alexNet.inference(self.x_src_lst[i], extract_feat=True),
397 | (-1, 8, 8, 64))
398 | else:
399 | x_src_mid_feat = tf.reshape(self.alexNet.inference(self.x_src_lst[i], reuse=True, extract_feat=True),
400 | (-1, 8, 8, 64))
401 |
402 | x_src_mid = self._build_source_middle(x_src_mid_feat, is_reused=i)
403 | self.x_src_mid_lst.append(x_src_mid)
404 |
405 | self.x_trg_mid_feat = tf.reshape(self.alexNet.inference(self.x_trg, reuse=True, extract_feat=True), (-1, 8, 8, 64))
406 | self.x_trg_mid = self._build_target_middle(self.x_trg_mid_feat, reuse=True)
407 |
408 | # classification heads: one per source domain plus the target head 'c-trg'
409 | self.y_src_logit_lst = []
410 | for i in range(self.num_src_domain):
411 | y_src_logit = self._build_class_src_discriminator(self.x_src_mid_lst[i], self.num_classes, i)
412 | self.y_src_logit_lst.append(y_src_logit)
413 |
414 | self.y_trg_logit = self._build_class_trg_discriminator(self.x_trg_mid,
415 | self.num_classes)
416 | #
417 |
418 | # source classification losses, weighted by src_domain_trade_off
419 | self.src_loss_class_lst = []
420 | self.src_loss_class_sum = tf.constant(0.0)
421 | for i in range(self.num_src_domain):
422 | src_loss_class_detail = tf.nn.softmax_cross_entropy_with_logits_v2(
423 | logits=self.y_src_logit_lst[i], labels=self.y_src_lst[i])
424 | src_loss_class = tf.reduce_mean(src_loss_class_detail)
425 | self.src_loss_class_lst.append(self.src_domain_trade_off[i]*src_loss_class)
426 | self.src_loss_class_sum += self.src_domain_trade_off[i]*src_loss_class
427 | #
428 |
429 | # domain discriminator loss over the concatenated source features
430 | self.x_src_mid_all = tf.concat(self.x_src_mid_lst, axis=0)
431 | self.y_src_discriminator_logit = self._build_domain_discriminator(self.x_src_mid_all)
432 | self.src_loss_discriminator_detail = tf.nn.softmax_cross_entropy_with_logits_v2(
433 | logits=self.y_src_discriminator_logit, labels=self.y_src_domain)
434 | self.src_loss_discriminator = tf.reduce_mean(self.src_loss_discriminator_detail)
435 | #
436 |
437 | # teacher predictions on each source mini-batch: per-source heads weighted by the domain discriminator
438 | self.y_src_teacher_all = []
439 | for i, bs in zip(range(self.num_src_domain),
440 | range(0, self.batch_size_src * self.num_src_domain, self.batch_size_src)):
441 | y_src_logit_each_h_lst = []
442 | for j in range(self.num_src_domain):
443 | y_src_logit_each_h = self._build_class_src_discriminator(self.x_src_mid_lst[i], self.num_classes,
444 | j, reuse=True)
445 | y_src_logit_each_h_lst.append(y_src_logit_each_h)
446 | y_src_logit_each_h_lst = tf.nn.softmax(tf.convert_to_tensor(y_src_logit_each_h_lst))
447 | y_src_discriminator_prob = tf.nn.softmax(tf.gather(self.y_src_discriminator_logit,
448 | tf.range(bs, bs + self.batch_size_src,
449 | dtype=tf.int32), axis=0))
450 | y_src_teacher = self._compute_teacher_hs(y_src_logit_each_h_lst, y_src_discriminator_prob)
451 | self.y_src_teacher_all.append(y_src_teacher)
452 | self.y_src_teacher_all = tf.concat(self.y_src_teacher_all, axis=0)
453 | #
454 |
455 | # teacher predictions on the target mini-batch
456 | y_trg_logit_each_h_lst = []
457 | for j in range(self.num_src_domain):
458 | y_trg_logit_each_h = self._build_class_src_discriminator(self.x_trg_mid, self.num_classes,
459 | j, reuse=True)
460 | y_trg_logit_each_h_lst.append(y_trg_logit_each_h)
461 | y_trg_logit_each_h_lst = tf.nn.softmax(tf.convert_to_tensor(y_trg_logit_each_h_lst))
462 | self.y_trg_src_domains_logit = self._build_domain_discriminator(self.x_trg_mid, reuse=True)
463 | y_trg_discriminator_prob = tf.nn.softmax(self.y_trg_src_domains_logit)
464 | self.y_trg_teacher = self._compute_teacher_hs(y_trg_logit_each_h_lst, y_trg_discriminator_prob)
465 | #
466 |
467 | # mimic (distillation) loss: the target head 'c-trg' matches the teacher on source and target features
468 | self.ht_g_xs = build_class_discriminator_template(
469 | self.x_src_mid_all, training_phase=self.is_training, scope='c-trg', num_classes=self.num_classes,
470 | reuse=True, internal_update=True, class_discriminator_layout=self.classify_layout, cnn_size=self.cnn_size
471 | )
472 | self.mimic_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
473 | logits=self.ht_g_xs, labels=self.y_src_teacher_all)) + \
474 | tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
475 | logits=self.y_trg_logit, labels=self.y_trg_teacher))
476 | #
477 |
478 | # OT-based data/label shift loss with the phi (Kantorovich potential) network
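# A sketch of the objective assembled below, read off the code (theta is the entropic
# regularisation strength, phi the potential network, B the batch size):
#   cost(s, t) = data_shift_troff * d_cos(g(x_t), g(x_s)) + CE(h(x_s), h(x_t))
#   OT_loss    = mean_s[ -theta * ( log(1/B) + logsumexp_t( (phi(g(x_t)) - cost(s, t)) / theta ) ) ]
#                + g_network_trade_off * mean_t[ phi(g(x_t)) ]
# Note: the cosine matrix from _compute_cosine_similarity is indexed [target, source]
# while the label-shift matrix is indexed [source, target]; both are [B, B] and are
# summed element-wise into data_label_shift_loss.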
479 | self.data_shift_loss = self._compute_cosine_similarity(self.x_trg_mid, self.x_src_mid_all)
480 | self.label_shift_loss = self._compute_label_shift_loss(self.y_trg_logit, self.ht_g_xs)
481 | self.data_label_shift_loss = self.data_shift_troff*self.data_shift_loss + self.label_shift_loss
482 | self.g_network = tf.reshape(self._build_phi_network(self.x_trg_mid), [-1])
483 | self.exp_term = (- self.data_label_shift_loss + self.g_network) / self.theta
484 | self.g_network_loss = tf.reduce_mean(self.g_network)
485 | self.OT_loss = tf.reduce_mean(
486 | - self.theta * \
487 | (
488 | tf.log(1.0 / self.batch_size) +
489 | tf.reduce_logsumexp(self.exp_term, axis=1)
490 | )
491 | ) + self.g_network_trade_off * self.g_network_loss
492 | #
493 |
494 | # conditional entropy of the target-head predictions on target data
495 | self.trg_loss_cond_entropy = tf.reduce_mean(softmax_x_entropy_two(labels=self.y_trg_logit,
496 | logits=self.y_trg_logit))
497 | #
498 |
499 | # accuracy metrics for the source classifiers, the target classifier and the domain discriminator
500 | self.src_accuracy_lst = []
501 | for i in range(self.num_src_domain):
502 | y_src_pred = tf.argmax(self.y_src_logit_lst[i], 1, output_type=tf.int32)
503 | y_src_sparse = tf.argmax(self.y_src_lst[i], 1, output_type=tf.int32)
504 | src_accuracy = tf.reduce_mean(tf.cast(tf.equal(y_src_sparse, y_src_pred), 'float32'))
505 | self.src_accuracy_lst.append(src_accuracy)
506 | # compute acc for target domain
507 | self.y_trg_pred = tf.argmax(self.y_trg_logit, 1, output_type=tf.int32)
508 | self.y_trg_sparse = tf.argmax(self.y_trg, 1, output_type=tf.int32)
509 | self.trg_accuracy = tf.reduce_mean(tf.cast(tf.equal(self.y_trg_sparse, self.y_trg_pred), 'float32'))
510 | # compute acc for src domain disc
511 | self.y_src_domain_pred = tf.argmax(self.y_src_discriminator_logit, 1, output_type=tf.int32)
512 | self.y_src_domain_sparse = tf.argmax(self.y_src_domain, 1, output_type=tf.int32)
513 | self.src_domain_acc = tf.reduce_mean(tf.cast(tf.equal(self.y_src_domain_sparse, self.y_src_domain_pred), 'float32'))
514 | #
515 |
516 | # total loss: trade-off weighted sum of the individual terms
517 | lst_losses = [
518 | (self.src_class_trade_off, self.src_loss_class_sum),
519 | (self.ot_trade_off, self.OT_loss),
520 | (self.domain_trade_off, self.src_loss_discriminator),
521 | (self.trg_ent_troff, self.trg_loss_cond_entropy),
522 | (self.mimic_trade_off, self.mimic_loss)
523 | ]
524 | self.total_loss = tf.constant(0.0)
525 | for trade_off, loss in lst_losses:
526 | self.total_loss += trade_off * loss
527 | #
528 |
529 | # optimization: build the two training ops
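# Two alternating updates: the primary op minimises the total loss over the teacher
# scopes, the student's target head 'c-trg' and the fine-tuned AlexNet layers, while
# the secondary op maximises the OT objective (by minimising -OT_loss) over the
# phi_net potential network only.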
530 | primary_student_variables = self._get_variables(self._get_student_primary_scopes())
531 |
532 | self.batch_student_acc = batch_ema_acc(self.y_trg, self.y_trg_logit)
533 | self.fn_batch_student_acc = tb.function(self.tf_session, [self.x_trg, self.y_trg, self.is_training], self.batch_student_acc)
534 |
535 | teacher_variables = self._get_variables(self._get_teacher_scopes())
536 | pretrained_var_list = [v for v in tf.trainable_variables() if v.name.split('/')[1] in self.train_layers]
537 | self.train_student_main = \
538 | tf.train.AdamOptimizer(self.learning_rate, 0.5).minimize(self.total_loss,
539 | var_list=teacher_variables + [primary_student_variables[1]] + [pretrained_var_list])
540 | self.primary_train_student_op = tf.group(self.train_student_main)
541 |
542 | secondary_variables = self._get_variables(self._get_student_secondary_scopes())
543 | self.secondary_train_student_op = \
544 | tf.train.AdamOptimizer(self.learning_rate, 0.5).minimize(-self.OT_loss,
545 | var_list=secondary_variables)
546 | #
547 |
548 | # flatten the generator/classifier variable lists (used when restoring the shallow model)
549 | g_h_variables = self._get_variables(self._get_all_g_h())
550 | self.all_g_h_variables = []
551 | for s in g_h_variables:
552 | self.all_g_h_variables += s
553 | #
554 |
555 | # TensorBoard summaries
556 | tf.summary.scalar('loss/total_loss', self.total_loss)
557 | tf.summary.scalar('loss/W_distance', self.OT_loss)
558 | tf.summary.scalar('loss/src_loss_discriminator', self.src_loss_discriminator)
559 | tf.summary.scalar('loss/data_shift_loss', tf.reduce_mean(self.data_shift_loss))
560 | tf.summary.scalar('loss/label_shift_loss', tf.reduce_mean(self.label_shift_loss))
561 | tf.summary.scalar('loss/data_label_shift_loss', tf.reduce_mean(self.data_label_shift_loss))
562 | tf.summary.scalar('loss/exp_term', tf.reduce_mean(self.exp_term))
563 | tf.summary.histogram('loss/g_batch', self.g_network)
564 | tf.summary.scalar('loss/g_network_loss', self.g_network_loss)
565 |
566 | for i in range(self.num_src_domain):
567 | tf.summary.scalar('loss/src_loss_class_{}'.format(i), self.src_loss_class_lst[i])
568 | tf.summary.scalar('acc/src_acc_{}'.format(i), self.src_accuracy_lst[i])
569 | tf.summary.scalar('acc/src_domain_acc', self.src_domain_acc)
570 | tf.summary.scalar('acc/trg_acc', self.trg_accuracy)
571 |
572 | tf.summary.scalar('hyperparameters/learning_rate', self.learning_rate)
573 | tf.summary.scalar('hyperparameters/src_class_trade_off', self.src_class_trade_off)
574 | tf.summary.scalar('hyperparameters/g_network_trade_off', self.g_network_trade_off)
575 | tf.summary.scalar('hyperparameters/domain_trade_off', self.domain_trade_off)
576 | tf.summary.scalar('hyperparameters/src_vat_trade_off', self.src_vat_trade_off)
577 | tf.summary.scalar('hyperparameters/trg_vat_troff', self.trg_vat_troff)
578 | tf.summary.scalar('hyperparameters/trg_ent_troff', self.trg_ent_troff)
579 | self.tf_merged_summaries = tf.summary.merge_all()
580 | #
581 |
582 | def _fit_loop(self):
583 | print('Start training MOST at', os.path.abspath(__file__))
584 | print('============ LOG-ID: %s ============' % self.current_time)
585 |
586 | src_batchsize = self.batch_size // self.num_src_domain
587 | self.tf_session.run(tf.global_variables_initializer())
588 | saver = tf.train.Saver(tf.global_variables(), max_to_keep=101)
589 | self.log_path = os.path.join(model_dir(), self.model_name, "logs",
590 | "{}".format(self.current_time))
591 | self.tf_summary_writer = tf.summary.FileWriter(self.log_path, self.tf_session.graph)
592 |
593 | print("Load pretrained AlexNet")
594 | self.alexNet.load_original_weights(self.tf_session, skip_layers=self.train_layers)
595 |
596 | # restore weights and batch-norm statistics from the pre-trained shallow (mdaot) model
597 | print("Load pre-trained shallow network")
598 | graph = tf.Graph()
599 | with graph.as_default():
600 | with tf.Session() as sess:
601 | checkpoint_path = os.path.join(model_dir(), self.model_name, "saved-model",
602 | "{}".format(self.mdaot_model_id))
603 | check_point = tf.train.get_checkpoint_state(checkpoint_path)
604 | model_checkpoint_path = os.path.join(checkpoint_path, check_point.model_checkpoint_path.split('/')[-1])
605 | saver_old = tf.train.import_meta_graph("{}.meta".format(model_checkpoint_path))
606 | saver_old.restore(sess, model_checkpoint_path)
607 | print("Loaded model parameters from %s\n" % model_checkpoint_path)
608 |
609 | saved_variables = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
610 | saved_values = [sess.run(v) for v in saved_variables]
611 |
612 | g_mean_var = []
613 | h_mean_var = []
614 | g_mean_var_id = [0]
615 | h_mean_var_id = [0, 0]
616 |
617 | for i in g_mean_var_id:
618 | mean = graph.get_tensor_by_name("generator/l{}/conv2d/bn/mean:0".format(i))
619 | var = graph.get_tensor_by_name("generator/l{}/conv2d/bn/var:0".format(i))
620 | mean, var = sess.run([mean, var])
621 | g_mean_var.append([mean, var])
622 |
623 | for i in range(len(h_mean_var_id)):
624 | mean = graph.get_tensor_by_name("classifier/{}/l{}/dense/bn/mean:0".format(i, h_mean_var_id[i]))
625 | var = graph.get_tensor_by_name("classifier/{}/l{}/dense/bn/var:0".format(i, h_mean_var_id[i]))
626 | mean, var = sess.run([mean, var])
627 | h_mean_var.append([mean, var])
628 |
629 | for i in range(len(g_mean_var)):
630 | mean = self.tf_graph.get_tensor_by_name("generator/l{}/conv2d/bn/mean:0".format(g_mean_var_id[i]))
631 | var = self.tf_graph.get_tensor_by_name("generator/l{}/conv2d/bn/var:0".format(g_mean_var_id[i]))
632 | self.tf_session.run(tf.assign(mean, g_mean_var[i][0]))
633 | self.tf_session.run(tf.assign(var, g_mean_var[i][1]))
634 |
635 | for i in range(len(h_mean_var)):
636 | mean = self.tf_graph.get_tensor_by_name(
637 | "classifier/{}/l{}/dense/bn/mean:0".format(i, h_mean_var_id[i]))
638 | var = self.tf_graph.get_tensor_by_name(
639 | "classifier/{}/l{}/dense/bn/var:0".format(i, h_mean_var_id[i]))
640 | self.tf_session.run(tf.assign(mean, h_mean_var[i][0]))
641 | self.tf_session.run(tf.assign(var, h_mean_var[i][1]))
642 |
643 | domain_disc_w = self.tf_graph.get_tensor_by_name('domain_disc/l0/dense/weights:0')
644 | domain_disc_b = self.tf_graph.get_tensor_by_name('domain_disc/l0/dense/biases:0')
645 | phi_net_w = self.tf_graph.get_tensor_by_name('phi_net/l0/dense/weights:0')
646 | phi_net_b = self.tf_graph.get_tensor_by_name('phi_net/l0/dense/biases:0')
647 | self.tf_session.run(tf.assign(domain_disc_w, saved_values[16]))
648 | self.tf_session.run(tf.assign(domain_disc_b, saved_values[17]))
649 | self.tf_session.run(tf.assign(phi_net_w, saved_values[18]))
650 | self.tf_session.run(tf.assign(phi_net_b, saved_values[19]))
651 | for i in range(12):
652 | self.tf_session.run(tf.assign(self.all_g_h_variables[i], saved_values[i]))
653 | #
654 |
655 | feed_y_src_domain = to_categorical(np.repeat(np.arange(self.num_src_domain),
656 | repeats=self.sample_size * self.num_classes, axis=0))
657 |
658 | for it in range(self.num_iters):
659 | feed_data = dict()
660 | for k in range(self.num_src_domain):
661 | feed_data[self.x_src_lst[k]], feed_data[self.y_src_lst[k]] = self.src_preprocessors[k].next_batch(
662 | src_batchsize)
663 |
664 | feed_data[self.x_trg], feed_data[self.y_trg] = self.trg_train_preprocessor.next_batch(self.batch_size)
665 |
666 | feed_data[self.y_src_domain] = feed_y_src_domain
667 | feed_data[self.is_training] = True
668 |
669 | for _ in range(5):  # several phi (potential) network updates per outer training step
670 | g_feed_data = dict()
671 | for k in range(self.num_src_domain):
672 | g_feed_data[self.x_src_lst[k]], g_feed_data[self.y_src_lst[k]] = self.src_preprocessors[k].next_batch(
673 | src_batchsize)
674 | g_feed_data[self.x_trg], g_feed_data[self.y_trg] = self.trg_train_preprocessor.next_batch(self.batch_size)
675 | g_feed_data[self.is_training] = True
676 |
677 | _, W_dist = \
678 | self.tf_session.run(
679 | [self.secondary_train_student_op, self.OT_loss],
680 | feed_dict=g_feed_data
681 | )
682 | _, total_loss, src_loss_class_sum, src_loss_class_lst, src_loss_discriminator, src_acc_lst, trg_acc, src_domain_acc, mimic_loss = \
683 | self.tf_session.run(
684 | [self.primary_train_student_op, self.total_loss, self.src_loss_class_sum, self.src_loss_class_lst, self.src_loss_discriminator,
685 | self.src_accuracy_lst, self.trg_accuracy, self.src_domain_acc, self.mimic_loss],
686 | feed_dict=feed_data
687 | )
688 |
689 | if it == 0 or (it + 1) % self.summary_freq == 0:
690 | print(
691 | "iter %d/%d total_loss %.3f; src_loss_class_sum %.3f; W_dist %.3f;\n src_loss_discriminator %.3f, pseudo_lbl_loss %.3f" % (
692 | it + 1, self.num_iters, total_loss, src_loss_class_sum, W_dist,
693 | src_loss_discriminator, mimic_loss))
694 | for k in range(self.num_src_domain):
695 | print('src_loss_class_{}: {:.3f} acc {:.2f}'.format(k, src_loss_class_lst[k], src_acc_lst[k]*100))
696 | print("src_domain_disc_acc: %.2f, trg_acc: %.2f;" % (src_domain_acc*100, trg_acc*100))
697 |
698 | summary = self.tf_session.run(self.tf_merged_summaries, feed_dict=feed_data)
699 | self.tf_summary_writer.add_summary(summary, it + 1)
700 | self.tf_summary_writer.flush()
701 |
702 | if it == 0 or (it + 1) % self.summary_freq == 0:
703 | if not self.only_save_final_model:
704 | self.save_trained_model(saver, it + 1)
705 | elif it + 1 == self.num_iters:
706 | self.save_trained_model(saver, it + 1)
707 | if (it + 1) % (self.num_iters // 50) == 0:
708 | self.save_value(step=it + 1)
709 |
710 | def save_trained_model(self, saver, step):
711 | checkpoint_path = os.path.join(model_dir(), self.model_name, "saved-model",
712 | "{}".format(self.current_time))
713 | checkpoint_path = os.path.join(checkpoint_path, "mdaot_" + self.current_time + ".ckpt")
714 |
715 | directory = os.path.dirname(checkpoint_path)
716 | if not os.path.exists(directory):
717 | os.makedirs(directory)
718 | saver.save(self.tf_session, checkpoint_path, global_step=step)
719 |
720 | def save_value(self, step):
721 | student_acc, summary = self.compute_value()
722 |
723 | self.tf_summary_writer.add_summary(summary, step)
724 | self.tf_summary_writer.flush()
725 |
726 | print_list = ['test_acc', round(student_acc * 100, 2)]
727 | print(print_list)
728 |
729 | def compute_value(self,):
730 | n = len(self.trg_test_preprocessor.labels)
731 | bs = 200
732 | student_acc_full = np.ones(n, dtype=float)
733 |
734 | for i in range(0, n, bs):
735 | x, y = self.trg_test_preprocessor.next_batch(bs)
736 | student_acc_batch = self.fn_batch_student_acc(x, y, False)
737 | student_acc_full[i:i + bs] = student_acc_batch
738 |
739 | student_acc = np.mean(student_acc_full)
740 | self.trg_test_preprocessor.reset_pointer()
741 |
742 | summary = tf.Summary.Value(tag='trg_test/student_acc', simple_value=student_acc)
743 | summary = tf.Summary(value=[summary])
744 |
745 | return student_acc, summary
746 |
--------------------------------------------------------------------------------