├── README.md ├── dataprocessor.py ├── log ├── Readme.md ├── checkpoint ├── mstnmodel_amazo_to_webcam_final2967.ckpt.index └── tensorboardamazo_to_webcam_final │ └── events.out.tfevents.1573467756.smartdsp-PC ├── model.py ├── pseudo.py ├── test.py └── train.py /README.md: -------------------------------------------------------------------------------- 1 | # PFAN 2 | Code for the CVPR 2019 paper "Progressive Feature Alignment for Unsupervised Domain Adaptation". We will release a journal version of the code which further improves the results reported in our paper. We will keep updating this code. 3 | 4 | Prerequisites: 5 | ============= 6 | Python2/Python3 7 | TensorFlow 1.10 8 | NumPy, OpenCV (cv2) 9 | 10 | Dataset: 11 | ======= 12 | You need to download the domain_adaptation_images dataset for testing. 13 | 14 | Training: 15 | ======== 16 | 1. Run 'train.py' to get the prototype vectors. 17 | 2. Run 'pseudo.py' to build the new training dataset. 18 | 3. Repeat steps 1 and 2 alternately and iteratively. 19 | 20 | 21 | Citation: 22 | ======== 23 | If you use this code for your research, please consider citing: 24 | @InProceedings{PFAN_2019_CVPR, 25 | author = {Chen, Chaoqi and Xie, Weiping and Huang, Wenbing and Rong, Yu and Ding, Xinghao and Huang, Yue and Xu, Tingyang and Huang, Junzhou}, 26 | title = {Progressive Feature Alignment for Unsupervised Domain Adaptation}, 27 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 28 | year = {2019} 29 | } 30 | 31 | Contact: 32 | ======== 33 | If you have any problems with our code, feel free to contact Xiewp67@stu.xmu.edu.cn.
# ---------------------------------------------------------------------------
# /dataprocessor.py
# ---------------------------------------------------------------------------
import numpy as np
import cv2


class BatchPreprocessor(object):
    """Serve preprocessed image batches described by a "<path> <label>" list file.

    Each non-empty line of the list file holds an image path and an integer
    class label separated by whitespace. Images are read with OpenCV,
    optionally flipped, resized (or randomly rescaled and then cropped when
    ``multi_scale`` is given), mean-subtracted, and returned together with
    one-hot labels.

    Args:
        dataset_file_path: path of the list file.
        num_classes: number of classes; labels must lie in ``[0, num_classes)``.
        output_size: ``[height, width]`` of the returned images.
        horizontal_flip: flip each image left-right with probability 0.5.
        shuffle: shuffle samples on construction and on every ``reset_pointer``.
        mean_color: per-channel value subtracted from every pixel.
        multi_scale: ``None`` to resize straight to ``output_size``; otherwise a
            ``[low, high)`` pair: the image is resized to a random square side
            drawn from that range and then cropped to ``output_size``.
        istraining: random crop when True, center crop when False (only used
            with ``multi_scale``).
    """

    def __init__(self, dataset_file_path, num_classes, output_size=(227, 227), horizontal_flip=False, shuffle=False,
                 mean_color=(104.0069879317889, 116.66876761696767, 122.6789143406786), multi_scale=None,
                 istraining=True):
        self.num_classes = num_classes
        # Copy into fresh lists: defaults are never shared, attribute type stays list.
        self.output_size = list(output_size)
        self.horizontal_flip = horizontal_flip
        self.shuffle = shuffle
        self.mean_color = list(mean_color)
        self.multi_scale = multi_scale
        self.istraining = istraining
        self.pointer = 0
        self.images = []
        self.labels = []

        # Read the dataset file; the context manager guarantees it is closed.
        with open(dataset_file_path) as dataset_file:
            for line in dataset_file:
                items = line.split()
                if not items:
                    continue  # tolerate blank lines
                self.images.append(items[0])
                self.labels.append(int(items[1]))

        # Shuffle the data; shuffle_data() also rebuilds the per-class index.
        if self.shuffle:
            self.shuffle_data()
        else:
            self._build_classpaths()

    def _build_classpaths(self):
        """Rebuild ``self.classpaths[c]`` = indices into ``self.images`` labelled ``c``.

        Must run whenever the order of ``self.images`` changes; otherwise
        ``class_next_batch`` would sample through stale indices.
        """
        self.classpaths = [[] for _ in range(self.num_classes)]
        for index, label in enumerate(self.labels):
            self.classpaths[label].append(index)

    def shuffle_data(self):
        """Permute images and labels with the same random order, then reindex classes."""
        order = np.random.permutation(len(self.labels))
        self.images = [self.images[i] for i in order]
        self.labels = [self.labels[i] for i in order]
        self._build_classpaths()

    def reset_pointer(self):
        """Restart sequential iteration (``next_batch``), reshuffling if enabled."""
        self.pointer = 0
        if self.shuffle:
            self.shuffle_data()

    def _crop(self, img, new_size):
        """Crop a square ``new_size`` image to ``output_size`` (random if training, else centered).

        Raises:
            ValueError: if ``new_size`` is smaller than the requested output.
        """
        out_h, out_w = self.output_size[0], self.output_size[1]
        if new_size == out_h and new_size == out_w:
            return img
        if new_size < out_h or new_size < out_w:
            raise ValueError('multi_scale size {} is smaller than output_size {}'.format(new_size, self.output_size))
        if self.istraining:
            # randint's upper bound is exclusive; +1 makes the last valid offset reachable.
            top = np.random.randint(0, new_size - out_h + 1, 1)[0]
            left = np.random.randint(0, new_size - out_w + 1, 1)[0]
        else:
            top = new_size // 2 - out_h // 2
            left = new_size // 2 - out_w // 2
        return img[top:top + out_h, left:left + out_w]

    def _load_image(self, path):
        """Read one image and apply flip, resize/crop and mean subtraction.

        Returns a float32 array of shape ``(output_size[0], output_size[1], 3)``.

        Raises:
            IOError: if OpenCV cannot read ``path`` (``cv2.imread`` returns None).
        """
        img = cv2.imread(path)
        if img is None:
            raise IOError('cannot read image: {}'.format(path))
        # Flip image at random if flag is selected
        if self.horizontal_flip and np.random.random() < 0.5:
            img = cv2.flip(img, 1)
        if self.multi_scale is None:
            # cv2.resize expects dsize as (width, height).
            img = cv2.resize(img, (self.output_size[1], self.output_size[0])).astype(np.float32)
        else:
            # Resize to a random square scale in [multi_scale[0], multi_scale[1]), then crop.
            new_size = np.random.randint(self.multi_scale[0], self.multi_scale[1], 1)[0]
            img = cv2.resize(img, (new_size, new_size)).astype(np.float32)
            img = self._crop(img, new_size)
        # Subtract mean color
        img -= np.array(self.mean_color)
        return img

    def _make_batch(self, paths, labels, batch_size):
        """Stack preprocessed images and one-hot labels into ``batch_size``-row arrays.

        Rows beyond ``len(paths)`` (a partial final batch) are all zeros for
        both images and labels instead of uninitialized memory.
        """
        images = np.zeros([batch_size, self.output_size[0], self.output_size[1], 3])
        one_hot_labels = np.zeros((batch_size, self.num_classes))
        for row, (path, label) in enumerate(zip(paths, labels)):
            images[row] = self._load_image(path)
            one_hot_labels[row][label] = 1
        return images, one_hot_labels

    def class_next_batch(self, num_per_class):
        """Sample a class-balanced batch of ``num_per_class`` distinct images per class.

        Returns:
            ``(images, one_hot_labels)`` with ``num_classes * num_per_class`` rows,
            grouped by class in increasing class order.

        Raises:
            ValueError: from ``np.random.choice`` when a class has fewer than
                ``num_per_class`` samples.
        """
        ids = []
        for class_indices in self.classpaths:
            ids += np.random.choice(class_indices, size=num_per_class, replace=False).tolist()
        paths = [self.images[i] for i in ids]
        labels = [self.labels[i] for i in ids]
        return self._make_batch(paths, labels, len(ids))

    def next_batch(self, batch_size):
        """Return the next sequential batch and advance the read pointer.

        Returns:
            ``(images, one_hot_labels)`` with exactly ``batch_size`` rows; see
            ``_make_batch`` for how a short final batch is padded.
        """
        paths = self.images[self.pointer:self.pointer + batch_size]
        labels = self.labels[self.pointer:self.pointer + batch_size]
        self.pointer += batch_size
        return self._make_batch(paths, labels, batch_size)


# ---------------------------------------------------------------------------
# /log/Readme.md
# ---------------------------------------------------------------------------
# The model is too large for the repository; download it from:
# Link (链接): https://pan.baidu.com/s/1NqNpPIaabo7zx2n4Fbn2qw
# Extraction code (提取码): 6qlt
# ---------------------------------------------------------------------------
# /log/checkpoint
# ---------------------------------------------------------------------------
# model_checkpoint_path: "mstnmodel_amazo_to_webcam_final2967.ckpt"
# all_model_checkpoint_paths: "mstnmodel_amazo_to_webcam_final2967.ckpt"
#
# /log/mstnmodel_amazo_to_webcam_final2967.ckpt.index (binary):
# https://raw.githubusercontent.com/Xiewp/PFAN/328c4c5b4e24a6d04b06136d9e83d8acb391a289/log/mstnmodel_amazo_to_webcam_final2967.ckpt.index
# /log/tensorboardamazo_to_webcam_final/events.out.tfevents.1573467756.smartdsp-PC (binary):
# https://raw.githubusercontent.com/Xiewp/PFAN/328c4c5b4e24a6d04b06136d9e83d8acb391a289/log/tensorboardamazo_to_webcam_final/events.out.tfevents.1573467756.smartdsp-PC

# ---------------------------------------------------------------------------
# /model.py
# ---------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import math


def protoloss(sc, tc):
    """Return the mean squared difference between two prototype (centroid) tensors."""
    return tf.reduce_mean(tf.square(sc - tc))


class AlexNetModel(object):
    """AlexNet classifier with a 256-d bottleneck, adversarial and prototype losses.

    The methods build TF1 graph pieces and are used in this order by the
    scripts: ``loss`` (creates the network variables under the
    ``reuse_inference`` scope), ``adloss``/``wganloss`` (reuse them and create
    the discriminator under ``reuse``), then ``optimize`` and ``adoptimize``.
    """

    def __init__(self, num_classes=1000, dropout_keep_prob=0.5):
        self.num_classes = num_classes
        self.dropout_keep_prob = dropout_keep_prob
        # Width of the fc8 bottleneck whose activations serve as features.
        self.featurelen = 256
        # Per-class feature centroids carried across steps; not trained by SGD,
        # overwritten by the op returned from optimize().
        self.source_moving_centroid = tf.get_variable(name='source_moving_centroid',
                                                      shape=[num_classes, self.featurelen],
                                                      initializer=tf.zeros_initializer(), trainable=False)
        self.target_moving_centroid = tf.get_variable(name='target_moving_centroid',
                                                      shape=[num_classes, self.featurelen],
                                                      initializer=tf.zeros_initializer(), trainable=False)
        tf.summary.histogram('source_moving_centroid', self.source_moving_centroid)
        tf.summary.histogram('target_moving_centroid', self.target_moving_centroid)

    def inference(self, x, training=False):
        """Build the forward pass and return the class logits (``self.score``).

        Also stores ``flattened``, ``fc6``, ``fc7``, ``fc8`` (aliased as
        ``vector`` and ``feature``), ``score`` and ``output`` on ``self``; the
        most recent call decides what those attributes refer to.

        Args:
            x: input image batch; the fixed 6*6*256 reshape assumes 227x227x3
               inputs as fed by the scripts (TODO confirm for other sizes).
            training: apply dropout after fc6 and fc7 when True.
        """
        # 1st Layer: Conv (w ReLu) -> Pool -> Lrn
        conv1 = conv(x, 11, 11, 96, 4, 4, padding='VALID', name='conv1')
        pool1 = max_pool(conv1, 3, 3, 2, 2, padding='VALID', name='pool1')
        norm1 = lrn(pool1, 1, 1e-5, 0.75, name='norm1')
        # 2nd Layer: Conv (w ReLu) -> Pool -> Lrn with 2 groups
        conv2 = conv(norm1, 5, 5, 256, 1, 1, groups=2, name='conv2')
        pool2 = max_pool(conv2, 3, 3, 2, 2, padding='VALID', name='pool2')
        norm2 = lrn(pool2, 1, 1e-5, 0.75, name='norm2')
        # 3rd Layer: Conv (w ReLu)
        conv3 = conv(norm2, 3, 3, 384, 1, 1, name='conv3')
        # 4th Layer: Conv (w ReLu) split into two groups
        conv4 = conv(conv3, 3, 3, 384, 1, 1, groups=2, name='conv4')
        # 5th Layer: Conv (w ReLu) -> Pool split into two groups
        conv5 = conv(conv4, 3, 3, 256, 1, 1, groups=2, name='conv5')
        pool5 = max_pool(conv5, 3, 3, 2, 2, padding='VALID', name='pool5')
        # 6th Layer: Flatten -> FC (w ReLu) -> Dropout
        flattened = tf.reshape(pool5, [-1, 6 * 6 * 256])
        self.flattened = flattened
        fc6 = fc(flattened, 6 * 6 * 256, 4096, name='fc6')
        if training:
            fc6 = dropout(fc6, self.dropout_keep_prob)
        self.fc6 = fc6
        # 7th Layer: FC (w ReLu) -> Dropout
        fc7 = fc(fc6, 4096, 4096, name='fc7')
        if training:
            fc7 = dropout(fc7, self.dropout_keep_prob)
        self.fc7 = fc7
        # 8th Layer: linear 256-d bottleneck used as the feature vector.
        fc8 = fc(fc7, 4096, 256, relu=False, name='fc8')
        self.vector = fc8
        self.fc8 = fc8
        # 9th Layer: unscaled class logits (for softmax_cross_entropy_with_logits).
        self.score = fc(fc8, 256, self.num_classes, relu=False, stddev=0.005, name='fc9')
        # Temperature-scaled (T=1.8) class probabilities.
        self.output = tf.nn.softmax(self.score / 1.8)
        self.feature = self.fc8
        return self.score

    def adoptimize(self, learning_rate, train_layers=None):
        """Return the discriminator update op minimizing ``self.D_loss`` + L2 penalty.

        Weights use ``learning_rate`` and biases ``2 * learning_rate``, both with
        momentum 0.9. ``train_layers`` is unused (kept for compatibility).
        Requires ``adloss`` or ``wganloss`` to have set ``self.D_loss``.
        """
        # Discriminator variables live in a variable scope named 'D' (see D()).
        var_list = [v for v in tf.trainable_variables() if 'D' in v.name.split('/')[:-1]]
        D_weights = [v for v in var_list if 'weights' in v.name]
        D_biases = [v for v in var_list if 'biases' in v.name]
        print('=================Discriminator_weights=====================')
        print(D_weights)
        print('=================Discriminator_biases=====================')
        print(D_biases)

        self.Dregloss = 0.0005 * tf.reduce_mean([tf.nn.l2_loss(v) for v in D_weights])
        D_op1 = tf.train.MomentumOptimizer(learning_rate, 0.9).minimize(self.D_loss + self.Dregloss, var_list=D_weights)
        D_op2 = tf.train.MomentumOptimizer(learning_rate * 2.0, 0.9).minimize(self.D_loss + self.Dregloss, var_list=D_biases)
        return tf.group(D_op1, D_op2)

    def wganloss(self, x, xt, batch_size, lam=10.0):
        """Build WGAN-GP losses on fc8 features of source ``x`` and target ``xt``.

        Sets ``self.G_loss``/``self.D_loss`` to the losses scaled by 0.3 and
        returns the *unscaled* ``(G_loss, D_loss)``. Both batches must contain
        exactly ``batch_size`` rows (interpolation weights are drawn per row).
        ``lam`` weights the gradient penalty.
        """
        with tf.variable_scope('reuse_inference') as scope:
            scope.reuse_variables()
            self.inference(x, training=True)
            source_fc8 = self.fc8
            source_output = outer(self.fc7, self.output)
            print('SOURCE_OUTPUT:  {}'.format(source_output.get_shape()))
            scope.reuse_variables()
            self.inference(xt, training=True)
            target_fc8 = self.fc8
            target_output = outer(self.fc7, self.output)
            print('TARGET_OUTPUT:  {}'.format(target_output.get_shape()))
        with tf.variable_scope('reuse') as scope:
            target_logits, _ = D(target_fc8)
            scope.reuse_variables()
            source_logits, _ = D(source_fc8)
            eps = tf.random_uniform([batch_size, 1], minval=0.0, maxval=1.0)
            X_inter = eps * source_fc8 + (1 - eps) * target_fc8
            # D returns (logits, sigmoid); differentiate the critic logits only.
            # Passing the whole tuple would sum the gradients of both outputs.
            inter_logits, _ = D(X_inter)
            grad = tf.gradients(inter_logits, [X_inter])[0]
            grad_norm = tf.sqrt(tf.reduce_sum(grad ** 2, axis=1))
            grad_pen = lam * tf.reduce_mean((grad_norm - 1) ** 2)
            D_loss = tf.reduce_mean(target_logits) - tf.reduce_mean(source_logits) + grad_pen
            G_loss = tf.reduce_mean(source_logits) - tf.reduce_mean(target_logits)
            self.G_loss = 0.3 * G_loss
            self.D_loss = 0.3 * D_loss
        return G_loss, D_loss

    def adloss(self, x, xt, y, global_step):
        """Build the adversarial losses and the class-prototype (semantic) loss.

        Source centroids use the labels ``y``; target centroids use the model's
        own predictions on ``xt``. Batches of different sizes are supported.
        ``global_step`` is unused.

        Returns:
            ``(G_loss, D_loss, source_centroid, target_centroid)``; the losses
            (also stored on ``self``) are scaled by 0.1.
        """
        with tf.variable_scope('reuse_inference') as scope:
            scope.reuse_variables()
            self.inference(x, training=True)
            source_feature = self.feature
            scope.reuse_variables()
            self.inference(xt, training=True)
            target_feature = self.feature
            target_pred = self.output
        with tf.variable_scope('reuse') as scope:
            source_logits, _ = D(source_feature)
            scope.reuse_variables()
            target_logits, _ = D(target_feature)
        self.source_feature = source_feature
        self.target_feature = target_feature
        self.concat_feature = tf.concat([source_feature, target_feature], 0)

        source_result = tf.argmax(y, 1)
        target_result = tf.argmax(target_pred, 1)
        # Per-class sample counts, broadcast over the feature width. Each domain
        # uses its own ones-tensor so source/target batch sizes may differ.
        current_source_count = tf.unsorted_segment_sum(tf.ones_like(source_feature), source_result, self.num_classes)
        current_target_count = tf.unsorted_segment_sum(tf.ones_like(target_feature), target_result, self.num_classes)
        # Clamp to >= 1 so a class absent from the batch yields a zero centroid, not 0/0.
        current_positive_source_count = tf.maximum(current_source_count, tf.ones_like(current_source_count))
        current_positive_target_count = tf.maximum(current_target_count, tf.ones_like(current_target_count))

        current_source_centroid = tf.divide(
            tf.unsorted_segment_sum(data=source_feature, segment_ids=source_result, num_segments=self.num_classes),
            current_positive_source_count)
        current_target_centroid = tf.divide(
            tf.unsorted_segment_sum(data=target_feature, segment_ids=target_result, num_segments=self.num_classes),
            current_positive_target_count)

        decay = tf.constant(0.3)
        self.decay = decay
        # NOTE(review): the target centroid is blended with the *source* moving
        # centroid (an earlier target_moving_centroid variant was commented out
        # upstream). Kept as-is; confirm this is the intended design.
        target_centroid = decay * current_target_centroid + (1. - decay) * self.source_moving_centroid
        source_centroid = decay * current_source_centroid + (1. - decay) * self.source_moving_centroid

        self.Semanticloss = protoloss(source_centroid, target_centroid)
        tf.summary.scalar('semanticloss', self.Semanticloss)

        # Discriminator labels: target = 1, source = 0.
        D_real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=target_logits, labels=tf.ones_like(target_logits)))
        D_fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=source_logits, labels=tf.zeros_like(source_logits)))
        self.D_loss = D_real_loss + D_fake_loss
        self.G_loss = -self.D_loss
        tf.summary.scalar('G_loss', self.G_loss)
        tf.summary.scalar('JSD', self.G_loss / 2 + math.log(2))

        self.G_loss = 0.1 * self.G_loss
        self.D_loss = 0.1 * self.D_loss
        return self.G_loss, self.D_loss, source_centroid, target_centroid

    def loss(self, batch_x, batch_y=None):
        """Create the network variables and return the source cross-entropy loss.

        Note: rebinds ``self.loss`` from this method to the returned tensor, so
        it can be called once per model; ``optimize`` reads that tensor.
        """
        with tf.variable_scope('reuse_inference') as scope:
            y_predict = self.inference(batch_x, training=True)
            self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y_predict, labels=batch_y))
        tf.summary.scalar('Closs', self.loss)
        return self.loss

    def optimize(self, learning_rate, train_layers, global_step, source_centroid, target_centroid):
        """Return the feature-extractor update op followed by the centroid updates.

        Minimizes ``loss + L2 + global_step * (G_loss + Semanticloss)``. The
        scripts pass the adaptation weight placeholder as ``global_step``.
        Pretrained layers use 0.1x/0.2x ``learning_rate`` (weights/biases);
        fc8/fc9 use 1x/2x. The moving centroids are assigned only after all
        four optimizer steps have run.
        """
        print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++')
        print(train_layers)
        var_list = [v for v in tf.trainable_variables() if v.name.split('/')[1] in train_layers + ['fc9']]
        finetune_list = [v for v in var_list if v.name.split('/')[1] in ['conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'fc6', 'fc7']]
        new_list = [v for v in var_list if v.name.split('/')[1] in ['fc8', 'fc9']]
        self.Gregloss = 0.0005 * tf.reduce_mean([tf.nn.l2_loss(v) for v in var_list if 'weights' in v.name])
        finetune_weights = [v for v in finetune_list if 'weights' in v.name]
        finetune_biases = [v for v in finetune_list if 'biases' in v.name]
        new_weights = [v for v in new_list if 'weights' in v.name]
        new_biases = [v for v in new_list if 'biases' in v.name]
        print('==============finetune_weights=======================')
        print(finetune_weights)
        print('==============finetune_biases=======================')
        print(finetune_biases)
        print('==============new_weights=======================')
        print(new_weights)
        print('==============new_biases=======================')
        print(new_biases)

        self.F_loss = self.loss + self.Gregloss + global_step * self.G_loss + global_step * self.Semanticloss
        train_op1 = tf.train.MomentumOptimizer(learning_rate * 0.1, 0.9).minimize(self.F_loss, var_list=finetune_weights)
        train_op2 = tf.train.MomentumOptimizer(learning_rate * 0.2, 0.9).minimize(self.F_loss, var_list=finetune_biases)
        train_op3 = tf.train.MomentumOptimizer(learning_rate * 1.0, 0.9).minimize(self.F_loss, var_list=new_weights)
        train_op4 = tf.train.MomentumOptimizer(learning_rate * 2.0, 0.9).minimize(self.F_loss, var_list=new_biases)
        with tf.control_dependencies([train_op1, train_op2, train_op3, train_op4]):
            update_sc = self.source_moving_centroid.assign(source_centroid)
            update_tc = self.target_moving_centroid.assign(target_centroid)
        return tf.group(update_sc, update_tc)

    def load_original_weights(self, session, skip_layers=None):
        """Copy pretrained AlexNet weights from 'bvlc_alexnet.npy' into matching layers.

        Layers are matched by name under the 'reuse_inference' scope. 'fc8' is
        always skipped: here it is a 4096->256 bottleneck, so the pretrained
        layer of that name (presumably the ImageNet classifier) cannot be
        assigned to it. ``skip_layers`` is accepted but ignored, as before.
        """
        # allow_pickle is required by newer NumPy to load this pickled dict.
        weights_dict = np.load('bvlc_alexnet.npy', encoding='bytes', allow_pickle=True).item()
        for op_name in weights_dict:
            # With encoding='bytes' the keys are bytes on Python 3.
            layer = op_name.decode('utf-8') if isinstance(op_name, bytes) else op_name
            if layer == 'fc8':
                continue
            with tf.variable_scope('reuse_inference/' + layer, reuse=True):
                print('=============================OP_NAME ========================================')
                for data in weights_dict[op_name]:
                    # 1-D arrays are biases, everything else is a weight tensor.
                    var = tf.get_variable('biases' if len(data.shape) == 1 else 'weights')
                    print('{} {}'.format(layer, var))
                    session.run(var.assign(data))


"""
Helper methods
"""


def conv(x, filter_height, filter_width, num_filters, stride_y, stride_x, name, padding='SAME', groups=1):
    """Grouped 2-D convolution + bias + ReLU with variables 'weights'/'biases' in scope ``name``."""
    input_channels = int(x.get_shape()[-1])
    convolve = lambda i, k: tf.nn.conv2d(i, k, strides=[1, stride_y, stride_x, 1], padding=padding)

    with tf.variable_scope(name) as scope:
        # Integer division keeps the shape an int on Python 3.
        weights = tf.get_variable('weights', shape=[filter_height, filter_width, input_channels // groups, num_filters])
        biases = tf.get_variable('biases', shape=[num_filters])

        if groups == 1:
            conv_out = convolve(x, weights)
        else:
            # Split inputs and filters along channels, convolve pairwise, re-concatenate.
            input_groups = tf.split(axis=3, num_or_size_splits=groups, value=x)
            weight_groups = tf.split(axis=3, num_or_size_splits=groups, value=weights)
            output_groups = [convolve(i, k) for i, k in zip(input_groups, weight_groups)]
            conv_out = tf.concat(axis=3, values=output_groups)

        bias = tf.reshape(tf.nn.bias_add(conv_out, biases), [-1] + conv_out.get_shape().as_list()[1:])
        return tf.nn.relu(bias, name=scope.name)


def D(x):
    """Domain discriminator: 1024-1024-1 MLP with ReLU and dropout (keep 0.5).

    Variables live in scope 'D' under the caller's scope. Dropout is applied
    on every call, including evaluation. Returns ``(logits, sigmoid(logits))``.
    """
    with tf.variable_scope('D'):
        num_units_in = int(x.get_shape()[-1])
        num_units_out = 1
        weights = tf.get_variable('weights', initializer=tf.truncated_normal([num_units_in, 1024], stddev=0.01))
        biases = tf.get_variable('biases', shape=[1024], initializer=tf.zeros_initializer())
        hx = tf.matmul(x, weights) + biases
        ax = tf.nn.dropout(tf.nn.relu(hx), 0.5)
        weights2 = tf.get_variable('weights2', initializer=tf.truncated_normal([1024, 1024], stddev=0.01))
        biases2 = tf.get_variable('biases2', shape=[1024], initializer=tf.zeros_initializer())
        hx2 = tf.matmul(ax, weights2) + biases2
        ax2 = tf.nn.dropout(tf.nn.relu(hx2), 0.5)
        weights3 = tf.get_variable('weights3', initializer=tf.truncated_normal([1024, num_units_out], stddev=0.3))
        biases3 = tf.get_variable('biases3', shape=[num_units_out], initializer=tf.zeros_initializer())
        hx3 = tf.matmul(ax2, weights3) + biases3
        return hx3, tf.nn.sigmoid(hx3)


def fc(x, num_in, num_out, name, relu=True, stddev=0.01):
    """Fully connected layer (optionally ReLU) with variables 'weights'/'biases' in scope ``name``."""
    with tf.variable_scope(name) as scope:
        weights = tf.get_variable('weights', initializer=tf.truncated_normal([num_in, num_out], stddev=stddev))
        biases = tf.get_variable('biases', initializer=tf.constant(0.1, shape=[num_out]))
        act = tf.nn.xw_plus_b(x, weights, biases, name=scope.name)
        if relu:
            return tf.nn.relu(act)
        return act


def leaky_relu(x, alpha=0.2):
    """Leaky ReLU: x for x >= 0, alpha * x for x < 0 (for 0 <= alpha <= 1)."""
    return tf.maximum(tf.minimum(0.0, alpha * x), x)


def outer(a, b):
    """Row-wise outer product of two 2-D tensors, flattened to (batch, dim_a * dim_b)."""
    a = tf.reshape(a, [-1, a.get_shape()[-1], 1])
    b = tf.reshape(b, [-1, 1, b.get_shape()[-1]])
    return tf.contrib.layers.flatten(a * b)


def max_pool(x, filter_height, filter_width, stride_y, stride_x, name, padding='SAME'):
    """2-D max pooling over an NHWC tensor."""
    return tf.nn.max_pool(x, ksize=[1, filter_height, filter_width, 1], strides=[1, stride_y, stride_x, 1],
                          padding=padding, name=name)


def lrn(x, radius, alpha, beta, name, bias=1.0):
    """Local response normalization."""
    return tf.nn.local_response_normalization(x, depth_radius=radius, alpha=alpha, beta=beta, bias=bias, name=name)


def dropout(x, keep_prob):
    """Dropout keeping each unit with probability ``keep_prob``."""
    return tf.nn.dropout(x, keep_prob)


# ---------------------------------------------------------------------------
# /pseudo.py
# ---------------------------------------------------------------------------
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Build the next-round training list: source samples plus pseudo-labelled target samples.

Reads the target list (webcam_list.txt) and the prediction file
(pre_and_sim.txt; one "<predicted_class> <score>" line per target image in
the same order, the score presumably a similarity). For every class it keeps
the TOP_K highest-scoring target images, then writes the whole source list
(amazon_list.txt) followed by those pseudo-labelled lines to a_with_pseudo.txt.

Created on Thu Oct 31 17:05:04 2019

@author: smartdsp
"""
import ast

PSEUDO_NUM_CLASSES = 31
TOP_K = 3

with open('webcam_list.txt', 'r') as f1:
    # Keep only the image path. The trailing label may have one or two digits
    # and the last line may lack a newline, so never slice fixed characters.
    filename = [line.rsplit(None, 1)[0] for line in f1 if line.strip()]

dic_target = {}
with open('pre_and_sim.txt', 'r') as f:
    entries = [line.split() for line in f if line.strip()]
for count, fields in enumerate(entries):
    # literal_eval parses the numeric fields as eval() did, without executing
    # arbitrary code read from the file.
    predicted = ast.literal_eval(fields[0])
    score = ast.literal_eval(fields[1])
    dic_target.setdefault(predicted, []).append([score, filename[count]])

result = []
for k in range(PSEUDO_NUM_CLASSES):
    if k not in dic_target:
        continue
    # Highest score first (ties broken by path, descending); keep the top TOP_K.
    ranked = sorted(dic_target[k], reverse=True)
    for _, path in ranked[:TOP_K]:
        result.append(path + ' ' + str(k))

with open('amazon_list.txt', 'r') as fs:
    source = fs.readlines()
# Guarantee a line break so the first pseudo line is not glued to the last source line.
if source and not source[-1].endswith('\n'):
    source[-1] += '\n'
with open('a_with_pseudo.txt', 'w') as fn:
    fn.writelines(source)
    for line in result:
        fn.write(line + '\n')
print('ok')

# ---------------------------------------------------------------------------
# /test.py
# ---------------------------------------------------------------------------
import os
import numpy as np
import tensorflow as tf
import datetime
from model import AlexNetModel
from dataprocessor import BatchPreprocessor
import math
from sklearn import manifold
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'Learning rate for adam optimizer')
tf.app.flags.DEFINE_float('dropout_keep_prob', 0.5, 'Dropout keep probability')
tf.app.flags.DEFINE_integer('num_epochs', 1, 'Number of epochs for training')
tf.app.flags.DEFINE_integer('batch_size', 100, 'Batch size')
tf.app.flags.DEFINE_string('train_layers', 'fc8,fc7,fc6,conv5,conv4,conv3,conv2,conv1', 'Finetuning layers, seperated by commas')
tf.app.flags.DEFINE_string('multi_scale', '256,257', 'As preprocessing; scale the image randomly between 2 numbers and crop randomly at networs input size')
tf.app.flags.DEFINE_string('train_root_dir', '../training', 'Root directory to put the training data')
tf.app.flags.DEFINE_integer('log_step', 10000, 'Logging period in terms of iteration')

NUM_CLASSES = 31
TRAINING_FILE = 'amazon_list.txt'
VAL_FILE = 'webcam_list.txt'
FLAGS = tf.app.flags.FLAGS
MAX_STEP = 10000


def decay(start_rate, epoch, num_epochs):
    """Inverse-decay schedule start_rate / (1 + 0.001 * epoch) ** 0.75 (num_epochs unused)."""
    return start_rate / pow(1 + 0.001 * epoch, 0.75)


def adaptation_factor(x):
    """Ramp 2 / (1 + exp(-10 x)) - 1 from 0 toward 1; returns 1.0 for x >= 1."""
    if x >= 1.0:
        return 1.0
    den = 1.0 + math.exp(-10 * x)
    return 2.0 / den - 1.0


def main(_):
    """Restore the released checkpoint and print its accuracy on VAL_FILE.

    The graph mirrors train.py (losses plus both optimizers) so the variables
    expected by the checkpoint exist before ``saver.restore``.
    """
    # Create output directories (parents are listed before children).
    now = datetime.datetime.now()
    train_dir = os.path.join(FLAGS.train_root_dir, now.strftime('alexnet_%Y%m%d_%H%M%S'))
    checkpoint_dir = os.path.join(train_dir, 'checkpoint')
    tensorboard_dir = os.path.join(train_dir, 'tensorboard')
    tensorboard_train_dir = os.path.join(tensorboard_dir, 'train')
    tensorboard_val_dir = os.path.join(tensorboard_dir, 'val')
    for directory in (FLAGS.train_root_dir, train_dir, checkpoint_dir, tensorboard_dir,
                      tensorboard_train_dir, tensorboard_val_dir):
        if not os.path.isdir(directory):
            os.mkdir(directory)

    # Write flags to txt
    with open(os.path.join(train_dir, 'flags.txt'), 'w') as flags_file:
        flags_file.write('learning_rate={}\n'.format(FLAGS.learning_rate))
        flags_file.write('dropout_keep_prob={}\n'.format(FLAGS.dropout_keep_prob))
        flags_file.write('num_epochs={}\n'.format(FLAGS.num_epochs))
        flags_file.write('batch_size={}\n'.format(FLAGS.batch_size))
        flags_file.write('train_layers={}\n'.format(FLAGS.train_layers))
        flags_file.write('multi_scale={}\n'.format(FLAGS.multi_scale))
        flags_file.write('train_root_dir={}\n'.format(FLAGS.train_root_dir))
        flags_file.write('log_step={}\n'.format(FLAGS.log_step))

    # Placeholders
    x = tf.placeholder(tf.float32, [None, 227, 227, 3], 'x')
    xt = tf.placeholder(tf.float32, [None, 227, 227, 3], 'xt')
    y = tf.placeholder(tf.float32, [None, NUM_CLASSES], 'y')
    yt = tf.placeholder(tf.float32, [None, NUM_CLASSES], 'yt')
    adlamb = tf.placeholder(tf.float32)
    decay_learning_rate = tf.placeholder(tf.float32)
    dropout_keep_prob = tf.placeholder(tf.float32)

    # Model (every call below creates variables or ops the checkpoint expects).
    train_layers = FLAGS.train_layers.split(',')
    model = AlexNetModel(num_classes=NUM_CLASSES, dropout_keep_prob=dropout_keep_prob)
    loss = model.loss(x, y)
    # Accuracy on the source-path logits (x); used for validation below.
    correct_pred = tf.equal(tf.argmax(model.score, 1), tf.argmax(y, 1))
    correct = tf.reduce_sum(tf.cast(correct_pred, tf.float32))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
    G_loss, D_loss, sc, tc = model.adloss(x, xt, y, 10)
    target_correct_pred = tf.equal(tf.argmax(model.score, 1), tf.argmax(yt, 1))
    target_correct = tf.reduce_sum(tf.cast(target_correct_pred, tf.float32))
    target_accuracy = tf.reduce_mean(tf.cast(target_correct_pred, tf.float32))
    train_op = model.optimize(decay_learning_rate, train_layers, adlamb, sc, tc)
    D_op = model.adoptimize(decay_learning_rate, train_layers)
    optimizer = tf.group(train_op, D_op)

    train_writer = tf.summary.FileWriter('./log/tensorboard_restore')
    train_writer.add_graph(tf.get_default_graph())
    tf.summary.scalar('Testing Accuracy', target_accuracy)
    merged = tf.summary.merge_all()

    print('============================GLOBAL TRAINABLE VARIABLES ============================')
    print(tf.trainable_variables())
    # Batch preprocessor
    multi_scale = FLAGS.multi_scale.split(',')
    if len(multi_scale) == 2:
        multi_scale = [int(multi_scale[0]), int(multi_scale[1])]
    else:
        multi_scale = None
    print('==================== MULTI SCALE===================================================')
    print(multi_scale)
    val_preprocessor = BatchPreprocessor(dataset_file_path=VAL_FILE, num_classes=NUM_CLASSES, output_size=[227, 227],
                                         multi_scale=multi_scale, istraining=False)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        train_writer.add_graph(sess.graph)
        # To start from ImageNet weights instead: model.load_original_weights(sess)
        # Directly restore (your model should be exactly the same with checkpoint)
        saver.restore(sess, "./log/mstnmodel_amazo_to_webcam_final2967.ckpt")

        print("{} Start training...".format(datetime.datetime.now()))
        print("{} Open Tensorboard at --logdir {}".format(datetime.datetime.now(), tensorboard_dir))
        # Evaluate once per epoch. The previous inner `while step < ...` loop
        # never advanced `step` and repeated the identical evaluation forever.
        for epoch in range(FLAGS.num_epochs):
            print("{} Start validation".format(datetime.datetime.now()))
            test_acc = 0.
            test_count = 0
            for _ in range(len(val_preprocessor.labels)):
                batch_tx, batch_ty = val_preprocessor.next_batch(1)
                acc = sess.run(correct, feed_dict={x: batch_tx, y: batch_ty, dropout_keep_prob: 1.})
                test_acc += acc
                test_count += 1
            print('{} {}'.format(test_acc, test_count))
            if test_count:
                test_acc /= test_count
            print("{} Validation Accuracy = {:.4f}".format(datetime.datetime.now(), test_acc))
            # Reset the dataset pointers
            val_preprocessor.reset_pointer()


if __name__ == '__main__':
    tf.app.run()

# ---------------------------------------------------------------------------
# /train.py (only the part visible in this excerpt; main() continues beyond it)
# ---------------------------------------------------------------------------
import os
import numpy as np
import tensorflow as tf
import datetime
from model import AlexNetModel
from dataprocessor import BatchPreprocessor
import math
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'Learning rate for adam optimizer')
tf.app.flags.DEFINE_float('dropout_keep_prob', 0.5, 'Dropout keep probability')
tf.app.flags.DEFINE_integer('num_epochs', 10000000, 'Number of epochs for training')
tf.app.flags.DEFINE_integer('batch_size', 100, 'Batch size')
tf.app.flags.DEFINE_string('train_layers', 'fc8,fc7,fc6,conv5,conv4,conv3,conv2,conv1', 'Finetuning layers, seperated by commas')
tf.app.flags.DEFINE_string('multi_scale', '256,257', 'As preprocessing; scale the image randomly between 2 numbers and crop randomly at networs input size')
tf.app.flags.DEFINE_string('train_root_dir', '../training', 'Root directory to put the training data')
tf.app.flags.DEFINE_integer('log_step', 10000, 'Logging period in terms of iteration')

NUM_CLASSES = 31
TRAINING_FILE = 'amazon_list.txt'
VAL_FILE = 'webcam_list.txt'
FLAGS = tf.app.flags.FLAGS
MAX_STEP = 10000
MODEL_NAME = 'amazo_to_webcam_final'


def decay(start_rate, epoch, num_epochs):
    """Inverse-decay schedule start_rate / (1 + 0.001 * epoch) ** 0.75 (num_epochs unused)."""
    return start_rate / pow(1 + 0.001 * epoch, 0.75)


def adaptation_factor(x):
    """Ramp 2 / (1 + exp(-10 x)) - 1 from 0 toward 1; returns 1.0 for x >= 1."""
    if x >= 1.0:
        return 1.0
    den = 1.0 + math.exp(-10 * x)
    lamb = 2.0 / den - 1.0
    return lamb


def cos_distance(vector1, vector2):
    """Cosine similarity of two equal-length sequences; None if either has zero norm."""
    dot_product = 0.0
    normA = 0.0
    normB = 0.0
    for a, b in zip(vector1, vector2):
        dot_product += a * b
        normA += a ** 2
        normB += b ** 2
    if normA == 0.0 or normB == 0.0:
        return None
    else:
        return dot_product / ((normA * normB) ** 0.5)


def main(_):
    now = datetime.datetime.now()
    train_dir_name = now.strftime('alexnet_%Y%m%d_%H%M%S')
    train_dir = os.path.join(FLAGS.train_root_dir, train_dir_name)
    checkpoint_dir = os.path.join(train_dir, 'checkpoint')
    tensorboard_dir = os.path.join(train_dir, 'tensorboard')
    tensorboard_train_dir = os.path.join(tensorboard_dir, 'train')
    tensorboard_val_dir = os.path.join(tensorboard_dir, 'val')

    if not os.path.isdir(FLAGS.train_root_dir): os.mkdir(FLAGS.train_root_dir)
    if not os.path.isdir(train_dir): os.mkdir(train_dir)
    if not os.path.isdir(checkpoint_dir): os.mkdir(checkpoint_dir)
    if not os.path.isdir(tensorboard_dir): os.mkdir(tensorboard_dir)
    if not os.path.isdir(tensorboard_train_dir): os.mkdir(tensorboard_train_dir)
    if not os.path.isdir(tensorboard_val_dir): os.mkdir(tensorboard_val_dir)

    # Write flags to txt
    flags_file_path = os.path.join(train_dir, 'flags.txt')
    flags_file = open(flags_file_path, 'w')
    flags_file.write('learning_rate={}\n'.format(FLAGS.learning_rate))
    flags_file.write('dropout_keep_prob={}\n'.format(FLAGS.dropout_keep_prob))
    flags_file.write('num_epochs={}\n'.format(FLAGS.num_epochs))
    flags_file.write('batch_size={}\n'.format(FLAGS.batch_size))
    flags_file.write('train_layers={}\n'.format(FLAGS.train_layers))
    flags_file.write('multi_scale={}\n'.format(FLAGS.multi_scale))
    flags_file.write('train_root_dir={}\n'.format(FLAGS.train_root_dir))
    flags_file.write('log_step={}\n'.format(FLAGS.log_step))
    flags_file.close()
    # Placeholders
    x = tf.placeholder(tf.float32, [None, 227, 227, 3], 'x')
    xt = tf.placeholder(tf.float32, [None, 227, 227, 3], 'xt')
    y = tf.placeholder(tf.float32, [None, NUM_CLASSES], 'y')
    yt = tf.placeholder(tf.float32, [None, NUM_CLASSES], 'yt')
    adlamb = tf.placeholder(tf.float32)
    decay_learning_rate = tf.placeholder(tf.float32)
    dropout_keep_prob = tf.placeholder(tf.float32)

    # Model
    train_layers = FLAGS.train_layers.split(',')
    model = AlexNetModel(num_classes=NUM_CLASSES, dropout_keep_prob=dropout_keep_prob)
    loss = model.loss(x, y)
    # Training accuracy of the model
    correct_pred = tf.equal(tf.argmax(model.score, 1), tf.argmax(y, 1))
    correct = tf.reduce_sum(tf.cast(correct_pred, tf.float32))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
    G_loss, D_loss, sc, tc = model.adloss(x, xt, y, adlamb)
    target_correct_pred = tf.equal(tf.argmax(model.score, 1), tf.argmax(yt, 1))
    target_correct = tf.reduce_sum(tf.cast(target_correct_pred, tf.float32))
    target_accuracy = tf.reduce_mean(tf.cast(target_correct_pred, tf.float32))
    train_op = model.optimize(decay_learning_rate, train_layers, adlamb, sc, tc)

    # Testing accuracy of the model
    source_vector = model.fc8
    target_vector = model.vector
    target_pre = tf.argmax(model.score, 1)

    D_op = model.adoptimize(decay_learning_rate, train_layers)
    optimizer = tf.group(train_op, D_op)

    train_writer = tf.summary.FileWriter('./log/tensorboard' + MODEL_NAME)
    train_writer.add_graph(tf.get_default_graph())
tf.summary.scalar('Testing Accuracy',target_accuracy) 114 | merged=tf.summary.merge_all() 115 | 116 | print '============================GLOBAL TRAINABLE VARIABLES ============================' 117 | print tf.trainable_variables() 118 | multi_scale = FLAGS.multi_scale.split(',') 119 | if len(multi_scale) == 2: 120 | multi_scale = [int(multi_scale[0]), int(multi_scale[1])] 121 | else: 122 | multi_scale = None 123 | print '==================== MULTI SCALE===================================================' 124 | print multi_scale 125 | train_preprocessor = BatchPreprocessor(dataset_file_path=TRAINING_FILE, num_classes=NUM_CLASSES, 126 | output_size=[227, 227], horizontal_flip=True, shuffle=True, multi_scale=multi_scale) 127 | Ttrain_preprocessor = BatchPreprocessor(dataset_file_path=VAL_FILE, num_classes=NUM_CLASSES, 128 | output_size=[227, 227], horizontal_flip=True, shuffle=True, multi_scale=multi_scale) 129 | val_preprocessor = BatchPreprocessor(dataset_file_path=VAL_FILE, num_classes=NUM_CLASSES, output_size=[227, 227],multi_scale=multi_scale,istraining=False) 130 | train_batches_per_epoch = np.floor(len(train_preprocessor.labels) / FLAGS.batch_size).astype(np.int16) 131 | Ttrain_batches_per_epoch = np.floor(len(Ttrain_preprocessor.labels) / FLAGS.batch_size).astype(np.int16) 132 | val_batches_per_epoch = np.floor(len(val_preprocessor.labels) / FLAGS.batch_size).astype(np.int16) 133 | 134 | 135 | dic_s = {} 136 | dic_temp = {} 137 | dic_temp1 = {} 138 | dic_t = {} 139 | 140 | 141 | 142 | with tf.Session() as sess: 143 | sess.run(tf.global_variables_initializer()) 144 | saver=tf.train.Saver() 145 | train_writer.add_graph(sess.graph) 146 | model.load_original_weights(sess, skip_layers=train_layers) 147 | print("{} Start training...".format(datetime.datetime.now())) 148 | print("{} Open Tensorboard at --logdir {}".format(datetime.datetime.now(), tensorboard_dir)) 149 | gs=0 150 | gd=0 151 | best_acc = 0.0 152 | flag = 1 153 | flag1 = 1 154 | first_s = 50 155 | for 
epoch in range(FLAGS.num_epochs): 156 | step = 1 157 | while step < train_batches_per_epoch: 158 | gd+=1 159 | lamb=adaptation_factor(gd*1.0/MAX_STEP) 160 | rate=decay(FLAGS.learning_rate,gd,MAX_STEP) 161 | for it in xrange(1): 162 | gs+=1 163 | if gs%Ttrain_batches_per_epoch==0: 164 | Ttrain_preprocessor.reset_pointer() 165 | if gs%train_batches_per_epoch==0: 166 | train_preprocessor.reset_pointer() 167 | batch_xs, batch_ys = train_preprocessor.next_batch(FLAGS.batch_size) 168 | Tbatch_xs, Tbatch_ys = Ttrain_preprocessor.next_batch(FLAGS.batch_size) 169 | summary,_=sess.run([merged,optimizer], feed_dict={x: batch_xs,xt: Tbatch_xs,yt:Tbatch_ys,adlamb:lamb, decay_learning_rate:rate,y: batch_ys,dropout_keep_prob:0.5}) 170 | train_writer.add_summary(summary,gd) 171 | closs,gloss,dloss,gregloss,dregloss,floss,smloss=sess.run([model.loss,model.G_loss,model.D_loss,model.Gregloss,model.Dregloss,model.F_loss,model.Semanticloss], 172 | feed_dict={x: batch_xs,xt: Tbatch_xs,adlamb:lamb, decay_learning_rate:rate,y: batch_ys,dropout_keep_prob:0.5}) 173 | step += 1 174 | 175 | 176 | if epoch == first_s: 177 | source_v = sess.run(source_vector, feed_dict={x: batch_xs, y: batch_ys, xt: Tbatch_xs,dropout_keep_prob: 1.}) 178 | for i in range(FLAGS.batch_size): 179 | dic_temp.setdefault(np.argmax(batch_ys[i]),[]).append(source_v[i]) 180 | if epoch == first_s+1 and flag == 1: 181 | for i in dic_temp.keys(): 182 | dic_s[i] = np.mean(dic_temp[i],axis=0) 183 | with open('dic_s.txt','w') as f: 184 | f.write(str(dic_s)) 185 | flag = 0 186 | 187 | 188 | 189 | # if gd%50==0: 190 | if epoch%5 == 0 and step == train_batches_per_epoch - 1: 191 | print '=================== Step {0:<10} ================='.format(gs) 192 | print 'Epoch {0:<5} Step {1:<5} Closs {2:<10} Gloss {3:<10} Dloss {4:<10} Total_Loss {7:<10} Gregloss {5:<10} Dregloss {6:<10} Semloss {7:<10}'.format(epoch,step,closs,gloss,dloss,gregloss,dregloss,floss,smloss) 193 | print 'lambda: ',lamb 194 | print 'rate: ',rate 195 | # Epoch 
completed, start validation 196 | print("{} Start validation".format(datetime.datetime.now())) 197 | test_acc = 0. 198 | test_count = 0 199 | fp = open('pre_and_sim.txt','w') 200 | for _ in range((len(val_preprocessor.labels))): 201 | batch_tx, batch_ty = val_preprocessor.next_batch(1) 202 | 203 | if flag == 0 and flag1 == 1: 204 | target_v = sess.run(target_vector, feed_dict={xt: batch_tx, dropout_keep_prob: 1.}) 205 | sim_list = [] 206 | for j in range(NUM_CLASSES): 207 | # print(target_v[0]) 208 | # print('okkk') 209 | # print(dic_s[j]) 210 | sim_value = cos_distance(target_v[0], dic_s[j]) 211 | sim_list.append(sim_value) 212 | max_sim = max(sim_list) 213 | max_idx = sim_list.index(max_sim) 214 | fp.write(str(max_idx) + ' ' + str(max_sim) + '\n') 215 | dic_temp1.setdefault(np.argmax(batch_ty[0]),[]).append(target_v[0]) 216 | if epoch > first_s and flag1 == 1: 217 | for i in dic_temp1.keys(): 218 | dic_t[i] = np.mean(dic_temp1[i],axis=0) 219 | with open('dic_t.txt','w') as f: 220 | f.write(str(dic_t)) 221 | flag1 = 0 222 | 223 | 224 | acc = sess.run(correct, feed_dict={x: batch_tx, y: batch_ty, dropout_keep_prob: 1.}) 225 | test_acc += acc 226 | test_count += 1 227 | fp.close() 228 | print test_acc,test_count 229 | test_acc /= test_count 230 | if test_acc > best_acc: 231 | best_acc = test_acc 232 | print('best acc is: %f'%best_acc) 233 | print("{} Validation Accuracy = {:.4f}".format(datetime.datetime.now(), test_acc)) 234 | # Reset the dataset pointers 235 | val_preprocessor.reset_pointer() 236 | #train_preprocessor.reset_pointer() 237 | if gd%4000==0 and gd>0: 238 | saver.save(sess,'./log/mstnmodel_'+MODEL_NAME+str(gd)+'.ckpt') 239 | print("{} Saving checkpoint of model...".format(datetime.datetime.now())) 240 | # while(1): 241 | # print("1") 242 | 243 | 244 | if __name__ == '__main__': 245 | tf.app.run() 246 | --------------------------------------------------------------------------------