├── images ├── logo.jpg └── framework.png ├── tensorbayes.tar ├── model ├── alexnet │ ├── download_weights.sh │ ├── examples.sh │ ├── ckpt2npy.py │ ├── model.py │ └── finetune.py ├── resnet │ ├── download_weights.sh │ ├── examples.sh │ ├── preprocessor.py │ ├── finetune.py │ └── model.py ├── generic_utils.py ├── layers.py ├── dataLoader.py ├── test_da_template_AlexNet_train_feat.py ├── test_da_template_digits.py ├── run_most_AlexNet_finetune.py ├── run_most_AlexNet_train_feat.py ├── test_da_template_AlexNet_finetune.py ├── run_most_digits.py ├── most_AlexNet_train_feat.py ├── most_digits.py └── most_AlexNet_finetune.py ├── LICENSE ├── tf1.9py3.5.yml └── README.md /images/logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuanrpt/MOST/HEAD/images/logo.jpg -------------------------------------------------------------------------------- /tensorbayes.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuanrpt/MOST/HEAD/tensorbayes.tar -------------------------------------------------------------------------------- /images/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tuanrpt/MOST/HEAD/images/framework.png -------------------------------------------------------------------------------- /model/alexnet/download_weights.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # By using curl 4 | curl -O http://www.cs.toronto.edu/~guerzhoy/tf_alexnet/bvlc_alexnet.npy 5 | 6 | # or by using wget 7 | # wget http://www.cs.toronto.edu/~guerzhoy/tf_alexnet/bvlc_alexnet.npy 8 | -------------------------------------------------------------------------------- /model/alexnet/examples.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | python finetune.py \ 4 | 
--learning_rate "0.00001" \ 5 | --train_layers "fc8,fc7,fc6" 6 | 7 | python finetune.py \ 8 | --num_epochs 30 \ 9 | --multi_scale "228,256" \ 10 | --train_layers "fc8,fc7,fc6,conv5,conv4,conv3,conv2,conv1" # full training 11 | -------------------------------------------------------------------------------- /model/resnet/download_weights.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # By using curl 4 | curl -O https://deniz.co/static/tensorflow-cnn-finetune/ResNet-L50.npy 5 | curl -O https://deniz.co/static/tensorflow-cnn-finetune/ResNet-L101.npy 6 | curl -O https://deniz.co/static/tensorflow-cnn-finetune/ResNet-L152.npy 7 | 8 | # or by using wget 9 | # wget https://www.dropbox.com/s/txb16f1khleyvdh/ResNet-L50.npy?dl=1 10 | # wget https://www.dropbox.com/s/8w10cs6v4rd9616/ResNet-L101.npy?dl=1 11 | # wget https://www.dropbox.com/s/n5inup5a7fi8lom/ResNet-L152.npy?dl=1 12 | -------------------------------------------------------------------------------- /model/resnet/examples.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | python finetune.py \ 4 | --learning_rate "0.00001" \ 5 | --train_layers "fc" 6 | 7 | python finetune.py \ 8 | --learning_rate "0.00001" \ 9 | --train_layers "fc,scale5/block3" 10 | 11 | python finetune.py \ 12 | --learning_rate "0.00001" \ 13 | --train_layers "fc,scale5/block3,scale5/block2" 14 | 15 | python finetune.py \ 16 | --learning_rate "0.00001" \ 17 | --train_layers "fc,scale5/block3,scale5/block2,scale5/block1" 18 | 19 | python finetune.py \ 20 | --learning_rate "0.00001" \ 21 | --multi_scale "225,256" \ 22 | --train_layers "fc,scale5" 23 | 24 | python finetune.py \ 25 | --learning_rate "0.00001" \ 26 | --multi_scale "225,256" \ 27 | --train_layers "fc,scale5,scale4/block6" 28 | 29 | python finetune.py \ 30 | --learning_rate "0.00001" \ 31 | --multi_scale "225,256" \ 32 | --train_layers 
_RANDOM_SEED = 6789


def _grandparent_dir():
    """Return the parent of the directory that contains this file, as a ``Path``."""
    here = Path(os.path.abspath(__file__))
    return here.parent.parent


def model_dir():
    '''
    Locate the parent of the directory containing this file
    :return: that directory as a string
    '''
    return str(_grandparent_dir())


def feat_dir():
    '''
    Locate the "features" folder next to the directory containing this file
    :return: the folder path as a string (the folder is not created or checked)
    '''
    return str(_grandparent_dir() / "features")


def data_dir():
    '''
    Locate the "data" folder next to the directory containing this file
    :return: the folder path as a string (the folder is not created or checked)
    '''
    return str(_grandparent_dir() / "data")


def random_seed():
    '''
    Expose the module-wide random seed
    :return: the integer stored in ``_RANDOM_SEED``
    '''
    return _RANDOM_SEED


def tuid():
    '''
    Create a string ID based on current time
    :return: the local time formatted as ``YYYY-mm-dd_HH.MM.SS`` followed by
             a random integer drawn from [0, 100) with NumPy's global RNG
    '''
    # Draw the suffix first, then read the clock, as the original did.
    suffix = np.random.randint(0, 100)
    stamp = time.strftime('%Y-%m-%d_%H.%M.%S')
    return stamp + str(suffix)
@add_arg_scope
def noise(x, std, phase, scope=None, reuse=None):
    '''
    Add zero-mean Gaussian noise to a tensor where a condition holds
    :param x: input tensor; the noise tensor is created with ``tf.shape(x)``
    :param std: standard deviation passed to ``tf.random_normal``
    :param phase: condition given to ``tf.where``; true selects ``x + eps``,
                  false selects ``x`` unchanged (presumably the training-phase
                  flag -- TODO confirm with callers)
    :param scope: optional name-scope name; ``'noise'`` is the default name
    :param reuse: not used by this function
    :return: the tensor produced by ``tf.where``
    '''
    with tf.name_scope(scope, 'noise'):
        eps = tf.random_normal(tf.shape(x), 0.0, std)
        output = tf.where(phase, x + eps, x)
    return output


@add_arg_scope
def leaky_relu(x, a=0.2, name=None):
    '''
    Leaky ReLU computed as ``max(x, a * x)``
    :param x: input tensor
    :param a: multiplier applied to ``x`` for the second argument of the maximum
    :param name: optional name-scope name; ``'leaky_relu'`` is the default name
    :return: element-wise maximum of ``x`` and ``a * x``
    '''
    with tf.name_scope(name, 'leaky_relu'):
        return tf.maximum(x, a * x)

@add_arg_scope
def basic_accuracy(a, b, scope=None):
    '''
    Fraction of rows whose arg-max along axis 1 agrees between ``a`` and ``b``
    :param a: first tensor, reduced with ``tf.argmax(a, 1)``
              (assumes shape (batch, classes) -- TODO confirm)
    :param b: second tensor, reduced the same way
    :param scope: optional name-scope name; ``'basic_acc'`` is the default name
    :return: scalar float32 tensor, the mean of the per-row equality indicator
    '''
    with tf.name_scope(scope, 'basic_acc'):
        a = tf.argmax(a, 1)
        b = tf.argmax(b, 1)
        eq = tf.cast(tf.equal(a, b), 'float32')
        output = tf.reduce_mean(eq)
    return output

@add_arg_scope
def batch_ema_acc(a, b, scope=None):
    '''
    Per-row agreement indicator between the arg-max of ``a`` and of ``b``
    :param a: first tensor, reduced with ``tf.argmax(a, 1)``
    :param b: second tensor, reduced the same way
    :param scope: optional name-scope name; note the default name is
                  ``'basic_acc'``, the same as in ``basic_accuracy``
    :return: float32 tensor with 1.0 where the arg-maxes are equal and 0.0
             elsewhere (no averaging is applied here)
    '''
    with tf.name_scope(scope, 'basic_acc'):
        a = tf.argmax(a, 1)
        b = tf.argmax(b, 1)
        output = tf.cast(tf.equal(a, b), 'float32')
    return output
# Edit just these
FILE_PATH = '/Users/dgurkaynak/Projects/marvel-finetuning/training/alexnet_20171125_124517/checkpoint/model_epoch7.ckpt'
NUM_CLASSES = 26
OUTPUT_FILE = 'alexnet_20171125_124517_epoch7.npy'


if __name__ == '__main__':
    # Convert a TensorFlow checkpoint of AlexNetModel into a ``.npy`` file
    # holding a dict ``{layer_name: [biases, weights]}``.

    # Build the graph so that the Saver knows which variables to restore.
    x = tf.placeholder(tf.float32, [128, 227, 227, 3])
    model = AlexNetModel(num_classes=NUM_CLASSES)
    model.inference(x)

    saver = tf.train.Saver()
    layers = ['conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'fc6', 'fc7', 'fc8']
    data = {op_name: [] for op_name in layers}

    with tf.Session() as sess:
        saver.restore(sess, FILE_PATH)

        for op_name in layers:
            # AlexNetModel.inference creates every layer inside the 'network'
            # variable scope, so the variables are named
            # 'network/<layer>/biases' and 'network/<layer>/weights'.
            # Looking them up as '<layer>/...' with reuse=True fails because
            # no such variable exists.
            with tf.variable_scope('network/' + op_name, reuse=True):
                biases_variable = tf.get_variable('biases')
                weights_variable = tf.get_variable('weights')
                # Stored as [biases, weights]; AlexNetModel.load_original_weights
                # tells the two apart by rank, so the order is not significant.
                data[op_name].append(sess.run(biases_variable))
                data[op_name].append(sess.run(weights_variable))

    # ``data`` is a dict, so NumPy stores it as a pickled 0-d object array;
    # reading it back requires ``allow_pickle=True`` and ``.item()``.
    np.save(OUTPUT_FILE, data)
-------------------------------------------------------------------------------- 1 | name: tf1.9py3.5 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - asn1crypto=1.4.0=py_0 7 | - blas=1.0=mkl 8 | - bzip2=1.0.8=h7b6447c_0 9 | - ca-certificates=2020.7.22=0 10 | - cairo=1.14.12=h8948797_3 11 | - certifi=2018.8.24=py35_1 12 | - cffi=1.11.5=py35he75722e_1 13 | - chardet=3.0.4=py35_1 14 | - cryptography=2.3.1=py35hc365091_0 15 | - cycler=0.10.0=py35hc4d5149_0 16 | - dbus=1.13.16=hb2f20db_0 17 | - dill=0.3.2=py_0 18 | - expat=2.2.9=he6710b0_2 19 | - ffmpeg=4.0=hcdf2ecd_0 20 | - fontconfig=2.13.0=h9420a91_0 21 | - freeglut=3.0.0=hf484d3e_5 22 | - freetype=2.10.2=h5ab3b9f_0 23 | - glib=2.63.1=h5a9c865_0 24 | - graphite2=1.3.14=h23475e2_0 25 | - gst-plugins-base=1.14.0=hbbd80ab_1 26 | - gstreamer=1.14.0=hb453b48_1 27 | - harfbuzz=1.8.8=hffaf4a1_0 28 | - hdf5=1.10.2=hba1933b_1 29 | - icu=58.2=he6710b0_3 30 | - idna=2.10=py_0 31 | - intel-openmp=2019.4=243 32 | - jasper=2.0.14=h07fcdf6_1 33 | - jpeg=9b=h024ee3a_2 34 | - kiwisolver=1.0.1=py35hf484d3e_0 35 | - libedit=3.1.20191231=h14c3975_1 36 | - libffi=3.2.1=hd88cf55_4 37 | - libgcc-ng=9.1.0=hdf63c60_0 38 | - libgfortran-ng=7.3.0=hdf63c60_0 39 | - libglu=9.0.0=hf484d3e_1 40 | - libopencv=3.4.2=hb342d67_1 41 | - libopus=1.3.1=h7b6447c_0 42 | - libpng=1.6.37=hbc83047_0 43 | - libprotobuf=3.6.0=hdbcaa40_0 44 | - libstdcxx-ng=9.1.0=hdf63c60_0 45 | - libtiff=4.1.0=h2733197_1 46 | - libuuid=1.0.3=h1bed415_2 47 | - libvpx=1.7.0=h439df22_0 48 | - libxcb=1.14=h7b6447c_0 49 | - libxml2=2.9.10=he19cac6_1 50 | - lz4-c=1.9.2=he6710b0_1 51 | - matplotlib=3.0.0=py35h5429711_0 52 | - mkl=2018.0.3=1 53 | - mkl_fft=1.0.6=py35h7dd41cf_0 54 | - mkl_random=1.0.1=py35h4414c95_1 55 | - ncurses=6.2=he6710b0_1 56 | - numpy=1.15.2=py35h1d66e8a_0 57 | - opencv=3.4.2=py35h6fd60c2_1 58 | - openssl=1.0.2u=h7b6447c_0 59 | - pandas=0.22.0=py35hf484d3e_0 60 | - pcre=8.44=he6710b0_0 61 | - pip=10.0.1=py35_0 62 | - 
class BatchPreprocessor(object):
    """Serve batches of mean-subtracted images and one-hot labels from a list file.

    Every non-blank line of the dataset file is split on whitespace; the first
    field is kept as an image path and the second is parsed as an integer
    class label.
    """

    def __init__(self, dataset_file_path, num_classes, output_size=[227, 227], horizontal_flip=False, shuffle=False,
                 mean_color=[132.2766, 139.6506, 146.9702], multi_scale=None):
        """Read the dataset list and store the preprocessing options.

        :param dataset_file_path: text file with one ``<image path> <label>`` pair per line
        :param num_classes: width of the one-hot label vectors
        :param output_size: two-element sequence; the batch buffer is allocated as
            ``[batch, output_size[0], output_size[1], 3]``
        :param horizontal_flip: if true, each image is flipped with probability 0.5
        :param shuffle: if true, ``next_batch`` reshuffles the whole list before every batch
        :param mean_color: per-channel values subtracted from every image
        :param multi_scale: ``None``/empty for a plain resize, or a two-element
            list/tuple ``[low, high)`` from which a random square size is drawn
            before a random crop
        :raises IndexError: if a non-blank line has fewer than two fields
        :raises ValueError: if a label field is not an integer
        """
        self.num_classes = num_classes
        # Copy the sequences so the shared default lists are never aliased.
        self.output_size = list(output_size)
        self.horizontal_flip = horizontal_flip
        self.shuffle = shuffle
        self.mean_color = list(mean_color)
        self.multi_scale = multi_scale

        self.pointer = 0
        self.images = []
        self.labels = []

        # Read the dataset file; the context manager closes the handle even
        # when a malformed line raises.
        with open(dataset_file_path) as dataset_file:
            for line in dataset_file:
                items = line.split()
                if not items:
                    # Blank lines (e.g. a trailing newline) carry no sample.
                    continue
                self.images.append(items[0])
                self.labels.append(int(items[1]))

    def shuffle_data(self):
        """Permute ``images`` and ``labels`` in place with the same random order."""
        order = np.random.permutation(len(self.labels))
        self.images = [self.images[i] for i in order]
        self.labels = [self.labels[i] for i in order]

    def reset_pointer(self):
        """Rewind the sequential read position to the start of the list."""
        self.pointer = 0

    def next_batch(self, batch_size):
        """Load the next batch of images and their one-hot labels.

        With ``shuffle`` enabled the whole list is reshuffled and the first
        ``batch_size`` entries are used; otherwise entries are taken
        sequentially from ``pointer``. ``pointer`` advances in both modes.

        :param batch_size: maximum number of samples to return
        :return: ``(images, one_hot_labels)``; the first axis is shorter than
            ``batch_size`` when fewer samples remain
        :raises IOError: if an image file cannot be read
        :raises ValueError: if a multi-scale size is smaller than the output size
        :raises TypeError: if ``multi_scale`` is neither empty nor a list/tuple
        """
        if self.shuffle:
            self.shuffle_data()
            paths = self.images[:batch_size]
            labels = self.labels[:batch_size]
        else:
            start = self.pointer
            stop = start + batch_size
            paths = self.images[start:stop]
            labels = self.labels[start:stop]

        # Update pointer
        self.pointer += batch_size

        out_h, out_w = self.output_size[0], self.output_size[1]
        images = np.ndarray([len(paths), out_h, out_w, 3])
        for i, path in enumerate(paths):
            images[i] = self._load_image(path, out_h, out_w)

        # Expand labels to one hot encoding
        one_hot_labels = np.zeros((len(labels), self.num_classes))
        one_hot_labels[np.arange(len(labels)), np.asarray(labels, dtype=int)] = 1

        return images, one_hot_labels

    def _load_image(self, path, out_h, out_w):
        """Read one image, apply flip/resize/crop, and subtract the mean colour."""
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals an unreadable file by returning None.
            raise IOError('Could not read image: {}'.format(path))

        # Flip image at random if flag is selected
        if self.horizontal_flip and np.random.random() < 0.5:
            img = cv2.flip(img, 1)

        if self.multi_scale is None or len(self.multi_scale) == 0:
            # cv2.resize takes (width, height); the buffer is (rows, cols).
            img = cv2.resize(img, (out_w, out_h))
            img = img.astype(np.float32)
        elif isinstance(self.multi_scale, (list, tuple)):
            # Resize to a random square scale, then crop to the output size.
            new_size = np.random.randint(self.multi_scale[0], self.multi_scale[1], 1)[0]
            img = cv2.resize(img, (new_size, new_size))
            img = img.astype(np.float32)

            top = self._random_offset(new_size - out_h)
            left = self._random_offset(new_size - out_w)
            img = img[top:(top + out_h), left:(left + out_w)]
        else:
            raise TypeError('multi_scale must be None, a list or a tuple, got {}'.format(
                type(self.multi_scale).__name__))

        # Subtract mean color
        return img - np.array(self.mean_color)

    @staticmethod
    def _random_offset(slack):
        """Draw a crop offset in ``[0, slack)``; a slack of 0 yields offset 0.

        The exclusive upper bound is kept from the original sampling so that
        seeded runs reproduce the same crops.
        """
        if slack < 0:
            raise ValueError('scaled image is smaller than the requested output size')
        if slack == 0:
            return 0
        return np.random.randint(0, slack, 1)[0]
def load_mat_file_single_label(filename):
    """Load ``X`` and ``y`` from a ``.mat`` file and return labels as class indices.

    For file names containing one of the known dataset tags the labels are
    taken from the first row of ``y``, except for ``mnist32_60_10`` files whose
    ``y`` is reduced with arg-max along axis 1. For any other file, ``y`` is
    reduced with arg-max only when it has more than one dimension.

    :param filename: path of the ``.mat`` file
    :return: ``(x, y)``
    """
    filename_list = ['mnist', 'stl32', 'synsign', 'gtsrb', 'cifar32', 'usps32']
    data = loadmat(filename)
    x = data['X']
    y = data['y']
    if any(fn in filename for fn in filename_list):
        if 'mnist32_60_10' not in filename:
            y = y[0]
        else:
            y = np.argmax(y, axis=1)
    elif len(y.shape) > 1:
        y = np.argmax(y, axis=1)
    return x, y


def load_mat_office31_AlexNet(filename):
    """Load features and labels from a ``.mat`` file with ``feas``/``labels`` keys.

    :param filename: path of the ``.mat`` file
    :return: ``(x, y)`` where ``x`` is ``feas`` reshaped to ``(-1, 8, 8, 64)``
        and ``y`` is the first row of ``labels``
    """
    data = loadmat(filename)
    x = np.reshape(data['feas'], (-1, 8, 8, 64))
    y = data['labels'][0]
    return x, y


def u2t(x):
    """Convert uint8 to [-1, 1] float

    Inputs longer than 50000 entries along the first axis are converted chunk
    by chunk so that only one float32 chunk-sized temporary exists at a time.
    """
    max_num = 50000
    if len(x) <= max_num:
        return (x.astype('float32') / 255) * 2 - 1

    y = np.empty_like(x, dtype='float32')
    # One uniform loop also covers the final partial chunk, with no reliance
    # on a loop variable surviving the loop.
    for start in range(0, len(x), max_num):
        stop = start + max_num
        y[start:stop] = (x[start:stop].astype('float32') / 255) * 2 - 1
    return y


class DataLoader:
    """Load train/test splits of source and target domains from ``data_path``.

    Each of ``src_train``, ``trg_train``, ``src_test`` and ``trg_test`` is a
    dict ``{index: [domain_name, x, one_hot_y]}``.
    """

    def __init__(self, src_domain='mnistm', trg_domain='mnist', data_path='./data', data_format='mat',
                 shuffle_data=False, dataset_name='digits', cast_data=True):
        """Load every requested domain eagerly.

        :param src_domain: comma-separated string, or a list/tuple, of source domain names
        :param trg_domain: comma-separated string, or a list/tuple, of target domain names
        :param data_path: directory holding ``<name>_train.<fmt>`` and ``<name>_test.<fmt>``
        :param data_format: file extension without the dot
        :param shuffle_data: if true, shuffle each loaded split
        :param dataset_name: ``'digits'`` or ``'office31_AlexNet_feat'``
        :param cast_data: if true, map inputs to [-1, 1] float32 with ``u2t``
            (skipped for names containing ``mnist32_60_10``)
        :raises FileNotFoundError: if a split file is missing
        :raises ValueError: if ``dataset_name`` is not supported
        """
        self.num_src_domain = len(self._split_names(src_domain))
        self.src_domain_name = src_domain
        self.trg_domain_name = trg_domain
        self.data_path = data_path
        self.data_format = data_format
        self.shuffle_data = shuffle_data
        self.dataset_name = dataset_name
        self.cast_data = cast_data

        self.src_train = {}
        self.trg_train = {}
        self.src_test = {}
        self.trg_test = {}

        print("Source domains", self.src_domain_name)
        print("Target domain", self.trg_domain_name)
        self._load_data_train()
        self._load_data_test()

        # Entry layout is [name, x, one_hot_y]: index 1 holds the inputs and
        # index 2 the one-hot labels.
        self.data_shape = self.src_train[0][1][0].shape
        self.num_domain = len(self.src_train.keys())
        self.num_class = self.src_train[0][2].shape[-1]

    @staticmethod
    def _split_names(names):
        """Return a list of domain names from a comma-separated string or a sequence."""
        if isinstance(names, str):
            return names.split(',')
        return list(names)

    def _load_data_train(self, tail_name="_train"):
        """Load the training splits unless source training data is already present."""
        if not self.src_train:
            self.src_train = self._load_file(self.src_domain_name, tail_name, self.shuffle_data)
            self.trg_train = self._load_file(self.trg_domain_name, tail_name, self.shuffle_data)

    def _load_data_test(self, tail_name="_test"):
        """Load the test splits unless source test data is already present."""
        if not self.src_test:
            self.src_test = self._load_file(self.src_domain_name, tail_name, self.shuffle_data)
            self.trg_test = self._load_file(self.trg_domain_name, tail_name, self.shuffle_data)

    def _load_file(self, name_file=(), tail_name="_train", shuffle_data=False):
        """Load one split for every domain in ``name_file``.

        :param name_file: comma-separated string, or a list/tuple, of domain names
        :param tail_name: suffix appended to each name to form the file name
        :param shuffle_data: if true, shuffle inputs and labels together
        :return: dict ``{index: [name, x, one_hot_y]}``
        :raises FileNotFoundError: if a file does not exist
        :raises ValueError: if ``self.dataset_name`` is not supported
        """
        data_list = {}
        for idx, s_n in enumerate(self._split_names(name_file)):
            file_path_train = os.path.join(self.data_path, '{}{}.{}'.format(s_n, tail_name, self.data_format))
            if not os.path.isfile(file_path_train):
                # ``raise('File not found!')`` raised a plain string, which is
                # itself a TypeError; raise a real exception naming the path.
                raise FileNotFoundError('File not found: {}'.format(file_path_train))

            if self.dataset_name == 'digits':
                x_train, y_train = load_mat_file_single_label(file_path_train)
            elif self.dataset_name == 'office31_AlexNet_feat':
                x_train, y_train = load_mat_office31_AlexNet(file_path_train)
            else:
                raise ValueError('Unsupported dataset_name: {}'.format(self.dataset_name))

            if shuffle_data:
                x_train, y_train = self.shuffle(x_train, y_train)
            if 'mnist32_60_10' not in s_n and self.cast_data:
                x_train = u2t(x_train)
            data_list.update({idx: [s_n, x_train, to_categorical(y_train)]})
        return data_list

    def shuffle(self, x, y=None):
        """Permute ``x`` (and ``y`` if given) along the first axis.

        The global NumPy RNG is re-seeded with ``random_seed()`` on every
        call, so arrays of equal length always receive the same permutation.
        """
        np.random.seed(random_seed())
        idx_train = np.random.permutation(x.shape[0])
        x = x[idx_train]
        if y is not None:
            y = y[idx_train]
        return x, y

    def onehot2scalar(self, onehot_vectors, axis=1):
        """Return the arg-max of ``onehot_vectors`` along ``axis``."""
        return np.argmax(onehot_vectors, axis=axis)
-------------------------------------------------------------------------------- /model/test_da_template_AlexNet_train_feat.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Tuan Nguyen. 2 | # All rights reserved. 3 | 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import absolute_import 7 | 8 | import os 9 | import sys 10 | 11 | import numpy as np 12 | import tensorflow as tf 13 | 14 | from generic_utils import random_seed 15 | from generic_utils import feat_dir 16 | from dataLoader import DataLoader 17 | 18 | 19 | def test_real_dataset(create_obj_func, src_name=None, trg_name=None): 20 | print('Running {} ...'.format(os.path.basename(__file__))) 21 | 22 | if src_name is None: 23 | if len(sys.argv) > 2: 24 | src_name = sys.argv[2] 25 | else: 26 | raise Exception('Not specify source dataset') 27 | if trg_name is None: 28 | if len(sys.argv) > 3: 29 | trg_name = sys.argv[3] 30 | else: 31 | raise Exception('Not specify trgget dataset') 32 | 33 | np.random.seed(random_seed()) 34 | tf.set_random_seed(random_seed()) 35 | tf.reset_default_graph() 36 | 37 | print("========== Test on real data ==========") 38 | 39 | users_params = dict() 40 | users_params = parse_arguments(users_params) 41 | data_format = 'mat' 42 | 43 | if 'format' in users_params: 44 | data_format, users_params = extract_param('format', data_format, users_params) 45 | 46 | data_folder = feat_dir() 47 | if len(users_params['data_dir']) != 0: 48 | data_folder = users_params['data_dir'] 49 | print("data path", data_folder) 50 | 51 | data_loader = DataLoader(src_domain=src_name, 52 | trg_domain=trg_name, 53 | data_path=data_folder, 54 | data_format=data_format, 55 | dataset_name='office31_AlexNet_feat', 56 | cast_data=users_params['cast_data']) 57 | 58 | assert users_params['batch_size'] % data_loader.num_src_domain == 0 59 | 60 | print('users_params:', users_params) 61 | 62 | learner = 
create_obj_func(users_params) 63 | learner.dim_src = data_loader.data_shape 64 | learner.dim_trg = data_loader.data_shape 65 | 66 | learner.x_trg_test = data_loader.trg_test[0][0] 67 | learner.y_trg_test = data_loader.trg_test[0][1] 68 | 69 | learner._init(data_loader) 70 | learner._build_model() 71 | learner._fit_loop() 72 | 73 | 74 | def main_func( 75 | create_obj_func, 76 | choice_default=0, 77 | src_name_default='svmguide1', 78 | trg_name_default='svmguide1', 79 | run_exp=False): 80 | 81 | if not run_exp: 82 | choice_lst = [0, 1, 2] 83 | src_name = src_name_default 84 | trg_name = trg_name_default 85 | elif len(sys.argv) > 1: 86 | choice_lst = [int(sys.argv[1])] 87 | src_name = None 88 | trg_name = None 89 | else: 90 | choice_lst = [choice_default] 91 | src_name = src_name_default 92 | trg_name = trg_name_default 93 | 94 | for choice in choice_lst: 95 | if choice == 0: 96 | pass 97 | # add another function here 98 | elif choice == 1: 99 | test_real_dataset(create_obj_func, src_name, trg_name) 100 | 101 | 102 | def parse_arguments(params, as_array=False): 103 | for it in range(4, len(sys.argv), 2): 104 | params[sys.argv[it]] = parse_argument(sys.argv[it + 1], as_array) 105 | return params 106 | 107 | 108 | def parse_argument(string, as_array=False): 109 | try: 110 | result = int(string) 111 | except ValueError: 112 | try: 113 | result = float(string) 114 | except ValueError: 115 | if str.lower(string) == 'true': 116 | result = True 117 | elif str.lower(string) == 'false': 118 | result = False 119 | elif string == "[]": 120 | return [] 121 | elif ('|' in string) and ('[' in string) and (']' in string): 122 | result = [float(item) for item in string[1:-1].split('|')] 123 | return result 124 | elif (',' in string) and ('(' in string) and (')' in string): 125 | split = string[1:-1].split(',') 126 | result = float(split[0]) ** np.arange(float(split[1]), float(split[2]), float(split[3])) 127 | return result 128 | else: 129 | result = string 130 | 131 | return [result] 
if as_array else result 132 | 133 | 134 | def resolve_conflict_params(primary_params, secondary_params): 135 | for key in primary_params.keys(): 136 | if key in secondary_params.keys(): 137 | del secondary_params[key] 138 | return secondary_params 139 | 140 | 141 | def extract_param(key, value, params_gridsearch, scalar=False): 142 | if key in params_gridsearch.keys(): 143 | value = params_gridsearch[key] 144 | del params_gridsearch[key] 145 | if scalar and (value is not None): 146 | value = value[0] 147 | return value, params_gridsearch 148 | 149 | 150 | def dict2string(params): 151 | result = '' 152 | for key, value in params.items(): 153 | if type(value) is np.ndarray: 154 | if value.size < 16: 155 | result += key + ': ' + '|'.join('{0:.4f}'.format(x) for x in value.ravel()) + ', ' 156 | else: 157 | result += key + ': ' + str(value) + ', ' 158 | return '{' + result[:-2] + '}' 159 | -------------------------------------------------------------------------------- /model/test_da_template_digits.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Tuan Nguyen. 2 | # All rights reserved. 
3 | 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import absolute_import 7 | 8 | import os 9 | import sys 10 | from scipy.io import loadmat 11 | import numpy as np 12 | import tensorflow as tf 13 | from generic_utils import random_seed 14 | from generic_utils import feat_dir 15 | from dataLoader import DataLoader 16 | 17 | 18 | def test_real_dataset(create_obj_func, src_name=None, trg_name=None, show=False, block_figure_on_end=False): 19 | print('Running {} ...'.format(os.path.basename(__file__))) 20 | 21 | if src_name is None: 22 | if len(sys.argv) > 2: 23 | src_name = sys.argv[2] 24 | else: 25 | raise Exception('Not specify source dataset') 26 | if trg_name is None: 27 | if len(sys.argv) > 3: 28 | trg_name = sys.argv[3] 29 | else: 30 | raise Exception('Not specify trgget dataset') 31 | 32 | np.random.seed(random_seed()) 33 | tf.set_random_seed(random_seed()) 34 | tf.reset_default_graph() 35 | 36 | print("========== Test on real data ==========") 37 | users_params = dict() 38 | users_params = parse_arguments(users_params) 39 | data_format = 'mat' 40 | 41 | if 'format' in users_params: 42 | data_format, users_params = extract_param('format', data_format, users_params) 43 | 44 | data_loader = DataLoader(src_domain=src_name, 45 | trg_domain=trg_name, 46 | data_path=feat_dir(), 47 | data_format=data_format, 48 | cast_data=users_params['cast_data']) 49 | 50 | assert users_params['batch_size'] % data_loader.num_src_domain == 0 51 | print('users_params:', users_params) 52 | 53 | learner = create_obj_func(users_params) 54 | learner.dim_src = data_loader.data_shape 55 | learner.dim_trg = data_loader.data_shape 56 | 57 | learner.x_trg_test = data_loader.trg_test[0][0] 58 | learner.y_trg_test = data_loader.trg_test[0][1] 59 | learner._init(data_loader) 60 | learner._build_model() 61 | learner._fit_loop() 62 | 63 | 64 | def main_func( 65 | create_obj_func, 66 | choice_default=0, 67 | src_name_default='svmguide1', 68 | 
trg_name_default='svmguide1', 69 | run_exp=False, 70 | keep_vars=[], 71 | **kwargs): 72 | 73 | if not run_exp: 74 | choice_lst = [0, 1, 2] 75 | src_name = src_name_default 76 | trg_name = trg_name_default 77 | elif len(sys.argv) > 1: 78 | choice_lst = [int(sys.argv[1])] 79 | src_name = None 80 | trg_name = None 81 | else: 82 | choice_lst = [choice_default] 83 | src_name = src_name_default 84 | trg_name = trg_name_default 85 | 86 | for choice in choice_lst: 87 | if choice == 0: 88 | pass # for synthetic data if possible 89 | elif choice == 1: 90 | test_real_dataset(create_obj_func, src_name, trg_name, show=False, block_figure_on_end=run_exp) 91 | 92 | 93 | def parse_arguments(params, as_array=False): 94 | for it in range(4, len(sys.argv), 2): 95 | params[sys.argv[it]] = parse_argument(sys.argv[it + 1], as_array) 96 | return params 97 | 98 | 99 | def parse_argument(string, as_array=False): 100 | try: 101 | result = int(string) 102 | except ValueError: 103 | try: 104 | result = float(string) 105 | except ValueError: 106 | if str.lower(string) == 'true': 107 | result = True 108 | elif str.lower(string) == 'false': 109 | result = False 110 | elif string == "[]": 111 | return [] 112 | elif ('|' in string) and ('[' in string) and (']' in string): 113 | result = [float(item) for item in string[1:-1].split('|')] 114 | return result 115 | elif (',' in string) and ('(' in string) and (')' in string): 116 | split = string[1:-1].split(',') 117 | result = float(split[0]) ** np.arange(float(split[1]), float(split[2]), float(split[3])) 118 | return result 119 | else: 120 | result = string 121 | 122 | return [result] if as_array else result 123 | 124 | 125 | def resolve_conflict_params(primary_params, secondary_params): 126 | for key in primary_params.keys(): 127 | if key in secondary_params.keys(): 128 | del secondary_params[key] 129 | return secondary_params 130 | 131 | 132 | def extract_param(key, value, params_gridsearch, scalar=False): 133 | if key in params_gridsearch.keys(): 
class AlexNetModel(object):
    """AlexNet graph builder for TensorFlow 1.x with fine-tuning helpers.

    Derived from https://github.com/kratzert/finetune_alexnet_with_tensorflow/.
    Every variable is created under the ``network`` variable scope, so
    variable names look like ``network/fc7/weights:0``.
    """

    def __init__(self, num_classes=1000, is_training=False, dropout_keep_prob=0.5):
        """Store the configuration; no graph ops are created here.

        Args:
            num_classes: Output width of the final ``fc8`` layer.
            is_training: Forwarded as ``training`` to ``tf.layers.dropout``.
            dropout_keep_prob: Keep probability; the dropout rate used is
                ``1.0 - dropout_keep_prob``.
        """
        self.num_classes = num_classes
        self.dropout_keep_prob = dropout_keep_prob
        self.is_training = is_training
        # Set by loss(); kept separate from the ``loss`` method so the method
        # is not overwritten on the instance.
        self._loss_tensor = None

    def inference(self, x, reuse=None, extract_feat=False):
        """Build the forward pass and return logits or ``fc7`` features.

        Args:
            x: Input tensor fed to ``conv1``.
            reuse: Forwarded to ``tf.variable_scope('network', reuse=...)``.
            extract_feat: If true, return the (dropout-wrapped) ``fc7``
                activations and skip ``fc8``.

        Returns:
            ``fc7`` activations when ``extract_feat`` is true, otherwise the
            unscaled ``fc8`` scores (also stored in ``self.score``).
        """
        with tf.variable_scope('network', reuse=reuse):
            # 1st Layer: Conv (w ReLu) -> Pool -> Lrn
            conv1 = conv(x, 11, 11, 96, 4, 4, padding='VALID', name='conv1')
            pool1 = max_pool(conv1, 3, 3, 2, 2, padding='VALID', name='pool1')
            norm1 = lrn(pool1, 2, 2e-05, 0.75, name='norm1')

            # 2nd Layer: Conv (w ReLu) -> Pool -> Lrn with 2 groups
            conv2 = conv(norm1, 5, 5, 256, 1, 1, groups=2, name='conv2')
            pool2 = max_pool(conv2, 3, 3, 2, 2, padding='VALID', name='pool2')
            norm2 = lrn(pool2, 2, 2e-05, 0.75, name='norm2')

            # 3rd Layer: Conv (w ReLu)
            conv3 = conv(norm2, 3, 3, 384, 1, 1, name='conv3')

            # 4th Layer: Conv (w ReLu) split into two groups
            conv4 = conv(conv3, 3, 3, 384, 1, 1, groups=2, name='conv4')

            # 5th Layer: Conv (w ReLu) -> Pool, split into two groups
            conv5 = conv(conv4, 3, 3, 256, 1, 1, groups=2, name='conv5')
            pool5 = max_pool(conv5, 3, 3, 2, 2, padding='VALID', name='pool5')

            # 6th Layer: Flatten -> FC (w ReLu) -> Dropout
            flattened = tf.reshape(pool5, [-1, 6*6*256])
            fc6 = fc(flattened, 6*6*256, 4096, name='fc6')
            fc6 = tf.layers.dropout(fc6, rate=1.0-self.dropout_keep_prob, training=self.is_training)

            # 7th Layer: FC (w ReLu) -> Dropout
            fc7 = fc(fc6, 4096, 4096, name='fc7')
            fc7 = tf.layers.dropout(fc7, rate=1.0-self.dropout_keep_prob, training=self.is_training)

            if extract_feat:
                return fc7

            # 8th Layer: FC, unscaled activations
            # (for tf.nn.softmax_cross_entropy_with_logits)
            self.score = fc(fc7, 4096, self.num_classes, relu=False, name='fc8')
            return self.score

    def loss(self, batch_x, batch_y=None):
        """Build and return the mean softmax cross-entropy loss.

        The previous version called ``inference(..., training=True)``, a
        keyword ``inference`` does not accept, so it always raised
        ``TypeError``. Dropout behaviour is controlled by ``self.is_training``.

        Args:
            batch_x: Input tensor passed to :meth:`inference`.
            batch_y: Labels passed as ``labels`` to the cross-entropy op.

        Returns:
            Scalar loss tensor.
        """
        y_predict = self.inference(batch_x)
        self._loss_tensor = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=y_predict, labels=batch_y))
        return self._loss_tensor

    @staticmethod
    def _is_trainable_layer(variable_name, train_layers):
        """Return True if any scope component of ``variable_name`` is listed.

        The trailing component (e.g. ``weights:0``) is ignored, so both
        ``network/fc7/weights:0`` and ``fc7/weights:0`` match ``'fc7'``.
        """
        scopes = variable_name.split('/')[:-1]
        return any(scope in train_layers for scope in scopes)

    def optimize(self, learning_rate, train_layers=None):
        """Return an Adam training op restricted to ``train_layers``.

        Args:
            learning_rate: Learning rate for ``tf.train.AdamOptimizer``.
            train_layers: Iterable of layer scope names (e.g. ``['fc7']``)
                whose variables are optimised. ``None`` means no layers.

        Raises:
            RuntimeError: If :meth:`loss` has not been called yet.
        """
        loss_tensor = getattr(self, '_loss_tensor', None)
        if loss_tensor is None:
            raise RuntimeError('loss() must be called before optimize()')

        wanted = set(train_layers or ())
        # Variables live under the 'network' scope, so checking only the first
        # name component (as before) never matched any layer.
        var_list = [v for v in tf.trainable_variables()
                    if self._is_trainable_layer(v.name, wanted)]
        return tf.train.AdamOptimizer(learning_rate).minimize(loss_tensor, var_list=var_list)

    def load_original_weights(self, session, skip_layers=None, weights_path='bvlc_alexnet.npy'):
        """Assign pretrained weights from a ``.npy`` dict into the graph.

        Args:
            session: Session used to run the assign ops.
            skip_layers: Accepted for interface compatibility; currently
                ignored (layer skipping was disabled in the original code).
            weights_path: Path of the pretrained weights file.

        ``fc8`` is skipped whenever ``num_classes`` differs from 1000.
        """
        # NOTE: allow_pickle=True unpickles the file; only load trusted files.
        weights_dict = np.load(weights_path, encoding='bytes', allow_pickle=True).item()

        for op_name in weights_dict:
            if op_name == 'fc8' and self.num_classes != 1000:
                continue

            with tf.variable_scope('network/' + op_name, reuse=True):
                for data in weights_dict[op_name]:
                    # 1-D arrays are biases, everything else is a weight tensor.
                    var_name = 'biases' if len(data.shape) == 1 else 'weights'
                    session.run(tf.get_variable(var_name).assign(data))
def max_pool(x, filter_height, filter_width, stride_y, stride_x, name, padding='SAME'):
    """Apply ``tf.nn.max_pool`` with a window and stride given per axis.

    The first and last entries of both ``ksize`` and ``strides`` are fixed
    to 1; only the two middle entries are taken from the arguments.
    """
    window = [1, filter_height, filter_width, 1]
    step = [1, stride_y, stride_x, 1]
    return tf.nn.max_pool(x, ksize=window, strides=step, padding=padding, name=name)
def encode_layout(preprocess, training_phase=True, cnn_size='large'):
    """Return the encoder layout as a list of ``(layer, args, kwargs)`` tuples.

    ``'small'`` yields one 64-filter conv followed by pool/dropout/noise.
    ``'large'`` yields ``preprocess`` followed by two stages (96 then 192
    filters) of three convs plus pool/dropout/noise. Any other ``cnn_size``
    yields an empty list.
    """
    def regularizers():
        # Fresh tuples/dicts on every call so stages never share kwargs.
        return [
            (max_pool, (2, 2), {}),
            (dropout, (), dict(training=training_phase)),
            (noise, (1,), dict(phase=training_phase)),
        ]

    if cnn_size == 'small':
        return [(conv2d, (64, 3, 1), {})] + regularizers()

    if cnn_size == 'large':
        stages = [(preprocess, (), {})]
        for width in (96, 192):
            stages += [(conv2d, (width, 3, 1), {}) for _ in range(3)]
            stages += regularizers()
        return stages

    return []
def create_obj_func(params):
    """Build a ``MOST`` learner from user parameters plus defaults.

    The choice index is read from ``sys.argv[1]`` when present, otherwise
    ``choice_default`` is used. Choice 0 supplies no defaults. Defaults that
    collide with keys in ``params`` are dropped by ``resolve_conflict_params``
    before both dicts are passed to ``MOST``.
    """
    my_choice = int(sys.argv[1]) if len(sys.argv) > 1 else choice_default

    default_params = {}
    if my_choice != 0:
        default_params = dict(
            batch_size=128,
            learning_rate=1e-4,
            num_iters=80000,
            src_class_trade_off=1.0,
            src_domain_trade_off='1.0,1.0',
            ot_trade_off=0.1,
            domain_trade_off=0.1,
            src_vat_trade_off=1.0,
            g_network_trade_off=1.0,
            theta=10.0,
            mdaot_model_id='',
            classify_layout=class_discriminator_layout,
            encode_layout=encode_layout,
            domain_layout=domain_layout,
            phi_layout=phi_layout,
            log_path=os.path.join(model_dir(), model_name, "logs", "{}".format(current_time)),
            summary_freq=400,
            current_time=current_time,
            inorm=True,
            cast_data=False,
            only_save_final_model=True,
            cnn_size='large',
            sample_size=20,
            data_shift_troff=10.0,
            num_classes=10,
            multi_scale='',
            resnet_depth=101,
            train_layers='fc7,fc6',
        )

    default_params = resolve_conflict_params(params, default_params)

    print('Default parameters:')
    pprint(default_params)

    return MOST(**params, **default_params)
class Logger(object):
    """Tee for ``sys.stdout``: echoes every write to the terminal and a log file.

    The log file lives under ``<model_dir>/<model_name>/console_output`` and
    is named after ``current_time``. On Ctrl+C the log file is closed and
    deleted before the process exits.
    """

    def __init__(self):
        """Open the log file and install the SIGINT handler."""
        self.terminal = sys.stdout
        self.console_log_path = os.path.join(model_dir(), model_name, "console_output", "{}.txt".format(current_time))
        # exist_ok avoids a race between the existence check and creation.
        os.makedirs(os.path.dirname(self.console_log_path), exist_ok=True)
        self.log = open(self.console_log_path, 'a')
        signal.signal(signal.SIGINT, self.signal_handler)

    def signal_handler(self, sig, frame):
        """Handle Ctrl+C: close and remove the log file, then exit with 0."""
        print('You pressed Ctrl+C.')
        self.log.close()

        # Remove logfile
        os.remove(self.console_log_path)
        # write() skips the closed log, so this print no longer raises
        # ValueError and sys.exit is actually reached.
        print('Removed console_output file')
        sys.exit(0)

    def write(self, message):
        """Write ``message`` to the terminal and, if still open, the log."""
        self.terminal.write(message)
        if not self.log.closed:
            self.log.write(message)

    def flush(self):
        """Flush the terminal and, if still open, the log file."""
        self.terminal.flush()
        if not self.log.closed:
            self.log.flush()
def encode_layout(preprocess, training_phase=True, cnn_size='large'):
    """Return the encoder layout as a list of ``(layer, args, kwargs)`` tuples.

    ``'small'`` yields one 64-filter conv followed by pool/dropout/noise.
    ``'large'`` yields ``preprocess`` followed by two stages (96 then 192
    filters) of three convs plus pool/dropout/noise. Any other ``cnn_size``
    yields an empty list.
    """
    def regularizers():
        # Fresh tuples/dicts on every call so stages never share kwargs.
        return [
            (max_pool, (2, 2), {}),
            (dropout, (), dict(training=training_phase)),
            (noise, (1,), dict(phase=training_phase)),
        ]

    if cnn_size == 'small':
        return [(conv2d, (64, 3, 1), {})] + regularizers()

    if cnn_size == 'large':
        stages = [(preprocess, (), {})]
        for width in (96, 192):
            stages += [(conv2d, (width, 3, 1), {}) for _ in range(3)]
            stages += regularizers()
        return stages

    return []
def create_obj_func(params):
    """Build a ``MOST`` learner from user parameters plus defaults.

    The choice index is read from ``sys.argv[1]`` when present, otherwise
    ``choice_default`` is used. Choice 0 supplies no defaults. Defaults that
    collide with keys in ``params`` are dropped by ``resolve_conflict_params``
    before both dicts are passed to ``MOST``.
    """
    my_choice = int(sys.argv[1]) if len(sys.argv) > 1 else choice_default

    default_params = {}
    if my_choice != 0:
        default_params = dict(
            batch_size=128,
            learning_rate=1e-4,
            num_iters=80000,
            phase1_iters=20000,
            src_class_trade_off=1.0,
            src_domain_trade_off='1.0,1.0',
            ot_trade_off=0.1,
            domain_trade_off=0.1,
            src_vat_trade_off=1.0,
            trg_vat_troff=0.1,
            trg_ent_troff=0.1,
            g_network_trade_off=1.0,
            mimic_trade_off=0.1,
            theta=10.0,
            mdaot_model_id='',
            classify_layout=class_discriminator_layout,
            encode_layout=encode_layout,
            domain_layout=domain_layout,
            phi_layout=phi_layout,
            log_path=os.path.join(model_dir(), model_name, "logs", "{}".format(current_time)),
            summary_freq=400,
            current_time=current_time,
            inorm=True,
            cast_data=False,
            only_save_final_model=True,
            cnn_size='large',
            sample_size=20,
            data_shift_troff=10.0,
            data_dir='',
        )

    default_params = resolve_conflict_params(params, default_params)

    print('Default parameters:')
    pprint(default_params)

    return MOST(**params, **default_params)
class Logger(object):
    """Tee for ``sys.stdout``: echoes every write to the terminal and a log file.

    The log file lives under ``<model_dir>/<model_name>/console_output`` and
    is named after ``current_time``. On Ctrl+C the log file is closed and
    deleted before the process exits.
    """

    def __init__(self):
        """Open the log file and install the SIGINT handler."""
        self.terminal = sys.stdout
        self.console_log_path = os.path.join(model_dir(), model_name, "console_output", "{}.txt".format(current_time))
        # exist_ok avoids a race between the existence check and creation.
        os.makedirs(os.path.dirname(self.console_log_path), exist_ok=True)
        self.log = open(self.console_log_path, 'a')
        signal.signal(signal.SIGINT, self.signal_handler)

    def signal_handler(self, sig, frame):
        """Handle Ctrl+C: close and remove the log file, then exit with 0."""
        print('You pressed Ctrl+C.')
        self.log.close()

        # Remove logfile
        os.remove(self.console_log_path)
        # write() skips the closed log, so this print no longer raises
        # ValueError and sys.exit is actually reached.
        print('Removed console_output file')
        sys.exit(0)

    def write(self, message):
        """Write ``message`` to the terminal and, if still open, the log."""
        self.terminal.write(message)
        if not self.log.closed:
            self.log.write(message)

    def flush(self):
        """Flush the terminal and, if still open, the log file."""
        self.terminal.flush()
        if not self.log.closed:
            self.log.flush()
3 | 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import absolute_import 7 | 8 | import os 9 | import sys 10 | import numpy as np 11 | import tensorflow as tf 12 | 13 | from generic_utils import random_seed 14 | from generic_utils import data_dir 15 | from resnet.preprocessor import BatchPreprocessor 16 | 17 | 18 | def test_real_dataset(create_obj_func, src_name=None, trg_name=None): 19 | print('Running {} ...'.format(os.path.basename(__file__))) 20 | 21 | if src_name is None: 22 | if len(sys.argv) > 2: 23 | src_name = sys.argv[2] 24 | else: 25 | raise Exception('Not specify source dataset') 26 | if trg_name is None: 27 | if len(sys.argv) > 3: 28 | trg_name = sys.argv[3] 29 | else: 30 | raise Exception('Not specify trgget dataset') 31 | 32 | np.random.seed(random_seed()) 33 | tf.set_random_seed(random_seed()) 34 | tf.reset_default_graph() 35 | 36 | print("========== Test on real data ==========") 37 | 38 | users_params = dict() 39 | users_params = parse_arguments(users_params) 40 | data_format = 'mat' 41 | 42 | if 'format' in users_params: 43 | data_format, users_params = extract_param('format', data_format, users_params) 44 | 45 | src_domains = src_name.split(',') 46 | num_src_domain = len(src_domains) 47 | input_size = [227, 227] 48 | n_channels = 3 49 | src_preprocessors = [] 50 | dataset_path = os.path.join(data_dir(), 'office31') 51 | multi_scale = list(map(int, users_params['multi_scale'].split(','))) 52 | 53 | for src_domain in src_domains: 54 | file_path_train = os.path.join(dataset_path, '{}_train.txt'.format(src_domain)) 55 | src_preprocessor_i = BatchPreprocessor(dataset_file_path=file_path_train, 56 | num_classes=users_params['num_classes'], 57 | output_size=input_size, horizontal_flip=True, shuffle=True, 58 | multi_scale=multi_scale) 59 | src_preprocessors.append(src_preprocessor_i) 60 | 61 | trg_train_preprocessor = BatchPreprocessor(dataset_file_path=os.path.join(dataset_path, 
def parse_argument(string, as_array=False):
    """Convert one command-line token into a Python value.

    Tried in order: ``int``, ``float``, case-insensitive ``true``/``false``,
    the literal ``[]`` (empty list), ``[a|b|...]`` (list of floats),
    ``(base,start,stop,step)`` (``base ** np.arange(start, stop, step)``),
    and finally the unchanged string.

    Args:
        string: Token to parse.
        as_array: If true, wrap scalar results (numbers, booleans, plain
            strings) in a one-element list. List and array results are
            returned as they are.
    """
    # Numeric forms first; the first caster that succeeds wins.
    for caster in (int, float):
        try:
            value = caster(string)
        except ValueError:
            continue
        return [value] if as_array else value

    lowered = str.lower(string)
    if lowered == 'true':
        value = True
    elif lowered == 'false':
        value = False
    elif string == "[]":
        return []
    elif '|' in string and '[' in string and ']' in string:
        # Drop the surrounding brackets, then split on the pipe separator.
        return [float(token) for token in string[1:-1].split('|')]
    elif ',' in string and '(' in string and ')' in string:
        fields = string[1:-1].split(',')
        base = float(fields[0])
        return base ** np.arange(float(fields[1]), float(fields[2]), float(fields[3]))
    else:
        value = string

    return [value] if as_array else value
def encode_layout(preprocess, training_phase=True, cnn_size='large'):
    """Return the generator layout as a list of ``(layer, args, kwargs)`` tuples.

    Both known sizes start with ``preprocess`` and contain two stages of
    three convs followed by pool/dropout/noise. ``'small'`` uses 64 filters
    in both stages, ``'large'`` uses 96 then 192. Any other ``cnn_size``
    yields an empty list.
    """
    if cnn_size == 'small':
        stage_widths = (64, 64)
    elif cnn_size == 'large':
        stage_widths = (96, 192)
    else:
        return []

    layout = [(preprocess, (), {})]
    for width in stage_widths:
        layout.extend((conv2d, (width, 3, 1), {}) for _ in range(3))
        layout.append((max_pool, (2, 2), {}))
        layout.append((dropout, (), dict(training=training_phase)))
        layout.append((noise, (1,), dict(phase=training_phase)))
    return layout
def create_obj_func(params):
    """Build a ``MOST`` learner from user parameters plus defaults.

    The choice index is read from ``sys.argv[1]`` when present, otherwise
    ``choice_default`` is used. Choice 0 supplies no defaults. Defaults that
    collide with keys in ``params`` are dropped by ``resolve_conflict_params``
    before both dicts are passed to ``MOST``.
    """
    my_choice = int(sys.argv[1]) if len(sys.argv) > 1 else choice_default

    default_params = {}
    if my_choice != 0:
        default_params = dict(
            batch_size=200,
            learning_rate=0.0002,
            num_iters=80000,
            phase1_iters=20000,
            src_class_trade_off=1.0,
            src_domain_trade_off='1.0,1.0,1.0,1.0',
            ot_trade_off=0.1,
            domain_trade_off=1.0,
            trg_vat_troff=0.1,
            trg_ent_troff=0.1,
            g_network_trade_off=1.0,
            mimic_trade_off=1.0,
            theta=10.0,
            mdaot_model_id='',
            classify_layout=class_discriminator_layout,
            encode_layout=encode_layout,
            domain_layout=domain_layout,
            phi_layout=phi_layout,
            log_path=os.path.join(model_dir(), model_name, "logs", "{}".format(current_time)),
            summary_freq=800,
            current_time=current_time,
            inorm=True,
            cast_data=True,
            only_save_final_model=True,
            cnn_size='small',
            sample_size=20,
            data_shift_troff=10.0,
            lbl_shift_troff=1.0,
        )

    default_params = resolve_conflict_params(params, default_params)
    return MOST(**params, **default_params)
def main(_):
    """Fine-tune AlexNet on the dataset files named by the command-line flags.

    Creates a timestamped run directory under ``FLAGS.tensorboard_root_dir``,
    records the flags, builds the graph, then for every epoch trains,
    validates, writes TensorBoard summaries and saves a checkpoint.

    Args:
        _: Unused positional argument supplied by ``tf.app.run``.
    """
    # Create training directories
    now = datetime.datetime.now()
    train_dir_name = now.strftime('alexnet_%Y%m%d_%H%M%S')
    train_dir = os.path.join(FLAGS.tensorboard_root_dir, train_dir_name)
    checkpoint_dir = os.path.join(train_dir, 'checkpoint')
    tensorboard_dir = os.path.join(train_dir, 'tensorboard')
    tensorboard_train_dir = os.path.join(tensorboard_dir, 'train')
    tensorboard_val_dir = os.path.join(tensorboard_dir, 'val')

    # makedirs creates every missing parent, replacing the mkdir chain.
    for directory in (checkpoint_dir, tensorboard_train_dir, tensorboard_val_dir):
        os.makedirs(directory, exist_ok=True)

    # Write flags to txt; the context manager closes the file even on error.
    flags_file_path = os.path.join(train_dir, 'flags.txt')
    with open(flags_file_path, 'w') as flags_file:
        flags_file.write('learning_rate={}\n'.format(FLAGS.learning_rate))
        flags_file.write('dropout_keep_prob={}\n'.format(FLAGS.dropout_keep_prob))
        flags_file.write('num_epochs={}\n'.format(FLAGS.num_epochs))
        flags_file.write('batch_size={}\n'.format(FLAGS.batch_size))
        flags_file.write('train_layers={}\n'.format(FLAGS.train_layers))
        flags_file.write('multi_scale={}\n'.format(FLAGS.multi_scale))
        flags_file.write('tensorboard_root_dir={}\n'.format(FLAGS.tensorboard_root_dir))
        flags_file.write('log_step={}\n'.format(FLAGS.log_step))

    # Placeholders
    x = tf.placeholder(tf.float32, [FLAGS.batch_size, 227, 227, 3])
    y = tf.placeholder(tf.float32, [None, FLAGS.num_classes])
    dropout_keep_prob = tf.placeholder(tf.float32)

    # Model. is_training=True makes the dropout layers honour the fed
    # keep probability; feeding 1.0 (logging/validation) disables dropout.
    # With the default is_training=False the placeholder was never used.
    train_layers = FLAGS.train_layers.split(',')
    model = AlexNetModel(num_classes=FLAGS.num_classes, is_training=True,
                         dropout_keep_prob=dropout_keep_prob)
    loss = model.loss(x, y)
    train_op = model.optimize(FLAGS.learning_rate, train_layers)

    # Training accuracy of the model
    correct_pred = tf.equal(tf.argmax(model.score, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # Summaries
    tf.summary.scalar('train_loss', loss)
    tf.summary.scalar('train_accuracy', accuracy)
    merged_summary = tf.summary.merge_all()

    train_writer = tf.summary.FileWriter(tensorboard_train_dir)
    val_writer = tf.summary.FileWriter(tensorboard_val_dir)
    saver = tf.train.Saver()

    # Batch preprocessors: multi_scale is used only when exactly two
    # comma-separated integers are given.
    multi_scale = FLAGS.multi_scale.split(',')
    if len(multi_scale) == 2:
        multi_scale = [int(multi_scale[0]), int(multi_scale[1])]
    else:
        multi_scale = None

    train_preprocessor = BatchPreprocessor(dataset_file_path=FLAGS.training_file, num_classes=FLAGS.num_classes,
                                           output_size=[227, 227], horizontal_flip=True, shuffle=True, multi_scale=multi_scale)
    val_preprocessor = BatchPreprocessor(dataset_file_path=FLAGS.val_file, num_classes=FLAGS.num_classes, output_size=[227, 227])

    # Number of full batches per epoch. Plain ints avoid the int16 overflow
    # the previous np.int16 cast allowed for very large datasets.
    train_batches_per_epoch = len(train_preprocessor.labels) // FLAGS.batch_size
    val_batches_per_epoch = len(val_preprocessor.labels) // FLAGS.batch_size

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        train_writer.add_graph(sess.graph)

        # Load the pretrained weights
        model.load_original_weights(sess, skip_layers=train_layers)

        print("{} Start training...".format(datetime.datetime.now()))
        print("{} Open Tensorboard at --logdir {}".format(datetime.datetime.now(), tensorboard_dir))

        for epoch in range(FLAGS.num_epochs):
            print("{} Epoch number: {}".format(datetime.datetime.now(), epoch+1))

            # Steps are 1-based and include the last full batch; the old
            # `while step < n` loop silently dropped one batch per epoch.
            for step in range(1, train_batches_per_epoch + 1):
                batch_xs, batch_ys = train_preprocessor.next_batch(FLAGS.batch_size)
                sess.run(train_op, feed_dict={x: batch_xs, y: batch_ys, dropout_keep_prob: FLAGS.dropout_keep_prob})

                # Logging
                if step % FLAGS.log_step == 0:
                    s = sess.run(merged_summary, feed_dict={x: batch_xs, y: batch_ys, dropout_keep_prob: 1.})
                    train_writer.add_summary(s, epoch * train_batches_per_epoch + step)

            # Epoch completed, start validation
            print("{} Start validation".format(datetime.datetime.now()))
            test_acc = 0.
            test_count = 0

            for _ in range(val_batches_per_epoch):
                batch_tx, batch_ty = val_preprocessor.next_batch(FLAGS.batch_size)
                acc = sess.run(accuracy, feed_dict={x: batch_tx, y: batch_ty, dropout_keep_prob: 1.})
                test_acc += acc
                test_count += 1

            # Guard against a validation set smaller than one batch.
            if test_count:
                test_acc /= test_count
            s = tf.Summary(value=[
                tf.Summary.Value(tag="validation_accuracy", simple_value=test_acc)
            ])
            val_writer.add_summary(s, epoch+1)
            print("{} Validation Accuracy = {:.4f}".format(datetime.datetime.now(), test_acc))

            # Reset the dataset pointers
            val_preprocessor.reset_pointer()
            train_preprocessor.reset_pointer()

            print("{} Saving checkpoint of model...".format(datetime.datetime.now()))

            # Save checkpoint of the model
            checkpoint_path = os.path.join(checkpoint_dir, 'model_epoch'+str(epoch+1)+'.ckpt')
            saver.save(sess, checkpoint_path)

            print("{} Model checkpoint saved at {}".format(datetime.datetime.now(), checkpoint_path))
def main(_):
    """Fine-tune a pretrained ResNet and validate it after every epoch.

    Creates a timestamped run directory below ``FLAGS.tensorboard_root_dir``
    (checkpoints, TensorBoard logs and a ``flags.txt`` dump), builds the
    training graph, loads the pretrained weights and then alternates one
    training pass and one validation pass for ``FLAGS.num_epochs`` epochs,
    saving a checkpoint after each epoch.

    Args:
        _: Unused positional argument supplied by ``tf.app.run``.
    """
    # Create training directories
    now = datetime.datetime.now()
    train_dir_name = now.strftime('resnet_%Y%m%d_%H%M%S')
    train_dir = os.path.join(FLAGS.tensorboard_root_dir, train_dir_name)
    checkpoint_dir = os.path.join(train_dir, 'checkpoint')
    tensorboard_dir = os.path.join(train_dir, 'tensorboard')
    tensorboard_train_dir = os.path.join(tensorboard_dir, 'train')
    tensorboard_val_dir = os.path.join(tensorboard_dir, 'val')

    # os.makedirs also creates the missing parents (root, run and tensorboard
    # directories), so only the three leaf directories are listed.
    for directory in (checkpoint_dir, tensorboard_train_dir, tensorboard_val_dir):
        if not os.path.isdir(directory):
            os.makedirs(directory)

    # Write flags to txt; the context manager closes the file even on error.
    flags_file_path = os.path.join(train_dir, 'flags.txt')
    with open(flags_file_path, 'w') as flags_file:
        flags_file.write('learning_rate={}\n'.format(FLAGS.learning_rate))
        flags_file.write('resnet_depth={}\n'.format(FLAGS.resnet_depth))
        flags_file.write('num_epochs={}\n'.format(FLAGS.num_epochs))
        flags_file.write('batch_size={}\n'.format(FLAGS.batch_size))
        flags_file.write('train_layers={}\n'.format(FLAGS.train_layers))
        flags_file.write('multi_scale={}\n'.format(FLAGS.multi_scale))
        flags_file.write('tensorboard_root_dir={}\n'.format(FLAGS.tensorboard_root_dir))
        flags_file.write('log_step={}\n'.format(FLAGS.log_step))

    # Placeholders
    x = tf.placeholder(tf.float32, [FLAGS.batch_size, 224, 224, 3])
    y = tf.placeholder(tf.float32, [None, FLAGS.num_classes])
    is_training = tf.placeholder('bool', [])

    # Model
    # NOTE(review): ResNetModel.inference in model.py ends in a fixed 256-unit
    # 'fc' layer while `y` is num_classes wide -- confirm the widths agree.
    train_layers = FLAGS.train_layers.split(',')
    model = ResNetModel(is_training, depth=FLAGS.resnet_depth, num_classes=FLAGS.num_classes)
    loss = model.loss(x, y)
    # ResNetModel.optimize expects (learning_rate, other_var_list, train_layers,
    # loss); there are no extra variables to train here, hence the empty list.
    train_op = model.optimize(FLAGS.learning_rate, [], train_layers, loss)

    # Training accuracy of the model
    correct_pred = tf.equal(tf.argmax(model.prob, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # Summaries
    tf.summary.scalar('train_loss', loss)
    tf.summary.scalar('train_accuracy', accuracy)
    merged_summary = tf.summary.merge_all()

    train_writer = tf.summary.FileWriter(tensorboard_train_dir)
    val_writer = tf.summary.FileWriter(tensorboard_val_dir)
    saver = tf.train.Saver()

    # Batch preprocessors: "low,high" enables random multi-scale cropping,
    # anything else disables it.
    multi_scale = FLAGS.multi_scale.split(',')
    if len(multi_scale) == 2:
        multi_scale = [int(multi_scale[0]), int(multi_scale[1])]
    else:
        multi_scale = None

    train_preprocessor = BatchPreprocessor(dataset_file_path=FLAGS.training_file, num_classes=FLAGS.num_classes,
                                           output_size=[224, 224], horizontal_flip=True, shuffle=True,
                                           multi_scale=multi_scale)
    val_preprocessor = BatchPreprocessor(dataset_file_path=FLAGS.val_file, num_classes=FLAGS.num_classes,
                                         output_size=[224, 224])

    # Number of full batches per epoch. Plain Python ints are used on purpose:
    # a 16-bit numpy integer can overflow in the global-step arithmetic below.
    train_batches_per_epoch = len(train_preprocessor.labels) // FLAGS.batch_size
    val_batches_per_epoch = len(val_preprocessor.labels) // FLAGS.batch_size

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        train_writer.add_graph(sess.graph)

        # Load the pretrained weights
        model.load_original_weights(sess, skip_layers=train_layers)

        print("{} Start training...".format(datetime.datetime.now()))
        print("{} Open Tensorboard at --logdir {}".format(datetime.datetime.now(), tensorboard_dir))

        for epoch in range(FLAGS.num_epochs):
            print("{} Epoch number: {}".format(datetime.datetime.now(), epoch+1))

            # Start training: steps run 1..train_batches_per_epoch inclusive so
            # that every full batch of the epoch is consumed.
            for step in range(1, train_batches_per_epoch + 1):
                batch_xs, batch_ys = train_preprocessor.next_batch(FLAGS.batch_size)
                sess.run(train_op, feed_dict={x: batch_xs, y: batch_ys, is_training: True})

                # Logging
                if step % FLAGS.log_step == 0:
                    s = sess.run(merged_summary, feed_dict={x: batch_xs, y: batch_ys, is_training: False})
                    train_writer.add_summary(s, epoch * train_batches_per_epoch + step)

            # Epoch completed, start validation
            print("{} Start validation".format(datetime.datetime.now()))
            test_acc = 0.

            for _ in range(val_batches_per_epoch):
                batch_tx, batch_ty = val_preprocessor.next_batch(FLAGS.batch_size)
                acc = sess.run(accuracy, feed_dict={x: batch_tx, y: batch_ty, is_training: False})
                test_acc += acc

            if val_batches_per_epoch > 0:
                test_acc /= val_batches_per_epoch
                s = tf.Summary(value=[
                    tf.Summary.Value(tag="validation_accuracy", simple_value=test_acc)
                ])
                val_writer.add_summary(s, epoch+1)
                print("{} Validation Accuracy = {:.4f}".format(datetime.datetime.now(), test_acc))
            else:
                # The validation set holds fewer samples than one batch, so
                # there is nothing to average.
                print("{} Validation skipped: validation set is smaller than one batch".format(
                    datetime.datetime.now()))

            # Reset the dataset pointers
            val_preprocessor.reset_pointer()
            train_preprocessor.reset_pointer()

            print("{} Saving checkpoint of model...".format(datetime.datetime.now()))

            # Save checkpoint of the model
            checkpoint_path = os.path.join(checkpoint_dir, 'model_epoch'+str(epoch+1)+'.ckpt')
            saver.save(sess, checkpoint_path)

            print("{} Model checkpoint saved at {}".format(datetime.datetime.now(), checkpoint_path))

    # Flush pending events and release the event files.
    train_writer.close()
    val_writer.close()
class ResNetModel(object):
    """Graph builder for a bottleneck ResNet (depth 50, 101 or 152).

    The number of residual blocks per scale is looked up in ``NUM_BLOCKS``.
    """

    def __init__(self, is_training, depth=50, num_classes=1000):
        """Store the configuration and resolve the block layout.

        Args:
            is_training: Value forwarded to every batch-norm layer; it is used
                as the predicate of a ``cond`` there.
            depth: ResNet depth; must be a key of ``NUM_BLOCKS``.
            num_classes: Number of target classes. Only used to decide whether
                the pretrained 'fc' weights are loaded.

        Raises:
            ValueError: If ``depth`` is not a supported architecture.
        """
        if depth not in NUM_BLOCKS:
            raise ValueError('Depth is not supported; it must be 50, 101 or 152')

        self.is_training = is_training
        self.num_classes = num_classes
        self.depth = depth
        self.num_blocks = NUM_BLOCKS[depth]

    def inference(self, x, reuse=None, extract_feat=False):
        """Build the forward pass.

        Args:
            x: Input image tensor.
            reuse: Forwarded to every top-level ``tf.variable_scope``.
            extract_feat: When true, return the globally average-pooled
                scale-5 features instead of the 'fc' output.

        Returns:
            The pooled features if ``extract_feat`` is true, otherwise the
            output of the 256-unit 'fc' layer (also stored as ``self.prob``).
        """
        # Scale 1
        with tf.variable_scope('scale1', reuse=reuse):
            s1_conv = conv(x, ksize=7, stride=2, filters_out=64)
            s1_bn = bn(s1_conv, is_training=self.is_training)
            s1 = tf.nn.relu(s1_bn)

        # Scale 2
        with tf.variable_scope('scale2', reuse=reuse):
            s2_mp = tf.nn.max_pool(s1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')
            s2 = stack(s2_mp, is_training=self.is_training, num_blocks=self.num_blocks[0], stack_stride=1,
                       block_filters_internal=64)

        # Scale 3
        with tf.variable_scope('scale3', reuse=reuse):
            s3 = stack(s2, is_training=self.is_training, num_blocks=self.num_blocks[1], stack_stride=2,
                       block_filters_internal=128)

        # Scale 4
        with tf.variable_scope('scale4', reuse=reuse):
            s4 = stack(s3, is_training=self.is_training, num_blocks=self.num_blocks[2], stack_stride=2,
                       block_filters_internal=256)

        # Scale 5
        with tf.variable_scope('scale5', reuse=reuse):
            s5 = stack(s4, is_training=self.is_training, num_blocks=self.num_blocks[3], stack_stride=2,
                       block_filters_internal=512)

        # post-net: global average pooling over the two spatial axes
        # (`axis` replaces the deprecated `reduction_indices` alias).
        avg_pool = tf.reduce_mean(s5, axis=[1, 2], name='avg_pool')

        if extract_feat:
            return avg_pool

        with tf.variable_scope('fc', reuse=reuse):
            # The output width is fixed at 256 and does not follow num_classes.
            self.prob = fc(avg_pool, num_units_out=256)
        return self.prob

    def loss(self, batch_x, batch_y=None):
        """Build the softmax cross-entropy loss plus all regularization losses.

        NOTE: the result is stored in ``self.loss``, which replaces this
        method on the instance, so it can only be called once per model.

        Args:
            batch_x: Input image tensor passed to ``inference``.
            batch_y: Label tensor used as ``labels`` of the cross-entropy.

        Returns:
            Scalar tensor: mean cross-entropy plus the regularization losses
            collected in ``tf.GraphKeys.REGULARIZATION_LOSSES``.
        """
        y_predict = self.inference(batch_x)
        cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=y_predict, labels=batch_y)
        cross_entropy_mean = tf.reduce_mean(cross_entropy)
        regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        self.loss = tf.add_n([cross_entropy_mean] + regularization_losses)
        return self.loss

    def optimize(self, learning_rate, other_var_list, train_layers, loss):
        """Build the Adam training op together with the batch-norm updates.

        Args:
            learning_rate: Learning rate of the Adam optimizer.
            other_var_list: List concatenated to the one-element list holding
                the selected ResNet variables; pass ``[]`` for none.
            train_layers: Substrings; a ResNet variable is trained when its
                name contains at least one of them.
            loss: Loss tensor to minimize.

        Returns:
            An op grouping the optimizer step with every op registered in
            ``UPDATE_OPS_COLLECTION``.
        """
        trainable_var_names = ['weights', 'biases', 'beta', 'gamma']
        var_list = [[v for v in tf.trainable_variables() if
                     v.name.split(':')[0].split('/')[-1] in trainable_var_names and
                     contains(v.name, train_layers)]] + other_var_list
        print('len(var_list)', len(var_list))
        train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss, var_list=var_list)

        # Track a moving average of the loss alongside the batch-norm updates.
        ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY)
        tf.add_to_collection(UPDATE_OPS_COLLECTION, ema.apply([loss]))

        batchnorm_updates = tf.get_collection(UPDATE_OPS_COLLECTION)
        batchnorm_updates_op = tf.group(*batchnorm_updates)

        return tf.group(train_op, batchnorm_updates_op)

    def load_original_weights(self, session, skip_layers=None):
        """Assign the pretrained weights from ``ResNet-L<depth>.npy``.

        The file is resolved relative to the current working directory. The
        'fc' entries are skipped unless ``num_classes`` is 1000; every other
        entry (including batch-norm moving statistics) is assigned.

        Args:
            session: Session used to run the assign ops.
            skip_layers: Accepted for interface compatibility; currently
                unused.

        Raises:
            IndexError: If a stored weight has no matching graph variable.
        """
        weights_path = 'ResNet-L{}.npy'.format(self.depth)
        # SECURITY: allow_pickle=True unpickles the file content; only load
        # weight files obtained from a trusted source.
        weights_dict = np.load(weights_path, encoding='bytes', allow_pickle=True).item()

        # Index the variables once instead of rescanning them for every weight.
        variables_by_name = {v.name: v for v in tf.global_variables()}

        for op_name in weights_dict:
            parts = op_name.split('/')

            if parts[0] == 'fc' and self.num_classes != 1000:
                continue

            full_name = "{}:0".format(op_name)
            if full_name not in variables_by_name:
                raise IndexError('No graph variable named {!r} for pretrained weight {!r}'.format(
                    full_name, op_name))
            session.run(variables_by_name[full_name].assign(weights_dict[op_name]))


"""
Helper methods
"""
def _get_variable(name, shape, initializer, weight_decay=0.0, dtype='float', trainable=True):
    """Wrap ``tf.get_variable``, attaching an L2 regularizer when requested.

    Args:
        name: Variable name inside the current variable scope.
        shape: Variable shape.
        initializer: Initializer passed to ``tf.get_variable``.
        weight_decay: L2 scale; a regularizer is attached only when it is > 0.
        dtype: Variable dtype.
        trainable: Whether the variable is trainable.

    Returns:
        The created (or reused) variable.
    """
    if weight_decay > 0:
        regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
    else:
        regularizer = None

    return tf.get_variable(name, shape=shape, initializer=initializer, dtype=dtype, regularizer=regularizer,
                           trainable=trainable)


def conv(x, ksize, stride, filters_out):
    """Apply a bias-free, SAME-padded square 2-D convolution.

    Args:
        x: Input tensor; its last dimension is taken as the input channels.
        ksize: Side length of the square kernel.
        stride: Stride applied to both spatial dimensions.
        filters_out: Number of output channels.

    Returns:
        The convolution output tensor.
    """
    filters_in = x.get_shape()[-1]
    shape = [ksize, ksize, filters_in, filters_out]
    initializer = tf.truncated_normal_initializer(stddev=CONV_WEIGHT_STDDEV)
    weights = _get_variable('weights', shape=shape, dtype='float', initializer=initializer,
                            weight_decay=CONV_WEIGHT_DECAY)
    return tf.nn.conv2d(x, weights, [1, stride, stride, 1], padding='SAME')
def bn(x, is_training):
    """Batch-normalize ``x`` over every axis except the last one.

    Creates ``beta``/``gamma`` and the non-trainable ``moving_mean``/
    ``moving_variance`` in the current variable scope. The moving-average
    update ops are only registered in ``UPDATE_OPS_COLLECTION``; they are not
    executed here.

    Args:
        x: Input tensor.
        is_training: Predicate choosing batch statistics (true) or the moving
            statistics (false).

    Returns:
        The normalized tensor.
    """
    input_shape = x.get_shape()
    channel_shape = input_shape[-1:]
    reduce_axes = list(range(len(input_shape) - 1))

    beta = _get_variable('beta', channel_shape, initializer=tf.zeros_initializer())
    gamma = _get_variable('gamma', channel_shape, initializer=tf.ones_initializer())

    moving_mean = _get_variable('moving_mean', channel_shape, initializer=tf.zeros_initializer(),
                                trainable=False)
    moving_variance = _get_variable('moving_variance', channel_shape, initializer=tf.ones_initializer(),
                                    trainable=False)

    batch_mean, batch_variance = tf.nn.moments(x, reduce_axes)

    # Register the running-statistics updates: mean first, then variance.
    for running_stat, batch_stat in ((moving_mean, batch_mean), (moving_variance, batch_variance)):
        update_op = moving_averages.assign_moving_average(running_stat, batch_stat, BN_DECAY)
        tf.add_to_collection(UPDATE_OPS_COLLECTION, update_op)

    def _batch_statistics():
        return batch_mean, batch_variance

    def _moving_statistics():
        return moving_mean, moving_variance

    mean, variance = control_flow_ops.cond(is_training, _batch_statistics, _moving_statistics)

    return tf.nn.batch_normalization(x, mean, variance, beta, gamma, BN_EPSILON)


def stack(x, is_training, num_blocks, stack_stride, block_filters_internal):
    """Chain ``num_blocks`` bottleneck blocks in scopes 'block1', 'block2', ...

    Only the first block uses ``stack_stride``; the rest use stride 1.

    Args:
        x: Input tensor.
        is_training: Forwarded to every block.
        num_blocks: Number of blocks to chain.
        stack_stride: Stride of the first block.
        block_filters_internal: Internal filter count of each block.

    Returns:
        Output tensor of the last block (``x`` itself when ``num_blocks`` is 0).
    """
    output = x
    for block_number in range(1, num_blocks + 1):
        stride = stack_stride if block_number == 1 else 1
        with tf.variable_scope('block%d' % block_number):
            output = block(output, is_training, block_filters_internal=block_filters_internal,
                           block_stride=stride)
    return output


def block(x, is_training, block_filters_internal, block_stride):
    """Build one bottleneck residual block (1x1 -> 3x3 -> 1x1 plus shortcut).

    The block emits ``4 * block_filters_internal`` channels. The shortcut is
    the identity unless the channel count or the stride changes, in which case
    it is a strided 1x1 convolution followed by batch norm.

    Args:
        x: Input tensor.
        is_training: Forwarded to every batch-norm layer.
        block_filters_internal: Channels of the two inner convolutions.
        block_stride: Stride of the first convolution and of the projection.

    Returns:
        ReLU of the sum of the residual branch and the shortcut.
    """
    expansion = 4
    in_channels = x.get_shape()[-1]
    out_channels = expansion * block_filters_internal

    with tf.variable_scope('a'):
        branch = conv(x, ksize=1, stride=block_stride, filters_out=block_filters_internal)
        branch = tf.nn.relu(bn(branch, is_training))

    with tf.variable_scope('b'):
        branch = conv(branch, ksize=3, stride=1, filters_out=block_filters_internal)
        branch = tf.nn.relu(bn(branch, is_training))

    with tf.variable_scope('c'):
        branch = conv(branch, ksize=1, stride=1, filters_out=out_channels)
        branch = bn(branch, is_training)

    identity = x
    with tf.variable_scope('shortcut'):
        if out_channels != in_channels or block_stride != 1:
            projected = conv(x, ksize=1, stride=block_stride, filters_out=out_channels)
            identity = bn(projected, is_training)

    return tf.nn.relu(branch + identity)
def fc(x, num_units_out):
    """Apply a fully connected layer ``x @ weights + biases``.

    Args:
        x: Input tensor; its second dimension is taken as the input width.
        num_units_out: Number of output units.

    Returns:
        The output tensor of the layer.
    """
    num_units_in = x.get_shape()[1]
    weights_initializer = tf.truncated_normal_initializer(stddev=FC_WEIGHT_STDDEV)
    # NOTE(review): the weight decay reuses FC_WEIGHT_STDDEV (the initializer's
    # stddev) -- looks like a constant mix-up; confirm before changing, since
    # it alters the regularization strength.
    weights = _get_variable('weights', shape=[num_units_in, num_units_out], initializer=weights_initializer,
                            weight_decay=FC_WEIGHT_STDDEV)
    biases = _get_variable('biases', shape=[num_units_out], initializer=tf.zeros_initializer())
    return tf.nn.xw_plus_b(x, weights, biases)


def contains(target_str, search_arr):
    """Return True if any string of ``search_arr`` is a substring of ``target_str``.

    Args:
        target_str: String to search in.
        search_arr: Iterable of candidate substrings.

    Returns:
        bool: True on the first match, False when nothing matches or
        ``search_arr`` is empty.
    """
    return any(search_str in target_str for search_str in search_arr)

2 | 3 |

4 | 5 | # MOST: Multi-Source Domain Adaptation via Optimal Transport for Student-Teacher Learning 6 | 7 | GitHub top languageGitHub last commitGitHub repo sizeGitHub license 8 | 9 | 10 | This is the implementation of the paper **[MOST: Multi-Source Domain Adaptation via Optimal Transport for Student-Teacher Learning](https://proceedings.mlr.press/v161/nguyen21a/nguyen21a.pdf)** which has been accepted at UAI 2021. 11 | 12 |

13 | 14 |

15 | 16 | ## A. Setup 17 | 18 | #### **Install manually** 19 | 20 | ``` 21 | Python Environment: >= 3.5 22 | Tensorflow: >= 1.9 23 | ``` 24 | 25 | #### **Install automatically from YAML file** 26 | 27 | ``` 28 | pip install --upgrade pip 29 | conda env create --file tf1.9py3.5.yml 30 | ``` 31 | 32 | #### **Install *tensorbayes*** 33 | 34 | Please note that tensorbayes 0.4.0 is out of date. Please copy a newer version to the *env* folder (tf1.9py3.5) using **tensorbayes.tar** 35 | 36 | ``` 37 | pip install tensorbayes 38 | tar -xvf tensorbayes.tar 39 | cp -rf /tensorbayes/* /opt/conda/envs/tf1.9py3.5/lib/python3.5/site-packages/tensorbayes/ 40 | ``` 41 | 42 | ## B. Training 43 | 44 | #### 1. Digits-five 45 | 46 | We first navigate to the *model* folder, and then run the *run_most_digits.py* file as below: 47 | 48 | ```python 49 | cd model 50 | ``` 51 | 52 | To run on *Digits-five* dataset, in the root folder, please create a new folder named *features*. 53 | 54 | Next, download the *Digits-five* dataset [here](https://drive.google.com/file/d/1OpoPALgaMdOlkSkJDqf-JwEvRUkyDqkw/view?usp=sharing) and place the extracted files in the *features* folder. 55 | 56 | 1. ''→ **mm**'' task 57 | 58 | ```python 59 | python run_most_digits.py 1 "mnist32_60_10,usps32,svhn,syn32" mnistm32_60_10 format mat num_iters 80000 phase1_iters 0 summary_freq 800 learning_rate 0.0002 batch_size 200 src_class_trade_off 1.0 src_domain_trade_off "1.0,1.0,1.0,1.0" ot_trade_off 0.1 domain_trade_off 1.0 trg_vat_troff 0.1 trg_ent_troff 0.1 data_shift_troff 10.0 mimic_trade_off 0.1 cast_data True cnn_size small theta 0.1 sample_size 5 60 | ``` 61 | 62 | 2. 
''→ **mt**'' task 63 | ```python 64 | python run_most_digits.py 1 "mnistm32_60_10,usps32,svhn,syn32" mnist32_60_10 format mat num_iters 80000 phase1_iters 0 summary_freq 800 learning_rate 0.0002 batch_size 200 src_class_trade_off 1.0 src_domain_trade_off "1.0,1.0,1.0,1.0" ot_trade_off 0.1 domain_trade_off 1.0 trg_vat_troff 0.1 trg_ent_troff 0.1 data_shift_troff 10.0 mimic_trade_off 1.0 cast_data True cnn_size small theta 0.1 sample_size 5 65 | ``` 66 | 67 | 3. ''→ **up**'' task 68 | ```python 69 | python run_most_digits.py 1 "mnistm32_60_10,mnist32_60_10,svhn,syn32" usps32 format mat num_iters 80000 phase1_iters 0 summary_freq 800 learning_rate 0.0002 batch_size 200 src_class_trade_off 1.0 src_domain_trade_off "1.0,1.0,1.0,1.0" ot_trade_off 0.1 domain_trade_off 1.0 trg_vat_troff 0.1 trg_ent_troff 0.1 data_shift_troff 10.0 mimic_trade_off 1.0 cast_data True cnn_size small theta 0.1 sample_size 5 70 | ``` 71 | 72 | 4. ''→ **sv**'' task 73 | ```python 74 | python run_most_digits.py 1 "mnistm32_60_10,mnist32_60_10,usps32,syn32" svhn format mat num_iters 80000 phase1_iters 0 summary_freq 800 learning_rate 0.0002 batch_size 200 src_class_trade_off 1.0 src_domain_trade_off "1.0,1.0,1.0,1.0" ot_trade_off 0.1 domain_trade_off 1.0 trg_vat_troff 0.1 trg_ent_troff 0.0 data_shift_troff 10.0 mimic_trade_off 1.0 cast_data True cnn_size small theta 0.1 sample_size 5 75 | ``` 76 | 77 | 5. ''→ **sy**'' task 78 | ```python 79 | python run_most_digits.py 1 "mnistm32_60_10,mnist32_60_10,usps32,svhn" syn32 format mat num_iters 80000 phase1_iters 0 summary_freq 800 learning_rate 0.0002 batch_size 200 src_class_trade_off 1.0 src_domain_trade_off "1.0,1.0,1.0,1.0" ot_trade_off 0.1 domain_trade_off 1.0 trg_vat_troff 0.1 trg_ent_troff 0.0 data_shift_troff 10.0 mimic_trade_off 1.0 cast_data True cnn_size small theta 0.1 sample_size 5 80 | ``` 81 | 82 | #### 2. 
Office-31 83 | 84 | #### Step 1: Train a shallow network using extracted features 85 | 86 | Please download extracted features from AlexNet [here](https://drive.google.com/file/d/1dsrHn4S6lCmlTa4Eg4RAE5JRfZUIxR8G/view?usp=sharing) and save them to the *features* folder. 87 | 88 | 1. ''→ **D**'' task 89 | 90 | ```python 91 | python run_most_AlexNet_train_feat.py 1 "amazon_AlexNet,webcam_AlexNet" dslr_AlexNet format mat num_iters 20000 summary_freq 200 learning_rate 0.0001 inorm True batch_size 62 src_class_trade_off 1.0 src_domain_trade_off "1.0,1.0" ot_trade_off 0.1 domain_trade_off 1.0 src_vat_trade_off 0.0 trg_vat_troff 0.1 trg_ent_troff 0.1 data_shift_troff 10.0 mimic_trade_off 0.1 cast_data False cnn_size small theta 0.1 g_network_trade_off 1.0 sample_size 1 num_classes 31 multi_scale "" data_dir "" 92 | ``` 93 | 94 | 2. ''→ **W**'' task 95 | 96 | ```python 97 | python run_most_AlexNet_train_feat.py 1 "amazon_AlexNet,dslr_AlexNet" webcam_AlexNet format mat num_iters 20000 summary_freq 200 learning_rate 0.0001 inorm True batch_size 62 src_class_trade_off 1.0 src_domain_trade_off "1.0,1.0" ot_trade_off 0.1 domain_trade_off 1.0 src_vat_trade_off 0.0 trg_vat_troff 0.1 trg_ent_troff 0.1 data_shift_troff 10.0 mimic_trade_off 0.1 cast_data False cnn_size small theta 0.1 g_network_trade_off 1.0 sample_size 1 num_classes 31 multi_scale "" data_dir "" 98 | ``` 99 | 100 | 3. 
''→ **A**'' task 101 | 102 | ```python 103 | python run_most_AlexNet_train_feat.py 1 "dslr_AlexNet,webcam_AlexNet" amazon_AlexNet format mat num_iters 20000 summary_freq 200 learning_rate 0.0001 inorm True batch_size 62 src_class_trade_off 1.0 src_domain_trade_off "1.0,1.0" ot_trade_off 0.1 domain_trade_off 1.0 src_vat_trade_off 0.0 trg_vat_troff 1.0 trg_ent_troff 0.1 data_shift_troff 10.0 mimic_trade_off 1.0 cast_data False cnn_size small theta 0.1 g_network_trade_off 1.0 sample_size 1 num_classes 31 multi_scale "" data_dir "" 104 | ``` 105 | 106 | After training, the model is saved in the folder *MOST-results/saved-model* together with its unique id (*LOG-ID*), which is printed out at the end of training. 107 | 108 | #### Step 2: Finetune the entire model including AlexNet and the shallow network. 109 | 110 | Some mini-steps should be taken for [finetuning](https://github.com/dgurkaynak/tensorflow-cnn-finetune). 111 | 112 | - Download image data of *Office-31* [here](https://drive.google.com/file/d/1GdbY8GJ-HCp-YrDhDaLC-7BCZ0oxY6G9/view?usp=sharing) and extract them to a new folder named *data* in the root folder. 113 | 114 | - Download pre-trained AlexNet using the following command line, and save it to the *model* folder. 115 | 116 | ```bash 117 | wget http://www.cs.toronto.edu/~guerzhoy/tf_alexnet/bvlc_alexnet.npy 118 | ``` 119 | 120 | Finally, take the model id saved at Step 1 and pass it as the value of the `mdaot_model_id` argument in the following scripts. The model id is a string ID based on the time the model was run at Step 1. It should be, for example, “2021-10-07_02.30.5748”. 121 | 122 | 1. 
''→ **D**'' task 123 | 124 | ```python 125 | python run_most_AlexNet_finetune.py 1 "amazon,webcam" dslr format mat num_iters 2000 summary_freq 20 learning_rate 0.0001 inorm True batch_size 62 src_class_trade_off 1.0 src_domain_trade_off "1.0,1.0" ot_trade_off 0.1 domain_trade_off 1.0 src_vat_trade_off 0.0 trg_vat_troff 0.1 trg_ent_troff 0.1 data_shift_troff 10.0 mimic_trade_off 0.1 cast_data False cnn_size small theta 0.1 g_network_trade_off 1.0 sample_size 1 num_classes 31 multi_scale "228,256" data_dir "" mdaot_model_id <model_id> train_layers "fc7,fc6,conv5,conv4,conv3,conv2,conv1" 126 | ``` 127 | 128 | 2. ''→ **W**'' task 129 | 130 | ```python 131 | python run_most_AlexNet_finetune.py 1 "amazon,dslr" webcam format mat num_iters 2000 summary_freq 20 learning_rate 0.0001 inorm True batch_size 62 src_class_trade_off 1.0 src_domain_trade_off "1.0,1.0" ot_trade_off 0.1 domain_trade_off 1.0 src_vat_trade_off 0.0 trg_vat_troff 0.1 trg_ent_troff 0.1 data_shift_troff 10.0 mimic_trade_off 0.1 cast_data False cnn_size small theta 0.1 g_network_trade_off 1.0 sample_size 1 num_classes 31 multi_scale "228,256" data_dir "" mdaot_model_id <model_id> train_layers "fc7,fc6,conv5,conv4,conv3,conv2,conv1" 132 | ``` 133 | 134 | 3. ''→ **A**'' task 135 | 136 | ```python 137 | python run_most_AlexNet_finetune.py 1 "dslr,webcam" amazon format mat num_iters 5000 summary_freq 50 learning_rate 0.0001 inorm True batch_size 62 src_class_trade_off 1.0 src_domain_trade_off "1.0,1.0" ot_trade_off 0.1 domain_trade_off 1.0 src_vat_trade_off 0.0 trg_vat_troff 0.1 trg_ent_troff 0.1 data_shift_troff 10.0 mimic_trade_off 0.1 cast_data False cnn_size small theta 0.1 g_network_trade_off 1.0 sample_size 1 num_classes 31 multi_scale "228,256" data_dir "" mdaot_model_id <model_id> train_layers "fc7,fc6,conv5,conv4,conv3,conv2,conv1" 138 | ``` 139 | 140 | ## C. Results 141 | 142 | #### Table 1: Classification accuracy (%) on Digits-five. 
143 | 144 | | Methods | → mm | → mt | → us | → sv | → sy | Avg | 145 | | :-------------: | :--------: | :--------: | :--------: | :--------: | :--------: | :------: | 146 | | MDAN [1] | 69.5 | 98.0 | 92.4 | 69.2 | 87.4 | 83.3 | 147 | | DCTN [2] | 70.5 | 96.2 | 92.8 | 77.6 | 86.8 | 84.8 | 148 | | M3SDA [3] | 72.8 | 98.4 | 96.1 | 81.3 | 89.6 | 87.7 | 149 | | MDDA [4] | 78.6 | 98.8 | 93.9 | 79.3 | 89.7 | 88.1 | 150 | | LtC-MSDA [5] | 85.6 | 99.0 | 98.3 | 83.2 | 93.0 | 91.8 | 151 | | **MOST** (ours) | **91.5** | **99.6** | **98.4** | **90.9** | **96.4** | **95.4** | 152 | 153 | #### Table 2: Classification accuracy (%) on Office-31 using pretrained AlexNet. 154 | 155 | | Methods | → D | → W | → A | Avg | 156 | | :-------------: | :-------: | :-------: | :-------: | :------: | 157 | | MDAN [1] | 99.2 | 95.4 | 55.2 | 83.3 | 158 | | DCTN [2] | 99.6 | 96.9 | 54.9 | 83.8 | 159 | | M3SDA [3] | 99.4 | 96.2 | 55.4 | 83.7 | 160 | | MDDA [4] | 99.2 | 97.1 | 56.2 | 84.2 | 161 | | LtC-MSDA [5] | 99.6 | 97.2 | 56.9 | 84.6 | 162 | | **MOST** (ours) | **100** | **98.7** | **60.6** | **86.4** | 163 | 164 | ## D. Citations 165 | 166 | Please cite the paper if MOST is helpful for your research: 167 | 168 | ``` 169 | @InProceedings{tuan2021most, 170 | author = {Nguyen, Tuan and Le, Trung and Zhao, He and Tran, Quan Hung and Nguyen, Truyen and Phung, Dinh}, 171 | title = {Most: multi-source domain adaptation via optimal transport for student-teacher learning}, 172 | booktitle= {Proceedings of the 37th Conference on Uncertainty in Artificial Intelligence (UAI)}, 173 | year = {2021}, 174 | abstract = {Multi-source domain adaptation (DA) is more challenging than conventional DA because the knowledge is transferred from several source domains to a target domain. To this end, we propose in this paper a novel model for multi-source DA using the theory of optimal transport and imitation learning. 
More specifically, our approach consists of two cooperative agents: a teacher classifier and a student classifier. The teacher classifier is a combined expert that leverages knowledge of domain experts that can be theoretically guaranteed to handle perfectly source examples, while the student classifier acting on the target domain tries to imitate the teacher classifier acting on the source domains. Our rigorous theory developed based on optimal transport makes this cross-domain imitation possible and also helps to mitigate not only the data shift but also the label shift, which are inherently thorny issues in DA research. We conduct comprehensive experiments on real-world datasets to demonstrate the merit of our approach and its optimal transport based imitation learning viewpoint. Experimental results show that our proposed method achieves state-of-the-art performance on benchmark datasets for multi-source domain adaptation including Digits-five, Office-Caltech10, and Office-31 to the best of our knowledge.} 175 | } 176 | ``` 177 | 178 | ## E. References 179 | 180 | #### Baselines: 181 | 182 | - [1] H. Zhao, S. Zhang, G. Wu, J. M. F. Moura, J. P. Costeira, and G. J Gordon. Adversarial multiple source domain adaptation. In S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. CesaBianchi, and R. Garnett, editors, Advances in Neural Information Processing Systems 31, pages 8559-8570. Curran Associates, Inc., 2018 . 183 | - [2] R. Xu, Z. Chen, W. Zuo, J. Yan, and L. Lin. Deep cocktail network: Multi-source unsupervised domain adaptation with category shift. In 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 3964-3973, 2018. 184 | - [3] X. Peng, Q. Bai, X. Xia, Z. Huang, K. Saenko, and B. Wang. Moment matching for multi-source domain adaptation. In Proceedings of the IEEE International Conference on Computer Vision, pages 1406-1415, 2019. 185 | - [4] S. Zhao, G. Wang, S. Zhang, Y. Gu, Y. Li, Z. Song, P. Xu, R. Hu, H. Chai, and K. Keutzer. 
Multi-source distilling domain adaptation. In The Thirty-Fourth AAAI Conference on Artificial Intelligence, AAAI 2020, The Thirty-Second Innovative Applications of Artificial Intelligence Conference, IAAI 2020, The Tenth AAAI Symposium on Educational Advances in Artificial Intelligence, EAAI 2020, New York, NY, USA, February 7-12, 2020, pages 12975-12983. AAAI Press, 2020. 186 | - [5] H. Wang, M. Xu, B. Ni, and W. Zhang. Learning to combine: Knowledge aggregation for multisource domain adaptation. In Computer Vision - ECCV, 2020. 187 | 188 | #### GitHub repositories: 189 | 190 | - Folders *alexnet* and *resnet* are cloned from [Deniz Gurkaynak’s repository](https://github.com/dgurkaynak/tensorflow-cnn-finetune.git) 191 | - Some parts of our code (e.g., VAT, evaluation, …) are rewritten with modifications from [DIRT-T](https://github.com/RuiShu/dirt-t). 192 | -------------------------------------------------------------------------------- /model/most_AlexNet_train_feat.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Tuan Nguyen. 2 | # All rights reserved. 
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

import tensorflow as tf
from tensorflow.contrib.framework import arg_scope
from tensorflow.contrib.framework import add_arg_scope
from tensorbayes.layers import dense, conv2d, batch_norm, instance_norm
from tensorbayes.tfutils import softmax_cross_entropy_with_two_logits as softmax_x_entropy_two

from keras import backend as K
from keras.preprocessing.image import ImageDataGenerator

from generic_utils import random_seed

from layers import leaky_relu
import os
from generic_utils import model_dir
import numpy as np
import tensorbayes as tb
from layers import batch_ema_acc
from keras.utils.np_utils import to_categorical


def build_block(input_layer, layout, info=1):
    """Apply a sequence of layer calls to ``input_layer``.

    Args:
        input_layer: Tensor fed to the first layer.
        layout: Sequence of ``(f, f_args, f_kwargs)`` triples. Layer ``i`` is
            invoked as ``f(x, *f_args, **f_kwargs)`` inside variable scope
            ``'l<i>'``.
        info: When greater than 1, every intermediate output is printed.

    Returns:
        The output of the last layer (``input_layer`` itself when ``layout``
        is empty).
    """
    x = input_layer
    for i, (f, f_args, f_kwargs) in enumerate(layout):
        with tf.variable_scope('l{:d}'.format(i)):
            x = f(x, *f_args, **f_kwargs)
            if info > 1:
                print(x)
    return x


@add_arg_scope
def normalize_perturbation(d, scope=None):
    """L2-normalise ``d`` jointly over every axis except the first one."""
    with tf.name_scope(scope, 'norm_pert'):
        output = tf.nn.l2_normalize(d, axis=np.arange(1, len(d.shape)))
    return output


def build_encode_template(
        input_layer, training_phase, scope, encode_layout,
        reuse=None, internal_update=False, getter=None, inorm=True, cnn_size='large'):
    """Build the encoder ("generator") described by ``encode_layout``.

    Args:
        input_layer: Input tensor.
        training_phase: Forwarded as ``phase`` to ``conv2d``/``dense`` and as
            ``training_phase`` to ``encode_layout``.
        scope: Variable scope the layers are created in.
        encode_layout: Callable returning a layout for :func:`build_block`;
            called with ``preprocess``, ``training_phase`` and ``cnn_size``.
        reuse: ``reuse`` flag of the variable scope.
        internal_update: Default ``internal_update`` for ``batch_norm``.
        getter: Optional custom variable getter for the scope.
        inorm: Use ``instance_norm`` as preprocessing when true, otherwise
            ``tf.identity``.
        cnn_size: Forwarded unchanged to ``encode_layout``.

    Returns:
        Output tensor of the last layer of the layout.
    """
    with tf.variable_scope(scope, reuse=reuse, custom_getter=getter):
        with arg_scope([leaky_relu], a=0.1), \
             arg_scope([conv2d, dense], activation=leaky_relu, bn=True, phase=training_phase), \
             arg_scope([batch_norm], internal_update=internal_update):

            preprocess = instance_norm if inorm else tf.identity

            layout = encode_layout(preprocess=preprocess, training_phase=training_phase, cnn_size=cnn_size)
            output_layer = build_block(input_layer, layout)

    return output_layer


def build_class_discriminator_template(
        input_layer, training_phase, scope, num_classes, class_discriminator_layout,
        reuse=None, internal_update=False, getter=None, cnn_size='large'):
    """Build a classifier head described by ``class_discriminator_layout``.

    Args:
        input_layer: Input tensor (typically encoder features).
        training_phase: Forwarded as ``phase`` to ``conv2d``/``dense``.
        scope: Variable scope the layers are created in.
        num_classes: Forwarded to the layout callable.
        class_discriminator_layout: Callable returning a layout for
            :func:`build_block`; called with ``num_classes``,
            ``global_pool=True``, ``activation=None`` and ``cnn_size``.
        reuse: ``reuse`` flag of the variable scope.
        internal_update: Default ``internal_update`` for ``batch_norm``.
        getter: Optional custom variable getter for the scope.
        cnn_size: Forwarded unchanged to the layout callable.

    Returns:
        Output tensor of the last layer of the layout.
    """
    with tf.variable_scope(scope, reuse=reuse, custom_getter=getter):
        with arg_scope([leaky_relu], a=0.1), \
             arg_scope([conv2d, dense], activation=leaky_relu, bn=True, phase=training_phase), \
             arg_scope([batch_norm], internal_update=internal_update):
            layout = class_discriminator_layout(num_classes=num_classes, global_pool=True, activation=None,
                                                cnn_size=cnn_size)
            output_layer = build_block(input_layer, layout)

    return output_layer


def build_domain_discriminator_template(x, domain_layout, c=1, reuse=None, scope='domain_disc'):
    """Build the source-domain discriminator from ``domain_layout(c=c)``.

    ``dense`` layers default to a ReLU activation inside this template.
    """
    with tf.variable_scope(scope, reuse=reuse):
        with arg_scope([dense], activation=tf.nn.relu):
            layout = domain_layout(c=c)
            output_layer = build_block(x, layout)

    return output_layer


def build_phi_network_template(x, domain_layout, c=1, reuse=None):
    """Build the phi network from ``domain_layout(c=c)`` in scope ``'phi_net'``.

    ``dense`` layers default to a ReLU activation inside this template.
    """
    with tf.variable_scope('phi_net', reuse=reuse):
        with arg_scope([dense], activation=tf.nn.relu):
            layout = domain_layout(c=c)
            output_layer = build_block(x, layout)

    return output_layer


def get_default_config():
    """Return a ``tf.ConfigProto`` with GPU memory growth and soft placement enabled."""
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    tf_config.log_device_placement = False
    tf_config.allow_soft_placement = True
    return tf_config


class MOST():
    """Multi-source domain adaptation model with a teacher and a student.

    As built in :meth:`_build_model`: one classifier per source domain plus a
    source-domain discriminator are combined into a "teacher" prediction; a
    target classifier (scope ``'c-trg'``) is trained to mimic it, together
    with an entropic optimal-transport term (phi network), a VAT term and a
    conditional-entropy term on the target data.
    """

    def __init__(self,
                 model_name="MOST-results",
                 learning_rate=0.001,
                 batch_size=128,
                 num_iters=80000,
                 summary_freq=400,
                 src_class_trade_off=1.0,
                 src_domain_trade_off='1.0,1.0',
                 src_vat_trade_off=1.0,
                 trg_vat_troff=0.1,
                 trg_ent_troff=0.1,
                 ot_trade_off=0.1,
                 domain_trade_off=0.1,
                 mimic_trade_off=0.1,
                 encode_layout=None,
                 classify_layout=None,
                 domain_layout=None,
                 phi_layout=None,
                 current_time='',
                 inorm=True,
                 theta=0.1,
                 g_network_trade_off=1.0,
                 mdaot_model_id='',
                 only_save_final_model=True,
                 cnn_size='large',
                 sample_size=50,
                 data_shift_troff=10.0,
                 train_layers='fc8',
                 **kwargs):
        """Store the hyper-parameters; no TensorFlow object is created here.

        Args:
            model_name: Sub-directory of ``model_dir()`` used for logs and
                checkpoints.
            learning_rate: Learning rate of both Adam optimizers.
            batch_size: Target batch size; the per-source batch size is
                ``batch_size // num_src_domain``.
            num_iters: Number of outer training iterations.
            summary_freq: Period (in iterations) of console/TensorBoard
                summaries and of intermediate checkpoints.
            src_class_trade_off: Weight of the summed source classification
                loss in the total loss.
            src_domain_trade_off: Per-source-domain weights of the source
                classification losses, either as a comma-separated string
                (e.g. ``'1.0,1.0'``) or as a sequence of numbers.
            src_vat_trade_off: Only logged as a summary in this file.
            trg_vat_troff: Weight of the target VAT loss.
            trg_ent_troff: Weight of the target conditional-entropy loss.
            ot_trade_off: Weight of the optimal-transport loss.
            domain_trade_off: Weight of the source-domain discriminator loss.
            mimic_trade_off: Weight of the teacher-mimicking loss.
            encode_layout: Layout callable for the encoder.
            classify_layout: Layout callable for the classifiers.
            domain_layout: Layout callable for the domain discriminator.
            phi_layout: Layout callable for the phi network.
            current_time: Run identifier used in log/checkpoint paths.
            inorm: Whether the encoder applies instance normalisation.
            theta: Temperature of the entropic optimal-transport term.
            g_network_trade_off: Weight of the mean phi-network output inside
                the optimal-transport loss.
            mdaot_model_id: Identifier of a saved run to restore from.
            only_save_final_model: Save only the final checkpoint instead of
                one per summary step.
            cnn_size: Forwarded to the layout callables.
            sample_size: Stored; not used in this file.
            data_shift_troff: Weight of the data-shift cost relative to the
                label-shift cost.
            train_layers: Stored; not used in this file.
            **kwargs: Ignored.
        """
        self.model_name = model_name
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.num_iters = num_iters
        self.summary_freq = summary_freq
        self.src_class_trade_off = src_class_trade_off
        # Accept the historical comma-separated string as well as any sequence of numbers.
        if hasattr(src_domain_trade_off, 'split'):
            src_domain_trade_off = src_domain_trade_off.split(',')
        self.src_domain_trade_off = [float(item) for item in src_domain_trade_off]
        self.src_vat_trade_off = src_vat_trade_off
        self.trg_vat_troff = trg_vat_troff
        self.trg_ent_troff = trg_ent_troff
        self.ot_trade_off = ot_trade_off
        self.domain_trade_off = domain_trade_off
        self.mimic_trade_off = mimic_trade_off

        self.encode_layout = encode_layout
        self.classify_layout = classify_layout
        self.domain_layout = domain_layout
        self.phi_layout = phi_layout

        self.current_time = current_time
        self.inorm = inorm

        self.theta = theta
        self.g_network_trade_off = g_network_trade_off

        self.mdaot_model_id = mdaot_model_id
        self.only_save_final_model = only_save_final_model

        self.cnn_size = cnn_size
        self.sample_size = sample_size
        self.data_shift_troff = data_shift_troff
        self.train_layers = train_layers

    def _init(self, data_loader):
        """Seed the RNGs, create a fresh graph/session and bind ``data_loader``.

        Args:
            data_loader: Object exposing at least ``num_class`` and
                ``num_src_domain``.

        Raises:
            ValueError: If the number of per-domain trade-offs differs from
                ``data_loader.num_src_domain``.
        """
        np.random.seed(random_seed())
        # The graph-level seed belongs to the current default graph, so it has to be set
        # AFTER the reset; setting it before (as the code used to) seeds a graph that is
        # immediately discarded.
        tf.reset_default_graph()
        tf.set_random_seed(random_seed())

        self.tf_graph = tf.get_default_graph()
        self.tf_config = get_default_config()
        self.tf_session = tf.Session(config=self.tf_config, graph=self.tf_graph)

        self.data_loader = data_loader
        self.num_classes = self.data_loader.num_class
        self.batch_size_src = self.batch_size // self.data_loader.num_src_domain
        if len(self.src_domain_trade_off) != self.data_loader.num_src_domain:
            raise ValueError(
                'src_domain_trade_off has {} entries but the data loader has {} source domains'.format(
                    len(self.src_domain_trade_off), self.data_loader.num_src_domain))

    def _get_variables(self, list_scopes):
        """Return one list of trainable variables per scope name (a list of lists)."""
        return [tf.get_collection('trainable_variables', scope_name) for scope_name in list_scopes]

    def convert_one_hot(self, y):
        """One-hot encode integer labels ``y`` with ``self.num_classes`` columns.

        ``y`` is flattened first; ``None`` is passed through unchanged.
        """
        if y is None:
            return None
        y_idx = y.reshape(-1).astype(int)
        return np.eye(self.num_classes)[y_idx]

    def _get_scope(self, part_name, side_name, same_network=True):
        """Return ``part_name``, suffixed with ``'/' + side_name`` unless ``same_network``."""
        suffix = ''
        if not same_network:
            suffix = '/' + side_name
        return part_name + suffix

    def _get_teacher_scopes(self):
        """Variable scopes trained as the teacher side."""
        return ['generator', 'classifier', 'domain_disc']

    def _get_student_primary_scopes(self):
        """Variable scopes of the student (encoder and target classifier)."""
        return ['generator', 'c-trg']

    def _get_student_secondary_scopes(self):
        """Variable scopes trained by the secondary (phi network) optimizer."""
        return ['phi_net']

    def _build_source_middle(self, x_src, is_reused):
        """Encode a source batch with the shared ``'generator'`` scope.

        The variables are created when ``is_reused == 0``; any other value
        reuses them and enables ``internal_update`` for batch norm.
        """
        scope_name = self._get_scope('generator', 'src')
        if is_reused == 0:
            generator_model = build_encode_template(x_src, encode_layout=self.encode_layout,
                                                    scope=scope_name, training_phase=self.is_training,
                                                    inorm=self.inorm, cnn_size=self.cnn_size)
        else:
            generator_model = build_encode_template(x_src, encode_layout=self.encode_layout,
                                                    scope=scope_name, training_phase=self.is_training,
                                                    inorm=self.inorm,
                                                    reuse=True, internal_update=True,
                                                    cnn_size=self.cnn_size)
        return generator_model

    def _build_target_middle(self, x_trg, reuse=None):
        """Encode a target batch with the shared ``'generator'`` scope."""
        scope_name = 'generator'
        return build_encode_template(
            x_trg, encode_layout=self.encode_layout,
            scope=scope_name, training_phase=self.is_training, inorm=self.inorm,
            reuse=reuse, internal_update=True, cnn_size=self.cnn_size
        )

    def _build_classifier(self, x, num_classes, ema=None, is_teacher=False):
        """Build an inference-mode (``training_phase=False``) encoder + classifier.

        With ``is_teacher`` false the existing ``'generator'``/``'c-trg'``
        variables are reused and read through the getter derived from ``ema``;
        with ``is_teacher`` true, new variables are created under
        ``'generator/teacher'`` and ``'c-trg/teacher'``.

        Returns:
            Classifier output (logits) for ``x``.
        """
        g_teacher_scope = self._get_scope('generator', 'teacher', same_network=False)
        g_x = build_encode_template(
            x, encode_layout=self.encode_layout,
            scope=g_teacher_scope if is_teacher else 'generator', training_phase=False, inorm=self.inorm,
            reuse=False if is_teacher else True, getter=None if is_teacher else tb.tfutils.get_getter(ema),
            cnn_size=self.cnn_size
        )

        h_teacher_scope = self._get_scope('c-trg', 'teacher', same_network=False)
        h_g_x = build_class_discriminator_template(
            g_x, training_phase=False, scope=h_teacher_scope if is_teacher else 'c-trg', num_classes=num_classes,
            reuse=False if is_teacher else True, class_discriminator_layout=self.classify_layout,
            getter=None if is_teacher else tb.tfutils.get_getter(ema), cnn_size=self.cnn_size
        )
        return h_g_x

    def _build_domain_discriminator(self, x_mid, reuse=None, scope='domain_disc'):
        """Build the discriminator over the ``num_src_domain`` source domains."""
        return build_domain_discriminator_template(x_mid, domain_layout=self.domain_layout,
                                                   c=self.data_loader.num_src_domain, reuse=reuse, scope=scope)

    def _build_phi_network(self, x_mid, reuse=None):
        """Build the single-output phi network on top of ``x_mid``."""
        return build_phi_network_template(x_mid, domain_layout=self.phi_layout, c=1, reuse=reuse)

    def _build_class_src_discriminator(self, x_src, num_src_classes, i, reuse=None):
        """Build (or reuse) the classifier of source domain ``i`` in scope ``'classifier/<i>'``."""
        classifier_model = build_class_discriminator_template(
            x_src, training_phase=self.is_training, scope='classifier/{}'.format(i), num_classes=num_src_classes,
            reuse=reuse, internal_update=True, class_discriminator_layout=self.classify_layout,
            cnn_size=self.cnn_size
        )
        return classifier_model

    def _build_class_trg_discriminator(self, x_trg, num_trg_classes):
        """Create the target (student) classifier in scope ``'c-trg'``."""
        return build_class_discriminator_template(
            x_trg, training_phase=self.is_training, scope='c-trg', num_classes=num_trg_classes,
            reuse=False, internal_update=True, class_discriminator_layout=self.classify_layout,
            cnn_size=self.cnn_size
        )

    def perturb_image(self, x, p, num_classes, class_discriminator_layout, encode_layout,
                      pert='vat', scope=None, radius=3.5, scope_classify=None, scope_encode=None,
                      training_phase=None):
        """Return an adversarially perturbed copy of ``x`` (VAT style).

        A tiny random perturbation is pushed through the (reused) encoder and
        classifier; the gradient of the cross-entropy against ``p`` with
        respect to that perturbation gives the adversarial direction, which is
        normalised and scaled by ``radius``.

        Args:
            x: Input batch.
            p: Reference logits the perturbed prediction is compared with.
            num_classes: Number of classes of the classifier.
            class_discriminator_layout: Classifier layout callable.
            encode_layout: Encoder layout callable.
            pert: Unused.
            scope: Optional name scope.
            radius: Norm of the final perturbation.
            scope_classify: Variable scope of the classifier to reuse.
            scope_encode: Variable scope of the encoder to reuse.
            training_phase: Forwarded to the encoder/classifier templates.

        Returns:
            ``x`` plus the adversarial perturbation, with gradients stopped.
        """
        with tf.name_scope(scope, 'perturb_image'):
            eps = 1e-6 * normalize_perturbation(tf.random_normal(shape=tf.shape(x)))

            # Predict on randomly perturbed image
            x_eps_mid = build_encode_template(
                x + eps, encode_layout=encode_layout, scope=scope_encode, training_phase=training_phase,
                reuse=True, inorm=self.inorm, cnn_size=self.cnn_size)
            x_eps_pred = build_class_discriminator_template(
                x_eps_mid, class_discriminator_layout=class_discriminator_layout,
                training_phase=training_phase, scope=scope_classify, reuse=True, num_classes=num_classes,
                cnn_size=self.cnn_size
            )
            loss = softmax_x_entropy_two(labels=p, logits=x_eps_pred)

            # Based on perturbed image, get direction of greatest error
            eps_adv = tf.gradients(loss, [eps], aggregation_method=2)[0]

            # Use that direction as adversarial perturbation
            eps_adv = normalize_perturbation(eps_adv)
            x_adv = tf.stop_gradient(x + radius * eps_adv)

        return x_adv

    def vat_loss(self, x, p, num_classes, class_discriminator_layout, encode_layout,
                 scope=None, scope_classify=None, scope_encode=None, training_phase=None):
        """Virtual adversarial training loss.

        Returns the mean cross-entropy between the (gradient-stopped) logits
        ``p`` and the prediction on the adversarial example produced by
        :meth:`perturb_image`.
        """
        with tf.name_scope(scope, 'smoothing_loss'):
            x_adv = self.perturb_image(
                x, p, num_classes, class_discriminator_layout=class_discriminator_layout,
                encode_layout=encode_layout,
                scope_classify=scope_classify, scope_encode=scope_encode, training_phase=training_phase)

            x_adv_mid = build_encode_template(
                x_adv, encode_layout=encode_layout, scope=scope_encode, training_phase=training_phase,
                inorm=self.inorm, reuse=True, cnn_size=self.cnn_size)
            x_adv_pred = build_class_discriminator_template(
                x_adv_mid, training_phase=training_phase, scope=scope_classify, reuse=True,
                num_classes=num_classes,
                class_discriminator_layout=class_discriminator_layout, cnn_size=self.cnn_size
            )
            loss = tf.reduce_mean(softmax_x_entropy_two(labels=tf.stop_gradient(p), logits=x_adv_pred))

        return loss

    def _build_vat_loss(self, x, p, num_classes, scope=None, scope_classify=None, scope_encode=None):
        """:meth:`vat_loss` bound to this model's layouts and ``is_training`` flag."""
        return self.vat_loss(
            x, p, num_classes,
            class_discriminator_layout=self.classify_layout,
            encode_layout=self.encode_layout,
            scope=scope, scope_classify=scope_classify, scope_encode=scope_encode,
            training_phase=self.is_training
        )

    def _compute_cosine_similarity(self, x_trg_mid, x_src_mid_all):
        """Pairwise cosine *distance* (``1 - cosine similarity``) of flattened features.

        Entry ``[i, j]`` compares target sample ``i`` with source sample ``j``.
        No epsilon is added, so an all-zero feature vector yields NaN.
        """
        x_trg_mid_flatten = tf.layers.Flatten()(x_trg_mid)
        x_src_mid_all_flatten = tf.layers.Flatten()(x_src_mid_all)
        similarity = tf.reduce_sum(x_trg_mid_flatten[:, tf.newaxis] * x_src_mid_all_flatten, axis=-1)
        similarity /= tf.norm(x_trg_mid_flatten[:, tf.newaxis], axis=-1) * tf.norm(x_src_mid_all_flatten, axis=-1)
        distance = 1.0 - similarity
        return distance

    def _compute_data_shift_loss(self, x_src_mid, x_trg_mid):
        """Pairwise Euclidean distance of flattened features.

        Entry ``[i, j]`` compares target sample ``i`` with source sample ``j``.
        """
        x_src_mid_flatten = tf.layers.Flatten()(x_src_mid)
        x_trg_mid_flatten = tf.layers.Flatten()(x_trg_mid)

        data_shift_loss = tf.norm(tf.subtract(x_src_mid_flatten, tf.expand_dims(x_trg_mid_flatten, 1)), axis=2)
        return data_shift_loss

    def _compute_teacher_hs(self, y_label_trg_output_each_h, y_d_trg_sofmax_output):
        """Mix the per-domain class predictions with the domain probabilities.

        Args:
            y_label_trg_output_each_h: Per-expert class probabilities; the
                first two axes are swapped after stacking, so this presumably
                arrives as (num_src_domain, batch, num_classes) -- TODO confirm.
            y_d_trg_sofmax_output: Domain probabilities, presumably
                (batch, num_src_domain) -- TODO confirm.

        Returns:
            The domain-probability-weighted sum of the expert predictions,
            reduced over the domain axis (axis 1).
        """
        y_label_trg_output_each_h = tf.transpose(tf.stack(y_label_trg_output_each_h), perm=[1, 0, 2])
        y_d_trg_sofmax_output_multi_y = y_d_trg_sofmax_output
        y_d_trg_sofmax_output_multi_y = tf.expand_dims(y_d_trg_sofmax_output_multi_y, axis=-1)
        y_d_trg_sofmax_output_multi_y = tf.tile(y_d_trg_sofmax_output_multi_y, [1, 1, self.num_classes])
        y_label_trg_output = y_d_trg_sofmax_output_multi_y * y_label_trg_output_each_h
        y_label_trg_output = tf.reduce_sum(y_label_trg_output, axis=1)
        return y_label_trg_output

    def get_distances(self, a, b, name='L2'):
        """Pairwise distance between the rows of ``b`` and the rows of ``a``.

        ``a`` is expanded on axis 0 and ``b`` on axis 1, so entry ``[i, j]``
        of the result compares ``b[i]`` with ``a[j]``.

        Args:
            a: First batch of vectors (logits for ``'CE'``).
            b: Second batch of vectors (logits for ``'CE'``).
            name: ``'L1'``, ``'L2'`` or ``'CE'`` (cross-entropy of
                ``softmax(a)`` against ``log softmax(b)``).

        Raises:
            ValueError: If ``name`` is not one of the supported metrics
                (previously ``None`` was returned silently).
        """
        if name == 'L1':
            return tf.reduce_sum(tf.abs(tf.expand_dims(a, 0) - tf.expand_dims(b, 1)), axis=-1)
        elif name == 'L2':
            return tf.sqrt(tf.reduce_sum(tf.square(tf.expand_dims(a, 0) - tf.expand_dims(b, 1)), axis=-1))
        elif name == 'CE':
            a_prob = tf.nn.softmax(a)
            b_prob = tf.nn.softmax(b)
            loss = -tf.reduce_sum(
                tf.multiply(tf.expand_dims(a_prob, 0), tf.log(tf.expand_dims(b_prob, 1) + 1e-12)), axis=-1)
            return loss
        raise ValueError("Unknown distance name {!r}; expected 'L1', 'L2' or 'CE'".format(name))

    def _build_model(self):
        """Create placeholders, networks, losses, train ops and summaries.

        Requires :meth:`_init` to have been called. Inputs are pre-extracted
        feature maps of shape ``(None, 8, 8, 64)``.
        """
        # ---- placeholders -------------------------------------------------------------
        self.x_src_lst = []
        self.y_src_lst = []
        for i in range(self.data_loader.num_src_domain):
            x_src = tf.placeholder(dtype=tf.float32, shape=(None, 8, 8, 64), name='x_src_{}_input'.format(i))
            y_src = tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes),
                                   name='y_src_{}_input'.format(i))

            self.x_src_lst.append(x_src)
            self.y_src_lst.append(y_src)

        self.x_trg = tf.placeholder(dtype=tf.float32, shape=(None, 8, 8, 64), name='x_trg_input')
        self.y_trg = tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes),
                                    name='y_trg_input')
        self.y_src_domain = tf.placeholder(dtype=tf.float32, shape=(None, self.data_loader.num_src_domain),
                                           name='y_src_domain_input')

        # Separate placeholders for the EMA evaluation path.
        T = tb.utils.TensorDict(dict(
            x_tmp=tf.placeholder(dtype=tf.float32, shape=(None, 8, 8, 64)),
            y_tmp=tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes))
        ))

        self.is_training = tf.placeholder(tf.bool, shape=(), name='is_training')

        # ---- shared encoder -----------------------------------------------------------
        self.x_src_mid_lst = []
        for i in range(self.data_loader.num_src_domain):
            # is_reused=i: the first source domain creates the generator variables.
            x_src_mid = self._build_source_middle(self.x_src_lst[i], is_reused=i)
            self.x_src_mid_lst.append(x_src_mid)
        self.x_trg_mid = self._build_target_middle(self.x_trg, reuse=True)

        # ---- classifiers: one per source domain, one for the target -------------------
        self.y_src_logit_lst = []
        for i in range(self.data_loader.num_src_domain):
            y_src_logit = self._build_class_src_discriminator(self.x_src_mid_lst[i], self.num_classes, i)
            self.y_src_logit_lst.append(y_src_logit)
        self.y_trg_logit = self._build_class_trg_discriminator(self.x_trg_mid,
                                                               self.num_classes)

        # ---- source classification loss -----------------------------------------------
        self.src_loss_class_lst = []
        self.src_loss_class_sum = tf.constant(0.0)
        for i in range(self.data_loader.num_src_domain):
            src_loss_class_detail = tf.nn.softmax_cross_entropy_with_logits_v2(
                logits=self.y_src_logit_lst[i], labels=self.y_src_lst[i])
            src_loss_class = tf.reduce_mean(src_loss_class_detail)
            self.src_loss_class_lst.append(self.src_domain_trade_off[i]*src_loss_class)
            self.src_loss_class_sum += self.src_domain_trade_off[i]*src_loss_class

        # ---- source-domain discriminator loss -----------------------------------------
        self.x_src_mid_all = tf.concat(self.x_src_mid_lst, axis=0)
        self.y_src_discriminator_logit = self._build_domain_discriminator(self.x_src_mid_all)

        self.src_loss_discriminator_detail = tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=self.y_src_discriminator_logit, labels=self.y_src_domain)
        self.src_loss_discriminator = tf.reduce_mean(self.src_loss_discriminator_detail)

        # ---- teacher prediction on the source batches ---------------------------------
        # Domain i occupies rows [bs, bs + batch_size_src) of the concatenated source batch.
        self.y_src_teacher_all = []
        for i, bs in zip(range(self.data_loader.num_src_domain),
                         range(0, self.batch_size_src * self.data_loader.num_src_domain, self.batch_size_src)):
            y_src_logit_each_h_lst = []
            for j in range(self.data_loader.num_src_domain):
                y_src_logit_each_h = self._build_class_src_discriminator(self.x_src_mid_lst[i], self.num_classes,
                                                                         j, reuse=True)
                y_src_logit_each_h_lst.append(y_src_logit_each_h)
            y_src_logit_each_h_lst = tf.nn.softmax(tf.convert_to_tensor(y_src_logit_each_h_lst))

            y_src_discriminator_prob = tf.nn.softmax(tf.gather(self.y_src_discriminator_logit,
                                                               tf.range(bs, bs + self.batch_size_src,
                                                                        dtype=tf.int32), axis=0))
            y_src_teacher = self._compute_teacher_hs(y_src_logit_each_h_lst, y_src_discriminator_prob)
            self.y_src_teacher_all.append(y_src_teacher)
        self.y_src_teacher_all = tf.concat(self.y_src_teacher_all, axis=0)

        # ---- teacher prediction on the target batch -----------------------------------
        y_trg_logit_each_h_lst = []
        for j in range(self.data_loader.num_src_domain):
            y_trg_logit_each_h = self._build_class_src_discriminator(self.x_trg_mid, self.num_classes,
                                                                     j, reuse=True)
            y_trg_logit_each_h_lst.append(y_trg_logit_each_h)
        y_trg_logit_each_h_lst = tf.nn.softmax(tf.convert_to_tensor(y_trg_logit_each_h_lst))
        self.y_trg_src_domains_logit = self._build_domain_discriminator(self.x_trg_mid, reuse=True)
        y_trg_discriminator_prob = tf.nn.softmax(self.y_trg_src_domains_logit)
        self.y_trg_teacher = self._compute_teacher_hs(y_trg_logit_each_h_lst, y_trg_discriminator_prob)

        # ---- mimic loss: the student imitates the teacher on source and target --------
        self.ht_g_xs = build_class_discriminator_template(
            self.x_src_mid_all, training_phase=self.is_training, scope='c-trg', num_classes=self.num_classes,
            reuse=True, internal_update=True, class_discriminator_layout=self.classify_layout,
            cnn_size=self.cnn_size
        )
        self.mimic_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=self.ht_g_xs, labels=self.y_src_teacher_all)) + \
            tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
                logits=self.y_trg_logit, labels=self.y_trg_teacher))

        # ---- entropic optimal-transport loss ------------------------------------------
        # NOTE(review): data_shift_loss is indexed [target, source] while label_shift_loss
        # (see get_distances) is indexed [source, target]; the two are added element-wise,
        # which only works when both batches have the same size and looks transposed.
        # Left untouched because changing it alters the numerics of the method -- confirm
        # against the paper before touching.
        self.data_shift_loss = self._compute_cosine_similarity(self.x_trg_mid, self.x_src_mid_all)
        self.label_shift_loss = self.get_distances(self.y_trg_logit, self.ht_g_xs, 'CE')
        self.data_label_shift_loss = self.data_shift_troff*self.data_shift_loss + self.label_shift_loss
        self.g_network = tf.reshape(self._build_phi_network(self.x_trg_mid), [-1])
        self.exp_term = (- self.data_label_shift_loss + self.g_network) / self.theta
        self.g_network_loss = tf.reduce_mean(self.g_network)
        self.OT_loss = tf.reduce_mean(
            - self.theta *
            (
                tf.log(1.0 / self.batch_size) +
                tf.reduce_logsumexp(self.exp_term, axis=1)
            )
        ) + self.g_network_trade_off * self.g_network_loss

        # ---- target VAT loss ----------------------------------------------------------
        self.trg_loss_vat = self._build_vat_loss(
            self.x_trg, self.y_trg_logit, self.num_classes,
            scope_encode=self._get_scope('generator', 'trg'), scope_classify='c-trg'
        )

        # ---- target conditional entropy (cross-entropy of the prediction with itself) -
        self.trg_loss_cond_entropy = tf.reduce_mean(softmax_x_entropy_two(labels=self.y_trg_logit,
                                                                          logits=self.y_trg_logit))

        # ---- accuracies ---------------------------------------------------------------
        self.src_accuracy_lst = []
        for i in range(self.data_loader.num_src_domain):
            y_src_pred = tf.argmax(self.y_src_logit_lst[i], 1, output_type=tf.int32)
            y_src_sparse = tf.argmax(self.y_src_lst[i], 1, output_type=tf.int32)
            src_accuracy = tf.reduce_mean(tf.cast(tf.equal(y_src_sparse, y_src_pred), 'float32'))
            self.src_accuracy_lst.append(src_accuracy)
        # compute acc for target domain
        self.y_trg_pred = tf.argmax(self.y_trg_logit, 1, output_type=tf.int32)
        self.y_trg_sparse = tf.argmax(self.y_trg, 1, output_type=tf.int32)
        self.trg_accuracy = tf.reduce_mean(tf.cast(tf.equal(self.y_trg_sparse, self.y_trg_pred), 'float32'))
        # compute acc for src domain disc
        self.y_src_domain_pred = tf.argmax(self.y_src_discriminator_logit, 1, output_type=tf.int32)
        self.y_src_domain_sparse = tf.argmax(self.y_src_domain, 1, output_type=tf.int32)
        self.src_domain_acc = tf.reduce_mean(
            tf.cast(tf.equal(self.y_src_domain_sparse, self.y_src_domain_pred), 'float32'))

        # ---- total loss ---------------------------------------------------------------
        lst_losses = [
            (self.src_class_trade_off, self.src_loss_class_sum),
            (self.ot_trade_off, self.OT_loss),
            (self.domain_trade_off, self.src_loss_discriminator),
            (self.trg_vat_troff, self.trg_loss_vat),
            (self.trg_ent_troff, self.trg_loss_cond_entropy),
            (self.mimic_trade_off, self.mimic_loss)
        ]
        self.total_loss = tf.constant(0.0)
        for trade_off, loss in lst_losses:
            self.total_loss += trade_off * loss

        # ---- optimizers and EMA evaluation network ------------------------------------
        # _get_variables returns one list per scope: [generator vars, c-trg vars].
        primary_student_variables = self._get_variables(self._get_student_primary_scopes())
        ema = tf.train.ExponentialMovingAverage(decay=0.998)
        var_list_for_ema = primary_student_variables[0] + primary_student_variables[1]
        ema_op = ema.apply(var_list=var_list_for_ema)
        self.ema_p = self._build_classifier(T.x_tmp, self.num_classes, ema)

        self.batch_ema_acc = batch_ema_acc(T.y_tmp, self.ema_p)
        self.fn_batch_ema_acc = tb.function(self.tf_session, [T.x_tmp, T.y_tmp], self.batch_ema_acc)

        teacher_variables = self._get_variables(self._get_teacher_scopes())
        # var_list is a list of per-scope lists (teacher scopes + the c-trg variables).
        self.train_student_main = \
            tf.train.AdamOptimizer(self.learning_rate, 0.5).minimize(
                self.total_loss,
                var_list=teacher_variables + [primary_student_variables[1]])
        self.primary_train_student_op = tf.group(self.train_student_main, ema_op)

        secondary_variables = self._get_variables(self._get_student_secondary_scopes())

        # The phi network maximises the OT objective, hence the negated loss.
        self.secondary_train_student_op = \
            tf.train.AdamOptimizer(self.learning_rate, 0.5).minimize(-self.OT_loss,
                                                                     var_list=secondary_variables)

        # ---- summaries ----------------------------------------------------------------
        tf.summary.scalar('loss/total_loss', self.total_loss)
        tf.summary.scalar('loss/W_distance', self.OT_loss)
        tf.summary.scalar('loss/src_loss_discriminator', self.src_loss_discriminator)
        tf.summary.scalar('loss/data_shift_loss', tf.reduce_mean(self.data_shift_loss))
        tf.summary.scalar('loss/label_shift_loss', tf.reduce_mean(self.label_shift_loss))
        tf.summary.scalar('loss/data_label_shift_loss', tf.reduce_mean(self.data_label_shift_loss))
        tf.summary.scalar('loss/exp_term', tf.reduce_mean(self.exp_term))
        tf.summary.histogram('loss/g_batch', self.g_network)
        tf.summary.scalar('loss/g_network_loss', self.g_network_loss)

        for i in range(self.data_loader.num_src_domain):
            tf.summary.scalar('loss/src_loss_class_{}'.format(i), self.src_loss_class_lst[i])
            tf.summary.scalar('acc/src_acc_{}'.format(i), self.src_accuracy_lst[i])
        tf.summary.scalar('acc/src_domain_acc', self.src_domain_acc)
        tf.summary.scalar('acc/trg_acc', self.trg_accuracy)

        tf.summary.scalar('hyperparameters/learning_rate', self.learning_rate)
        tf.summary.scalar('hyperparameters/src_class_trade_off', self.src_class_trade_off)
        tf.summary.scalar('hyperparameters/g_network_trade_off', self.g_network_trade_off)
        tf.summary.scalar('hyperparameters/domain_trade_off', self.domain_trade_off)
        tf.summary.scalar('hyperparameters/src_vat_trade_off', self.src_vat_trade_off)
        tf.summary.scalar('hyperparameters/trg_vat_troff', self.trg_vat_troff)
        tf.summary.scalar('hyperparameters/trg_ent_troff', self.trg_ent_troff)
        self.tf_merged_summaries = tf.summary.merge_all()

    def _sample_feed(self, num_src_samples_lst, num_trg_samples, src_batchsize):
        """Draw one random training batch and return it as a feed dict.

        The target indices are drawn first, then one index set per source
        domain. ``is_training`` is fed as ``True``; ``y_src_domain`` is not
        included.
        """
        idx_trg_samples = np.random.permutation(num_trg_samples)[:self.batch_size]
        feed_data = dict()
        for k in range(self.data_loader.num_src_domain):
            idx_src_samples = np.random.permutation(num_src_samples_lst[k])[:src_batchsize]
            feed_data[self.x_src_lst[k]] = self.data_loader.src_train[k][1][idx_src_samples, :]
            feed_data[self.y_src_lst[k]] = self.data_loader.src_train[k][2][idx_src_samples]

        feed_data[self.x_trg] = self.data_loader.trg_train[0][1][idx_trg_samples, :]
        feed_data[self.y_trg] = self.data_loader.trg_train[0][2][idx_trg_samples]
        feed_data[self.is_training] = True
        return feed_data

    def _fit_loop(self):
        """Run the training loop.

        Each outer iteration performs five phi-network updates on freshly
        sampled batches followed by one update of the main objective. Progress
        is printed and summarised every ``summary_freq`` iterations (and at
        the first one); checkpoints and EMA test accuracy are written
        according to ``only_save_final_model`` and ``num_iters``.
        """
        print('Start training MOST at', os.path.abspath(__file__))
        print('============ LOG-ID: %s ============' % self.current_time)

        num_src_samples_lst = [self.data_loader.src_train[k][2].shape[0]
                               for k in range(self.data_loader.num_src_domain)]

        num_trg_samples = self.data_loader.trg_train[0][1].shape[0]
        src_batchsize = self.batch_size // self.data_loader.num_src_domain

        self.tf_session.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=101)
        self.log_path = os.path.join(model_dir(), self.model_name, "logs",
                                     "{}".format(self.current_time))
        self.tf_summary_writer = tf.summary.FileWriter(self.log_path, self.tf_session.graph)

        self.checkpoint_path = os.path.join(model_dir(), self.model_name, "saved-model",
                                            "{}".format(self.mdaot_model_id))
        check_point = tf.train.get_checkpoint_state(self.checkpoint_path)

        if check_point and tf.train.checkpoint_exists(check_point.model_checkpoint_path):
            print("Load model parameters from %s\n" % check_point.model_checkpoint_path)
            saver.restore(self.tf_session, check_point.model_checkpoint_path)

        # Domain labels of the concatenated source batch: src_batchsize rows per domain, in order.
        feed_y_src_domain = to_categorical(np.repeat(np.arange(self.data_loader.num_src_domain),
                                                     repeats=self.batch_size//self.data_loader.num_src_domain,
                                                     axis=0))

        # Guard against num_iters < 50, which used to raise ZeroDivisionError in the modulo below.
        eval_every = max(1, self.num_iters // 50)

        for it in range(self.num_iters):
            feed_data = self._sample_feed(num_src_samples_lst, num_trg_samples, src_batchsize)
            feed_data[self.y_src_domain] = feed_y_src_domain

            # Five phi-network steps, each on its own batch, per main step.
            for i in range(0, 5):
                g_feed_data = self._sample_feed(num_src_samples_lst, num_trg_samples, src_batchsize)

                _, W_dist = \
                    self.tf_session.run(
                        [self.secondary_train_student_op, self.OT_loss],
                        feed_dict=g_feed_data
                    )
            _, total_loss, src_loss_class_sum, src_loss_class_lst, src_loss_discriminator, \
                src_acc_lst, trg_acc, src_domain_acc, mimic_loss = \
                self.tf_session.run(
                    [self.primary_train_student_op, self.total_loss, self.src_loss_class_sum,
                     self.src_loss_class_lst, self.src_loss_discriminator,
                     self.src_accuracy_lst, self.trg_accuracy, self.src_domain_acc, self.mimic_loss],
                    feed_dict=feed_data
                )

            is_summary_step = it == 0 or (it + 1) % self.summary_freq == 0
            is_last_step = it + 1 == self.num_iters

            if is_summary_step:
                print(
                    "iter %d/%d total_loss %.3f; src_loss_class_sum %.3f; W_dist %.3f;\n src_loss_discriminator %.3f, pseudo_lbl_loss %.3f" % (
                        it + 1, self.num_iters, total_loss, src_loss_class_sum, W_dist,
                        src_loss_discriminator, mimic_loss))
                for k in range(self.data_loader.num_src_domain):
                    print('src_loss_class_{}: {:.3f} acc {:.2f}'.format(k, src_loss_class_lst[k],
                                                                       src_acc_lst[k]*100))
                print("src_domain_disc_acc: %.2f, trg_acc: %.2f;" % (src_domain_acc*100, trg_acc*100))

                summary = self.tf_session.run(self.tf_merged_summaries, feed_dict=feed_data)
                self.tf_summary_writer.add_summary(summary, it + 1)
                self.tf_summary_writer.flush()

            if self.only_save_final_model:
                # The final checkpoint is written on the last iteration regardless of
                # summary_freq; it used to be skipped when num_iters was not a multiple of it.
                if is_last_step:
                    self.save_trained_model(saver, it + 1)
            elif is_summary_step:
                self.save_trained_model(saver, it + 1)

            # NOTE(review): evaluation is kept nested under the summary-step condition, as
            # reconstructed from a listing whose indentation was lost -- confirm.
            if is_summary_step and (it + 1) % eval_every == 0:
                self.save_value(step=it + 1)

    def save_trained_model(self, saver, step):
        """Save a checkpoint for ``step`` under ``saved-model/<current_time>/``.

        The directory is created when missing.
        """
        checkpoint_path = os.path.join(model_dir(), self.model_name, "saved-model",
                                       "{}".format(self.current_time))
        checkpoint_path = os.path.join(checkpoint_path, "mdaot_" + self.current_time + ".ckpt")

        directory = os.path.dirname(checkpoint_path)
        if not os.path.exists(directory):
            os.makedirs(directory)
        saver.save(self.tf_session, checkpoint_path, global_step=step)

    def save_value(self, step):
        """Evaluate EMA accuracy on the target test set, log it and print it."""
        ema_acc, summary = self.compute_value(x_full=self.data_loader.trg_test[0][1],
                                              y=self.data_loader.trg_test[0][2], labeler=None)

        self.tf_summary_writer.add_summary(summary, step)
        self.tf_summary_writer.flush()

        print_list = ['ema_acc', round(ema_acc * 100, 2)]
        print(print_list)

    def compute_value(self, x_full, y, labeler, full=True):
        """Compute the accuracy of the EMA classifier in batches of 200.

        Args:
            x_full: Input samples.
            y: Labels aligned with ``x_full`` (in whatever encoding
                ``batch_ema_acc`` expects), or ``None``.
            labeler: Callable producing labels for a batch of inputs; used
                only when ``y`` is ``None``.
            full: When false, only the first 1000 shuffled samples are used.

        Returns:
            ``(ema_acc, summary)``: the mean accuracy and a ``tf.Summary``
            carrying it under the tag ``'trg_test/ema_acc'``.

        Raises:
            ValueError: If both ``y`` and ``labeler`` are ``None``.
        """
        if y is None and labeler is None:
            raise ValueError('either y or labeler must be provided')

        # Fixed seed so that the evaluated subset/order is the same on every call.
        with tb.nputils.FixedSeed(0):
            shuffle = np.random.permutation(len(x_full))

        xs = x_full[shuffle]
        ys = y[shuffle] if y is not None else None

        if not full:
            xs = xs[:1000]
            ys = ys[:1000] if ys is not None else None

        n = len(xs)
        bs = 200
        ema_acc_full = np.ones(n, dtype=float)

        for i in range(0, n, bs):
            x_batch = xs[i:i + bs]
            y_batch = ys[i:i + bs] if ys is not None else labeler(x_batch)
            ema_acc_batch = self.fn_batch_ema_acc(x_batch, y_batch)
            ema_acc_full[i:i + bs] = ema_acc_batch

        ema_acc = np.mean(ema_acc_full)
        summary = tf.Summary.Value(tag='trg_test/ema_acc', simple_value=ema_acc)
        summary = tf.Summary(value=[summary])
        return ema_acc, summary

# --------------------------------------------------------------------------------
# /model/most_digits.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021, Tuan Nguyen.
# All rights reserved.
3 | 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import absolute_import 7 | 8 | import tensorflow as tf 9 | from tensorflow.contrib.framework import arg_scope 10 | from tensorflow.contrib.framework import add_arg_scope 11 | from tensorbayes.layers import dense, conv2d, batch_norm, instance_norm 12 | from tensorbayes.tfutils import softmax_cross_entropy_with_two_logits as softmax_x_entropy_two 13 | from keras import backend as K 14 | from generic_utils import random_seed 15 | from layers import leaky_relu 16 | import os 17 | from generic_utils import model_dir 18 | import numpy as np 19 | import tensorbayes as tb 20 | from layers import batch_ema_acc 21 | from keras.utils.np_utils import to_categorical 22 | 23 | 24 | def build_block(input_layer, layout, info=1): 25 | x = input_layer 26 | for i in range(0, len(layout)): 27 | with tf.variable_scope('l{:d}'.format(i)): 28 | f, f_args, f_kwargs = layout[i] 29 | x = f(x, *f_args, **f_kwargs) 30 | if info > 1: 31 | print(x) 32 | return x 33 | 34 | 35 | @add_arg_scope 36 | def normalize_perturbation(d, scope=None): 37 | with tf.name_scope(scope, 'norm_pert'): 38 | output = tf.nn.l2_normalize(d, axis=np.arange(1, len(d.shape))) 39 | return output 40 | 41 | 42 | def build_encode_template( 43 | input_layer, training_phase, scope, encode_layout, 44 | reuse=None, internal_update=False, getter=None, inorm=True, cnn_size='large'): 45 | with tf.variable_scope(scope, reuse=reuse, custom_getter=getter): 46 | with arg_scope([leaky_relu], a=0.1), \ 47 | arg_scope([conv2d, dense], activation=leaky_relu, bn=True, phase=training_phase), \ 48 | arg_scope([batch_norm], internal_update=internal_update): 49 | preprocess = instance_norm if inorm else tf.identity 50 | 51 | layout = encode_layout(preprocess=preprocess, training_phase=training_phase, cnn_size=cnn_size) 52 | output_layer = build_block(input_layer, layout) 53 | 54 | return output_layer 55 | 56 | 57 | def 
def __init__(self,
             model_name="MOST-results",
             learning_rate=0.001,
             batch_size=128,
             num_iters=80000,
             phase1_iters=20000,
             summary_freq=400,
             src_class_trade_off=1.0,
             src_domain_trade_off='1.0,1.0',
             trg_vat_troff=0.1,
             trg_ent_troff=0.1,
             ot_trade_off=0.1,
             domain_trade_off=0.1,
             mimic_trade_off=0.1,
             encode_layout=None,
             classify_layout=None,
             domain_layout=None,
             phi_layout=None,
             current_time='',
             inorm=True,
             theta=0.1,
             g_network_trade_off=1.0,
             mdaot_model_id='',
             only_save_final_model=True,
             cnn_size='small',
             sample_size=50,
             data_shift_troff=10.0,
             lbl_shift_troff=1.0,
             **kwargs):
    """Store the hyper-parameters of the model; no TensorFlow state is created here.

    Args:
        model_name: directory name placed under ``model_dir()`` for logs and checkpoints.
        learning_rate: learning rate given to every Adam optimizer.
        batch_size: number of target samples drawn per training step.
        num_iters: total number of training iterations.
        phase1_iters: number of initial iterations that run the phase-1 (teacher) update.
        summary_freq: period, in iterations, of console logging and summary writing.
        src_class_trade_off: weight of the summed source classification loss.
        src_domain_trade_off: per-source-domain weights, given either as a
            comma-separated string (``'1.0,0.5'``) or as an iterable of numbers.
            Always stored as a list of floats.
        trg_vat_troff, trg_ent_troff, ot_trade_off, domain_trade_off, mimic_trade_off:
            weights of the corresponding terms of the phase losses.
        encode_layout, classify_layout, domain_layout, phi_layout: callables that
            return the layer layouts of the respective sub-networks.
        current_time: run identifier used in log and checkpoint paths.
        inorm: whether the encoder preprocesses its input with instance normalisation.
        theta: divisor applied inside the exponent of the OT loss.
        g_network_trade_off: weight of the mean phi-network output in the OT loss.
        mdaot_model_id: name of the checkpoint directory to restore from.
        only_save_final_model: when true, only milestone checkpoints are written.
        cnn_size: forwarded to the layout callables.
        sample_size: samples per class and per source domain in a batch.
        data_shift_troff, lbl_shift_troff: weights of the data-shift and label-shift costs.
        **kwargs: accepted and ignored, so callers may pass a superset of options.

    Raises:
        ValueError: if an entry of ``src_domain_trade_off`` cannot be parsed as a float.
    """
    self.model_name = model_name
    self.batch_size = batch_size
    self.learning_rate = learning_rate
    self.num_iters = num_iters
    self.phase1_iters = phase1_iters
    self.summary_freq = summary_freq
    self.src_class_trade_off = src_class_trade_off

    # Duck-typed string check: anything with ``split`` is treated as the legacy
    # comma-separated form, everything else as an iterable of numbers.
    if hasattr(src_domain_trade_off, 'split'):
        src_domain_trade_off = src_domain_trade_off.split(',')
    self.src_domain_trade_off = [float(item) for item in src_domain_trade_off]

    self.trg_vat_troff = trg_vat_troff
    self.trg_ent_troff = trg_ent_troff
    self.ot_trade_off = ot_trade_off
    self.domain_trade_off = domain_trade_off
    self.mimic_trade_off = mimic_trade_off
    self.encode_layout = encode_layout
    self.classify_layout = classify_layout
    self.domain_layout = domain_layout
    self.phi_layout = phi_layout
    self.current_time = current_time
    self.inorm = inorm
    self.theta = theta
    self.g_network_trade_off = g_network_trade_off
    self.mdaot_model_id = mdaot_model_id
    self.only_save_final_model = only_save_final_model
    self.cnn_size = cnn_size
    self.sample_size = sample_size
    self.data_shift_troff = data_shift_troff
    self.lbl_shift_troff = lbl_shift_troff
def _get_scope(self, part_name, side_name, same_network=True):
    """Return the variable-scope name of a network part.

    With ``same_network`` true (the default) the result is ``part_name`` alone and
    ``side_name`` is not used; otherwise it is ``part_name + '/' + side_name``.
    """
    return part_name + ('' if same_network else '/' + side_name)
def perturb_image(self, x, p, num_classes, class_discriminator_layout, encode_layout,
                  pert='vat', scope=None, radius=3.5, scope_classify=None, scope_encode=None, training_phase=None):
    """Build the graph that turns ``x`` into an adversarially perturbed copy.

    A tiny random direction is added to ``x``, the reused encoder and classifier
    are applied to it, and the gradient of the resulting loss with respect to
    that direction is normalised and scaled by ``radius``.

    Args:
        x: input tensor to perturb.
        p: passed as ``labels`` to ``softmax_x_entropy_two``
            (presumably the classifier logits for ``x`` -- TODO confirm with callers).
        num_classes: forwarded to the classifier template.
        class_discriminator_layout: layout callable for the classifier.
        encode_layout: layout callable for the encoder.
        pert: not read anywhere in this method.
        scope: optional name-scope name; defaults to ``'perturb_image'``.
        radius: multiplier applied to the normalised adversarial direction.
        scope_classify: variable scope of the classifier to reuse.
        scope_encode: variable scope of the encoder to reuse.
        training_phase: forwarded as ``training_phase`` to both templates.

    Returns:
        ``x + radius * eps_adv`` wrapped in ``tf.stop_gradient``.
    """
    with tf.name_scope(scope, 'perturb_image'):
        # Random direction, normalised and shrunk to 1e-6 so it barely moves x.
        eps = 1e-6 * normalize_perturbation(tf.random_normal(shape=tf.shape(x)))

        # Predict on randomly perturbed image
        # Both templates are called with reuse=True, so no new variables are created here.
        x_eps_mid = build_encode_template(
            x + eps, encode_layout=encode_layout, scope=scope_encode, training_phase=training_phase, reuse=True,
            inorm=self.inorm, cnn_size=self.cnn_size)
        x_eps_pred = build_class_discriminator_template(
            x_eps_mid, class_discriminator_layout=class_discriminator_layout,
            training_phase=training_phase, scope=scope_classify, reuse=True, num_classes=num_classes,
            cnn_size=self.cnn_size
        )
        # eps_p = classifier(x + eps, phase=True, reuse=True)
        loss = softmax_x_entropy_two(labels=p, logits=x_eps_pred)

        # Based on perturbed image, get direction of greatest error
        # NOTE(review): aggregation_method=2 is presumably
        # tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N -- confirm against the TF version in use.
        eps_adv = tf.gradients(loss, [eps], aggregation_method=2)[0]

        # Use that direction as adversarial perturbation
        eps_adv = normalize_perturbation(eps_adv)
        # stop_gradient: the perturbed input is treated as a constant by later losses.
        x_adv = tf.stop_gradient(x + radius * eps_adv)
    return x_adv
def _compute_cosine_similarity(self, x_trg_mid, x_src_mid_all):
    """Return the pairwise cosine *distance* between target and source features.

    Both inputs are flattened to rank 2.  Entry ``[i, j]`` of the result is
    ``1 - cos(trg_i, src_j)``, so the result has one row per target sample and
    one column per source sample.  No epsilon is added to the norms.
    """
    trg_flat = tf.layers.Flatten()(x_trg_mid)
    src_flat = tf.layers.Flatten()(x_src_mid_all)

    # The inserted axis on the target side makes the product broadcast to
    # (n_trg, n_src, dim) before it is summed over the feature axis.
    dot_products = tf.reduce_sum(trg_flat[:, tf.newaxis] * src_flat, axis=-1)
    norm_products = tf.norm(trg_flat[:, tf.newaxis], axis=-1) * tf.norm(src_flat, axis=-1)

    cosine = dot_products / norm_products
    return 1.0 - cosine
tf.tile(y_d_trg_sofmax_output_multi_y, [1, 1, self.num_classes]) 326 | y_label_trg_output = y_d_trg_sofmax_output_multi_y * y_label_trg_output_each_h 327 | y_label_trg_output = tf.reduce_sum(y_label_trg_output, axis=1) 328 | return y_label_trg_output 329 | 330 | def _build_model(self): 331 | self.x_src_lst = [] 332 | self.y_src_lst = [] 333 | for i in range(self.data_loader.num_src_domain): 334 | x_src = tf.placeholder(dtype=tf.float32, shape=tuple([None]) + self.dim_src, 335 | name='x_src_{}_input'.format(i)) 336 | y_src = tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes), 337 | name='y_src_{}_input'.format(i)) 338 | self.x_src_lst.append(x_src) 339 | self.y_src_lst.append(y_src) 340 | 341 | self.x_trg = tf.placeholder(dtype=tf.float32, shape=tuple([None]) + self.dim_trg, name='x_trg_input') 342 | self.y_trg = tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes), 343 | name='y_trg_input') 344 | self.y_src_domain = tf.placeholder(dtype=tf.float32, shape=(None, self.data_loader.num_src_domain), 345 | name='y_src_domain_input') 346 | 347 | T = tb.utils.TensorDict(dict( 348 | x_tmp=tf.placeholder(dtype=tf.float32, shape=tuple([None]) + self.dim_src), 349 | y_tmp=tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes)) 350 | )) 351 | 352 | self.is_training = tf.placeholder(tf.bool, shape=(), name='is_training') 353 | 354 | self.x_src_mid_lst = [] 355 | for i in range(self.data_loader.num_src_domain): 356 | x_src_mid = self._build_source_middle(self.x_src_lst[i], is_reused=i) 357 | self.x_src_mid_lst.append(x_src_mid) 358 | self.x_trg_mid = self._build_target_middle(self.x_trg, reuse=True) 359 | 360 | # 361 | self.y_src_logit_lst = [] 362 | for i in range(self.data_loader.num_src_domain): 363 | y_src_logit = self._build_class_src_discriminator(self.x_src_mid_lst[i], self.num_classes, i) 364 | self.y_src_logit_lst.append(y_src_logit) 365 | 366 | self.y_trg_logit = self._build_class_trg_discriminator(self.x_trg_mid, 367 | self.num_classes) 
368 | # 369 | 370 | # 371 | self.src_loss_class_lst = [] 372 | self.src_loss_class_sum = tf.constant(0.0) 373 | for i in range(self.data_loader.num_src_domain): 374 | src_loss_class_detail = tf.nn.softmax_cross_entropy_with_logits_v2( 375 | logits=self.y_src_logit_lst[i], labels=self.y_src_lst[i]) 376 | src_loss_class = tf.reduce_mean(src_loss_class_detail) 377 | self.src_loss_class_lst.append(self.src_domain_trade_off[i] * src_loss_class) 378 | self.src_loss_class_sum += self.src_domain_trade_off[i] * src_loss_class 379 | self.trg_loss_class_detail = tf.nn.softmax_cross_entropy_with_logits_v2( 380 | logits=self.y_trg_logit, labels=self.y_trg) 381 | self.trg_loss_class = tf.reduce_mean(self.trg_loss_class_detail) 382 | # 383 | 384 | # 385 | self.x_src_mid_all = tf.concat(self.x_src_mid_lst, axis=0) 386 | self.y_src_discriminator_logit = self._build_domain_discriminator(self.x_src_mid_all) 387 | 388 | self.src_loss_discriminator_detail = tf.nn.softmax_cross_entropy_with_logits_v2( 389 | logits=self.y_src_discriminator_logit, labels=self.y_src_domain) 390 | self.src_loss_discriminator = tf.reduce_mean(self.src_loss_discriminator_detail) 391 | # 392 | 393 | # 394 | self.y_src_teacher_all = [] 395 | for i, bs in zip(range(self.data_loader.num_src_domain), 396 | range(0, self.batch_size_src * self.data_loader.num_src_domain, self.batch_size_src)): 397 | y_src_logit_each_h_lst = [] 398 | for j in range(self.data_loader.num_src_domain): 399 | y_src_logit_each_h = self._build_class_src_discriminator(self.x_src_mid_lst[i], self.num_classes, 400 | j, reuse=True) 401 | y_src_logit_each_h_lst.append(y_src_logit_each_h) 402 | y_src_logit_each_h_lst = tf.nn.softmax(tf.convert_to_tensor(y_src_logit_each_h_lst)) 403 | y_src_discriminator_prob = tf.nn.softmax(tf.gather(self.y_src_discriminator_logit, 404 | tf.range(bs, bs + self.batch_size_src, 405 | dtype=tf.int32), axis=0)) 406 | y_src_teacher = self._compute_teacher_hs(y_src_logit_each_h_lst, y_src_discriminator_prob) 407 | 
self.y_src_teacher_all.append(y_src_teacher) 408 | self.y_src_teacher_all = tf.concat(self.y_src_teacher_all, axis=0) 409 | # 410 | 411 | # 412 | y_trg_logit_each_h_lst = [] 413 | for j in range(self.data_loader.num_src_domain): 414 | y_trg_logit_each_h = self._build_class_src_discriminator(self.x_trg_mid, self.num_classes, 415 | j, reuse=True) 416 | y_trg_logit_each_h_lst.append(y_trg_logit_each_h) 417 | y_trg_logit_each_h_lst = tf.nn.softmax(tf.convert_to_tensor(y_trg_logit_each_h_lst)) 418 | self.y_trg_src_domains_logit = self._build_domain_discriminator(self.x_trg_mid, reuse=True) 419 | y_trg_discriminator_prob = tf.nn.softmax(self.y_trg_src_domains_logit) 420 | self.y_trg_teacher = self._compute_teacher_hs(y_trg_logit_each_h_lst, y_trg_discriminator_prob) 421 | # 422 | 423 | # 424 | self.ht_g_xs = build_class_discriminator_template( 425 | self.x_src_mid_all, training_phase=self.is_training, scope='c-trg', num_classes=self.num_classes, 426 | reuse=True, internal_update=True, class_discriminator_layout=self.classify_layout, cnn_size=self.cnn_size 427 | ) 428 | self.mimic_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2( 429 | logits=self.ht_g_xs, labels=self.y_src_teacher_all)) + \ 430 | tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2( 431 | logits=self.y_trg_logit, labels=self.y_trg_teacher)) 432 | # 433 | 434 | # 435 | self.data_shift_loss = self._compute_cosine_similarity(self.x_trg_mid, self.x_src_mid_all) 436 | self.label_shift_loss = self._compute_label_shift_loss(self.y_trg_logit, self.ht_g_xs) 437 | self.data_label_shift_loss = self.data_shift_troff * self.data_shift_loss + self.lbl_shift_troff * self.label_shift_loss 438 | self.g_network = tf.reshape(self._build_phi_network(self.x_trg_mid), [-1]) 439 | self.exp_term = (- self.data_label_shift_loss + self.g_network) / self.theta 440 | self.g_network_loss = tf.reduce_mean(self.g_network) 441 | self.OT_loss = tf.reduce_mean( 442 | - self.theta * \ 443 | ( 444 | tf.log(1.0 / 
self.batch_size) + 445 | tf.reduce_logsumexp(self.exp_term, axis=1) 446 | ) 447 | ) + self.g_network_trade_off * self.g_network_loss 448 | # 449 | 450 | # 451 | self.trg_loss_vat = self._build_vat_loss( 452 | self.x_trg, self.y_trg_logit, self.num_classes, 453 | scope_encode=self._get_scope('generator', 'trg'), scope_classify='c-trg' 454 | ) 455 | # 456 | 457 | # 458 | self.trg_loss_cond_entropy = tf.reduce_mean(softmax_x_entropy_two(labels=self.y_trg_logit, 459 | logits=self.y_trg_logit)) 460 | # 461 | 462 | # 463 | self.src_accuracy_lst = [] 464 | for i in range(self.data_loader.num_src_domain): 465 | y_src_pred = tf.argmax(self.y_src_logit_lst[i], 1, output_type=tf.int32) 466 | y_src_sparse = tf.argmax(self.y_src_lst[i], 1, output_type=tf.int32) 467 | src_accuracy = tf.reduce_mean(tf.cast(tf.equal(y_src_sparse, y_src_pred), 'float32')) 468 | self.src_accuracy_lst.append(src_accuracy) 469 | # compute acc for target domain 470 | self.y_trg_pred = tf.argmax(self.y_trg_logit, 1, output_type=tf.int32) 471 | self.y_trg_sparse = tf.argmax(self.y_trg, 1, output_type=tf.int32) 472 | self.trg_accuracy = tf.reduce_mean(tf.cast(tf.equal(self.y_trg_sparse, self.y_trg_pred), 'float32')) 473 | # compute acc for src domain disc 474 | self.y_src_domain_pred = tf.argmax(self.y_src_discriminator_logit, 1, output_type=tf.int32) 475 | self.y_src_domain_sparse = tf.argmax(self.y_src_domain, 1, output_type=tf.int32) 476 | self.src_domain_acc = tf.reduce_mean( 477 | tf.cast(tf.equal(self.y_src_domain_sparse, self.y_src_domain_pred), 'float32')) 478 | # 479 | 480 | # 481 | lst_phase1_losses = [ 482 | (self.src_class_trade_off, self.src_loss_class_sum), 483 | (self.domain_trade_off, self.src_loss_discriminator), 484 | ] 485 | self.phase1_loss = tf.constant(0.0) 486 | for trade_off, loss in lst_phase1_losses: 487 | self.phase1_loss += trade_off * loss 488 | 489 | lst_phase2_losses = [ 490 | (self.src_class_trade_off, self.src_loss_class_sum), 491 | (self.ot_trade_off, self.OT_loss), 492 | 
(self.domain_trade_off, self.src_loss_discriminator), 493 | (self.trg_vat_troff, self.trg_loss_vat), 494 | (self.trg_ent_troff, self.trg_loss_cond_entropy), 495 | (self.mimic_trade_off, self.mimic_loss) 496 | ] 497 | self.phase2_loss = tf.constant(0.0) 498 | for trade_off, loss in lst_phase2_losses: 499 | self.phase2_loss += trade_off * loss 500 | # 501 | 502 | # 503 | primary_student_variables = self._get_variables(self._get_student_primary_scopes()) 504 | ema = tf.train.ExponentialMovingAverage(decay=0.998) 505 | var_list_for_ema = primary_student_variables[0] + primary_student_variables[1] 506 | ema_op = ema.apply(var_list=var_list_for_ema) 507 | self.ema_p = self._build_classifier(T.x_tmp, self.num_classes, ema) 508 | self.batch_ema_acc = batch_ema_acc(T.y_tmp, self.ema_p) 509 | self.fn_batch_ema_acc = tb.function(self.tf_session, [T.x_tmp, T.y_tmp], self.batch_ema_acc) 510 | # 511 | 512 | teacher_variables = self._get_variables(self._get_teacher_scopes()) 513 | self.train_teacher = \ 514 | tf.train.AdamOptimizer(self.learning_rate, 0.5).minimize(self.phase1_loss, 515 | var_list=teacher_variables) 516 | self.train_student_main = \ 517 | tf.train.AdamOptimizer(self.learning_rate, 0.5).minimize(self.phase2_loss, 518 | var_list=teacher_variables + [ 519 | primary_student_variables[1]]) 520 | self.primary_train_student_op = tf.group(self.train_student_main, ema_op) 521 | 522 | # 523 | secondary_variables = self._get_variables(self._get_student_secondary_scopes()) 524 | self.secondary_train_student_op = \ 525 | tf.train.AdamOptimizer(self.learning_rate, 0.5).minimize(-self.OT_loss, 526 | var_list=secondary_variables) 527 | # 528 | 529 | # 530 | tf.summary.scalar('loss/phase1_loss', self.phase1_loss) 531 | tf.summary.scalar('loss/phase2_loss', self.phase2_loss) 532 | tf.summary.scalar('loss/W_distance', self.OT_loss) 533 | tf.summary.scalar('loss/src_loss_discriminator', self.src_loss_discriminator) 534 | 535 | for i in range(self.data_loader.num_src_domain): 536 | 
def _fit_loop(self):
    """Train the model in two phases with periodic logging, checkpointing and evaluation.

    All variables are initialised, then the latest checkpoint under
    ``<model_dir>/<model_name>/saved-model/<mdaot_model_id>`` is restored when one
    exists.  After that ``num_iters`` iterations are run:

    * iterations below ``phase1_iters`` run ``train_teacher`` on the phase-1 loss;
    * later iterations run five ``secondary_train_student_op`` steps, each on a
      freshly sampled batch, followed by one ``primary_train_student_op`` step.

    On the first iteration and every ``summary_freq`` iterations the losses are
    printed and the merged summaries are written.  Checkpoints are written at the
    end of phase 1 and at the last iteration, plus on every logging step when
    ``only_save_final_model`` is false.  During phase 2 ``save_value`` is called
    roughly fifty times over the whole run.
    """
    print('Start training MOST model at', os.path.abspath(__file__))
    print('============ LOG-ID: %s ============' % self.current_time)

    num_src_domain = self.data_loader.num_src_domain
    num_src_samples_lst = [self.data_loader.src_train[k][2].shape[0] for k in range(num_src_domain)]
    num_trg_samples = self.data_loader.trg_train[0][1].shape[0]
    src_batchsize = self.batch_size // num_src_domain

    self.tf_session.run(tf.global_variables_initializer())
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=101)
    self.log_path = os.path.join(model_dir(), self.model_name, "logs",
                                 "{}".format(self.current_time))
    self.tf_summary_writer = tf.summary.FileWriter(self.log_path, self.tf_session.graph)

    self.checkpoint_path = os.path.join(model_dir(), self.model_name, "saved-model",
                                        "{}".format(self.mdaot_model_id))
    check_point = tf.train.get_checkpoint_state(self.checkpoint_path)
    if check_point and tf.train.checkpoint_exists(check_point.model_checkpoint_path):
        print("Load model parameters from %s\n" % check_point.model_checkpoint_path)
        saver.restore(self.tf_session, check_point.model_checkpoint_path)

    # One-hot domain labels: sample_size * num_classes consecutive rows per source domain.
    feed_y_src_domain = to_categorical(np.repeat(np.arange(num_src_domain),
                                                 repeats=self.sample_size * self.num_classes, axis=0))

    def sample_feed():
        # Target indices are drawn before the source ones so the NumPy RNG is
        # consumed in the same order as before this helper existed.
        idx_trg = np.random.permutation(num_trg_samples)[:self.batch_size]
        feed = dict()
        for k in range(num_src_domain):
            idx_src = np.random.permutation(num_src_samples_lst[k])[:src_batchsize]
            feed[self.x_src_lst[k]] = self.data_loader.src_train[k][1][idx_src, :]
            feed[self.y_src_lst[k]] = self.data_loader.src_train[k][2][idx_src]
        feed[self.x_trg] = self.data_loader.trg_train[0][1][idx_trg, :]
        feed[self.y_trg] = self.data_loader.trg_train[0][2][idx_trg]
        feed[self.is_training] = True
        return feed

    # Values reported in the console log; fetched in both phases.
    metric_fetches = [self.src_loss_class_sum, self.src_loss_class_lst, self.src_loss_discriminator,
                      self.src_accuracy_lst, self.trg_accuracy, self.src_domain_acc, self.mimic_loss]

    phi_steps = 5
    # max(1, ...) keeps the modulo below defined when num_iters < 50.
    eval_every = max(1, self.num_iters // 50)

    for it in range(self.num_iters):
        step = it + 1
        feed_data = sample_feed()
        feed_data[self.y_src_domain] = feed_y_src_domain

        if it < self.phase1_iters:
            results = self.tf_session.run(
                [self.train_teacher, self.phase1_loss, self.OT_loss] + metric_fetches,
                feed_dict=feed_data
            )
            total_loss, W_dist = results[1], results[2]
            metrics = results[3:]
        else:
            # Phi-network updates, each on its own freshly sampled batch.
            for _ in range(phi_steps):
                W_dist = self.tf_session.run(
                    [self.secondary_train_student_op, self.OT_loss],
                    feed_dict=sample_feed()
                )[1]
            results = self.tf_session.run(
                [self.primary_train_student_op, self.phase2_loss] + metric_fetches,
                feed_dict=feed_data
            )
            total_loss = results[1]
            metrics = results[2:]

        (src_loss_class_sum, src_loss_class_lst, src_loss_discriminator,
         src_acc_lst, trg_acc, src_domain_acc, mimic_loss) = metrics

        is_summary_step = it == 0 or step % self.summary_freq == 0
        if is_summary_step:
            print(
                "iter %d/%d total_loss %.3f; src_loss_class_sum %.3f; W_dist %.3f; \n src_loss_discriminator %.3f, pseudo_lbl_loss %.3f" % (
                    step, self.num_iters, total_loss, src_loss_class_sum, W_dist,
                    src_loss_discriminator, mimic_loss))
            for k in range(num_src_domain):
                print('src_loss_class_{}: {:.3f} acc {:.2f}'.format(k, src_loss_class_lst[k], src_acc_lst[k] * 100))
            print("src_domain_disc_acc: %.2f, trg_acc: %.2f;" % (src_domain_acc * 100, trg_acc * 100))
            summary = self.tf_session.run(self.tf_merged_summaries, feed_dict=feed_data)
            self.tf_summary_writer.add_summary(summary, step)
            self.tf_summary_writer.flush()

        # Milestone checkpoints must not depend on summary_freq: previously they were
        # skipped whenever phase1_iters / num_iters was not a multiple of it.
        is_milestone = step == self.phase1_iters or step == self.num_iters
        if is_milestone or (is_summary_step and not self.only_save_final_model):
            self.save_trained_model(saver, step)

        if it >= self.phase1_iters and step % eval_every == 0:
            self.save_value(step=step)
@add_arg_scope
def normalize_perturbation(d, scope=None):
    """L2-normalise ``d`` jointly over every axis except the first.

    The op is created under ``tf.name_scope(scope, 'norm_pert')``.
    """
    # Axes 1 .. rank-1, i.e. everything but the leading axis.
    trailing_axes = np.arange(1, len(d.shape))
    with tf.name_scope(scope, 'norm_pert'):
        return tf.nn.l2_normalize(d, axis=trailing_axes)
def __init__(self,
             model_name="MOST-results",
             learning_rate=0.001,
             batch_size=128,
             num_iters=80000,
             summary_freq=400,
             src_class_trade_off=1.0,
             src_domain_trade_off='1.0,1.0',
             src_vat_trade_off=1.0,
             trg_vat_troff=0.1,
             trg_ent_troff=0.1,
             ot_trade_off=0.1,
             domain_trade_off=0.1,
             mimic_trade_off=0.1,
             encode_layout=None,
             classify_layout=None,
             domain_layout=None,
             phi_layout=None,
             current_time='',
             inorm=True,
             theta=0.1,
             g_network_trade_off=1.0,
             mdaot_model_id='',
             only_save_final_model=True,
             cnn_size='large',
             sample_size=50,
             data_shift_troff=10.0,
             lbl_shift_troff=1.0,
             num_classes=10,
             multi_scale='',
             resnet_depth=101,
             train_layers='fc8,fc7,fc6',
             **kwargs):
    """Store the hyper-parameters of the model; nothing else is done here.

    Every argument is stored unchanged on the instance under its own name,
    with two exceptions:

    * ``src_domain_trade_off`` is stored as a list of floats.  It may be given
      as a comma-separated string (``'1.0,0.5'``) or as an iterable of numbers.
    * ``train_layers`` is stored as a list.  It may be given as a
      comma-separated string (``'fc8,fc7'``) or as an iterable of names.

    Extra keyword arguments are accepted and ignored.

    Raises:
        ValueError: if an entry of ``src_domain_trade_off`` cannot be parsed as a float.
    """
    self.model_name = model_name
    self.batch_size = batch_size
    self.learning_rate = learning_rate
    self.num_iters = num_iters
    self.summary_freq = summary_freq
    self.src_class_trade_off = src_class_trade_off

    # Duck-typed string check: anything with ``split`` is the legacy
    # comma-separated form, everything else is iterated directly.
    if hasattr(src_domain_trade_off, 'split'):
        src_domain_trade_off = src_domain_trade_off.split(',')
    self.src_domain_trade_off = [float(item) for item in src_domain_trade_off]

    self.src_vat_trade_off = src_vat_trade_off
    self.trg_vat_troff = trg_vat_troff
    self.trg_ent_troff = trg_ent_troff
    self.ot_trade_off = ot_trade_off
    self.domain_trade_off = domain_trade_off
    self.mimic_trade_off = mimic_trade_off

    self.encode_layout = encode_layout
    self.classify_layout = classify_layout
    self.domain_layout = domain_layout
    self.phi_layout = phi_layout

    self.current_time = current_time
    self.inorm = inorm

    self.theta = theta
    self.g_network_trade_off = g_network_trade_off

    self.mdaot_model_id = mdaot_model_id
    self.only_save_final_model = only_save_final_model

    self.cnn_size = cnn_size
    self.sample_size = sample_size
    self.data_shift_troff = data_shift_troff
    self.lbl_shift_troff = lbl_shift_troff

    self.num_classes = num_classes
    self.multi_scale = multi_scale
    self.resnet_depth = resnet_depth
    if hasattr(train_layers, 'split'):
        self.train_layers = train_layers.split(',')
    else:
        self.train_layers = list(train_layers)
177 | tf.set_random_seed(random_seed()) 178 | tf.reset_default_graph() 179 | 180 | self.tf_graph = tf.get_default_graph() 181 | self.tf_config = get_default_config() 182 | self.tf_session = tf.Session(config=self.tf_config, graph=self.tf_graph) 183 | 184 | self.src_preprocessors = src_preprocessors 185 | self.trg_train_preprocessor = trg_train_preprocessor 186 | self.trg_test_preprocessor = trg_test_preprocessor 187 | self.num_src_domain = num_src_domain 188 | self.batch_size_src = self.sample_size*self.num_classes 189 | 190 | assert len(self.src_domain_trade_off) == self.num_src_domain 191 | assert self.sample_size*self.num_classes*self.num_src_domain == self.batch_size 192 | 193 | def _get_variables(self, list_scopes): 194 | variables = [] 195 | for scope_name in list_scopes: 196 | variables.append(tf.get_collection('trainable_variables', scope_name)) 197 | return variables 198 | 199 | def convert_one_hot(self, y): 200 | y_idx = y.reshape(-1).astype(int) if y is not None else None 201 | y = np.eye(self.num_classes)[y_idx] if y is not None else None 202 | return y 203 | 204 | def _get_scope(self, part_name, side_name, same_network=True): 205 | suffix = '' 206 | if not same_network: 207 | suffix = '/' + side_name 208 | return part_name + suffix 209 | 210 | def _get_all_g_h(self): 211 | return ['generator', 'classifier'] 212 | 213 | def _get_teacher_scopes(self): 214 | return ['generator', 'classifier', 'domain_disc'] 215 | 216 | def _get_student_primary_scopes(self): 217 | return ['generator', 'c-trg'] 218 | 219 | def _get_student_secondary_scopes(self): 220 | return ['phi_net'] 221 | 222 | def _build_source_middle(self, x_src, is_reused): 223 | scope_name = self._get_scope('generator', 'src') 224 | if is_reused == 0: 225 | generator_model = build_encode_template(x_src, encode_layout=self.encode_layout, 226 | scope=scope_name, training_phase=self.is_training, inorm=self.inorm, cnn_size=self.cnn_size) 227 | else: 228 | generator_model = build_encode_template(x_src, 
encode_layout=self.encode_layout, 229 | scope=scope_name, training_phase=self.is_training, inorm=self.inorm, 230 | reuse=True, internal_update=True, 231 | cnn_size=self.cnn_size) 232 | return generator_model 233 | 234 | def _build_target_middle(self, x_trg, reuse=None): 235 | scope_name = 'generator' 236 | return build_encode_template( 237 | x_trg, encode_layout=self.encode_layout, 238 | scope=scope_name, training_phase=self.is_training, inorm=self.inorm, 239 | reuse=reuse, internal_update=True, cnn_size=self.cnn_size 240 | ) 241 | 242 | def _build_classifier(self, x, num_classes, ema=None, is_teacher=False): 243 | g_teacher_scope = self._get_scope('generator', 'teacher', same_network=False) 244 | g_x = build_encode_template( 245 | x, encode_layout=self.encode_layout, 246 | scope=g_teacher_scope if is_teacher else 'generator', training_phase=False, inorm=self.inorm, 247 | reuse=False if is_teacher else True, getter=None if is_teacher else tb.tfutils.get_getter(ema), 248 | cnn_size=self.cnn_size 249 | ) 250 | 251 | h_teacher_scope = self._get_scope('c-trg', 'teacher', same_network=False) 252 | h_g_x = build_class_discriminator_template( 253 | g_x, training_phase=False, scope=h_teacher_scope if is_teacher else 'c-trg', num_classes=num_classes, 254 | reuse=False if is_teacher else True, class_discriminator_layout=self.classify_layout, 255 | getter=None if is_teacher else tb.tfutils.get_getter(ema), cnn_size=self.cnn_size 256 | ) 257 | return h_g_x 258 | 259 | def _build_domain_discriminator(self, x_mid, reuse=None, scope='domain_disc'): 260 | return build_domain_discriminator_template(x_mid, domain_layout=self.domain_layout, c=self.num_src_domain, reuse=reuse, scope=scope) 261 | 262 | def _build_phi_network(self, x_mid, reuse=None): 263 | return build_phi_network_template(x_mid, domain_layout=self.phi_layout, c=1, reuse=reuse) 264 | 265 | def _build_class_src_discriminator(self, x_src, num_src_classes, i, reuse=None): 266 | classifier_model = 
build_class_discriminator_template( 267 | x_src, training_phase=self.is_training, scope='classifier/{}'.format(i), num_classes=num_src_classes, 268 | reuse=reuse, internal_update=True, class_discriminator_layout=self.classify_layout, cnn_size=self.cnn_size 269 | ) 270 | return classifier_model 271 | 272 | def _build_class_trg_discriminator(self, x_trg, num_trg_classes): 273 | return build_class_discriminator_template( 274 | x_trg, training_phase=self.is_training, scope='c-trg', num_classes=num_trg_classes, 275 | reuse=False, internal_update=True, class_discriminator_layout=self.classify_layout, cnn_size=self.cnn_size 276 | ) 277 | 278 | def perturb_image(self, x, p, num_classes, class_discriminator_layout, encode_layout, 279 | pert='vat', scope=None, radius=3.5, scope_classify=None, scope_encode=None, training_phase=None): 280 | with tf.name_scope(scope, 'perturb_image'): 281 | eps = 1e-6 * normalize_perturbation(tf.random_normal(shape=tf.shape(x))) 282 | 283 | # Predict on randomly perturbed image 284 | x_eps_mid = build_encode_template( 285 | x + eps, encode_layout=encode_layout, scope=scope_encode, training_phase=training_phase, reuse=True, 286 | inorm=self.inorm, cnn_size=self.cnn_size) 287 | x_eps_pred = build_class_discriminator_template( 288 | x_eps_mid, class_discriminator_layout=class_discriminator_layout, 289 | training_phase=training_phase, scope=scope_classify, reuse=True, num_classes=num_classes, 290 | cnn_size=self.cnn_size 291 | ) 292 | # eps_p = classifier(x + eps, phase=True, reuse=True) 293 | loss = softmax_x_entropy_two(labels=p, logits=x_eps_pred) 294 | 295 | # Based on perturbed image, get direction of greatest error 296 | eps_adv = tf.gradients(loss, [eps], aggregation_method=2)[0] 297 | 298 | # Use that direction as adversarial perturbation 299 | eps_adv = normalize_perturbation(eps_adv) 300 | x_adv = tf.stop_gradient(x + radius * eps_adv) 301 | 302 | return x_adv 303 | 304 | def vat_loss(self, x, p, num_classes, class_discriminator_layout, 
encode_layout, 305 | scope=None, scope_classify=None, scope_encode=None, training_phase=None): 306 | 307 | with tf.name_scope(scope, 'smoothing_loss'): 308 | x_adv = self.perturb_image( 309 | x, p, num_classes, class_discriminator_layout=class_discriminator_layout, encode_layout=encode_layout, 310 | scope_classify=scope_classify, scope_encode=scope_encode, training_phase=training_phase) 311 | 312 | x_adv_mid = build_encode_template( 313 | x_adv, encode_layout=encode_layout, scope=scope_encode, training_phase=training_phase, inorm=self.inorm, 314 | reuse=True, cnn_size=self.cnn_size) 315 | x_adv_pred = build_class_discriminator_template( 316 | x_adv_mid, training_phase=training_phase, scope=scope_classify, reuse=True, num_classes=num_classes, 317 | class_discriminator_layout=class_discriminator_layout, cnn_size=self.cnn_size 318 | ) 319 | # p_adv = classifier(x_adv, phase=True, reuse=True) 320 | loss = tf.reduce_mean(softmax_x_entropy_two(labels=tf.stop_gradient(p), logits=x_adv_pred)) 321 | 322 | return loss 323 | 324 | def _build_vat_loss(self, x, p, num_classes, scope=None, scope_classify=None, scope_encode=None): 325 | return self.vat_loss( 326 | x, p, num_classes, 327 | class_discriminator_layout=self.classify_layout, 328 | encode_layout=self.encode_layout, 329 | scope=scope, scope_classify=scope_classify, scope_encode=scope_encode, 330 | training_phase=self.is_training 331 | ) 332 | 333 | def _compute_cosine_similarity(self, x_trg_mid, x_src_mid_all): 334 | x_trg_mid_flatten = tf.layers.Flatten()(x_trg_mid) 335 | x_src_mid_all_flatten = tf.layers.Flatten()(x_src_mid_all) 336 | similarity = tf.reduce_sum(x_trg_mid_flatten[:, tf.newaxis] * x_src_mid_all_flatten, axis=-1) 337 | similarity /= tf.norm(x_trg_mid_flatten[:, tf.newaxis], axis=-1) * tf.norm(x_src_mid_all_flatten, axis=-1) 338 | distance = 1.0 - similarity 339 | return distance 340 | 341 | def _compute_data_shift_loss(self, x_src_mid, x_trg_mid): 342 | x_src_mid_flatten = tf.layers.Flatten()(x_src_mid) 
343 | x_trg_mid_flatten = tf.layers.Flatten()(x_trg_mid) 344 | 345 | data_shift_loss = tf.norm(tf.subtract(x_src_mid_flatten, tf.expand_dims(x_trg_mid_flatten, 1)), axis=2) 346 | return data_shift_loss 347 | 348 | def _compute_label_shift_loss(self, y_trg_logit, y_src_teacher): 349 | y_trg_logit_rep = K.repeat_elements(tf.expand_dims(y_trg_logit, axis=0), rep=self.batch_size, axis=0) 350 | y_trg_logit_rep = tf.reshape(y_trg_logit_rep, [-1, y_trg_logit_rep.get_shape()[-1]]) 351 | y_src_teacher_rep = K.repeat_elements(y_src_teacher, rep=self.batch_size, axis=0) 352 | label_shift_loss = softmax_x_entropy_two(labels=y_src_teacher_rep, 353 | logits=y_trg_logit_rep) 354 | 355 | label_shift_loss = tf.reshape(label_shift_loss, [self.batch_size, self.batch_size]) 356 | return label_shift_loss 357 | 358 | def _compute_teacher_hs(self, y_label_trg_output_each_h, y_d_trg_sofmax_output): 359 | y_label_trg_output_each_h = tf.transpose(tf.stack(y_label_trg_output_each_h), perm=[1, 0, 2]) 360 | y_d_trg_sofmax_output_multi_y = y_d_trg_sofmax_output 361 | y_d_trg_sofmax_output_multi_y = tf.expand_dims(y_d_trg_sofmax_output_multi_y, axis=-1) 362 | y_d_trg_sofmax_output_multi_y = tf.tile(y_d_trg_sofmax_output_multi_y, [1, 1, self.num_classes]) 363 | y_label_trg_output = y_d_trg_sofmax_output_multi_y * y_label_trg_output_each_h 364 | y_label_trg_output = tf.reduce_sum(y_label_trg_output, axis=1) 365 | return y_label_trg_output 366 | 367 | def _build_model(self): 368 | self.x_src_lst = [] 369 | self.y_src_lst = [] 370 | for i in range(self.num_src_domain): 371 | x_src = tf.placeholder(dtype=tf.float32, shape=tuple([None]) + self.dim_src, 372 | name='x_src_{}_input'.format(i)) 373 | y_src = tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes), 374 | name='y_src_{}_input'.format(i)) 375 | 376 | self.x_src_lst.append(x_src) 377 | self.y_src_lst.append(y_src) 378 | 379 | self.x_trg = tf.placeholder(dtype=tf.float32, shape=tuple([None]) + self.dim_trg, name='x_trg_input') 380 | 
self.y_trg = tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes), 381 | name='y_trg_input') 382 | self.y_src_domain = tf.placeholder(dtype=tf.float32, shape=(None, self.num_src_domain), 383 | name='y_src_domain_input') 384 | 385 | T = tb.utils.TensorDict(dict( 386 | x_tmp=tf.placeholder(dtype=tf.float32, shape=tuple([None]) + self.dim_src), 387 | y_tmp=tf.placeholder(dtype=tf.float32, shape=(None, self.num_classes)) 388 | )) 389 | 390 | self.is_training = tf.placeholder(tf.bool, shape=(), name='is_training') 391 | self.alexNet = AlexNetModel(num_classes=self.num_classes, is_training=self.is_training) 392 | 393 | self.x_src_mid_lst = [] 394 | for i in range(self.num_src_domain): 395 | if i == 0: 396 | x_src_mid_feat = tf.reshape(self.alexNet.inference(self.x_src_lst[i], extract_feat=True), 397 | (-1, 8, 8, 64)) 398 | else: 399 | x_src_mid_feat = tf.reshape(self.alexNet.inference(self.x_src_lst[i], reuse=True, extract_feat=True), 400 | (-1, 8, 8, 64)) 401 | 402 | x_src_mid = self._build_source_middle(x_src_mid_feat, is_reused=i) 403 | self.x_src_mid_lst.append(x_src_mid) 404 | 405 | self.x_trg_mid_feat = tf.reshape(self.alexNet.inference(self.x_trg, reuse=True, extract_feat=True), (-1, 8, 8, 64)) 406 | self.x_trg_mid = self._build_target_middle(self.x_trg_mid_feat, reuse=True) 407 | 408 | # 409 | self.y_src_logit_lst = [] 410 | for i in range(self.num_src_domain): 411 | y_src_logit = self._build_class_src_discriminator(self.x_src_mid_lst[i], self.num_classes, i) 412 | self.y_src_logit_lst.append(y_src_logit) 413 | 414 | self.y_trg_logit = self._build_class_trg_discriminator(self.x_trg_mid, 415 | self.num_classes) 416 | # 417 | 418 | # 419 | self.src_loss_class_lst = [] 420 | self.src_loss_class_sum = tf.constant(0.0) 421 | for i in range(self.num_src_domain): 422 | src_loss_class_detail = tf.nn.softmax_cross_entropy_with_logits_v2( 423 | logits=self.y_src_logit_lst[i], labels=self.y_src_lst[i]) 424 | src_loss_class = tf.reduce_mean(src_loss_class_detail) 
425 | self.src_loss_class_lst.append(self.src_domain_trade_off[i]*src_loss_class) 426 | self.src_loss_class_sum += self.src_domain_trade_off[i]*src_loss_class 427 | # 428 | 429 | # 430 | self.x_src_mid_all = tf.concat(self.x_src_mid_lst, axis=0) 431 | self.y_src_discriminator_logit = self._build_domain_discriminator(self.x_src_mid_all) 432 | self.src_loss_discriminator_detail = tf.nn.softmax_cross_entropy_with_logits_v2( 433 | logits=self.y_src_discriminator_logit, labels=self.y_src_domain) 434 | self.src_loss_discriminator = tf.reduce_mean(self.src_loss_discriminator_detail) 435 | # 436 | 437 | # 438 | self.y_src_teacher_all = [] 439 | for i, bs in zip(range(self.num_src_domain), 440 | range(0, self.batch_size_src * self.num_src_domain, self.batch_size_src)): 441 | y_src_logit_each_h_lst = [] 442 | for j in range(self.num_src_domain): 443 | y_src_logit_each_h = self._build_class_src_discriminator(self.x_src_mid_lst[i], self.num_classes, 444 | j, reuse=True) 445 | y_src_logit_each_h_lst.append(y_src_logit_each_h) 446 | y_src_logit_each_h_lst = tf.nn.softmax(tf.convert_to_tensor(y_src_logit_each_h_lst)) 447 | y_src_discriminator_prob = tf.nn.softmax(tf.gather(self.y_src_discriminator_logit, 448 | tf.range(bs, bs + self.batch_size_src, 449 | dtype=tf.int32), axis=0)) 450 | y_src_teacher = self._compute_teacher_hs(y_src_logit_each_h_lst, y_src_discriminator_prob) 451 | self.y_src_teacher_all.append(y_src_teacher) 452 | self.y_src_teacher_all = tf.concat(self.y_src_teacher_all, axis=0) 453 | # 454 | 455 | # 456 | y_trg_logit_each_h_lst = [] 457 | for j in range(self.num_src_domain): 458 | y_trg_logit_each_h = self._build_class_src_discriminator(self.x_trg_mid, self.num_classes, 459 | j, reuse=True) 460 | y_trg_logit_each_h_lst.append(y_trg_logit_each_h) 461 | y_trg_logit_each_h_lst = tf.nn.softmax(tf.convert_to_tensor(y_trg_logit_each_h_lst)) 462 | self.y_trg_src_domains_logit = self._build_domain_discriminator(self.x_trg_mid, reuse=True) 463 | y_trg_discriminator_prob 
= tf.nn.softmax(self.y_trg_src_domains_logit) 464 | self.y_trg_teacher = self._compute_teacher_hs(y_trg_logit_each_h_lst, y_trg_discriminator_prob) 465 | # 466 | 467 | # 468 | self.ht_g_xs = build_class_discriminator_template( 469 | self.x_src_mid_all, training_phase=self.is_training, scope='c-trg', num_classes=self.num_classes, 470 | reuse=True, internal_update=True, class_discriminator_layout=self.classify_layout, cnn_size=self.cnn_size 471 | ) 472 | self.mimic_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2( 473 | logits=self.ht_g_xs, labels=self.y_src_teacher_all)) + \ 474 | tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2( 475 | logits=self.y_trg_logit, labels=self.y_trg_teacher)) 476 | # 477 | 478 | # 479 | self.data_shift_loss = self._compute_cosine_similarity(self.x_trg_mid, self.x_src_mid_all) 480 | self.label_shift_loss = self._compute_label_shift_loss(self.y_trg_logit, self.ht_g_xs) 481 | self.data_label_shift_loss = self.data_shift_troff*self.data_shift_loss + self.label_shift_loss 482 | self.g_network = tf.reshape(self._build_phi_network(self.x_trg_mid), [-1]) 483 | self.exp_term = (- self.data_label_shift_loss + self.g_network) / self.theta 484 | self.g_network_loss = tf.reduce_mean(self.g_network) 485 | self.OT_loss = tf.reduce_mean( 486 | - self.theta * \ 487 | ( 488 | tf.log(1.0 / self.batch_size) + 489 | tf.reduce_logsumexp(self.exp_term, axis=1) 490 | ) 491 | ) + self.g_network_trade_off * self.g_network_loss 492 | # 493 | 494 | # 495 | self.trg_loss_cond_entropy = tf.reduce_mean(softmax_x_entropy_two(labels=self.y_trg_logit, 496 | logits=self.y_trg_logit)) 497 | # 498 | 499 | # 500 | self.src_accuracy_lst = [] 501 | for i in range(self.num_src_domain): 502 | y_src_pred = tf.argmax(self.y_src_logit_lst[i], 1, output_type=tf.int32) 503 | y_src_sparse = tf.argmax(self.y_src_lst[i], 1, output_type=tf.int32) 504 | src_accuracy = tf.reduce_mean(tf.cast(tf.equal(y_src_sparse, y_src_pred), 'float32')) 505 | 
self.src_accuracy_lst.append(src_accuracy) 506 | # compute acc for target domain 507 | self.y_trg_pred = tf.argmax(self.y_trg_logit, 1, output_type=tf.int32) 508 | self.y_trg_sparse = tf.argmax(self.y_trg, 1, output_type=tf.int32) 509 | self.trg_accuracy = tf.reduce_mean(tf.cast(tf.equal(self.y_trg_sparse, self.y_trg_pred), 'float32')) 510 | # compute acc for src domain disc 511 | self.y_src_domain_pred = tf.argmax(self.y_src_discriminator_logit, 1, output_type=tf.int32) 512 | self.y_src_domain_sparse = tf.argmax(self.y_src_domain, 1, output_type=tf.int32) 513 | self.src_domain_acc = tf.reduce_mean(tf.cast(tf.equal(self.y_src_domain_sparse, self.y_src_domain_pred), 'float32')) 514 | # 515 | 516 | # 517 | lst_losses = [ 518 | (self.src_class_trade_off, self.src_loss_class_sum), 519 | (self.ot_trade_off, self.OT_loss), 520 | (self.domain_trade_off, self.src_loss_discriminator), 521 | (self.trg_ent_troff, self.trg_loss_cond_entropy), 522 | (self.mimic_trade_off, self.mimic_loss) 523 | ] 524 | self.total_loss = tf.constant(0.0) 525 | for trade_off, loss in lst_losses: 526 | self.total_loss += trade_off * loss 527 | # 528 | 529 | # 530 | primary_student_variables = self._get_variables(self._get_student_primary_scopes()) 531 | 532 | self.batch_student_acc = batch_ema_acc(self.y_trg, self.y_trg_logit) 533 | self.fn_batch_student_acc = tb.function(self.tf_session, [self.x_trg, self.y_trg, self.is_training], self.batch_student_acc) 534 | 535 | teacher_variables = self._get_variables(self._get_teacher_scopes()) 536 | pretrained_var_list = [v for v in tf.trainable_variables() if v.name.split('/')[1] in self.train_layers] 537 | self.train_student_main = \ 538 | tf.train.AdamOptimizer(self.learning_rate, 0.5).minimize(self.total_loss, 539 | var_list=teacher_variables + [primary_student_variables[1]] + [pretrained_var_list]) 540 | self.primary_train_student_op = tf.group(self.train_student_main) 541 | 542 | secondary_variables = 
self._get_variables(self._get_student_secondary_scopes()) 543 | self.secondary_train_student_op = \ 544 | tf.train.AdamOptimizer(self.learning_rate, 0.5).minimize(-self.OT_loss, 545 | var_list=secondary_variables) 546 | # 547 | 548 | # 549 | g_h_variables = self._get_variables(self._get_all_g_h()) 550 | self.all_g_h_variables = [] 551 | for s in g_h_variables: 552 | self.all_g_h_variables += s 553 | # 554 | 555 | # 556 | tf.summary.scalar('loss/total_loss', self.total_loss) 557 | tf.summary.scalar('loss/W_distance', self.OT_loss) 558 | tf.summary.scalar('loss/src_loss_discriminator', self.src_loss_discriminator) 559 | tf.summary.scalar('loss/data_shift_loss', tf.reduce_mean(self.data_shift_loss)) 560 | tf.summary.scalar('loss/label_shift_loss', tf.reduce_mean(self.label_shift_loss)) 561 | tf.summary.scalar('loss/data_label_shift_loss', tf.reduce_mean(self.data_label_shift_loss)) 562 | tf.summary.scalar('loss/exp_term', tf.reduce_mean(self.exp_term)) 563 | tf.summary.histogram('loss/g_batch', self.g_network) 564 | tf.summary.scalar('loss/g_network_loss', self.g_network_loss) 565 | 566 | for i in range(self.num_src_domain): 567 | tf.summary.scalar('loss/src_loss_class_{}'.format(i), self.src_loss_class_lst[i]) 568 | tf.summary.scalar('acc/src_acc_{}'.format(i), self.src_accuracy_lst[i]) 569 | tf.summary.scalar('acc/src_domain_acc', self.src_domain_acc) 570 | tf.summary.scalar('acc/trg_acc', self.trg_accuracy) 571 | 572 | tf.summary.scalar('hyperparameters/learning_rate', self.learning_rate) 573 | tf.summary.scalar('hyperparameters/src_class_trade_off', self.src_class_trade_off) 574 | tf.summary.scalar('hyperparameters/g_network_trade_off', self.g_network_trade_off) 575 | tf.summary.scalar('hyperparameters/domain_trade_off', self.domain_trade_off) 576 | tf.summary.scalar('hyperparameters/src_vat_trade_off', self.src_vat_trade_off) 577 | tf.summary.scalar('hyperparameters/trg_vat_troff', self.trg_vat_troff) 578 | tf.summary.scalar('hyperparameters/trg_ent_troff', 
self.trg_ent_troff) 579 | self.tf_merged_summaries = tf.summary.merge_all() 580 | # 581 | 582 | def _fit_loop(self): 583 | print('Start training MOST at', os.path.abspath(__file__)) 584 | print('============ LOG-ID: %s ============' % self.current_time) 585 | 586 | src_batchsize = self.batch_size // self.num_src_domain 587 | self.tf_session.run(tf.global_variables_initializer()) 588 | saver = tf.train.Saver(tf.global_variables(), max_to_keep=101) 589 | self.log_path = os.path.join(model_dir(), self.model_name, "logs", 590 | "{}".format(self.current_time)) 591 | self.tf_summary_writer = tf.summary.FileWriter(self.log_path, self.tf_session.graph) 592 | 593 | print("Load pretrained AlexNet") 594 | self.alexNet.load_original_weights(self.tf_session, skip_layers=self.train_layers) 595 | 596 | # 597 | print("Load pre-trained shallow network") 598 | graph = tf.Graph() 599 | with graph.as_default(): 600 | with tf.Session() as sess: 601 | checkpoint_path = os.path.join(model_dir(), self.model_name, "saved-model", 602 | "{}".format(self.mdaot_model_id)) 603 | check_point = tf.train.get_checkpoint_state(checkpoint_path) 604 | model_checkpoint_path = os.path.join(checkpoint_path, check_point.model_checkpoint_path.split('/')[-1]) 605 | saver_old = tf.train.import_meta_graph("{}.meta".format(model_checkpoint_path)) 606 | saver_old.restore(sess, model_checkpoint_path) 607 | print("Loaded model parameters from %s\n" % model_checkpoint_path) 608 | 609 | saved_variables = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) 610 | saved_values = [sess.run(v) for v in saved_variables] 611 | 612 | g_mean_var = [] 613 | h_mean_var = [] 614 | g_mean_var_id = [0] 615 | h_mean_var_id = [0, 0] 616 | 617 | for i in g_mean_var_id: 618 | mean = graph.get_tensor_by_name("generator/l{}/conv2d/bn/mean:0".format(i)) 619 | var = graph.get_tensor_by_name("generator/l{}/conv2d/bn/var:0".format(i)) 620 | mean, var = sess.run([mean, var]) 621 | g_mean_var.append([mean, var]) 622 | 623 | for i in 
range(len(h_mean_var_id)): 624 | mean = graph.get_tensor_by_name("classifier/{}/l{}/dense/bn/mean:0".format(i, h_mean_var_id[i])) 625 | var = graph.get_tensor_by_name("classifier/{}/l{}/dense/bn/var:0".format(i, h_mean_var_id[i])) 626 | mean, var = sess.run([mean, var]) 627 | h_mean_var.append([mean, var]) 628 | 629 | for i in range(len(g_mean_var)): 630 | mean = self.tf_graph.get_tensor_by_name("generator/l{}/conv2d/bn/mean:0".format(g_mean_var_id[i])) 631 | var = self.tf_graph.get_tensor_by_name("generator/l{}/conv2d/bn/var:0".format(g_mean_var_id[i])) 632 | self.tf_session.run(tf.assign(mean, g_mean_var[i][0])) 633 | self.tf_session.run(tf.assign(var, g_mean_var[i][1])) 634 | 635 | for i in range(len(h_mean_var)): 636 | mean = self.tf_graph.get_tensor_by_name( 637 | "classifier/{}/l{}/dense/bn/mean:0".format(i, h_mean_var_id[i])) 638 | var = self.tf_graph.get_tensor_by_name( 639 | "classifier/{}/l{}/dense/bn/var:0".format(i, h_mean_var_id[i])) 640 | self.tf_session.run(tf.assign(mean, h_mean_var[i][0])) 641 | self.tf_session.run(tf.assign(var, h_mean_var[i][1])) 642 | 643 | domain_disc_w = self.tf_graph.get_tensor_by_name('domain_disc/l0/dense/weights:0') 644 | domain_disc_b = self.tf_graph.get_tensor_by_name('domain_disc/l0/dense/biases:0') 645 | phi_net_w = self.tf_graph.get_tensor_by_name('phi_net/l0/dense/weights:0') 646 | phi_net_b = self.tf_graph.get_tensor_by_name('phi_net/l0/dense/biases:0') 647 | self.tf_session.run(tf.assign(domain_disc_w, saved_values[16])) 648 | self.tf_session.run(tf.assign(domain_disc_b, saved_values[17])) 649 | self.tf_session.run(tf.assign(phi_net_w, saved_values[18])) 650 | self.tf_session.run(tf.assign(phi_net_b, saved_values[19])) 651 | for i in range(12): 652 | self.tf_session.run(tf.assign(self.all_g_h_variables[i], saved_values[i])) 653 | # 654 | 655 | feed_y_src_domain = to_categorical(np.repeat(np.arange(self.num_src_domain), 656 | repeats=self.sample_size * self.num_classes, axis=0)) 657 | 658 | for it in 
range(self.num_iters): 659 | feed_data = dict() 660 | for k in range(self.num_src_domain): 661 | feed_data[self.x_src_lst[k]], feed_data[self.y_src_lst[k]] = self.src_preprocessors[k].next_batch( 662 | src_batchsize) 663 | 664 | feed_data[self.x_trg], feed_data[self.y_trg] = self.trg_train_preprocessor.next_batch(self.batch_size) 665 | 666 | feed_data[self.y_src_domain] = feed_y_src_domain 667 | feed_data[self.is_training] = True 668 | 669 | for i in range(0, 5): 670 | g_feed_data = dict() 671 | for k in range(self.num_src_domain): 672 | g_feed_data[self.x_src_lst[k]], g_feed_data[self.y_src_lst[k]] = self.src_preprocessors[k].next_batch( 673 | src_batchsize) 674 | g_feed_data[self.x_trg], g_feed_data[self.y_trg] = self.trg_train_preprocessor.next_batch(self.batch_size) 675 | g_feed_data[self.is_training] = True 676 | 677 | _, W_dist = \ 678 | self.tf_session.run( 679 | [self.secondary_train_student_op, self.OT_loss], 680 | feed_dict=g_feed_data 681 | ) 682 | _, total_loss, src_loss_class_sum, src_loss_class_lst, src_loss_discriminator, src_acc_lst, trg_acc, src_domain_acc, mimic_loss = \ 683 | self.tf_session.run( 684 | [self.primary_train_student_op, self.total_loss, self.src_loss_class_sum, self.src_loss_class_lst, self.src_loss_discriminator, 685 | self.src_accuracy_lst, self.trg_accuracy, self.src_domain_acc, self.mimic_loss], 686 | feed_dict=feed_data 687 | ) 688 | 689 | if it == 0 or (it + 1) % self.summary_freq == 0: 690 | print( 691 | "iter %d/%d total_loss %.3f; src_loss_class_sum %.3f; W_dist %.3f;\n src_loss_discriminator %.3f, pseudo_lbl_loss %.3f" % ( 692 | it + 1, self.num_iters, total_loss, src_loss_class_sum, W_dist, 693 | src_loss_discriminator, mimic_loss)) 694 | for k in range(self.num_src_domain): 695 | print('src_loss_class_{}: {:.3f} acc {:.2f}'.format(k, src_loss_class_lst[k], src_acc_lst[k]*100)) 696 | print("src_domain_disc_acc: %.2f, trg_acc: %.2f;" % (src_domain_acc*100, trg_acc*100)) 697 | 698 | summary = 
self.tf_session.run(self.tf_merged_summaries, feed_dict=feed_data) 699 | self.tf_summary_writer.add_summary(summary, it + 1) 700 | self.tf_summary_writer.flush() 701 | 702 | if it == 0 or (it + 1) % self.summary_freq == 0: 703 | if not self.only_save_final_model: 704 | self.save_trained_model(saver, it + 1) 705 | elif it + 1 == self.num_iters: 706 | self.save_trained_model(saver, it + 1) 707 | if (it + 1) % (self.num_iters // 50) == 0: 708 | self.save_value(step=it + 1) 709 | 710 | def save_trained_model(self, saver, step): 711 | checkpoint_path = os.path.join(model_dir(), self.model_name, "saved-model", 712 | "{}".format(self.current_time)) 713 | checkpoint_path = os.path.join(checkpoint_path, "mdaot_" + self.current_time + ".ckpt") 714 | 715 | directory = os.path.dirname(checkpoint_path) 716 | if not os.path.exists(directory): 717 | os.makedirs(directory) 718 | saver.save(self.tf_session, checkpoint_path, global_step=step) 719 | 720 | def save_value(self, step): 721 | student_acc, summary = self.compute_value() 722 | 723 | self.tf_summary_writer.add_summary(summary, step) 724 | self.tf_summary_writer.flush() 725 | 726 | print_list = ['test_acc', round(student_acc * 100, 2)] 727 | print(print_list) 728 | 729 | def compute_value(self,): 730 | n = len(self.trg_test_preprocessor.labels) 731 | bs = 200 732 | student_acc_full = np.ones(n, dtype=float) 733 | 734 | for i in range(0, n, bs): 735 | x, y = self.trg_test_preprocessor.next_batch(bs) 736 | student_acc_batch = self.fn_batch_student_acc(x, y, False) 737 | student_acc_full[i:i + bs] = student_acc_batch 738 | 739 | student_acc = np.mean(student_acc_full) 740 | self.trg_test_preprocessor.reset_pointer() 741 | 742 | summary = tf.Summary.Value(tag='trg_test/student_acc', simple_value=student_acc) 743 | summary = tf.Summary(value=[summary]) 744 | 745 | return student_acc, summary 746 | --------------------------------------------------------------------------------