├── README.md ├── LICENSE ├── create_tf.py ├── .gitignore ├── test.py ├── get_data.py ├── net.py └── train.py /README.md: -------------------------------------------------------------------------------- 1 | # facial-attribute-classification-with-graph 2 | Facial attribute classification based on graph attention (tensorflow) 3 | 4 | 1.Datasets 5 | 6 | CelebA(aligned & cropped version):http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html 7 | 8 | LFWA:http://vis-www.cs.umass.edu/lfw/ 9 | 10 | 2.Pretrain Model 11 | 12 | Alexnet with bn:converted from caffe model trained by Marcel Simon, Erik Rodner, Joachim Denzler. 13 | 14 | You can download the caffe model here: 15 | 16 | https://github.com/cvjena/cnn-models 17 | 18 | We also provide the converted tensorflow model in the following link, along with the converted resnet50 mdoel trained on VggFace2. 19 | 20 | 链接:https://pan.baidu.com/s/1KBSN0ZGuGAtXRRyP7fq_7Q 21 | 提取码:xhzu 22 | 23 | 3.Data Augmentation 24 | 25 | As the LFWA dataset is too small, we add some distortion to the training set. The distortion tool is from https://github.com/mdbloice/Augmentor 26 | 27 | 28 | 29 | The paper has been uploaded on https://arxiv.org/abs/1810.09162. 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Ivy Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /create_tf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Thu Mar 15 13:20:39 2018 5 | 6 | @author: crazydemo 7 | """ 8 | 9 | import tensorflow as tf 10 | from PIL import Image 11 | import numpy as np 12 | import time 13 | list_file ="path/to/val.txt" 14 | root = 'path/to/root' 15 | 16 | count = 0 17 | writer = tf.python_io.TFRecordWriter("vali.tfrecords") 18 | with open(list_file, 'r') as f: 19 | for line in f: 20 | line = line.strip() 21 | field = line.split(' ') 22 | temp = field[1:41] 23 | label=[np.int(i) for i in temp] 24 | img = Image.open(root+field[0]) 25 | if float(img.size[0])/float(img.size[1])>4 or float(img.size[1])/float(img.size[0])>4: 26 | continue 27 | img= img.resize((256,256)) 28 | img_raw = img.tobytes() 29 | example = tf.train.Example(features=tf.train.Features(feature={'label': tf.train.Feature(int64_list=tf.train.Int64List(value=label)), 30 | 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))})) 31 | writer.write(example.SerializeToString()) 32 | count = count + 1 33 | if count%500 ==0: 34 | print 'Time:{0},{1} images are processed.'.format(time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())),count) 35 | print "%d images are processed." %count 36 | print 'Done!' 37 | writer.close() 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Oct 17 20:31:40 2018 5 | 6 | @author: ivy 7 | """ 8 | 9 | import tensorflow as tf 10 | import numpy as np 11 | from PIL import Image 12 | from net import * 13 | from get_data import * 14 | 15 | TEST_BATCH_SIZE=1 16 | num_of_attri = 40 17 | total_num = 19962 18 | 19 | test_file = "path/to/your/test.tfrecords" 20 | 21 | phase_train = tf.placeholder(tf.bool, name='phase_train') 22 | config = tf.ConfigProto() 23 | config.gpu_options.allow_growth = True 24 | sess=tf.Session(config=config) 25 | 26 | x_image = tf.placeholder(tf.float32, [TEST_BATCH_SIZE, 227,227,3]) 27 | y = tf.placeholder(tf.int64, shape=[TEST_BATCH_SIZE, 40]) 28 | 29 | img_val,label_val= read_and_decode_test(test_file) 30 | img_batch_val,label_batch_val = tf.train.batch([img_val,label_val], batch_size=VAL_BATCH_SIZE, capacity=2000) 31 | 32 | logits, _, affinity_matrix= mynet(x_image, phase_train) 33 | cross_entropy = [0]*40 34 | with tf.name_scope("cross_ent"): 35 | for i in range(num_of_attri): 36 | cross_entropy[i] = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y[:,i],logits=logits[i])) 37 | cross_ent40 = tf.reduce_sum(cross_entropy) 38 | acc = [0.0]*40 39 | temp_y = [0]*40 40 | with tf.name_scope("accuracy"): 41 | for i in range(num_of_attri): 42 | temp_y[i] = tf.cast(tf.argmax(logits[i],1),tf.int64) 43 | acc[i] = tf.reduce_mean(tf.cast(tf.equal(temp_y[i], y[:,i]), tf.float32)) 44 | accuracy40 = tf.reduce_mean(acc) 45 | 46 | saver = tf.train.Saver(max_to_keep = None) 47 | 48 | saver.restore(sess,'path/to/your/model') 49 | 50 | mean_ce_ = 0.0 51 | mean_acc_ = 0.0 52 | acc_v = np.array([40]) 53 | acc_v_ = np.zeros([40]) 54 | visual_v = np.zeros([40, 40]) 55 | 56 | threads = tf.train.start_queue_runners(sess=sess) 57 | 58 | for i in range(total_num): 59 | x_,y_= sess.run([img_batch_val,label_batch_val]) 60 | vali_op = [cross_ent40, accuracy40, acc, logits, temp_y] 61 | vali_ce_v, vali_acc_v, acc_v, logits_v, y_predict_v, temp_y_v = sess.run(vali_op,feed_dict={x_image:x_, y:y_, phase_train:False}) 62 | 63 | print("batch:{}, mean_ce:{:.4f}, mean_acc:{:.4f}".format(i, vali_ce_v, vali_acc_v)) 64 | mean_ce_+=vali_ce_v 65 | mean_acc_+=vali_acc_v 66 | acc_v_ += acc_v 67 | mean_ce_ /= total_num 68 | mean_acc_ /= total_num 69 | acc_v_ /= total_num 70 | for i in range(40): 71 | print acc_v_[i] 72 | 73 | print("mean_ce_:{:.4f}, mean_acc_:{:.4f}".format(mean_ce_, mean_acc_)) 74 | 75 | sess.close() -------------------------------------------------------------------------------- /get_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Oct 17 20:30:09 2018 5 | 6 | @author: crazydemo 7 | """ 8 | 9 | import tensorflow as tf 10 | from PIL import Image 11 | 12 | image_height=227 13 | image_width=227 14 | 15 | def read_and_decode(filename): 16 | filename_queue = tf.train.string_input_producer([filename], shuffle=False) 17 | reader = tf.TFRecordReader() 18 | _, serialized_example = reader.read(filename_queue) 19 | features = tf.parse_single_example(serialized_example,features={'label': tf.FixedLenFeature([40], tf.int64), 20 | 'img_raw' : tf.FixedLenFeature([], tf.string)}) 21 | label = tf.cast(features['label'], tf.int32) 22 | img = tf.decode_raw(features['img_raw'], tf.uint8) 23 | img = tf.reshape(img, [256, 256, 3]) 24 | img = tf.image.random_flip_left_right(img) 25 | img = tf.image.resize_image_with_crop_or_pad(img, image_height, image_width) 26 | # img = tf.image.per_image_standardization(img) 27 | img = tf.cast(img, tf.float32) 28 | return img,label 29 | 30 | def read_and_decode_test(filename): 31 | filename_queue = tf.train.string_input_producer([filename], shuffle=False) 32 | reader = tf.TFRecordReader() 33 | _, serialized_example = reader.read(filename_queue) 34 | features = tf.parse_single_example(serialized_example,features={'label': tf.FixedLenFeature([40], tf.int64), 35 | 'img_raw' : tf.FixedLenFeature([], tf.string)}) 36 | label = tf.cast(features['label'], tf.int32) 37 | img = tf.decode_raw(features['img_raw'], tf.uint8) 38 | img = tf.reshape(img, [256, 256, 3]) 39 | img = tf.image.resize_image_with_crop_or_pad(img, image_height, image_width) 40 | # img = tf.image.per_image_standardization(img) 41 | img = tf.cast(img, tf.float32) 42 | return img,label 43 | 44 | def read_and_decode_test_ordinal(filename): 45 | filename_queue = tf.train.string_input_producer([filename], shuffle=False) 46 | reader = tf.TFRecordReader() 47 | _, serialized_example = reader.read(filename_queue) 48 | features = tf.parse_single_example(serialized_example,features={'label': tf.FixedLenFeature([40], tf.int64), 49 | 'img_raw' : tf.FixedLenFeature([], tf.string), 50 | 'name': tf.FixedLenFeature([], tf.int64)}) 51 | name = tf.cast(features['name'], tf.int32) 52 | label = tf.cast(features['label'], tf.int32) 53 | img = tf.decode_raw(features['img_raw'], tf.uint8) 54 | img = tf.reshape(img, [256, 256, 3]) 55 | img = tf.image.resize_image_with_crop_or_pad(img, image_height, image_width) 56 | # img = tf.image.per_image_standardization(img) 57 | img = tf.cast(img, tf.float32) 58 | return img,label,name -------------------------------------------------------------------------------- /net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Oct 17 20:21:08 2018 5 | 6 | @author: crazydemo 7 | """ 8 | 9 | import tensorflow as tf 10 | import tensorflow.python.layers.layers 11 | 12 | REGULARIZATION_RATE = 0.001 13 | regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE) 14 | 15 | def max_pool(x, kheight, kwidth, stridex, stridey, padding): 16 | return tf.nn.max_pool(x, ksize=[1, kheight, kwidth, 1],strides=[1, stridex, stridey, 1], padding=padding) 17 | 18 | def bn(x, phase_train, name, activation=None): 19 | if activation=="relu": 20 | activation = tf.nn.relu 21 | return tf.contrib.layers.batch_norm(x, activation_fn=activation, center=True, scale=True, is_training=phase_train,scope=name) 22 | 23 | def conv(x, kHeight, kWidth, strideX, strideY, featureNum, name, padding = "SAME"): 24 | channel = int(x.get_shape()[-1]) 25 | shape = [kHeight, kWidth, channel, featureNum] 26 | with tf.variable_scope(name): 27 | w = tf.get_variable('w', shape=shape, initializer=tf.contrib.layers.xavier_initializer(), regularizer=regularizer) 28 | b = tf.get_variable('b', shape = [featureNum], initializer=tf.constant_initializer(0.0)) 29 | out = tf.nn.conv2d(x, w, strides=[1, strideX, strideY, 1], padding=padding)+b 30 | return out 31 | 32 | def positionFeature(x): 33 | pool = tf.reduce_mean(x, 3) 34 | out = tf.expand_dims(pool, 3) 35 | return out 36 | 37 | def sigmLayer(x): 38 | return tf.nn.sigmoid(x) 39 | 40 | def gapLayer(x, kHeight, kWidth, padding = "VALID"): 41 | return tf.nn.avg_pool(x, ksize = [1, kHeight, kWidth, 1], strides = [1, 1, 1, 1], padding = padding) 42 | 43 | def fc(x, outD, name): 44 | inD = int(x.get_shape()[-1]) 45 | layer_flat = tf.reshape(x, [-1, inD]) 46 | shape = [inD, outD] 47 | with tf.variable_scope(name): 48 | w = tf.get_variable('w', shape=shape, initializer=tf.contrib.layers.xavier_initializer(), regularizer=regularizer) 49 | b = tf.get_variable('b', shape = [outD], initializer=tf.constant_initializer(0.0)) 50 | out = tf.matmul(layer_flat, w) + b 51 | return out 52 | 53 | def fc_conv(x, outD, name): 54 | h = int(x.get_shape()[1]) 55 | w = int(x.get_shape()[2]) 56 | c = int(x.get_shape()[3]) 57 | inD = h*w*c 58 | layer_flat = tf.reshape(x, [-1, inD]) 59 | shape = [inD, outD] 60 | with tf.variable_scope(name): 61 | w = tf.get_variable('w', shape=shape, initializer=tf.contrib.layers.xavier_initializer(), regularizer=regularizer) 62 | b = tf.get_variable('b', shape = [outD], initializer=tf.constant_initializer(0.0)) 63 | out = tf.matmul(layer_flat, w) + b 64 | return out 65 | 66 | def convLayer(x, kHeight, kWidth, strideX, strideY, 67 | featureNum, name, padding = "SAME", groups = 1): 68 | """convolution""" 69 | channel = int(x.get_shape()[-1]) 70 | initializer = tf.contrib.layers.xavier_initializer() 71 | conv = lambda a, b: tf.nn.conv2d(a, b, strides = [1, strideY, strideX, 1], padding = padding) 72 | with tf.variable_scope(name) as scope: 73 | w = tf.get_variable("w", shape = [kHeight, kWidth, channel/groups, featureNum], initializer=initializer) 74 | b = tf.get_variable("b", shape = [featureNum], initializer=initializer) 75 | 76 | xNew = tf.split(value = x, num_or_size_splits = groups, axis = 3) 77 | wNew = tf.split(value = w, num_or_size_splits = groups, axis = 3) 78 | 79 | featureMap = [conv(t1, t2) for t1, t2 in zip(xNew, wNew)] 80 | mergeFeatureMap = tf.concat(axis = 3, values = featureMap) 81 | 82 | out = tf.nn.bias_add(mergeFeatureMap, b) 83 | return out 84 | 85 | 86 | def net(x, train_phase): 87 | y_predict_cln = [0]*40 88 | y_predict_fln = [0]*40 89 | 90 | data_bn = bn(x, train_phase, "data_bn") 91 | # print('bn0:{}'.format(bn0.get_shape())) 92 | conv1 = conv(data_bn, 11, 11, 4, 4, 96, "conv1", "VALID") 93 | conv1_bn = bn(conv1, train_phase, "conv1_bn", "relu") 94 | pool1 = max_pool(conv1_bn, 3, 3, 2, 2, "VALID") 95 | # print('pool1:{}'.format(pool1.get_shape())) 96 | conv2 = convLayer(pool1, 5, 5, 1, 1, 256, "conv2", "SAME", 2) 97 | conv2_bn = bn(conv2, train_phase, "conv2_bn", "relu") 98 | pool2 = max_pool(conv2_bn, 3, 3, 2, 2, "VALID") 99 | # print('pool2:{}'.format(pool2.get_shape())) 100 | conv3 = conv(pool2, 3, 3, 1, 1, 384, "conv3") 101 | conv3_bn = bn(conv3, train_phase, "conv3_bn", "relu") 102 | # print('bn3:{}'.format(bn3.get_shape())) 103 | conv4 = convLayer(conv3_bn, 3, 3, 1, 1, 384, "conv4", "SAME", 2) 104 | conv4_bn = bn(conv4, train_phase, "conv4_bn", "relu") 105 | # print('bn4:{}'.format(bn4.get_shape())) 106 | conv5 = convLayer(conv4_bn, 3, 3, 1, 1, 256, "conv5", "SAME", 2) 107 | conv5_bn = bn(conv5, train_phase, "conv5_bn", "relu") 108 | pool5 = max_pool(conv5_bn, 3, 3, 2, 2, "VALID") 109 | # print('pool5:{}'.format(pool5.get_shape())) 110 | 111 | with tf.variable_scope("block0") as scope: 112 | bconv1 = conv(pool5, 1, 1, 1, 1, 256, "bconv1") 113 | bbn1 = bn(bconv1, train_phase, "bbn1", "relu") 114 | # print('bbn1:{}'.format(bbn1.get_shape())) 115 | pf = positionFeature(bbn1) 116 | # print('pf:{}'.format(pf.get_shape())) 117 | bconv2 = conv(pf, 1, 1, 1, 1, 16, "bconv2") 118 | bbn2 = bn(bconv2, train_phase, "bbn2", "relu") 119 | # print('bbn2:{}'.format(bbn2.get_shape())) 120 | bconv3 = conv(bbn2, 1, 1, 1, 1, 1, "bconv3", "VALID") 121 | sigm3 = sigmLayer(bconv3) 122 | # print('sigm3:{}'.format(sigm3.get_shape())) 123 | ele_mul = tf.multiply(bbn1, sigm3) 124 | ###################### 40 tasks ######################################################## 125 | gap = gapLayer(ele_mul, 6, 6) 126 | fc_out = fc(gap, 2, scope.name) 127 | y_predict_fln[0] = fc_out 128 | ###################### 40 tasks ######################################################## 129 | 130 | psetemp = conv(ele_mul, 1, 1, 1, 1, 10, "cln_conv", "SAME") 131 | pseout = psetemp 132 | # print('pseout:{}'.format(pseout.get_shape())) 133 | b, h, w, c = psetemp.get_shape().as_list() 134 | affinity_matrix = tf.reshape(psetemp, [b, h*w*c, 1]) 135 | # print('affinity_matrix:{}'.format(affinity_matrix.get_shape())) 136 | 137 | for i in range(1,40): 138 | with tf.variable_scope("block"+str(i)) as scope: 139 | bconv1 = conv(pool5, 1, 1, 1, 1, 256, "bconv1") 140 | bbn1 = bn(bconv1, train_phase, "bbn1", "relu") 141 | 142 | pf = positionFeature(bbn1) 143 | # print('pf:{}'.format(pf.get_shape())) 144 | bconv2 = conv(pf, 1, 1, 1, 1, 16, "bconv2") 145 | bbn2 = bn(bconv2, train_phase, "bbn2", "relu") 146 | # print('bbn2:{}'.format(bbn2.get_shape())) 147 | bconv3 = conv(bbn2, 1, 1, 1, 1, 1, "bconv3", "VALID") 148 | sigm3 = sigmLayer(bconv3) 149 | # print('sigm3:{}'.format(sigm3.get_shape())) 150 | ele_mul = tf.multiply(bbn1, sigm3) 151 | ###################### 40 tasks ######################################################## 152 | gap = gapLayer(ele_mul, 6, 6) 153 | fc_out = fc(gap, 2, scope.name) 154 | y_predict_fln[i] = fc_out 155 | ###################### 40 tasks ######################################################## 156 | psetemp = conv(ele_mul, 1, 1, 1, 1, 10, "cln_conv", "SAME") 157 | pseout = tf.concat([pseout, psetemp], 3) 158 | b, h, w, c = psetemp.get_shape().as_list() 159 | affinity_matrix = tf.concat([affinity_matrix, tf.reshape(psetemp, [b, h*w*c, 1])], 2) 160 | 161 | affinity_matrix_ = tf.matmul(tf.transpose(affinity_matrix, [0, 2, 1]), affinity_matrix) 162 | merge_mat = tf.expand_dims(affinity_matrix_, 3) 163 | with tf.variable_scope("affinity_mat") as scope: 164 | tf.summary.image('merge_mat', merge_mat) 165 | # print('affinity_matrix_:{}'.format(affinity_matrix_.get_shape())) 166 | affinity_weight = [0]*40 167 | affinity_weight = tf.split(affinity_matrix_,40, 1) 168 | # print('affinity_weight[0]:{}'.format(affinity_weight[0].get_shape())) 169 | pseout_ = tf.reshape(pseout, [b, h*w*c, 40]) 170 | # print('pseout_:{}'.format(pseout_.get_shape())) 171 | pseout_ = tf.transpose(pseout_, [0,2,1]) 172 | for i in range(40): 173 | with tf.variable_scope('cln_fc'+str(i)) as scope: 174 | aw = tf.nn.softmax(affinity_weight[i]) 175 | tf.summary.histogram("aw", aw) 176 | weighted_pse = tf.matmul(aw,pseout_) 177 | fc_out = fc(weighted_pse, 2, scope.name) 178 | y_predict_cln[i] = fc_out 179 | merged_summary = tf.summary.merge([tf.get_collection(tf.GraphKeys.SUMMARIES,'affinity_mat')]) 180 | return y_predict_cln, y_predict_fln, merged_summary 181 | 182 | 183 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Oct 17 19:35:19 2018 5 | 6 | @author: crazydemo 7 | """ 8 | 9 | import tensorflow as tf 10 | import numpy as np 11 | import os 12 | from PIL import Image 13 | from net import * 14 | from get_data import * 15 | 16 | BATCH_SIZE=128 17 | VAL_BATCH_SIZE=128 18 | num_of_attri = 40 19 | total_step = 25600 20 | pretrain = False 21 | continue_train = False 22 | 23 | last_model_path = 'path/to/your/restore/model' 24 | train_file = 'path/to/your/train.tfrecords' 25 | vali_file = 'path/to/your/vali.tfrecords' 26 | tensorboard_path = 'path/logs/tensorboard' 27 | model_path = 'path/logs/model/' 28 | pretrain_path = 'path/to/your/pretrain/model'#Here we use the alexbn.npy model, please refer to readme for detail information. 29 | 30 | '''The learning rate generator may be used, especially when training the lfwa dataset''' 31 | def cyclical_learning_rate(global_step, step_size, max_bound, min_bound, decay, name=None):#cyclical learning rate generator 32 | if global_step is None: 33 | raise ValueError("global_step is required for exponential_decay.") 34 | global_step = tf.cast(global_step, tf.float32) 35 | step_size = tf.convert_to_tensor(step_size, tf.float32) 36 | max_bound = tf.convert_to_tensor(max_bound, tf.float32) 37 | min_bound = tf.convert_to_tensor(min_bound, tf.float32) 38 | decay = tf.convert_to_tensor(decay, tf.float32) 39 | inverse = tf.floordiv(global_step, step_size) 40 | max_bound = tf.multiply(max_bound, tf.pow(decay, tf.floordiv(inverse, 2))) 41 | 42 | x = tf.mod(global_step, step_size) 43 | 44 | p = tf.cond(tf.mod(inverse, 2)<1, lambda: tf.divide((max_bound-min_bound), step_size), 45 | lambda: tf.divide((min_bound-max_bound), step_size)) 46 | res = tf.multiply(x, p) 47 | return tf.cond(tf.mod(inverse, 2)<1, lambda: tf.add(res, min_bound, name), 48 | lambda: tf.add(res, max_bound, name)) 49 | 50 | def test_learning_rate(total_step, global_step, maximum_bound):#function for determining the hyparameters of cyclical learning rate 51 | if global_step is None: 52 | raise ValueError("global_step is required for exponential_decay.") 53 | global_step = tf.cast(global_step, tf.float32) 54 | max_bound = tf.convert_to_tensor(maximum_bound, tf.float32) 55 | k = tf.divide(max_bound, total_step) 56 | return tf.multiply(global_step, k) 57 | 58 | 59 | '''basic definition and configuration''' 60 | if not os.path.exists(tensorboard_path): 61 | os.makedirs(tensorboard_path) 62 | if not os.path.exists(model_path): 63 | os.makedirs(model_path) 64 | 65 | config = tf.ConfigProto() 66 | config.gpu_options.allow_growth = True 67 | sess=tf.Session(config=config) 68 | 69 | phase_train = tf.placeholder(tf.bool, name='phase_train') 70 | x_image = tf.placeholder(tf.float32, [BATCH_SIZE, 227,227,3]) 71 | y = tf.placeholder(tf.int64, shape=[BATCH_SIZE, 40]) 72 | global_step=tf.Variable(0,trainable=False) 73 | learning_rate = tf.train.polynomial_decay(0.001, global_step/3, 25600, 0) 74 | 75 | lr_summary = tf.summary.scalar('learning_rate',learning_rate) 76 | 77 | img,label = read_and_decode(train_file) 78 | img_batch,label_batch = tf.train.shuffle_batch([img,label], batch_size=BATCH_SIZE, capacity=2000, min_after_dequeue=1000) 79 | img_val,label_val = read_and_decode_test(vali_file) 80 | img_batch_val,label_batch_val = tf.train.shuffle_batch([img_val,label_val], batch_size=VAL_BATCH_SIZE, capacity=2000, min_after_dequeue=1000) 81 | 82 | 83 | '''net, loss, accuracy and summary''' 84 | logits_cln, logits_fln, out1_summary = net(x_image, phase_train) 85 | 86 | cross_entropy_cln = [0]*40 87 | cross_entropy_flc = [0]*40 88 | with tf.name_scope("cross_ent"): 89 | for i in range(num_of_attri): 90 | cross_entropy_cln[i] = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y[:,i],logits=logits_cln[i])) 91 | cross_entropy_flc[i] = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y[:,i],logits=logits_fln[i])) 92 | cross_ent_cln_40 = tf.reduce_sum(cross_entropy_cln) 93 | cross_ent_fln_40 = tf.reduce_sum(cross_entropy_flc) 94 | 95 | regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) 96 | reg_variables = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) 97 | reg_term = tf.contrib.layers.apply_regularization(regularizer, reg_variables) 98 | reg_term_summary = tf.summary.scalar('reg_term', reg_term) 99 | loss_cln = cross_ent_cln_40+reg_term 100 | loss_fln = cross_ent_fln_40+reg_term 101 | loss_cln_summary = tf.summary.scalar('loss_cln', loss_cln) 102 | loss_fln_summary = tf.summary.scalar('loss_fln', loss_fln) 103 | 104 | acc_cln = [0.0]*40 105 | acc_fln = [0.0]*40 106 | with tf.name_scope("accuracy"): 107 | for i in range(num_of_attri): 108 | temp_y_cln = tf.cast(tf.argmax(logits_cln[i],1),tf.int64) 109 | acc_cln[i] = tf.reduce_mean(tf.cast(tf.equal(temp_y_cln, y[:,i]), tf.float32)) 110 | temp_y_fln = tf.cast(tf.argmax(logits_fln[i],1),tf.int64) 111 | acc_fln[i] = tf.reduce_mean(tf.cast(tf.equal(temp_y_fln, y[:,i]), tf.float32)) 112 | accuracy40_cln = tf.reduce_mean(acc_cln) 113 | accuracy40_fln = tf.reduce_mean(acc_fln) 114 | acc_cln_summary = tf.summary.scalar('acc_cln',accuracy40_cln) 115 | acc_fln_summary = tf.summary.scalar('acc_fln',accuracy40_fln) 116 | 117 | 118 | merged_train_summary = tf.summary.merge([lr_summary, loss_cln_summary, loss_fln_summary, acc_cln_summary,acc_fln_summary, out1_summary]) 119 | merged_vali_summary = tf.summary.merge([loss_cln_summary, loss_fln_summary, acc_cln_summary, acc_fln_summary]) 120 | 121 | summary_writer_train = tf.summary.FileWriter(tensorboard_path+'/train', sess.graph) 122 | summary_writer_test = tf.summary.FileWriter(tensorboard_path+'/test') 123 | 124 | '''training method''' 125 | var = tf.trainable_variables() 126 | var_cln = [] 127 | var_fln = [] 128 | var_backbone = [] 129 | for v in var: 130 | if "block" not in v.name and "cln" not in v.name: 131 | var_backbone.append(v) 132 | elif "block" in v.name and "cln" not in v.name: 133 | var_fln.append(v) 134 | elif "cln" in v.name: 135 | var_cln.append(v) 136 | 137 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 138 | with tf.control_dependencies(update_ops): 139 | train_backbone = tf.train.MomentumOptimizer(learning_rate, 0.9).minimize(loss_fln, var_list=var_backbone, global_step=global_step) 140 | train_cln = tf.train.MomentumOptimizer(learning_rate*10, 0.9).minimize(loss_cln, var_list=var_cln, global_step=global_step) 141 | train_fln = tf.train.MomentumOptimizer(learning_rate*10, 0.9).minimize(loss_fln, var_list=var_fln, global_step=global_step) 142 | train_op = tf.group(train_backbone, train_cln, train_fln) 143 | 144 | sess.run(tf.global_variables_initializer()) 145 | 146 | 147 | '''main''' 148 | if pretrain: 149 | #############################alexbn################################################## 150 | print("pretrain...") 151 | weights_dict = np.load(pretrain_path, encoding='bytes').item() 152 | for op_name in weights_dict: 153 | if op_name not in ['fc6', 'fc7', 'fc8', 'fc6_bn', 'fc7_bn', 'fc8_bn']: 154 | with tf.variable_scope(op_name, reuse=True): 155 | data = weights_dict[op_name] 156 | if 'bn' in op_name: 157 | var = tf.get_variable('beta', trainable=True) 158 | sess.run(var.assign(data['mean'])) 159 | var = tf.get_variable('gamma', trainable=True) 160 | sess.run(var.assign(data['variance'])) 161 | else: 162 | var = tf.get_variable('b', trainable=True) 163 | sess.run(var.assign(data['biases'])) 164 | var = tf.get_variable('w', trainable=True) 165 | sess.run(var.assign(data['weights'])) 166 | print("restoring:"+op_name) 167 | #############################alexbn################################################## 168 | 169 | if continue_train: 170 | saver = tf.train.Saver(max_to_keep = None) 171 | saver.restore(sess, last_model_path) 172 | print("restoring...") 173 | 174 | print("training...") 175 | threads = tf.train.start_queue_runners(sess=sess) 176 | for i in range(total_step): 177 | x_,y_= sess.run([img_batch,label_batch]) 178 | op = [accuracy40_fln, accuracy40_cln, cross_ent_cln_40, cross_ent_fln_40, loss_cln, loss_fln, merged_train_summary, train_op, global_step] 179 | at_v, af_v, cet_v, cef_v, lt_v, lf_v, merged_train_summary_str, _, step=sess.run(op,feed_dict={x_image:x_, y: y_, phase_train:True}) 180 | step /= 3 181 | print("step:{}, cet:{:.4f}, lt:{:.4f}, at:{:.4f}, cef:{:.4f}, lf:{:.4f}, af:{:.4f}\r".format(step, cet_v, lt_v, at_v, cef_v, lf_v, af_v)) 182 | if step%25==0: 183 | summary_writer_train.add_summary(merged_train_summary_str,step) 184 | if step%100 == 0: 185 | x_,y_= sess.run([img_batch_val,label_batch_val]) 186 | vali_op = [accuracy40_cln, accuracy40_fln, cross_ent_cln_40, cross_ent_fln_40, loss_cln, loss_fln, merged_vali_summary] 187 | at_v, af_v, cet_v, cef_v, lt_v, lf_v, merged_vali_summary_str = sess.run(vali_op,feed_dict={x_image:x_, y: y_,phase_train:False}) 188 | summary_writer_test.add_summary(merged_vali_summary_str,step) 189 | print("validating...") 190 | print("vali step:{}, cet:{:.4f}, lt:{:.4f}, at:{:.4f}, cef:{:.4f}, lf:{:.4f}, af:{:.4f}\r".format(step, cet_v, lt_v, at_v, cef_v, lf_v, af_v)) 191 | print("training...") 192 | if step%100 == 0 and at_v>0.92: 193 | save_path = model_path + "model.ckpt" 194 | saver.save(sess,save_path,global_step=step) 195 | if step==total_step: 196 | save_path = model_path + "model.ckpt" 197 | saver.save(sess,save_path,global_step=step) 198 | summary_writer_train.close() 199 | summary_writer_test.close() 200 | sess.close() 201 | --------------------------------------------------------------------------------