├── README.md ├── datagenerator.py ├── images.py ├── images ├── teaser1.png └── teaser2.png └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Joint-Learning-of-NN 2 | The full paper can be found here: https://arxiv.org/abs/1905.06526 3 | 4 | Results on Generative Model| Data Social Network 5 | :-------------------------:|:-------------------------: 6 | ![](images/teaser1.png) | ![](images/teaser2.png) 7 | 8 | ## Requirements 9 | 10 | - Python 2.7 11 | - TensorFlow >= 1.2rc0 12 | - Numpy, CVXOPT (used by `utils.py`) 13 | - TensorFlow-Slim library [https://github.com/tensorflow/models/tree/master/research/slim] (for the Inception v3 architecture; you can use a custom network architecture if you have one. Since the five datasets shown in the paper have different numbers of classes, I added 4 more final layers to Inception v3. Please adapt this to your needs.) 14 | 15 | ## Content 16 | 17 | - `images.py`: Script to run the training process. 18 | - `utils.py`: Utility functions. 19 | - `datagenerator.py`: Contains a wrapper class for the new input pipeline. 20 | - `images/*`: Contains some teaser images from the paper. 21 | 22 | ## Usage 23 | 24 | All you need to touch is `images.py`; you can configure different parameters there. You have to provide `.txt` files to the script (`exp_1_train.txt`, `exp_2_train.txt`, ... and `exp_1_test.txt`, `exp_2_test.txt`, ... for the different datasets; in the paper, I used five datasets). The script builds their paths from the `--data_dir` argument. Each file lists the complete paths to your train/test images together with their class numbers, in the following format: 25 | 26 | ``` 27 | Example train.txt: 28 | /path/to/train/image1.png 0 29 | /path/to/train/image2.png 1 30 | /path/to/train/image3.png 2 31 | /path/to/train/image4.png 0 32 | . 33 | . 34 | ``` 35 | where the first column is the path and the second is the class label. 36 | 37 | In the paper and in the current training script, I used five datasets: 38 | 1. 
[Caltech-UCSD Birds 200](http://www.vision.caltech.edu/visipedia/CUB-200.html) 39 | 2. [Stanford Dogs Dataset](http://vision.stanford.edu/aditya86/ImageNetDogs/) 40 | 3. [Flower Datasets](http://www.robots.ox.ac.uk/~vgg/data/flowers/) 41 | 4. [Cars Dataset](https://ai.stanford.edu/~jkrause/cars/car_dataset.html) 42 | 5. [FGVC-Aircraft Benchmark](http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/) 43 | -------------------------------------------------------------------------------- /datagenerator.py: -------------------------------------------------------------------------------- 1 | # Created on Wed May 31 14:48:46 2017 2 | # 3 | # @author: Frederik Kratzert 4 | 5 | """Containes a helper class for image input pipelines in tensorflow.""" 6 | ### Borrowed from the repo: https://github.com/kratzert/finetune_alexnet_with_tensorflow 7 | 8 | import tensorflow as tf 9 | import numpy as np 10 | 11 | from tensorflow.contrib.data import Dataset 12 | from tensorflow.python.framework import dtypes 13 | from tensorflow.python.framework.ops import convert_to_tensor 14 | 15 | VGG_MEAN = tf.constant([123.68, 116.779, 103.939], dtype=tf.float32) 16 | 17 | 18 | class ImageDataGenerator(object): 19 | """Wrapper class around the new Tensorflows dataset pipeline. 20 | 21 | Requires Tensorflow >= version 1.12rc0 22 | """ 23 | 24 | def __init__(self, txt_file, mode, batch_size, num_classes, shuffle=True, 25 | buffer_size=1000): 26 | """Create a new ImageDataGenerator. 27 | 28 | Recieves a path string to a text file, which consists of many lines, 29 | where each line has first a path string to an image and seperated by 30 | a space an integer, referring to the class number. Using this data, 31 | this class will create TensrFlow datasets, that can be used to train 32 | e.g. a convolutional neural network. 33 | 34 | Args: 35 | txt_file: Path to the text file. 36 | mode: Either 'training' or 'validation'. Depending on this value, 37 | different parsing functions will be used. 
38 | batch_size: Number of images per batch. 39 | num_classes: Number of classes in the dataset. 40 | shuffle: Wether or not to shuffle the data in the dataset and the 41 | initial file list. 42 | buffer_size: Number of images used as buffer for TensorFlows 43 | shuffling of the dataset. 44 | 45 | Raises: 46 | ValueError: If an invalid mode is passed. 47 | 48 | """ 49 | self.txt_file = txt_file 50 | self.num_classes = num_classes 51 | 52 | # retrieve the data from the text file 53 | self._read_txt_file() 54 | 55 | # number of samples in the dataset 56 | self.data_size = len(self.labels) 57 | 58 | # initial shuffling of the file and label lists (together!) 59 | if shuffle: 60 | self._shuffle_lists() 61 | 62 | # convert lists to TF tensor 63 | self.img_paths = convert_to_tensor(self.img_paths, dtype=dtypes.string) 64 | self.labels = convert_to_tensor(self.labels, dtype=dtypes.int32) 65 | 66 | # create dataset 67 | data = Dataset.from_tensor_slices((self.img_paths, self.labels)) 68 | 69 | # distinguish between train/infer. when calling the parsing functions 70 | if mode == 'training': 71 | data = data.map(self._parse_function_train, num_threads=8, 72 | output_buffer_size=100*batch_size) 73 | 74 | elif mode == 'inference': 75 | data = data.map(self._parse_function_inference, num_threads=8, 76 | output_buffer_size=100*batch_size) 77 | 78 | else: 79 | raise ValueError("Invalid mode '%s'." 
% (mode)) 80 | 81 | # shuffle the first `buffer_size` elements of the dataset 82 | if shuffle: 83 | data = data.shuffle(buffer_size=buffer_size) 84 | 85 | # create a new dataset with batches of images 86 | data = data.batch(batch_size) 87 | 88 | self.data = data 89 | 90 | def _read_txt_file(self): 91 | """Read the content of the text file and store it into lists.""" 92 | self.img_paths = [] 93 | self.labels = [] 94 | with open(self.txt_file, 'r') as f: 95 | lines = f.readlines() 96 | for line in lines: 97 | items = line.split(' ') 98 | self.img_paths.append(items[0]) 99 | self.labels.append(int(items[1])) 100 | 101 | def _shuffle_lists(self): 102 | """Conjoined shuffling of the list of paths and labels.""" 103 | path = self.img_paths 104 | labels = self.labels 105 | permutation = np.random.permutation(self.data_size) 106 | self.img_paths = [] 107 | self.labels = [] 108 | for i in permutation: 109 | self.img_paths.append(path[i]) 110 | self.labels.append(labels[i]) 111 | 112 | def _parse_function_train(self, filename, label): 113 | """Input parser for samples of the training set.""" 114 | # convert label number into one-hot-encoding 115 | one_hot = tf.one_hot(label, self.num_classes) 116 | 117 | # load and preprocess the image 118 | img_string = tf.read_file(filename) 119 | #img_decoded = tf.image.decode_png(img_string, channels=3) 120 | img_decoded = tf.image.decode_jpeg(img_string, channels=3) 121 | img_resized = tf.image.resize_images(img_decoded, [227, 227]) 122 | #img_resized = tf.image.resize_images(img_decoded, [128, 128]) 123 | """ 124 | Dataaugmentation comes here. 
125 | """ 126 | img_centered = tf.subtract(img_resized, VGG_MEAN) 127 | 128 | # RGB -> BGR 129 | img_bgr = img_centered[:, :, ::-1] 130 | 131 | return img_bgr, one_hot 132 | 133 | def _parse_function_inference(self, filename, label): 134 | """Input parser for samples of the validation/test set.""" 135 | # convert label number into one-hot-encoding 136 | one_hot = tf.one_hot(label, self.num_classes) 137 | 138 | # load and preprocess the image 139 | img_string = tf.read_file(filename) 140 | #img_decoded = tf.image.decode_png(img_string, channels=3) 141 | img_decoded = tf.image.decode_jpeg(img_string, channels=3) 142 | img_resized = tf.image.resize_images(img_decoded, [227, 227]) 143 | #img_resized = tf.image.resize_images(img_decoded, [128, 128]) 144 | img_centered = tf.subtract(img_resized, VGG_MEAN) 145 | 146 | # RGB -> BGR 147 | img_bgr = img_centered[:, :, ::-1] 148 | 149 | return img_bgr, one_hot 150 | -------------------------------------------------------------------------------- /images.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import sys 7 | import tempfile 8 | 9 | import tensorflow as tf 10 | 11 | from random import shuffle 12 | 13 | import time 14 | import utils 15 | 16 | 17 | import numpy as np 18 | from datagenerator import ImageDataGenerator 19 | 20 | from nets import inception_v3 21 | from tensorflow.contrib.data import Iterator 22 | 23 | from tensorflow.python import pywrap_tensorflow 24 | 25 | import os 26 | import subprocess 27 | 28 | slim = tf.contrib.slim 29 | 30 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152 31 | nvidia_output = subprocess.check_output('nvidia-smi') 32 | if nvidia_output.split("\n")[-3].split(' ')[4] == '1': 33 | print('using GPU 2') 34 | os.environ["CUDA_VISIBLE_DEVICES"]="2" 35 | else: 36 | print('using GPU 1') 37 | 
os.environ["CUDA_VISIBLE_DEVICES"]="1" 38 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 39 | 40 | FLAGS = None 41 | 42 | def reinit_variables(sess, variable_list): 43 | for i in range(len(variable_list)): 44 | if 'Logits' in variable_list[i].op.name: 45 | sess.run(variable_list[i].initializer) 46 | 47 | def extract_vars(cur, post, checkpoint): 48 | reader = pywrap_tensorflow.NewCheckpointReader(checkpoint) 49 | var_to_shape_map = reader.get_variable_to_shape_map() 50 | var_list = [] 51 | for var in var_to_shape_map: 52 | var_list.append(var.replace(cur, post)) 53 | return var_list 54 | 55 | def main(_): 56 | ## Some hyper parameters 57 | mu = 0.01/2.0 #for l2 58 | t_mu = 10 59 | #mu = 1 #for l1 60 | batch_size = 64#32 61 | display_step = 40 62 | global_step = 30000#15000 63 | save_model = 5000 64 | optimize_w = 5000 65 | w_lambda_update = mu 66 | initw_step = 1500 67 | num_classes = [200, 120, 102, 196, 100] 68 | last_layer_name = "Logits" 69 | var_in_checkpoint = extract_vars("InceptionV3", "model0", FLAGS.pre_train) 70 | 71 | # Import data 72 | # Multiple dataset 73 | data = [] 74 | testdata = [] 75 | train_init_op = [] 76 | test_init_op = [] 77 | iterator = [] 78 | next_batch = [] 79 | #XXX 80 | ### Create your own data files 81 | ### You can use the naming stand as exp_1_train.txt for the first dataset 82 | ### and exp_1_test.txt for the testing data used for the first dataset 83 | with tf.device('/cpu:0'): 84 | for i in range(int(FLAGS.num_data)): 85 | data.append(ImageDataGenerator(FLAGS.data_dir+"exp_"+str(i+1)+"_train.txt", 86 | mode='training', 87 | batch_size=batch_size, 88 | num_classes=num_classes[i], 89 | shuffle=True)) 90 | testdata.append(ImageDataGenerator(FLAGS.data_dir+"exp_"+str(i+1)+"_test.txt", 91 | mode='inference', 92 | batch_size=batch_size, 93 | num_classes=num_classes[i], 94 | shuffle=False)) 95 | 96 | iterator.append(Iterator.from_structure(data[i].data.output_types, data[i].data.output_shapes)) 97 | next_batch.append(iterator[i].get_next()) 
98 | train_init_op.append(iterator[i].make_initializer(data[i].data)) 99 | test_init_op.append(iterator[i].make_initializer(testdata[i].data)) 100 | 101 | # Create the model 102 | x_list = [] 103 | for i in range(1): 104 | x_list.append(tf.placeholder(tf.float32, [None, 299, 299, 3], name="data"+str(i))) 105 | 106 | # Define loss and optimizer 107 | y_losses = [] 108 | for i in range(int(FLAGS.num_data)): 109 | y_losses.append(tf.placeholder(tf.float32, [None, num_classes[i]], name="loss"+str(i))) 110 | 111 | # Build the graph for the deep net 112 | # Multiple networks 113 | variable_list = [] 114 | layer_list = [] 115 | test_layer_list = [] 116 | paired_loss = [] 117 | data_loss = [] 118 | optimizer = [] 119 | paired_optimizer = [] 120 | accuracy_list = [] 121 | test_accuracy_list = [] 122 | joint_optimizer = [] 123 | joint_loss = [] 124 | isolated_optimizer = [] 125 | naive_joint_optimizer = [] 126 | naive_joint_loss = [] 127 | winitial_optimizer = [] 128 | model_list = [] 129 | shared_variable_list = [] 130 | stored_vars = [] 131 | update_ops = [] 132 | ### Create the model only once 133 | for i in range(1): 134 | with slim.arg_scope(inception_v3.inception_v3_arg_scope()): 135 | net,_ = inception_v3.inception_v3(x_list[i], num_classes=num_classes, scope='model'+str(i), create_aux_logits=False) 136 | testnet,_ = inception_v3.inception_v3(x_list[i], num_classes=num_classes, scope='model'+str(i), is_training=False, reuse=True, create_aux_logits=False) 137 | layer_list = net 138 | test_layer_list = testnet 139 | 140 | ## Special for recnet to remove the bias layers 141 | templist = [] 142 | org_vars = slim.get_trainable_variables(scope='model'+str(i)) 143 | update_ops.append(tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='model'+str(i))) 144 | for var in org_vars: 145 | if var.op.name in var_in_checkpoint: 146 | templist.append(var) 147 | variable_list.append(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='model'+str(i))) 148 | stored_vars.append(templist) 
149 | 150 | temp = [] 151 | for var in variable_list[i]: 152 | if last_layer_name not in var.op.name and "BatchNorm" not in var.op.name: 153 | temp.append(var) 154 | shared_variable_list.append(temp) 155 | 156 | #XXX 157 | # claim just variables 158 | copy_to_model0_op = [] 159 | copy_from_model0_op = [] 160 | model0_vars = variable_list[0]#tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='model0') 161 | ### Implement a trick here to create multiple copies of training parameters 162 | for i in range(1, int(FLAGS.num_data)+1): 163 | variable_list_i = [] 164 | shared_variable_list_i = [] 165 | copy_from_model0_op_i = [] 166 | copy_to_model0_op_i = [] 167 | for var in model0_vars: 168 | new_var_name = var.name.replace('model0', 'model%d' % i) 169 | if var in tf.trainable_variables(): 170 | trainable = True 171 | else: 172 | trainable = False 173 | new_var = tf.get_variable(new_var_name.split(':')[0], shape=var.shape, 174 | dtype=var.dtype, trainable=trainable) 175 | if last_layer_name not in new_var.op.name and "BatchNorm" not in var.op.name: 176 | shared_variable_list_i.append(new_var) 177 | variable_list_i.append(new_var) 178 | copy_from_model0_op_i.append(new_var.assign(var)) 179 | copy_to_model0_op_i.append(var.assign(new_var)) 180 | shared_variable_list.append(shared_variable_list_i) 181 | variable_list.append(variable_list_i) 182 | copy_to_model0_op.append(copy_to_model0_op_i) 183 | copy_from_model0_op.append(copy_from_model0_op_i) 184 | 185 | var_loss = [] 186 | for k in range(len(shared_variable_list[0])): 187 | temp1 = [] 188 | for i in range(int(FLAGS.num_data)): 189 | temp2 = [] 190 | for j in range(int(FLAGS.num_data)): 191 | temp2.append(0) 192 | temp1.append(temp2) 193 | var_loss.append(temp1) 194 | 195 | ## Savers 196 | saver = tf.train.Saver() 197 | ## For model1 198 | pre_saver = [] 199 | pretrain = {} 200 | for i in range(len(stored_vars[0])): 201 | if last_layer_name not in stored_vars[0][i].op.name: 202 | org_name = 
stored_vars[0][i].op.name.replace("model0", "InceptionV3") 203 | pretrain[org_name] = stored_vars[0][i] 204 | pre_saver = tf.train.Saver(pretrain) 205 | 206 | weight_graph = tf.placeholder(tf.float32, [len(shared_variable_list[0]), int(FLAGS.num_data)]) 207 | weight_scale = tf.placeholder(tf.float32, [len(shared_variable_list[0])]) 208 | w_lambda = tf.placeholder(tf.float32) 209 | 210 | for i in range(int(FLAGS.num_data)): 211 | with tf.variable_scope('data_loss'+str(i)): 212 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y_losses[i], logits=layer_list[i]) 213 | data_loss.append(tf.reduce_mean(cross_entropy)) 214 | 215 | for i in range(int(FLAGS.num_data)): 216 | ## Add the pairwise training loss and optimizer 217 | for j in range(i+1, int(FLAGS.num_data)): 218 | with tf.name_scope('paired_weight_loss'+str(i)+str(j)): 219 | w_loss = 0 220 | for n_var in range(len(shared_variable_list[0])):## Same model 221 | cur_var_loss = tf.nn.l2_loss(shared_variable_list[i+1][n_var] - shared_variable_list[j+1][n_var]) 222 | var_loss[n_var][i][j] = cur_var_loss 223 | var_loss[n_var][j][i] = cur_var_loss 224 | 225 | ## Add the joint loss here 226 | winit_losses = [] 227 | for i in range(1): 228 | with tf.name_scope('joint'+str(i)): 229 | w_loss = 0 230 | naive_joint = 0 231 | for j in range(1, int(FLAGS.num_data)+1): 232 | if i != j: 233 | for n_var in range(len(shared_variable_list[0])):## Same model 234 | if FLAGS.norm == "l1": 235 | w_loss += weight_graph[n_var][j-1]*tf.reduce_mean(tf.abs(shared_variable_list[i][n_var] - shared_variable_list[j][n_var])) 236 | naive_joint += tf.reduce_mean(tf.abs(shared_variable_list[i][n_var] - shared_variable_list[j][n_var])) 237 | else: 238 | w_loss += weight_graph[n_var][j-1]*tf.nn.l2_loss(shared_variable_list[i][n_var] - shared_variable_list[j][n_var]) * (1.0 / weight_scale[n_var]) 239 | naive_joint += tf.nn.l2_loss(shared_variable_list[i][n_var] - shared_variable_list[j][n_var]) * (1.0 / weight_scale[n_var]) 240 | 241 | 
w_loss *= w_lambda 242 | winit = naive_joint * w_lambda 243 | naive_joint *= w_lambda 244 | winit_losses.append(winit) 245 | for k in range(int(FLAGS.num_data)): 246 | joint_loss.append(data_loss[k]+w_loss) 247 | naive_joint_loss.append(data_loss[k]+naive_joint) 248 | 249 | if FLAGS.stage == 0: 250 | ### Stage for pre pairwise training 251 | for k in range(int(FLAGS.num_data)): 252 | with tf.variable_scope('Moment_winitial%d' % k): 253 | with tf.control_dependencies(update_ops[0]): 254 | winitial_optimizer.append(tf.train.MomentumOptimizer(1e-2, momentum=0.9).minimize(data_loss[k]+winit, var_list=variable_list[i])) 255 | if FLAGS.stage == 1: 256 | ### Stage for joint training 257 | for k in range(int(FLAGS.num_data)): 258 | num_train = int(np.floor(data[k].data_size / batch_size)) 259 | with tf.variable_scope('Moment_joint%d' % k): 260 | global_step_iso = tf.Variable(0, trainable=False) 261 | starter_learning_rate = 0.01 262 | learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step_iso, num_train*60, 0.1, staircase=True) 263 | with tf.control_dependencies(update_ops[0]): 264 | joint_optimizer.append(tf.train.MomentumOptimizer(learning_rate, momentum=0.9).minimize(joint_loss[k], var_list=variable_list[i], global_step=global_step_iso)) 265 | if FLAGS.stage == 2: 266 | ### Stage for isolated training 267 | for k in range(int(FLAGS.num_data)): 268 | num_train = int(np.floor(data[k].data_size / batch_size)) 269 | with tf.variable_scope('Moment_isolated%d' % k): 270 | global_step_iso = tf.Variable(0, trainable=False) 271 | starter_learning_rate = 0.01 272 | learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step_iso, num_train*60, 0.1, staircase=True) 273 | with tf.control_dependencies(update_ops[0]): 274 | isolated_optimizer.append(tf.train.MomentumOptimizer(learning_rate, momentum=0.9).minimize(data_loss[k], var_list=variable_list[i], global_step=global_step_iso)) 275 | if FLAGS.stage == 3: 276 | ### Stage for naive joint 277 
| for k in range(int(FLAGS.num_data)): 278 | num_train = int(np.floor(data[k].data_size / batch_size)) 279 | with tf.variable_scope('Moment_naive%d' % k): 280 | global_step_iso = tf.Variable(0, trainable=False) 281 | starter_learning_rate = 0.01 282 | learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step_iso, num_train*60, 0.1, staircase=True) 283 | with tf.control_dependencies(update_ops[0]): 284 | naive_joint_optimizer.append(tf.train.MomentumOptimizer(learning_rate, momentum=0.9).minimize(naive_joint_loss[k], var_list=variable_list[i], global_step=global_step_iso)) 285 | 286 | for i in range(int(FLAGS.num_data)): 287 | with tf.name_scope('accuracy'+str(i)): 288 | correct_prediction = tf.equal(tf.argmax(layer_list[i], 1), tf.argmax(y_losses[i], 1)) 289 | correct_prediction = tf.cast(correct_prediction, tf.float32) 290 | accuracy = tf.reduce_mean(correct_prediction) 291 | accuracy_list.append(accuracy) 292 | 293 | test_correct_prediction = tf.equal(tf.argmax(test_layer_list[i], 1), tf.argmax(y_losses[i], 1)) 294 | test_correct_prediction = tf.cast(test_correct_prediction, tf.float32) 295 | test_accuracy = tf.reduce_mean(test_correct_prediction) 296 | test_accuracy_list.append(test_accuracy) 297 | 298 | graph_location = FLAGS.log_location + "log_graph" 299 | print('Saving graph to: %s' % graph_location) 300 | train_writer = tf.summary.FileWriter(graph_location) 301 | train_writer.add_graph(tf.get_default_graph()) 302 | print('done') 303 | """ 304 | training 305 | """ 306 | with tf.Session() as sess: 307 | sess.run(tf.global_variables_initializer()) 308 | w_scale = [] 309 | for i in range(len(shared_variable_list[0])): #get the weight scale 310 | w_scale.append(1.0) 311 | 312 | distMat = [] 313 | for k in range(len(shared_variable_list[0])): 314 | temp1 = [] 315 | for i in range(int(FLAGS.num_data)): 316 | temp2 = [] 317 | for j in range(int(FLAGS.num_data)): 318 | temp2.append(0) 319 | temp1.append(temp2) 320 | distMat.append(temp1) 321 | 
distMat_test = [] 322 | for i in range(int(FLAGS.num_data)): 323 | temp2 = [] 324 | for j in range(int(FLAGS.num_data)): 325 | temp2.append(0) 326 | distMat_test.append(temp2) 327 | 328 | ##First we initialize the bias 329 | ##Each train 1000 epochs 330 | data_sum = 0 331 | w_sum = 0 332 | w_lambda_update = mu 333 | """ ==== """ 334 | if FLAGS.stage == 0: 335 | print ('Initializeing w') 336 | for i in range(initw_step): 337 | for j in range(int(FLAGS.num_data)): 338 | """ get data""" 339 | #XXX 340 | x = data[j] 341 | numtrain = int(np.floor(x.data_size / batch_size)) 342 | sess.run(copy_to_model0_op[j]) 343 | #numtrain = 1 344 | if i % numtrain == 0: 345 | sess.run(train_init_op[j]) 346 | 347 | x_batch_images, x_batch_labels = sess.run(next_batch[j]) 348 | 349 | """ get weight """ 350 | winitial_optimizer[j].run(feed_dict={x_list[0]: x_batch_images, 351 | y_losses[j]: x_batch_labels, w_lambda: w_lambda_update, weight_scale: w_scale}) 352 | if i % display_step == 0: 353 | train_accuracy = accuracy_list[j].eval(feed_dict={x_list[0]: x_batch_images, y_losses[j]: x_batch_labels}) 354 | dataloss = data_loss[j].eval(feed_dict={x_list[0]: x_batch_images, y_losses[j]: x_batch_labels}) 355 | ylabel = layer_list[j].eval(feed_dict={x_list[0]: x_batch_images}) 356 | wloss = winit_losses[0].eval(feed_dict={w_lambda: w_lambda_update, weight_scale: w_scale}) 357 | print('Epoch %g, dataset %g, training accuracy %g, data loss %g, wloss %g' % (i, j+1, train_accuracy, dataloss, wloss)) 358 | sys.stdout.flush() 359 | if (i+1) % numtrain == 0: 360 | test_accuracy = 0 361 | sess.run(test_init_op[j]) 362 | numtest = int(np.floor(testdata[j].data_size / batch_size)) 363 | for iter1 in range(numtest): 364 | x_batch_images, x_batch_labels = sess.run(next_batch[j]) 365 | 366 | test_accuracy += test_accuracy_list[j].eval(feed_dict={ 367 | x_list[0]: x_batch_images, y_losses[j]: x_batch_labels}) 368 | test_accuracy /= numtest 369 | print('Epoch %g, dataset %g, test accuracy %g' % (i, j+1, 
test_accuracy)) 370 | sys.stdout.flush() 371 | sess.run(copy_from_model0_op[j]) 372 | 373 | saver.save(sess, FLAGS.log_location+"pre_joint_model"+str(i+1)+".ckpt") 374 | 375 | ## Fill in the data into distMat 376 | for i in range(int(FLAGS.num_data)): 377 | ## Add the pairwise training loss and optimizer 378 | for j in range(i+1, int(FLAGS.num_data)): 379 | for n_var in range(len(shared_variable_list[i])): 380 | print ("for var "+shared_variable_list[i][n_var].op.name) 381 | sys.stdout.flush() 382 | distMat[n_var][i][j] = var_loss[n_var][i][j].eval() 383 | print ("we get first "+str(var_loss[n_var][j][i].eval())) 384 | sys.stdout.flush() 385 | distMat[n_var][j][i] = var_loss[n_var][j][i].eval() 386 | print ("we get second "+str(var_loss[n_var][j][i].eval())) 387 | sys.stdout.flush() 388 | 389 | ##Optimize the dist matrix here 390 | print (distMat) 391 | np.save("tempdist", distMat) 392 | w_opt = utils.optimizeW(distMat, int(FLAGS.num_data)) 393 | np.save("tempw", w_opt) 394 | else: 395 | """ Simply load bias and scale here""" 396 | w_opt = np.load("tempw.npy") 397 | 398 | ##Alternating minization here 399 | if FLAGS.stage == 1: 400 | ## Clear the dataset 401 | for i in range(1, int(FLAGS.num_data)+1): 402 | pre_saver.restore(sess, FLAGS.pre_train) 403 | sess.run(copy_from_model0_op[i-1]) 404 | reinit_variables(sess, variable_list[i]) 405 | 406 | print('Begin joint training!') 407 | data_sum = 0 408 | w_sum = 0 409 | w_lambda_update = mu 410 | for i in range(global_step): 411 | for j in range(int(FLAGS.num_data)): 412 | x = data[j] 413 | numtrain = int(np.floor(x.data_size / batch_size)) 414 | sess.run(copy_to_model0_op[j]) 415 | if i % numtrain == 0: 416 | sess.run(train_init_op[j]) 417 | #XXX 418 | x_batch_images, x_batch_labels = sess.run(next_batch[j]) 419 | if i % display_step == 0: 420 | train_accuracy = accuracy_list[j].eval(feed_dict={x_list[0]: x_batch_images, y_losses[j]: x_batch_labels}) 421 | jloss = joint_loss[j].eval(feed_dict={x_list[0]: 
x_batch_images, 422 | y_losses[j]: x_batch_labels, weight_graph: w_opt[:, j, :], w_lambda: w_lambda_update, weight_scale: w_scale}) 423 | dataloss = data_loss[j].eval(feed_dict={x_list[0]: x_batch_images, y_losses[j]: x_batch_labels}) 424 | w_sum += jloss - dataloss 425 | data_sum += dataloss 426 | print('Epoch %g, dataset %g, training accuracy %g, joint_loss %g, data loss %g' % (i, j+1, train_accuracy, jloss, dataloss)) 427 | sys.stdout.flush() 428 | joint_optimizer[j].run(feed_dict={x_list[0]: x_batch_images, y_losses[j]: x_batch_labels, weight_graph: w_opt[:, j, :], w_lambda: w_lambda_update, weight_scale: w_scale}) 429 | if (i+1) % numtrain == 0: 430 | test_accuracy = 0 431 | sess.run(test_init_op[j]) 432 | numtest = int(np.floor(testdata[j].data_size / batch_size)) 433 | for iter1 in range(numtest): 434 | x_batch_images, x_batch_labels = sess.run(next_batch[j]) 435 | test_accuracy += test_accuracy_list[j].eval(feed_dict={ 436 | x_list[0]: x_batch_images, y_losses[j]: x_batch_labels}) 437 | test_accuracy /= numtest 438 | print('Epoch %g, dataset %g, test accuracy %g' % (i, j+1, test_accuracy)) 439 | sys.stdout.flush() 440 | sess.run(copy_from_model0_op[j]) 441 | if (i+1) % save_model == 0: 442 | saver.save(sess, FLAGS.log_location+"joint_model"+str(i+1)+".ckpt") 443 | 444 | ## Begin isolated training 445 | """ ==== """ 446 | if FLAGS.stage == 2: 447 | print ('Isolated training begin') 448 | data_sum = 0 449 | w_sum = 0 450 | w_lambda_update = mu 451 | 452 | for i in range(1, int(FLAGS.num_data)+1): 453 | pre_saver.restore(sess, FLAGS.pre_train) 454 | sess.run(copy_from_model0_op[i-1]) 455 | reinit_variables(sess, variable_lit[i]) 456 | train_accuracy = [0]*int(FLAGS.num_data) 457 | for i in range(global_step): 458 | for j in range(int(FLAGS.num_data)): 459 | x = data[j] 460 | numtrain = int(np.floor(x.data_size / batch_size)) 461 | if i % numtrain == 0: 462 | sess.run(train_init_op[j]) 463 | sess.run(copy_to_model0_op[j]) 464 | x_batch_images, x_batch_labels = 
sess.run(next_batch[j]) 465 | train_accuracy[j] += accuracy_list[j].eval(feed_dict={x_list[0]: x_batch_images, y_losses[j]: x_batch_labels}) 466 | dataloss = data_loss[j].eval(feed_dict={x_list[0]: x_batch_images, y_losses[j]: x_batch_labels}) 467 | isolated_optimizer[j].run(feed_dict={x_list[0]: x_batch_images, y_losses[j]: x_batch_labels}) 468 | if (i+1) % numtrain == 0: 469 | test_accuracy = 0 470 | print('Epoch %g, dataset %g, training accuracy %g, data loss %g' % (i, j+1, train_accuracy[j]/numtrain, dataloss)) 471 | sys.stdout.flush() 472 | train_accuracy[j] = 0 473 | sess.run(test_init_op[j]) 474 | numtest = int(np.floor(testdata[j].data_size / batch_size)) 475 | for iter1 in range(numtest): 476 | x_batch_images, x_batch_labels = sess.run(next_batch[j]) 477 | test_accuracy += test_accuracy_list[j].eval(feed_dict={ 478 | x_list[0]: x_batch_images, y_losses[j]: x_batch_labels}) 479 | test_accuracy /= numtest 480 | print('Epoch %g, dataset %g, test accuracy %g' % (i, j+1, test_accuracy)) 481 | sys.stdout.flush() 482 | sess.run(copy_from_model0_op[j]) 483 | 484 | if (i+1) % save_model == 0: 485 | saver.save(sess, FLAGS.log_location+"isolate_model"+str(i+1)+".ckpt") 486 | 487 | ### naive joint training 488 | """ ==== """ 489 | if FLAGS.stage == 3: 490 | print ('Naive joint training begin') 491 | data_sum = 0 492 | w_sum = 0 493 | w_lambda_update = mu 494 | for i in range(1, int(FLAGS.num_data)+1): 495 | pre_saver.restore(sess, FLAGS.pre_train) 496 | sess.run(copy_from_model0_op[i-1]) 497 | reinit_variables(sess, variable_list[i]) 498 | 499 | for i in range(global_step): 500 | for j in range(int(FLAGS.num_data)): 501 | x = data[j] 502 | numtrain = int(np.floor(x.data_size / batch_size)) 503 | if i % numtrain == 0: 504 | sess.run(train_init_op[j]) 505 | x_batch_images, x_batch_labels = sess.run(next_batch[j]) 506 | naive_joint_optimizer[j].run(feed_dict={x_list[0]: x_batch_images, y_losses[j]: x_batch_labels, w_lambda: w_lambda_update, weight_scale: w_scale}) 507 | 
if i % display_step == 0: 508 | train_accuracy = accuracy_list[j].eval(feed_dict={x_list[0]: x_batch_images, y_losses[j]: x_batch_labels}) 509 | dataloss = data_loss[j].eval(feed_dict={x_list[0]: x_batch_images, y_losses[j]: x_batch_labels}) 510 | print('Epoch %g, dataset %g, training accuracy %g, data loss %g' % (i, j+1, train_accuracy, dataloss)) 511 | sys.stdout.flush() 512 | if (i+1) % numtrain == 0: 513 | test_accuracy = 0 514 | sess.run(test_init_op[j]) 515 | numtest = int(np.floor(testdata[j].data_size / batch_size)) 516 | for iter1 in range(numtest): 517 | x_batch_images, x_batch_labels = sess.run(next_batch[j]) 518 | test_accuracy += test_accuracy_list[j].eval(feed_dict={ 519 | x_list[0]: x_batch_images, y_losses[j]: x_batch_labels}) 520 | test_accuracy /= numtest 521 | print('Epoch %g, dataset %g, test accuracy %g' % (i, j+1, test_accuracy)) 522 | sys.stdout.flush() 523 | if (i+1) % save_model == 0: 524 | saver.save(sess, FLAGS.log_location+"naive_joint_model"+str(i+1)+".ckpt") 525 | 526 | 527 | if __name__ == '__main__': 528 | parser = argparse.ArgumentParser() 529 | parser.add_argument('--data_dir', type=str, 530 | default='/tmp/tensorflow/mnist/input_data', 531 | help='Directory for storing input data') 532 | parser.add_argument('--num_data', type=str, 533 | default='5', 534 | help='Number of datasets') 535 | parser.add_argument('--log_location', type=str, 536 | default='logs', 537 | help='Directory for storing the log files') 538 | parser.add_argument('--norm', type=str, 539 | default='l2', 540 | help='Type of norm between variables') 541 | parser.add_argument('--pre_train', type=str, 542 | default='logs/checkpoint', 543 | help='checkpoint path') 544 | parser.add_argument('--stage', type=int, 545 | default='0', 546 | help='0:winitial, 1:joint, 2:isolated, 3:naive') 547 | FLAGS, unparsed = parser.parse_known_args() 548 | tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 549 | 
-------------------------------------------------------------------------------- /images/teaser1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zaiweizhang/Joint-Learning-of-NN/9b1bceb986670a1532b6ee217e9b3add25ea22b4/images/teaser1.png -------------------------------------------------------------------------------- /images/teaser2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zaiweizhang/Joint-Learning-of-NN/9b1bceb986670a1532b6ee217e9b3add25ea22b4/images/teaser2.png -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from cvxopt import matrix, solvers 3 | 4 | def getImage(listofDict): 5 | images = [] 6 | for data in listofDict: 7 | images.append(data['image']) 8 | return images 9 | 10 | def getLabel(listofDict): 11 | labels = [] 12 | for data in listofDict: 13 | labels.append(data['label']) 14 | return np.squeeze(labels) 15 | 16 | def digit_data(data_path): 17 | inputdata = sio.loadmat(data_path) 18 | train = inputdata['digitdata'] 19 | test = inputdata['testdata'] 20 | training_data = [] 21 | for i in range(len(train)): 22 | dataset = [] 23 | for j in range(25): 24 | for k in range(len(train[0])): 25 | data = {} 26 | data['image'] = train[i,k,j,:,:]/255.0 27 | temp = np.zeros([1, 10]) 28 | temp[0][k] = 1 29 | data['label'] = temp 30 | dataset.append(data) 31 | training_data.append(dataset) 32 | ## Split the test dataset 33 | testing_data = [] 34 | for k in range(0,5,2): 35 | dataset = [] 36 | for i in range(100*(k+1), (k+3)*100): 37 | for j in range(len(test)): 38 | data = {} 39 | data['image'] = test[j,i,:,:]/255.0 40 | temp = np.zeros([1, 10]) 41 | temp[0][j] = 1 42 | data['label'] = temp 43 | dataset.append(data) 44 | testing_data.append(dataset) 45 | 46 | for k in 
range(2): 47 | dataset = [] 48 | for i in range(0, 100): 49 | for j in range(len(test)): 50 | data = {} 51 | data['image'] = test[j,i,:,:]/255.0 52 | temp = np.zeros([1, 10]) 53 | temp[0][j] = 1 54 | data['label'] = temp 55 | dataset.append(data) 56 | # Additional test data 57 | for i in range(400, 500): 58 | for j in range(len(test)): 59 | data = {} 60 | data['image'] = train[4,j,i,:,:]/255.0 61 | temp = np.zeros([1, 10]) 62 | temp[0][j] = 1 63 | data['label'] = temp 64 | dataset.append(data) 65 | testing_data.append(dataset) 66 | 67 | print (len(training_data[0])) 68 | print (len(testing_data)) 69 | print (len(testing_data[4])) 70 | return training_data, testing_data 71 | 72 | def optimizePSD(distMat, numdata): 73 | G0 = matrix(np.concatenate([-np.eye(numdata*numdata), np.eye(numdata*numdata)])) 74 | h0 = matrix(np.concatenate([np.zeros([numdata*numdata, 1]), np.ones([numdata*numdata, 1])])) 75 | wopt = [] 76 | for nvar in range(len(distMat)): 77 | c = np.reshape(distMat[nvar], [numdata*numdata, 1]) 78 | stat = np.sort(np.reshape(c, [numdata*numdata])) 79 | c = matrix(c) - np.median(c) 80 | sol = solvers.sdp(c, Gl=G0, hl=h0) 81 | #print (np.round((np.reshape(sol['x'], [numdata, numdata])), 3)) 82 | wopt.append(np.round(np.reshape(sol['x'], [numdata, numdata]), 3)) 83 | return wopt 84 | --------------------------------------------------------------------------------