├── .gitignore ├── .idea ├── inspectionProfiles │ └── Project_Default.xml ├── misc.xml ├── modules.xml ├── sphereface_tensorflow.iml ├── vcs.xml └── workspace.xml ├── README.md ├── data └── __init__.py ├── requirements.txt ├── src ├── __init__.py ├── align │ └── __init__.py ├── data │ └── learning_rate_classifier_casia.txt ├── loss │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-35.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── sphere.cpython-35.pyc │ │ └── sphere.cpython-36.pyc │ └── sphere.py ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-35.pyc │ │ ├── __init__.cpython-36.pyc │ │ └── inception_resnet.cpython-35.pyc │ ├── inception_resnet.py │ ├── inception_resnet_no_train.py │ ├── small_inceptin_resnet.py │ └── squeezenet.py ├── printPretrainedTensor.py ├── train_demo.py ├── train_softmax_demo.py ├── train_softmax_no_train_inception.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-36.pyc │ ├── faceUtil.cpython-35.pyc │ └── faceUtil.cpython-36.pyc │ └── faceUtil.py └── test ├── Loss_ASoftmax.ipynb ├── .ipynb_checkpoints ├── Loss_ASoftmax-checkpoint.ipynb └── sphereloss-checkpoint.ipynb ├── __init__.py └── sphereloss.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | src/align/casia_maxpy_mtcnnpy_182 2 | src/align 3 | src/modeltrained/20170512 4 | src/models/inception_resnet_v1.py 5 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 67 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/sphereface_tensorflow.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 14 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Introduction 2 | 3 | This is a simple re-implementation of the A-Softmax loss proposed in the paper [SphereFace: Deep Hypersphere Embedding for Face Recognition](https://arxiv.org/abs/1704.08063). If this code helps your research, please cite that paper.
4 | This repository contains only the sphere (A-Softmax) loss and a training demo.
5 | Thanks to [sphereface](https://github.com/wy1iu/sphereface), which inspired much of this work. 6 | 7 | ## Requirements 8 | TensorFlow 1.2+
9 | scipy
10 | scikit-learn
11 | opencv-python
12 | h5py
13 | matplotlib
14 | Pillow
15 | requests
16 | psutil 17 | 18 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/data/__init__.py -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow==1.2 2 | scipy 3 | scikit-learn 4 | opencv-python 5 | h5py 6 | matplotlib 7 | Pillow 8 | requests 9 | psutil -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/__init__.py -------------------------------------------------------------------------------- /src/align/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/align/__init__.py -------------------------------------------------------------------------------- /src/data/learning_rate_classifier_casia.txt: -------------------------------------------------------------------------------- 1 | # Learning rate schedule 2 | # Maps an epoch number to a learning rate 3 | 0: 0.001 4 | 65: 0.0001 5 | 77: 0.0001 6 | 1000: 0.0001 -------------------------------------------------------------------------------- /src/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/loss/__init__.py -------------------------------------------------------------------------------- 
/src/loss/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/loss/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /src/loss/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/loss/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /src/loss/__pycache__/sphere.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/loss/__pycache__/sphere.cpython-35.pyc -------------------------------------------------------------------------------- /src/loss/__pycache__/sphere.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/loss/__pycache__/sphere.cpython-36.pyc -------------------------------------------------------------------------------- /src/loss/sphere.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def sphereloss(inputs,label,classes,batch_size,fraction = 1, scope='Logits',reuse=None,m =4,eplion = 1e-8): 4 | """ 5 | inputs tensor shape=[batch,features_num] 6 | labels tensor shape=[batch] each unit belong num_outputs 7 | 8 | """ 9 | inputs_shape = inputs.get_shape().as_list() 10 | with tf.variable_scope(name_or_scope=scope): 11 | weight = tf.Variable(initial_value=tf.random_normal((classes,inputs_shape[1])) * 
tf.sqrt(2 / inputs_shape[1]),dtype=tf.float32,name='weights') # shaep =classes, features, 12 | print("weight shape = ",weight.get_shape().as_list()) 13 | 14 | weight_unit = tf.nn.l2_normalize(weight,dim=1) 15 | print("weight_unit shape = ",weight_unit.get_shape().as_list()) 16 | 17 | inputs_mo = tf.sqrt(tf.reduce_sum(tf.square(inputs),axis=1)+eplion) #shape=[batch 18 | print("inputs_mo shape = ",inputs_mo.get_shape().as_list()) 19 | 20 | inputs_unit = tf.nn.l2_normalize(inputs,dim=1) #shape = [batch,features_num] 21 | print("inputs_unit shape = ",inputs_unit.get_shape().as_list()) 22 | 23 | logits = tf.matmul(inputs,tf.transpose(weight_unit)) #shape = [batch,classes] x * w_unit 24 | print("logits shape = ",logits.get_shape().as_list()) 25 | 26 | weight_unit_batch = tf.gather(weight_unit,label) # shaep =batch,features_num, 27 | print("weight_unit_batch shape = ",weight_unit_batch.get_shape().as_list()) 28 | 29 | logits_inputs = tf.reduce_sum(tf.multiply(inputs,weight_unit_batch),axis=1) # shaep =batch, 30 | 31 | print("logits_inputs shape = ",logits_inputs.get_shape().as_list()) 32 | 33 | cos_theta = tf.reduce_sum(tf.multiply(inputs_unit,weight_unit_batch),axis=1) # shaep =batch, 34 | print("cos_theta shape = ",cos_theta.get_shape().as_list()) 35 | 36 | cos_theta_square = tf.square(cos_theta) 37 | cos_theta_biq = tf.pow(cos_theta,4) 38 | sign0 = tf.sign(cos_theta) 39 | sign2 = tf.sign(2 * cos_theta_square-1) 40 | sign3 = tf.multiply(sign2,sign0) 41 | sign4 = 2 * sign0 +sign3 -3 42 | cos_far_theta = sign3 * (8 * cos_theta_biq - 8 * cos_theta_square + 1) + sign4 43 | print("cos_far_theta = ",cos_far_theta.get_shape().as_list()) 44 | 45 | logit_ii = tf.multiply(cos_far_theta,inputs_mo)#shape = batch 46 | print("logit_ii shape = ",logit_ii.get_shape().as_list()) 47 | 48 | index_range = tf.range(start=0,limit= tf.shape(inputs,out_type=tf.int64)[0],delta=1,dtype=tf.int64) 49 | index_labels = tf.stack([index_range, label], axis = 1) 50 | index_logits = 
tf.scatter_nd(index_labels,tf.subtract(logit_ii,logits_inputs), tf.shape(logits,out_type=tf.int64)) 51 | print("index_logits shape = ",logit_ii.get_shape().as_list()) 52 | 53 | logits_final = tf.add(logits,index_logits) 54 | logits_final = fraction * logits_final + (1 - fraction) * logits 55 | 56 | 57 | loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label,logits=logits_final)) 58 | 59 | return logits_final,loss 60 | 61 | 62 | def soft_loss(inputs,label,classes,scope='Logits'): 63 | """ 64 | inputs tensor shape=[batch,features_num] 65 | labels tensor shape=[batch] each unit belong num_outputs 66 | 67 | """ 68 | inputs_shape = inputs.get_shape().as_list() 69 | with tf.variable_scope(name_or_scope=scope): 70 | weight = tf.Variable(initial_value=tf.random_normal((classes,inputs_shape[1])) * tf.sqrt(2 / inputs_shape[1]), 71 | dtype=tf.float32,name='weights') # shaep =classes, features, 72 | bias = tf.Variable(initial_value=tf.zeros(classes),dtype=tf.float32,name='bias') 73 | print("weight shape = ",weight.get_shape().as_list()) 74 | print("bias shape = ", bias.get_shape().as_list()) 75 | 76 | weight = tf.Print(weight, [tf.shape(weight)], message='logits weights shape = ',summarize=4, first_n=1) 77 | logits = tf.nn.bias_add(tf.matmul(inputs,tf.transpose(weight)),bias,name='logits') 78 | loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label,logits=logits, 79 | name='cross_entropy_per_example'), 80 | name='cross_entropy') 81 | 82 | return logits,loss 83 | 84 | def soft_loss_nobias(inputs,label,classes,scope='Logits'): 85 | """ 86 | inputs tensor shape=[batch,features_num] 87 | labels tensor shape=[batch] each unit belong num_outputs 88 | 89 | """ 90 | inputs_shape = inputs.get_shape().as_list() 91 | with tf.variable_scope(name_or_scope=scope): 92 | weight = tf.Variable(initial_value=tf.random_normal((classes,inputs_shape[1])) * tf.sqrt(2 / inputs_shape[1]), 93 | dtype=tf.float32,name='weights') # shaep =classes, 
features, 94 | print("weight shape = ",weight.get_shape().as_list()) 95 | 96 | weight = tf.Print(weight, [tf.shape(weight)], message='logits weights shape = ',summarize=4, first_n=1) 97 | logits =tf.matmul(inputs,tf.transpose(weight),name='logits') 98 | loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label,logits=logits, 99 | name='cross_entropy_per_example'), 100 | name='cross_entropy') 101 | 102 | return logits,loss 103 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | -------------------------------------------------------------------------------- /src/models/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/models/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/inception_resnet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/models/__pycache__/inception_resnet.cpython-35.pyc -------------------------------------------------------------------------------- /src/models/inception_resnet.py: -------------------------------------------------------------------------------- 1 | 2 | from 
__future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import tensorflow as tf 7 | import tensorflow.contrib.slim as slim 8 | 9 | # Inception-Renset-A 10 | def incep_res_a(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None): 11 | """Builds the 35x35 resnet block.""" 12 | with tf.variable_scope(scope, 'Block35', [net], reuse=reuse): 13 | with tf.variable_scope('Branch_0'): 14 | tower_conv = slim.conv2d(net, 32, 1, scope='Conv2d_1x1') 15 | with tf.variable_scope('Branch_1'): 16 | tower_conv1_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1') 17 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 32, 3, scope='Conv2d_0b_3x3') 18 | with tf.variable_scope('Branch_2'): 19 | tower_conv2_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1') 20 | tower_conv2_1 = slim.conv2d(tower_conv2_0, 32, 3, scope='Conv2d_0b_3x3') 21 | tower_conv2_2 = slim.conv2d(tower_conv2_1, 32, 3, scope='Conv2d_0c_3x3') 22 | mixed = tf.concat([tower_conv, tower_conv1_1, tower_conv2_2], 3) 23 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None, 24 | activation_fn=None, scope='Conv2d_1x1') 25 | net += scale * up 26 | if activation_fn: 27 | net = activation_fn(net) 28 | return net 29 | 30 | # Inception-Renset-B 31 | def incep_res_b(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None): 32 | """Builds the 17x17 resnet block.""" 33 | with tf.variable_scope(scope, 'Block17', [net], reuse=reuse): 34 | with tf.variable_scope('Branch_0'): 35 | tower_conv = slim.conv2d(net, 128, 1, scope='Conv2d_1x1') 36 | with tf.variable_scope('Branch_1'): 37 | tower_conv1_0 = slim.conv2d(net, 128, 1, scope='Conv2d_0a_1x1') 38 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 128, [1, 7], 39 | scope='Conv2d_0b_1x7') 40 | tower_conv1_2 = slim.conv2d(tower_conv1_1, 128, [7, 1], 41 | scope='Conv2d_0c_7x1') 42 | mixed = tf.concat([tower_conv, tower_conv1_2], 3) 43 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None, 44 | 
activation_fn=None, scope='Conv2d_1x1') 45 | net += scale * up 46 | if activation_fn: 47 | net = activation_fn(net) 48 | return net 49 | 50 | 51 | # Inception-Resnet-C 52 | def incep_res_c(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None): 53 | """Builds the 8x8 resnet block.""" 54 | with tf.variable_scope(scope, 'Block8', [net], reuse=reuse): 55 | with tf.variable_scope('Branch_0'): 56 | tower_conv = slim.conv2d(net, 192, 1, scope='Conv2d_1x1') 57 | with tf.variable_scope('Branch_1'): 58 | tower_conv1_0 = slim.conv2d(net, 192, 1, scope='Conv2d_0a_1x1') 59 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 192, [1, 3], 60 | scope='Conv2d_0b_1x3') 61 | tower_conv1_2 = slim.conv2d(tower_conv1_1, 192, [3, 1], 62 | scope='Conv2d_0c_3x1') 63 | mixed = tf.concat([tower_conv, tower_conv1_2], 3) 64 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None, 65 | activation_fn=None, scope='Conv2d_1x1') 66 | net += scale * up 67 | if activation_fn: 68 | net = activation_fn(net) 69 | return net 70 | 71 | def reduction_a(net, k, l, m, n): 72 | with tf.variable_scope('Branch_0'): 73 | tower_conv = slim.conv2d(net, n, 3, stride=2, padding='VALID', 74 | scope='Conv2d_1a_3x3') 75 | with tf.variable_scope('Branch_1'): 76 | tower_conv1_0 = slim.conv2d(net, k, 1, scope='Conv2d_0a_1x1') 77 | tower_conv1_1 = slim.conv2d(tower_conv1_0, l, 3, 78 | scope='Conv2d_0b_3x3') 79 | tower_conv1_2 = slim.conv2d(tower_conv1_1, m, 3, 80 | stride=2, padding='VALID', 81 | scope='Conv2d_1a_3x3') 82 | with tf.variable_scope('Branch_2'): 83 | tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID', 84 | scope='MaxPool_1a_3x3') 85 | net = tf.concat([tower_conv, tower_conv1_2, tower_pool], 3) 86 | return net 87 | 88 | def reduction_b(net): 89 | with tf.variable_scope('Branch_0'): 90 | tower_conv = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1') 91 | tower_conv_1 = slim.conv2d(tower_conv, 384, 3, stride=2, 92 | padding='VALID', scope='Conv2d_1a_3x3') 93 | with 
tf.variable_scope('Branch_1'): 94 | tower_conv1 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1') 95 | tower_conv1_1 = slim.conv2d(tower_conv1, 256, 3, stride=2, 96 | padding='VALID', scope='Conv2d_1a_3x3') 97 | with tf.variable_scope('Branch_2'): 98 | tower_conv2 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1') 99 | tower_conv2_1 = slim.conv2d(tower_conv2, 256, 3, 100 | scope='Conv2d_0b_3x3') 101 | tower_conv2_2 = slim.conv2d(tower_conv2_1, 256, 3, stride=2, 102 | padding='VALID', scope='Conv2d_1a_3x3') 103 | with tf.variable_scope('Branch_3'): 104 | tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID', 105 | scope='MaxPool_1a_3x3') 106 | net = tf.concat([tower_conv_1, tower_conv1_1, 107 | tower_conv2_2, tower_pool], 3) 108 | return net 109 | 110 | def inference(images, keep_probability, phase_train=True, 111 | bottleneck_layer_size=128, weight_decay=0.0, reuse=None): 112 | batch_norm_params = { 113 | # Decay for the moving averages. 114 | 'decay': 0.995, 115 | # epsilon to prevent 0s in variance. 116 | 'epsilon': 0.001, 117 | # force in-place updates of mean and variance estimates 118 | 'updates_collections': None, 119 | # Moving averages ends up in the trainable variables collection 120 | 'variables_collections': [ tf.GraphKeys.TRAINABLE_VARIABLES ], 121 | } 122 | 123 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 124 | weights_initializer=tf.truncated_normal_initializer(stddev=0.1), 125 | weights_regularizer=slim.l2_regularizer(weight_decay), 126 | normalizer_fn=slim.batch_norm, 127 | normalizer_params=batch_norm_params): 128 | return inception_resnet_v1(images, is_training=phase_train, 129 | dropout_keep_prob=keep_probability, bottleneck_layer_size=bottleneck_layer_size, reuse=reuse) 130 | 131 | 132 | def inception_resnet_v1(inputs, is_training=True, 133 | dropout_keep_prob=0.8, 134 | bottleneck_layer_size=128, 135 | reuse=None, 136 | scope='InceptionResnetV1'): 137 | """Creates the Inception Resnet V1 model. 
138 | Args: 139 | inputs: a 4-D tensor of size [batch_size, height, width, 3]. 140 | num_classes: number of predicted classes. 141 | is_training: whether is training or not. 142 | dropout_keep_prob: float, the fraction to keep before final layer. 143 | reuse: whether or not the network and its variables should be reused. To be 144 | able to reuse 'scope' must be given. 145 | scope: Optional variable_scope. 146 | Returns: 147 | logits: the logits outputs of the model. 148 | end_points: the set of end_points from the inception model. 149 | """ 150 | end_points = {} 151 | 152 | with tf.variable_scope(scope, 'InceptionResnetV1', [inputs], reuse=reuse): 153 | with slim.arg_scope([slim.batch_norm, slim.dropout], 154 | is_training=is_training): 155 | with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d], 156 | stride=1, padding='SAME'): 157 | 158 | # 149 x 149 x 32 159 | net = slim.conv2d(inputs, 32, 3, stride=2, padding='VALID', 160 | scope='Conv2d_1a_3x3') 161 | end_points['Conv2d_1a_3x3'] = net 162 | # 147 x 147 x 32 163 | net = slim.conv2d(net, 32, 3, padding='VALID', 164 | scope='Conv2d_2a_3x3') 165 | end_points['Conv2d_2a_3x3'] = net 166 | # 147 x 147 x 64 167 | net = slim.conv2d(net, 64, 3, scope='Conv2d_2b_3x3') 168 | end_points['Conv2d_2b_3x3'] = net 169 | # 73 x 73 x 64 170 | net = slim.max_pool2d(net, 3, stride=2, padding='VALID', 171 | scope='MaxPool_3a_3x3') 172 | end_points['MaxPool_3a_3x3'] = net 173 | # 73 x 73 x 80 174 | net = slim.conv2d(net, 80, 1, padding='VALID', 175 | scope='Conv2d_3b_1x1') 176 | end_points['Conv2d_3b_1x1'] = net 177 | # 71 x 71 x 192 178 | net = slim.conv2d(net, 192, 3, padding='VALID', 179 | scope='Conv2d_4a_3x3') 180 | end_points['Conv2d_4a_3x3'] = net 181 | # 35 x 35 x 256 182 | net = slim.conv2d(net, 256, 3, stride=2, padding='VALID', 183 | scope='Conv2d_4b_3x3') 184 | end_points['Conv2d_4b_3x3'] = net 185 | 186 | # 5 x Inception-resnet-A 35 x 35 x 256 187 | net = slim.repeat(net, 5, incep_res_a, scale=0.17) 188 | 
end_points['Mixed_5a'] = net 189 | 190 | # Reduction-A 191 | with tf.variable_scope('Mixed_6a'): 192 | net = reduction_a(net, 192, 192, 256, 384) 193 | end_points['Mixed_6a'] = net 194 | 195 | # 10 x Inception-Resnet-B 196 | net = slim.repeat(net, 10, incep_res_b, scale=0.10) 197 | end_points['Mixed_6b'] = net 198 | 199 | # Reduction-B 200 | with tf.variable_scope('Mixed_7a'): 201 | net = reduction_b(net) 202 | end_points['Mixed_7a'] = net 203 | 204 | # 5 x Inception-Resnet-C 205 | net = slim.repeat(net, 5, incep_res_c, scale=0.20) 206 | end_points['Mixed_8a'] = net 207 | 208 | net = incep_res_c(net, activation_fn=None) 209 | end_points['Mixed_8b'] = net 210 | 211 | with tf.variable_scope('Logits'): 212 | end_points['PrePool'] = net 213 | #pylint: disable=no-member 214 | net = slim.avg_pool2d(net, net.get_shape()[1:3], padding='VALID', 215 | scope='AvgPool_1a_8x8') 216 | net = slim.flatten(net) 217 | 218 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 219 | scope='Dropout') 220 | 221 | end_points['PreLogitsFlatten'] = net 222 | 223 | net = slim.fully_connected(net, bottleneck_layer_size, activation_fn=None, 224 | scope='Bottleneck', reuse=False) 225 | 226 | return net, end_points 227 | -------------------------------------------------------------------------------- /src/models/inception_resnet_no_train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Contains the definition of the Inception Resnet V1 architecture. 17 | As described in http://arxiv.org/abs/1602.07261. 18 | Inception-v4, Inception-ResNet and the Impact of Residual Connections 19 | on Learning 20 | Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi 21 | """ 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | import tensorflow as tf 27 | import tensorflow.contrib.slim as slim 28 | 29 | # Inception-Renset-A 30 | def block35(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None): 31 | """Builds the 35x35 resnet block.""" 32 | with tf.variable_scope(scope, 'Block35', [net], reuse=reuse): 33 | with tf.variable_scope('Branch_0'): 34 | tower_conv = slim.conv2d(net, 32, 1, scope='Conv2d_1x1') 35 | with tf.variable_scope('Branch_1'): 36 | tower_conv1_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1') 37 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 32, 3, scope='Conv2d_0b_3x3') 38 | with tf.variable_scope('Branch_2'): 39 | tower_conv2_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1') 40 | tower_conv2_1 = slim.conv2d(tower_conv2_0, 32, 3, scope='Conv2d_0b_3x3') 41 | tower_conv2_2 = slim.conv2d(tower_conv2_1, 32, 3, scope='Conv2d_0c_3x3') 42 | mixed = tf.concat([tower_conv, tower_conv1_1, tower_conv2_2], 3) 43 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None, 44 | activation_fn=None, scope='Conv2d_1x1') 45 | net += scale * up 46 | if activation_fn: 47 | net = activation_fn(net) 48 | return net 49 | 50 | # Inception-Renset-B 51 | def block17(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None): 52 | """Builds the 17x17 resnet block.""" 53 | with tf.variable_scope(scope, 'Block17', [net], reuse=reuse): 54 | with 
tf.variable_scope('Branch_0'): 55 | tower_conv = slim.conv2d(net, 128, 1, scope='Conv2d_1x1') 56 | with tf.variable_scope('Branch_1'): 57 | tower_conv1_0 = slim.conv2d(net, 128, 1, scope='Conv2d_0a_1x1') 58 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 128, [1, 7], 59 | scope='Conv2d_0b_1x7') 60 | tower_conv1_2 = slim.conv2d(tower_conv1_1, 128, [7, 1], 61 | scope='Conv2d_0c_7x1') 62 | mixed = tf.concat([tower_conv, tower_conv1_2], 3) 63 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None, 64 | activation_fn=None, scope='Conv2d_1x1') 65 | net += scale * up 66 | if activation_fn: 67 | net = activation_fn(net) 68 | return net 69 | 70 | 71 | # Inception-Resnet-C 72 | def block8(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None): 73 | """Builds the 8x8 resnet block.""" 74 | with tf.variable_scope(scope, 'Block8', [net], reuse=reuse): 75 | with tf.variable_scope('Branch_0'): 76 | tower_conv = slim.conv2d(net, 192, 1, scope='Conv2d_1x1') 77 | with tf.variable_scope('Branch_1'): 78 | tower_conv1_0 = slim.conv2d(net, 192, 1, scope='Conv2d_0a_1x1') 79 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 192, [1, 3], 80 | scope='Conv2d_0b_1x3') 81 | tower_conv1_2 = slim.conv2d(tower_conv1_1, 192, [3, 1], 82 | scope='Conv2d_0c_3x1') 83 | mixed = tf.concat([tower_conv, tower_conv1_2], 3) 84 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None, 85 | activation_fn=None, scope='Conv2d_1x1') 86 | net += scale * up 87 | if activation_fn: 88 | net = activation_fn(net) 89 | return net 90 | 91 | def reduction_a(net, k, l, m, n): 92 | with tf.variable_scope('Branch_0'): 93 | tower_conv = slim.conv2d(net, n, 3, stride=2, padding='VALID', 94 | scope='Conv2d_1a_3x3') 95 | with tf.variable_scope('Branch_1'): 96 | tower_conv1_0 = slim.conv2d(net, k, 1, scope='Conv2d_0a_1x1') 97 | tower_conv1_1 = slim.conv2d(tower_conv1_0, l, 3, 98 | scope='Conv2d_0b_3x3') 99 | tower_conv1_2 = slim.conv2d(tower_conv1_1, m, 3, 100 | stride=2, padding='VALID', 101 | 
scope='Conv2d_1a_3x3') 102 | with tf.variable_scope('Branch_2'): 103 | tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID', 104 | scope='MaxPool_1a_3x3') 105 | net = tf.concat([tower_conv, tower_conv1_2, tower_pool], 3) 106 | return net 107 | 108 | def reduction_b(net): 109 | with tf.variable_scope('Branch_0'): 110 | tower_conv = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1') 111 | tower_conv_1 = slim.conv2d(tower_conv, 384, 3, stride=2, 112 | padding='VALID', scope='Conv2d_1a_3x3') 113 | with tf.variable_scope('Branch_1'): 114 | tower_conv1 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1') 115 | tower_conv1_1 = slim.conv2d(tower_conv1, 256, 3, stride=2, 116 | padding='VALID', scope='Conv2d_1a_3x3') 117 | with tf.variable_scope('Branch_2'): 118 | tower_conv2 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1') 119 | tower_conv2_1 = slim.conv2d(tower_conv2, 256, 3, 120 | scope='Conv2d_0b_3x3') 121 | tower_conv2_2 = slim.conv2d(tower_conv2_1, 256, 3, stride=2, 122 | padding='VALID', scope='Conv2d_1a_3x3') 123 | with tf.variable_scope('Branch_3'): 124 | tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID', 125 | scope='MaxPool_1a_3x3') 126 | net = tf.concat([tower_conv_1, tower_conv1_1, 127 | tower_conv2_2, tower_pool], 3) 128 | return net 129 | 130 | def inference(images, keep_probability, phase_train=True, 131 | bottleneck_layer_size=128, weight_decay=0.0, reuse=None): 132 | batch_norm_params = { 133 | # Decay for the moving averages. 134 | 'decay': 0.995, 135 | # epsilon to prevent 0s in variance. 
136 | 'epsilon': 0.001, 137 | # force in-place updates of mean and variance estimates 138 | 'updates_collections': None, 139 | # Moving averages ends up in the trainable variables collection 140 | 'variables_collections': [ tf.GraphKeys.TRAINABLE_VARIABLES ], 141 | } 142 | 143 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 144 | weights_initializer=tf.contrib.layers.xavier_initializer(), 145 | # tf.truncated_normal_initializer(stddev=0.01), 146 | weights_regularizer=slim.l2_regularizer(weight_decay), 147 | normalizer_fn=slim.batch_norm, 148 | normalizer_params=batch_norm_params): 149 | return inception_resnet_v1(images, is_training=phase_train, 150 | dropout_keep_prob=keep_probability, bottleneck_layer_size=bottleneck_layer_size, reuse=reuse) 151 | 152 | 153 | def inception_resnet_v1(inputs, is_training=True, 154 | dropout_keep_prob=0.8, 155 | bottleneck_layer_size=128, 156 | reuse=None, 157 | scope='InceptionResnetV1'): 158 | """Creates the Inception Resnet V1 model. 159 | Args: 160 | inputs: a 4-D tensor of size [batch_size, height, width, 3]. 161 | num_classes: number of predicted classes. 162 | is_training: whether is training or not. 163 | dropout_keep_prob: float, the fraction to keep before final layer. 164 | reuse: whether or not the network and its variables should be reused. To be 165 | able to reuse 'scope' must be given. 166 | scope: Optional variable_scope. 167 | Returns: 168 | logits: the logits outputs of the model. 169 | end_points: the set of end_points from the inception model. 
170 | """ 171 | end_points = {} 172 | 173 | with tf.variable_scope(scope, 'InceptionResnetV1', [inputs], reuse=reuse): 174 | with slim.arg_scope([slim.batch_norm, slim.dropout], 175 | is_training=is_training): 176 | with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d], 177 | stride=1, padding='SAME'): 178 | with slim.arg_scope([slim.conv2d,slim.fully_connected],trainable=False): 179 | 180 | # 149 x 149 x 32 181 | net = slim.conv2d(inputs, 32, 3, stride=2, padding='VALID', 182 | scope='Conv2d_1a_3x3') 183 | end_points['Conv2d_1a_3x3'] = net 184 | # 147 x 147 x 32 185 | net = slim.conv2d(net, 32, 3, padding='VALID', 186 | scope='Conv2d_2a_3x3') 187 | end_points['Conv2d_2a_3x3'] = net 188 | # 147 x 147 x 64 189 | net = slim.conv2d(net, 64, 3, scope='Conv2d_2b_3x3') 190 | end_points['Conv2d_2b_3x3'] = net 191 | # 73 x 73 x 64 192 | net = slim.max_pool2d(net, 3, stride=2, padding='VALID', 193 | scope='MaxPool_3a_3x3') 194 | end_points['MaxPool_3a_3x3'] = net 195 | # 73 x 73 x 80 196 | net = slim.conv2d(net, 80, 1, padding='VALID', 197 | scope='Conv2d_3b_1x1') 198 | end_points['Conv2d_3b_1x1'] = net 199 | # 71 x 71 x 192 200 | net = slim.conv2d(net, 192, 3, padding='VALID', 201 | scope='Conv2d_4a_3x3') 202 | end_points['Conv2d_4a_3x3'] = net 203 | # 35 x 35 x 256 204 | net = slim.conv2d(net, 256, 3, stride=2, padding='VALID', 205 | scope='Conv2d_4b_3x3') 206 | end_points['Conv2d_4b_3x3'] = net 207 | 208 | # 5 x Inception-resnet-A 35 x 35 x 256 209 | net = slim.repeat(net, 5, block35, scale=0.17) 210 | end_points['Mixed_5a'] = net 211 | 212 | # Reduction-A 213 | with tf.variable_scope('Mixed_6a'): 214 | net = reduction_a(net, 192, 192, 256, 384) 215 | end_points['Mixed_6a'] = net 216 | 217 | # 10 x Inception-Resnet-B 218 | net = slim.repeat(net, 10, block17, scale=0.10) 219 | end_points['Mixed_6b'] = net 220 | 221 | # Reduction-B 222 | with tf.variable_scope('Mixed_7a'): 223 | net = reduction_b(net) 224 | end_points['Mixed_7a'] = net 225 | 226 | # 5 x 
Inception-Resnet-C 227 | net = slim.repeat(net, 5, block8, scale=0.20) 228 | end_points['Mixed_8a'] = net 229 | 230 | net = block8(net, activation_fn=None) 231 | end_points['Mixed_8b'] = net 232 | 233 | with tf.variable_scope('Logits'): 234 | end_points['PrePool'] = net 235 | #pylint: disable=no-member 236 | net = slim.avg_pool2d(net, net.get_shape()[1:3], padding='VALID', 237 | scope='AvgPool_1a_8x8') 238 | net = slim.flatten(net) 239 | 240 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 241 | scope='Dropout') 242 | 243 | end_points['PreLogitsFlatten'] = net 244 | 245 | net = slim.fully_connected(net, bottleneck_layer_size, activation_fn=None, 246 | scope='Bottleneck', reuse=False) 247 | 248 | return net, end_points 249 | -------------------------------------------------------------------------------- /src/models/small_inceptin_resnet.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import tensorflow as tf 7 | import tensorflow.contrib.slim as slim 8 | 9 | # Inception-Renset-A 10 | def block35(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None): 11 | """Builds the 35x35 resnet block.""" 12 | with tf.variable_scope(scope, 'Block35', [net], reuse=reuse): 13 | with tf.variable_scope('Branch_0'): 14 | tower_conv = slim.conv2d(net, 32, 1, scope='Conv2d_1x1') 15 | with tf.variable_scope('Branch_1'): 16 | tower_conv1_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1') 17 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 32, 3, scope='Conv2d_0b_3x3') 18 | with tf.variable_scope('Branch_2'): 19 | tower_conv2_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1') 20 | tower_conv2_1 = slim.conv2d(tower_conv2_0, 32, 3, scope='Conv2d_0b_3x3') 21 | tower_conv2_2 = slim.conv2d(tower_conv2_1, 32, 3, scope='Conv2d_0c_3x3') 22 | mixed = tf.concat([tower_conv, tower_conv1_1, tower_conv2_2], 
3) 23 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None, 24 | activation_fn=None, scope='Conv2d_1x1') 25 | net += scale * up 26 | if activation_fn: 27 | net = activation_fn(net) 28 | return net 29 | 30 | # Inception-Renset-B 31 | def block17(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None): 32 | """Builds the 17x17 resnet block.""" 33 | with tf.variable_scope(scope, 'Block17', [net], reuse=reuse): 34 | with tf.variable_scope('Branch_0'): 35 | tower_conv = slim.conv2d(net, 128, 1, scope='Conv2d_1x1') 36 | with tf.variable_scope('Branch_1'): 37 | tower_conv1_0 = slim.conv2d(net, 128, 1, scope='Conv2d_0a_1x1') 38 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 128, [1, 7], 39 | scope='Conv2d_0b_1x7') 40 | tower_conv1_2 = slim.conv2d(tower_conv1_1, 128, [7, 1], 41 | scope='Conv2d_0c_7x1') 42 | mixed = tf.concat([tower_conv, tower_conv1_2], 3) 43 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None, 44 | activation_fn=None, scope='Conv2d_1x1') 45 | net += scale * up 46 | if activation_fn: 47 | net = activation_fn(net) 48 | return net 49 | 50 | 51 | # Inception-Resnet-C 52 | def block8(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None): 53 | """Builds the 8x8 resnet block.""" 54 | with tf.variable_scope(scope, 'Block8', [net], reuse=reuse): 55 | with tf.variable_scope('Branch_0'): 56 | tower_conv = slim.conv2d(net, 192, 1, scope='Conv2d_1x1') 57 | with tf.variable_scope('Branch_1'): 58 | tower_conv1_0 = slim.conv2d(net, 192, 1, scope='Conv2d_0a_1x1') 59 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 192, [1, 3], 60 | scope='Conv2d_0b_1x3') 61 | tower_conv1_2 = slim.conv2d(tower_conv1_1, 192, [3, 1], 62 | scope='Conv2d_0c_3x1') 63 | mixed = tf.concat([tower_conv, tower_conv1_2], 3) 64 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None, 65 | activation_fn=None, scope='Conv2d_1x1') 66 | net += scale * up 67 | if activation_fn: 68 | net = activation_fn(net) 69 | return net 70 | 71 | def 
reduction_a(net, k, l, m, n): 72 | with tf.variable_scope('Branch_0'): 73 | tower_conv = slim.conv2d(net, n, 3, stride=2, padding='VALID', 74 | scope='Conv2d_1a_3x3') 75 | with tf.variable_scope('Branch_1'): 76 | tower_conv1_0 = slim.conv2d(net, k, 1, scope='Conv2d_0a_1x1') 77 | tower_conv1_1 = slim.conv2d(tower_conv1_0, l, 3, 78 | scope='Conv2d_0b_3x3') 79 | tower_conv1_2 = slim.conv2d(tower_conv1_1, m, 3, 80 | stride=2, padding='VALID', 81 | scope='Conv2d_1a_3x3') 82 | with tf.variable_scope('Branch_2'): 83 | tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID', 84 | scope='MaxPool_1a_3x3') 85 | net = tf.concat([tower_conv, tower_conv1_2, tower_pool], 3) 86 | return net 87 | 88 | def reduction_b(net): 89 | with tf.variable_scope('Branch_0'): 90 | tower_conv = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1') 91 | tower_conv_1 = slim.conv2d(tower_conv, 384, 3, stride=1, 92 | padding='VALID', scope='Conv2d_1a_3x3') 93 | with tf.variable_scope('Branch_1'): 94 | tower_conv1 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1') 95 | tower_conv1_1 = slim.conv2d(tower_conv1, 256, 3, stride=1, 96 | padding='VALID', scope='Conv2d_1a_3x3') 97 | with tf.variable_scope('Branch_2'): 98 | tower_conv2 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1') 99 | tower_conv2_1 = slim.conv2d(tower_conv2, 256, 3, 100 | scope='Conv2d_0b_3x3') 101 | tower_conv2_2 = slim.conv2d(tower_conv2_1, 256, 3, stride=1, 102 | padding='VALID', scope='Conv2d_1a_3x3') 103 | with tf.variable_scope('Branch_3'): 104 | tower_pool = slim.max_pool2d(net, 3, stride=1, padding='VALID', 105 | scope='MaxPool_1a_3x3') 106 | net = tf.concat([tower_conv_1, tower_conv1_1, 107 | tower_conv2_2, tower_pool], 3) 108 | return net 109 | 110 | def inference(images, keep_probability, phase_train=True, 111 | bottleneck_layer_size=128, weight_decay=0.0, reuse=None): 112 | batch_norm_params = { 113 | # Decay for the moving averages. 114 | 'decay': 0.995, 115 | # epsilon to prevent 0s in variance. 
116 | 'epsilon': 0.001, 117 | # force in-place updates of mean and variance estimates 118 | 'updates_collections': None, 119 | # Moving averages ends up in the trainable variables collection 120 | 'variables_collections': [ tf.GraphKeys.TRAINABLE_VARIABLES ], 121 | } 122 | 123 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 124 | weights_initializer=tf.truncated_normal_initializer(stddev=0.1), 125 | weights_regularizer=slim.l2_regularizer(weight_decay), 126 | normalizer_fn=slim.batch_norm, 127 | normalizer_params=batch_norm_params): 128 | return inception_resnet_v1(images, is_training=phase_train, 129 | dropout_keep_prob=keep_probability, bottleneck_layer_size=bottleneck_layer_size, reuse=reuse) 130 | 131 | 132 | def inception_resnet_v1(inputs, is_training=True, 133 | dropout_keep_prob=0.8, 134 | bottleneck_layer_size=128, 135 | reuse=None, 136 | scope='InceptionResnetV1'): 137 | """Creates model 138 | Args: 139 | inputs: a 4-D tensor of size [batch_size, 32, 32, 3]. 140 | num_classes: number of predicted classes. 141 | is_training: whether is training or not. 142 | dropout_keep_prob: float, the fraction to keep before final layer. 143 | reuse: whether or not the network and its variables should be reused. 144 | scope: Optional variable_scope. 145 | Returns: 146 | logits: the logits outputs of the model. 147 | end_points: the set of end_points from the inception model. 
148 | """ 149 | end_points = {} 150 | 151 | with tf.variable_scope(scope, 'InceptionResnetV1', [inputs], reuse=reuse): 152 | with slim.arg_scope([slim.batch_norm, slim.dropout], 153 | is_training=is_training): 154 | with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d], 155 | stride=1, padding='SAME'): 156 | 157 | 158 | #31 x 31 x 32 159 | net = slim.conv2d(inputs,32,3,stride=1,padding='VALID',scope='conv_1_3x3') 160 | 161 | #15 * 15 * 64 162 | net = slim.conv2d(net,64,3,stride=2,padding='VALID',scope='conv_2_3x3') 163 | 164 | #7 * 7 * 96 165 | net = slim.conv2d(net,96,3,stride=2,padding='VALID',scope='conv_3_3x3') 166 | 167 | #7 * 7 * 96 168 | net = slim.repeat(net, 4, block35, scale=0.17) 169 | 170 | #4 * 4 * 224 171 | with tf.variable_scope('Mixed_6a'): 172 | net = reduction_a(net, 32, 32, 64, 96) 173 | 174 | net = slim.repeat(net, 4, block8, scale=0.20) 175 | 176 | with tf.variable_scope('Mixed_7a'): 177 | net = reduction_b(net) 178 | 179 | 180 | with tf.variable_scope('Logits'): 181 | 182 | #pylint: disable=no-member 183 | net = slim.avg_pool2d(net, net.get_shape()[1:3], padding='VALID', 184 | scope='AvgPool_1a_8x8') 185 | net = slim.flatten(net) 186 | 187 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 188 | scope='Dropout') 189 | 190 | end_points['PreLogitsFlatten'] = net 191 | 192 | net = slim.fully_connected(net, bottleneck_layer_size, activation_fn=None, 193 | scope='Bottleneck', reuse=False) 194 | 195 | return net, end_points 196 | 197 | -------------------------------------------------------------------------------- /src/models/squeezenet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | import tensorflow.contrib.slim as slim 7 | 8 | def fire_module(inputs, 9 | squeeze_depth, 10 | expand_depth, 11 | reuse=None, 12 | scope=None, 
13 | outputs_collections=None): 14 | with tf.variable_scope(scope, 'fire', [inputs], reuse=reuse): 15 | with slim.arg_scope([slim.conv2d, slim.max_pool2d], 16 | outputs_collections=None): 17 | net = squeeze(inputs, squeeze_depth) 18 | outputs = expand(net, expand_depth) 19 | return outputs 20 | 21 | def squeeze(inputs, num_outputs): 22 | return slim.conv2d(inputs, num_outputs, [1, 1], stride=1, scope='squeeze') 23 | 24 | def expand(inputs, num_outputs): 25 | with tf.variable_scope('expand'): 26 | e1x1 = slim.conv2d(inputs, num_outputs, [1, 1], stride=1, scope='1x1') 27 | e3x3 = slim.conv2d(inputs, num_outputs, [3, 3], scope='3x3') 28 | return tf.concat([e1x1, e3x3], 3) 29 | 30 | def inference(images, keep_probability, phase_train=True, bottleneck_layer_size=128, weight_decay=0.0, reuse=None): 31 | batch_norm_params = { 32 | # Decay for the moving averages. 33 | 'decay': 0.995, 34 | # epsilon to prevent 0s in variance. 35 | 'epsilon': 0.001, 36 | # force in-place updates of mean and variance estimates 37 | 'updates_collections': None, 38 | # Moving averages ends up in the trainable variables collection 39 | 'variables_collections': [ tf.GraphKeys.TRAINABLE_VARIABLES ], 40 | } 41 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 42 | weights_initializer=slim.xavier_initializer_conv2d(uniform=True), 43 | weights_regularizer=slim.l2_regularizer(weight_decay), 44 | normalizer_fn=slim.batch_norm, 45 | normalizer_params=batch_norm_params): 46 | with tf.variable_scope('squeezenet', [images], reuse=reuse): 47 | with slim.arg_scope([slim.batch_norm, slim.dropout], 48 | is_training=phase_train): 49 | net = slim.conv2d(images, 96, [7, 7], stride=2, scope='conv1') 50 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='maxpool1') 51 | net = fire_module(net, 16, 64, scope='fire2') 52 | net = fire_module(net, 16, 64, scope='fire3') 53 | net = fire_module(net, 32, 128, scope='fire4') 54 | net = slim.max_pool2d(net, [2, 2], stride=2, scope='maxpool4') 55 | net = 
fire_module(net, 32, 128, scope='fire5') 56 | net = fire_module(net, 48, 192, scope='fire6') 57 | net = fire_module(net, 48, 192, scope='fire7') 58 | net = fire_module(net, 64, 256, scope='fire8') 59 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='maxpool8') 60 | net = fire_module(net, 64, 256, scope='fire9') 61 | net = slim.dropout(net, keep_probability) 62 | net = slim.conv2d(net, 1000, [1, 1], activation_fn=None, normalizer_fn=None, scope='conv10') 63 | net = slim.avg_pool2d(net, net.get_shape()[1:3], scope='avgpool10') 64 | net = tf.squeeze(net, [1, 2], name='logits') 65 | net = slim.fully_connected(net, bottleneck_layer_size, activation_fn=None, 66 | scope='Bottleneck', reuse=False) 67 | return net, None 68 | -------------------------------------------------------------------------------- /src/printPretrainedTensor.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tensorflow.python import pywrap_tensorflow 3 | checkpoint_path = 'modeltrained/20170512/model-20170512-110547.ckpt-250000' 4 | 5 | reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path) 6 | var_to_shape_map = reader.get_variable_to_shape_map() 7 | print(len(var_to_shape_map)) 8 | # for key in var_to_shape_map: 9 | # print("tensor_name: ", key) 10 | # print(reader.get_tensor(key).shape) 11 | -------------------------------------------------------------------------------- /src/train_demo.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import os.path 7 | import sys 8 | import time 9 | from datetime import datetime 10 | 11 | import numpy as np 12 | from tensorflow.python.framework import ops 13 | from tensorflow.python.ops import array_ops 14 | from tensorflow.python.ops import data_flow_ops 15 | 16 | from loss.sphere import * 17 | from models import 
inception_resnet_v1 as network 18 | from utils import faceUtil 19 | 20 | 21 | def main(args): 22 | print("main start") 23 | np.random.seed(seed=args.seed) 24 | #train_set = ImageClass list 25 | train_set = faceUtil.get_dataset(args.data_dir) 26 | 27 | #总类别 28 | nrof_classes = len(train_set) 29 | print(nrof_classes) 30 | 31 | #subdir =20171122-112109 32 | subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S') 33 | 34 | #log_dir = c:\User\logs\facenet\20171122- 35 | log_dir = os.path.join(os.path.expanduser(args.logs_base_dir),subdir) 36 | if not os.path.isdir(log_dir): 37 | os.makedirs(log_dir) 38 | print("log_dir =",log_dir) 39 | 40 | # model_dir =c:\User/models/facenet/2017;;; 41 | model_dir = os.path.join(os.path.expanduser(args.models_base_dir), subdir) 42 | if not os.path.isdir(model_dir): # Create the model directory if it doesn't exist 43 | os.makedirs(model_dir) 44 | 45 | print("model_dir =", model_dir) 46 | pretrained_model = None 47 | if args.pretrained_model: 48 | # pretrained_model = os.path.expanduser(args.pretrained_model) 49 | # pretrained_model = tf.train.get_checkpoint_state(args.pretrained_model) 50 | pretrained_model = args.pretrained_model 51 | print('Pre-trained model: %s' % pretrained_model) 52 | 53 | 54 | # Write arguments to a text file 55 | faceUtil.write_arguments_to_file(args, os.path.join(log_dir, 'arguments.txt')) 56 | print("write_arguments_to_file") 57 | with tf.Graph().as_default(): 58 | tf.set_random_seed(args.seed) 59 | global_step = tf.Variable(0,trainable=False) 60 | 61 | #两个列表 image_list= 图片地址列表, label_list = 对应label列表,两个大小相同 62 | image_list, label_list = faceUtil.get_image_paths_and_labels(train_set) 63 | assert len(image_list) > 0 , 'dataset is empty' 64 | print("len(image_list) = ",len(image_list)) 65 | 66 | # Create a queue that produces indices into the image_list and label_list 67 | labels = ops.convert_to_tensor(label_list,dtype=tf.int64) 68 | range_size = array_ops.shape(labels)[0] 69 | range_size = 
tf.Print(range_size, [tf.shape(range_size)],message='Shape of range_input_producer range_size : ',summarize=4, first_n=1) 70 | 71 | #产生一个队列,队列包含0到range_size-1的元素,打乱 72 | index_queue = tf.train.range_input_producer(range_size,num_epochs=None,shuffle=True,seed=None,capacity=32) 73 | 74 | #从index_queue中取出 args.batch_size*args.epoch_size 个元素,用来从image_list, label_list中取出一部分feed给网络 75 | index_dequeue_op = index_queue.dequeue_many(args.batch_size * args.epoch_size,'index_dequeue') 76 | 77 | #学习率 78 | learning_rate_placeholder = tf.placeholder(tf.float32,name='learning_rate') 79 | #批大小 arg.batch_size 80 | batch_size_placeholder = tf.placeholder(tf.int32,name='batch_size') 81 | #是否训练中 82 | phase_train_placeholder = tf.placeholder(tf.bool,name='phase_train') 83 | #图像路径 大小 arg.batch_size * arg.epoch_size 84 | image_paths_placeholder = tf.placeholder(tf.string,shape=[None,1],name='image_paths') 85 | #图像标签 大小:arg.batch_size * arg.epoch_size 86 | labels_placeholder = tf.placeholder(tf.int64,shape=[None,1],name='labels') 87 | 88 | #新建一个队列,数据流操作,fifo,先入先出 89 | input_queue = data_flow_ops.FIFOQueue(capacity=100000,dtypes=[tf.string,tf.int64],shapes=[(1,),(1,)],shared_name=None,name=None) 90 | 91 | # enqueue_many返回的是一个操作 ,入站的数量是 len(image_paths_placeholder) = 从index_queue中取出 args.batch_size*args.epoch_size个元素 92 | enqueue_op = input_queue.enqueue_many([image_paths_placeholder,labels_placeholder],name='enqueue_op') 93 | 94 | nrof_preprocess_threads = 4 95 | images_and_labels = [] 96 | 97 | for _ in range(nrof_preprocess_threads): 98 | filenames , label = input_queue.dequeue() 99 | # label = tf.Print(label,[tf.shape(label)],message='Shape of one thread input_queue.dequeue label : ', 100 | # summarize=4,first_n=1) 101 | # filenames = tf.Print(filenames, [tf.shape(filenames)], message='Shape of one thread input_queue.dequeue filenames : ', 102 | # summarize=4, first_n=1) 103 | print("one thread input_queue.dequeue len = ",tf.shape(label)) 104 | images =[] 105 | for filenames in 
tf.unstack(filenames): 106 | file_contents = tf.read_file(filenames) 107 | image = tf.image.decode_image(file_contents,channels=3) 108 | 109 | if args.random_rotate: 110 | image = tf.py_func(faceUtil.random_rotate_image, [image], tf.uint8) 111 | 112 | if args.random_crop: 113 | image = tf.random_crop(image,[args.image_size,args.image_size,3]) 114 | 115 | else: 116 | image = tf.image.resize_image_with_crop_or_pad(image,args.image_size,args.image_size) 117 | 118 | if args.random_flip: 119 | image = tf.image.random_flip_left_right(image) 120 | 121 | image.set_shape((args.image_size,args.image_size,3)) 122 | images.append(tf.image.per_image_standardization(image)) 123 | 124 | #从队列中取出名字 解析为image 然后加进images_and_labels 可能长度 = 4 * 125 | images_and_labels.append([images,label]) 126 | 127 | #最终一次进入网络的数据: 长应该度 = batch_size_placeholder 128 | image_batch, label_batch = tf.train.batch_join(images_and_labels,batch_size=batch_size_placeholder, 129 | shapes=[(args.image_size,args.image_size,3),()], 130 | enqueue_many = True, 131 | capacity = 4 * nrof_preprocess_threads * args.batch_size, 132 | allow_smaller_final_batch=True) 133 | print('final input net image_batch len = ',tf.shape(image_batch)) 134 | 135 | image_batch = tf.Print(image_batch, [tf.shape(image_batch)], message='final input net image_batch shape = ', 136 | summarize=4, first_n=1) 137 | image_batch = tf.identity(image_batch, 'image_batch') 138 | image_batch = tf.identity(image_batch, 'input') 139 | label_batch = tf.identity(label_batch, 'label_batch') 140 | 141 | print('Total number of classes: %d' % nrof_classes) 142 | print('Total number of examples: %d' % len(image_list)) 143 | 144 | print('Building training graph') 145 | 146 | # 将指数衰减应用到学习率上 147 | learning_rate = tf.train.exponential_decay(learning_rate= learning_rate_placeholder, 148 | global_step = global_step, 149 | decay_steps=args.learning_rate_decay_epochs * args.epoch_size, 150 | decay_rate=args.learning_rate_decay_factor, 151 | staircase = True) 152 | 
#decay_steps=args.learning_rate_decay_epochs * args.epoch_size, 153 | 154 | tf.summary.scalar('learning_rate', learning_rate) 155 | 156 | # Build the inference graph 157 | prelogits, _ = network.inference(image_batch,args.keep_probability,phase_train=phase_train_placeholder, 158 | bottleneck_layer_size=args.embedding_size,weight_decay=args.weight_decay) 159 | print("prelogits.shape = ",prelogits.get_shape().as_list()) 160 | 161 | logits,sphere_loss = sphereloss(prelogits,label_batch,len(train_set),batch_size=args.batch_size) 162 | 163 | tf.add_to_collection('losses',sphere_loss) 164 | regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) 165 | total_loss = tf.add_n([sphere_loss] + regularization_losses,name='total_loss') 166 | 167 | train_op = faceUtil.train(total_loss, global_step, args.optimizer, learning_rate, 168 | args.moving_average_decay, tf.global_variables(), args.log_histograms) 169 | # print("global_variables len = {}".format(len(tf.global_variables()))) 170 | # print("local_variables len = {}".format(len(tf.local_variables()))) 171 | # print("trainable_variables len = {}".format(len(tf.trainable_variables()))) 172 | # for v in tf.trainable_variables() : 173 | # print("trainable_variables :{}".format(v.name)) 174 | # train_op = faceUtil.train(sphere_loss,global_step,args.optimizer,learning_rate, 175 | # args.moving_average_decay, tf.global_variables(), args.log_histograms) 176 | 177 | #创建saver 178 | variables = tf.trainable_variables() 179 | print("variables_trainable len = ", len(variables)) 180 | for v in variables: 181 | print('variables_trainable : {}'.format(v.name)) 182 | saver = tf.train.Saver(var_list=variables, max_to_keep=2) 183 | 184 | # variables_to_restore = [v for v in variables if v.name.split('/')[0] != 'Logits'] 185 | # print("variables_trainable len = ",len(variables)) 186 | # print("variables_to_restore len = ",len(variables_to_restore)) 187 | # # for v in variables_to_restore : 188 | # # 
print("variables_to_restore : ",v.name) 189 | # saver = tf.train.Saver(var_list=variables_to_restore,max_to_keep=3) 190 | 191 | 192 | # variables_trainable = tf.trainable_variables() 193 | # print("variables_trainable len = ",len(variables_trainable)) 194 | # # for v in variables_trainable : 195 | # # print('variables_trainable : {}'.format(v.name)) 196 | # variables_to_restore = slim.get_variables_to_restore(include=['InceptionResnetV1']) 197 | # print("variables_to_restore len = ",len(variables_to_restore)) 198 | # saver = tf.train.Saver(var_list=variables_to_restore,max_to_keep=3) 199 | 200 | 201 | 202 | # Build the summary operation based on the TF collection of Summaries. 203 | summary_op = tf.summary.merge_all() 204 | 205 | # 能够在gpu上分配的最大内存 206 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = args.gpu_memory_fraction) 207 | sess = tf.Session(config=tf.ConfigProto(gpu_options = gpu_options,log_device_placement = False)) 208 | 209 | # Initialize variables 210 | sess.run(tf.global_variables_initializer()) 211 | sess.run(tf.local_variables_initializer()) 212 | summary_writer = tf.summary.FileWriter(log_dir, sess.graph) 213 | 214 | # 获取线程坐标 215 | coord = tf.train.Coordinator() 216 | 217 | # 将队列中的所有Runner开始执行 218 | tf.train.start_queue_runners(coord=coord,sess=sess) 219 | 220 | with sess.as_default(): 221 | print('Running training') 222 | if pretrained_model : 223 | print('Restoring pretrained model_checkpoint_path: %s' % pretrained_model) 224 | saver.restore(sess,pretrained_model) 225 | 226 | # Training and validation loop 227 | print('Running training really') 228 | epoch = 0 229 | # 将所有数据过一遍的次数 230 | while epoch < args.max_nrof_epochs: 231 | 232 | #这里是返回当前的global_step值吗,step可以看做是全局的批处理个数 233 | step = sess.run(global_step,feed_dict=None) 234 | 235 | #epoch_size是一个epoch中批的个数 236 | # 这个epoch是全局的批处理个数除以一个epoch中批的个数得到epoch,这个epoch将用于求学习率 237 | epoch = step // args.epoch_size 238 | # Train for one epoch 239 | train(args, sess, epoch, image_list, 
label_list, index_dequeue_op, enqueue_op, image_paths_placeholder, labels_placeholder, 240 | learning_rate_placeholder, phase_train_placeholder, batch_size_placeholder, global_step, 241 | total_loss, train_op, summary_op, summary_writer, regularization_losses, args.learning_rate_schedule_file) 242 | 243 | # Save variables and the metagraph if it doesn't exist already 244 | save_variables_and_metagraph(sess, saver, summary_writer, model_dir, subdir, step) 245 | 246 | return model_dir 247 | 248 | 249 | 250 | def train(args, sess, epoch, image_list, label_list, index_dequeue_op, enqueue_op, image_paths_placeholder, labels_placeholder, 251 | learning_rate_placeholder, phase_train_placeholder, batch_size_placeholder, global_step, 252 | loss, train_op, summary_op, summary_writer, regularization_losses, learning_rate_schedule_file): 253 | 254 | batch_number = 0 255 | if args.learning_rate>0.0: 256 | 257 | lr = args.learning_rate 258 | else: 259 | lr = faceUtil.get_learning_rate_from_file(learning_rate_schedule_file, epoch) 260 | 261 | index_epoch = sess.run(index_dequeue_op) 262 | label_epoch = np.array(label_list)[index_epoch] 263 | image_epoch = np.array(image_list)[index_epoch] 264 | 265 | # Enqueue one epoch of image paths and labels 266 | labels_array = np.expand_dims(np.array(label_epoch),1) 267 | image_paths_array = np.expand_dims(np.array(image_epoch),1) 268 | sess.run(enqueue_op,{image_paths_placeholder:image_paths_array,labels_placeholder:labels_array}) 269 | 270 | # Training loop 271 | train_time = 0 272 | while batch_number < args.epoch_size: 273 | start_time = time.time() 274 | feed_dict = {learning_rate_placeholder: lr, phase_train_placeholder:True, batch_size_placeholder:args.batch_size} 275 | 276 | if (batch_number % 100 == 0) : 277 | err, _, step, reg_loss, summary_str = sess.run([loss,train_op,global_step,regularization_losses,summary_op],feed_dict=feed_dict) 278 | summary_writer.add_summary(summary_str, global_step=step) 279 | else : 280 | err, _, step, 
reg_loss = sess.run([loss, train_op, global_step, regularization_losses], feed_dict=feed_dict) 281 | 282 | duration = time.time() - start_time 283 | print('global_step[%d],Epoch: [%d][%d/%d]\tTime %.3f\tLoss %2.3f\tRegLoss %2.3f' % 284 | (step,epoch, batch_number+1, args.epoch_size, duration, err, np.sum(reg_loss))) 285 | batch_number += 1 286 | train_time += duration 287 | 288 | # Add validation loss and accuracy to summary 289 | summary = tf.Summary() 290 | 291 | #pylint: disable=maybe-no-member 292 | summary.value.add(tag='time/total', simple_value=train_time) 293 | summary_writer.add_summary(summary, step) 294 | return step 295 | 296 | 297 | 298 | def save_variables_and_metagraph(sess, saver, summary_writer, model_dir, model_name, step): 299 | 300 | # Save the model checkpoint 301 | print('Saving variables') 302 | start_time = time.time() 303 | checkpoint_path = os.path.join(model_dir,'model-%s.ckpt' % model_name) 304 | saver.save(sess,checkpoint_path,global_step=step,write_meta_graph=False) 305 | save_time_variables = time.time() - start_time 306 | print('Variables saved in %.2f seconds' % save_time_variables) 307 | metagraph_filename = os.path.join(model_dir,'model-%s.meta' % model_name) 308 | save_time_metagraph = 0 309 | if not os.path.exists(metagraph_filename): 310 | print('Saving metagraph') 311 | start_time = time.time() 312 | saver.export_meta_graph(metagraph_filename) 313 | save_time_metagraph = time.time() - start_time 314 | print('Metagraph saved in %.2f seconds' % save_time_metagraph) 315 | 316 | summary = tf.Summary() 317 | #pylint: disable=maybe-no-member 318 | summary.value.add(tag='time/save_variables', simple_value=save_time_variables) 319 | summary.value.add(tag='time/save_metagraph', simple_value=save_time_metagraph) 320 | summary_writer.add_summary(summary, step) 321 | 322 | 323 | 324 | 325 | 326 | 327 | 328 | 329 | 330 | 331 | 332 | 333 | 334 | 335 | 336 | 337 | 338 | 339 | 340 | 341 | 342 | def parse_arguments(argv): 343 | parser = 
argparse.ArgumentParser() 344 | print("parser = argparse.ArgumentParser()") 345 | parser.add_argument('--data_dir',type=str,default='align/casia_maxpy_mtcnnpy_182') 346 | parser.add_argument('--gpu_memory_fraction',type=float,default=0.8) 347 | parser.add_argument('--pretrained_model', type=str,default = '/home/huyu/models/facenet/20180211-122649/model-20180211-122649.ckpt-9000', 348 | help='Load a pretrained model before training starts.') 349 | #default = '/home/huyu/models/facenet/20180209-114624/model-20180209-114624.ckpt-0', 350 | #default='/home/huyu/models/facenet/20180207-203905/model-20180207-203905.ckpt-21000', 351 | # default='modeltrained/20170512/model-20170512-110547.ckpt-250000', 352 | 353 | parser.add_argument('--max_nrof_epochs', type=int, default=100) 354 | parser.add_argument('--batch_size', type=int, default=64) 355 | parser.add_argument('--image_size',type=int,default=160) 356 | parser.add_argument('--epoch_size', type=int, default=300) 357 | parser.add_argument('--embedding_size', type=int, default=128) 358 | parser.add_argument('--random_crop', 359 | help='Performs random cropping of training images. If false, the center image_size pixels from the training images are used. ' + 360 | 'If the size of the images in the data directory is equal to image_size no cropping is performed', action='store_true') 361 | parser.add_argument('--random_flip', 362 | help='Performs random horizontal flipping of training images.', action='store_true') 363 | parser.add_argument('--random_rotate', 364 | help='Performs random rotations of training images.', action='store_true') 365 | parser.add_argument('--keep_probability', type=float, 366 | help='Keep probability of dropout for the fully connected layer(s).', default=1) 367 | parser.add_argument('--weight_decay', type=float, 368 | help='L2 weight regularization.', default=0.0) 369 | parser.add_argument('--learning_rate', type=float, 370 | help='Initial learning rate. 
If set to a negative value a learning rate ' + 371 | 'schedule can be specified in the file "learning_rate_schedule.txt"', default=0.0001) 372 | parser.add_argument('--optimizer', type=str, choices=['ADAGRAD', 'ADADELTA', 'ADAM', 'RMSPROP', 'MOM'], 373 | help='The optimization algorithm to use', default='ADAM') 374 | parser.add_argument('--learning_rate_decay_epochs', type=int, 375 | help='Number of epochs between learning rate decay.', default=100) 376 | parser.add_argument('--learning_rate_decay_factor', type=float, 377 | help='Learning rate decay factor.', default=1) 378 | parser.add_argument('--moving_average_decay', type=float, 379 | help='Exponential decay for tracking of training parameters.', default=0.9999) 380 | parser.add_argument('--seed', type=int, 381 | help='Random seed.', default=666) 382 | parser.add_argument('--nrof_preprocess_threads', type=int, 383 | help='Number of preprocessing (data loading and augmentation) threads.', default=4) 384 | parser.add_argument('--log_histograms', 385 | help='Enables logging of weight/bias histograms in tensorboard.', action='store_true') 386 | parser.add_argument('--learning_rate_schedule_file', type=str, 387 | help='File containing the learning rate schedule that is used when learning_rate is set to to -1.', default='data/learning_rate_classifier_casia.txt') 388 | parser.add_argument('--filter_filename', type=str, 389 | help='File containing image data used for dataset filtering', default='') 390 | parser.add_argument('--filter_percentile', type=float, 391 | help='Keep only the percentile images closed to its class center', default=100.0) 392 | parser.add_argument('--filter_min_nrof_images_per_class', type=int, 393 | help='Keep only the classes with this number of examples or more', default=0) 394 | parser.add_argument('--logs_base_dir', type=str, 395 | help='Directory where to write event logs.', default='~/logs/facenet') 396 | parser.add_argument('--models_base_dir', type=str, 397 | help='Directory where to 
write trained models and checkpoints.', default='~/models/facenet') 398 | 399 | 400 | return parser.parse_args(argv) 401 | 402 | 403 | if __name__ == '__main__': 404 | main(parse_arguments(sys.argv[1:])) 405 | -------------------------------------------------------------------------------- /src/train_softmax_demo.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import os.path 7 | import sys 8 | import time 9 | from datetime import datetime 10 | 11 | import numpy as np 12 | from tensorflow.python.framework import ops 13 | from tensorflow.python.ops import array_ops 14 | from tensorflow.python.ops import data_flow_ops 15 | import tensorflow.contrib.slim as slim 16 | 17 | from loss.sphere import * 18 | from models import inception_resnet_v1 as network 19 | from utils import faceUtil 20 | 21 | 22 | def main(args): 23 | print("main start") 24 | np.random.seed(seed=args.seed) 25 | #train_set = ImageClass list 26 | train_set = faceUtil.get_dataset(args.data_dir) 27 | 28 | #总类别 29 | nrof_classes = len(train_set) 30 | print(nrof_classes) 31 | 32 | #subdir =20171122-112109 33 | subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S') 34 | 35 | #log_dir = c:\User\logs\facenet\20171122- 36 | log_dir = os.path.join(os.path.expanduser(args.logs_base_dir),subdir) 37 | if not os.path.isdir(log_dir): 38 | os.makedirs(log_dir) 39 | print("log_dir =",log_dir) 40 | 41 | # model_dir =c:\User/models/facenet/2017;;; 42 | model_dir = os.path.join(os.path.expanduser(args.models_base_dir), subdir) 43 | if not os.path.isdir(model_dir): # Create the model directory if it doesn't exist 44 | os.makedirs(model_dir) 45 | 46 | print("model_dir =", model_dir) 47 | pretrained_model = None 48 | if args.pretrained_model: 49 | # pretrained_model = os.path.expanduser(args.pretrained_model) 50 | # pretrained_model = 
tf.train.get_checkpoint_state(args.pretrained_model) 51 | pretrained_model = args.pretrained_model 52 | print('Pre-trained model: %s' % pretrained_model) 53 | 54 | 55 | # Write arguments to a text file 56 | faceUtil.write_arguments_to_file(args, os.path.join(log_dir, 'arguments.txt')) 57 | print("write_arguments_to_file") 58 | with tf.Graph().as_default(): 59 | tf.set_random_seed(args.seed) 60 | global_step = tf.Variable(0,trainable=False) 61 | 62 | #两个列表 image_list= 图片地址列表, label_list = 对应label列表,两个大小相同 63 | image_list, label_list = faceUtil.get_image_paths_and_labels(train_set) 64 | assert len(image_list) > 0 , 'dataset is empty' 65 | print("len(image_list) = ",len(image_list)) 66 | 67 | # Create a queue that produces indices into the image_list and label_list 68 | labels = ops.convert_to_tensor(label_list,dtype=tf.int64) 69 | range_size = array_ops.shape(labels)[0] 70 | range_size = tf.Print(range_size, [tf.shape(range_size)],message='Shape of range_input_producer range_size : ',summarize=4, first_n=1) 71 | 72 | #产生一个队列,队列包含0到range_size-1的元素,打乱 73 | index_queue = tf.train.range_input_producer(range_size,num_epochs=None,shuffle=True,seed=None,capacity=32) 74 | 75 | #从index_queue中取出 args.batch_size*args.epoch_size 个元素,用来从image_list, label_list中取出一部分feed给网络 76 | index_dequeue_op = index_queue.dequeue_many(args.batch_size * args.epoch_size,'index_dequeue') 77 | 78 | #学习率 79 | learning_rate_placeholder = tf.placeholder(tf.float32,name='learning_rate') 80 | #批大小 arg.batch_size 81 | batch_size_placeholder = tf.placeholder(tf.int32,name='batch_size') 82 | #是否训练中 83 | phase_train_placeholder = tf.placeholder(tf.bool,name='phase_train') 84 | #图像路径 大小 arg.batch_size * arg.epoch_size 85 | image_paths_placeholder = tf.placeholder(tf.string,shape=[None,1],name='image_paths') 86 | #图像标签 大小:arg.batch_size * arg.epoch_size 87 | labels_placeholder = tf.placeholder(tf.int64,shape=[None,1],name='labels') 88 | 89 | #新建一个队列,数据流操作,fifo,先入先出 90 | input_queue = 
data_flow_ops.FIFOQueue(capacity=100000,dtypes=[tf.string,tf.int64],shapes=[(1,),(1,)],shared_name=None,name=None) 91 | 92 | # enqueue_many返回的是一个操作 ,入站的数量是 len(image_paths_placeholder) = 从index_queue中取出 args.batch_size*args.epoch_size个元素 93 | enqueue_op = input_queue.enqueue_many([image_paths_placeholder,labels_placeholder],name='enqueue_op') 94 | 95 | nrof_preprocess_threads = 4 96 | images_and_labels = [] 97 | 98 | for _ in range(nrof_preprocess_threads): 99 | filenames , label = input_queue.dequeue() 100 | # label = tf.Print(label,[tf.shape(label)],message='Shape of one thread input_queue.dequeue label : ', 101 | # summarize=4,first_n=1) 102 | # filenames = tf.Print(filenames, [tf.shape(filenames)], message='Shape of one thread input_queue.dequeue filenames : ', 103 | # summarize=4, first_n=1) 104 | print("one thread input_queue.dequeue len = ",tf.shape(label)) 105 | images =[] 106 | for filenames in tf.unstack(filenames): 107 | file_contents = tf.read_file(filenames) 108 | image = tf.image.decode_image(file_contents,channels=3) 109 | 110 | if args.random_rotate: 111 | image = tf.py_func(faceUtil.random_rotate_image, [image], tf.uint8) 112 | 113 | if args.random_crop: 114 | image = tf.random_crop(image,[args.image_size,args.image_size,3]) 115 | 116 | else: 117 | image = tf.image.resize_image_with_crop_or_pad(image,args.image_size,args.image_size) 118 | 119 | if args.random_flip: 120 | image = tf.image.random_flip_left_right(image) 121 | 122 | image.set_shape((args.image_size,args.image_size,3)) 123 | images.append(tf.image.per_image_standardization(image)) 124 | 125 | #从队列中取出名字 解析为image 然后加进images_and_labels 可能长度 = 4 * 126 | images_and_labels.append([images,label]) 127 | 128 | #最终一次进入网络的数据: 长应该度 = batch_size_placeholder 129 | image_batch, label_batch = tf.train.batch_join(images_and_labels,batch_size=batch_size_placeholder, 130 | shapes=[(args.image_size,args.image_size,3),()], 131 | enqueue_many = True, 132 | capacity = 4 * nrof_preprocess_threads * 
args.batch_size, 133 | allow_smaller_final_batch=True) 134 | print('final input net image_batch len = ',tf.shape(image_batch)) 135 | 136 | image_batch = tf.Print(image_batch, [tf.shape(image_batch)], message='final input net image_batch shape = ', 137 | summarize=4, first_n=1) 138 | image_batch = tf.identity(image_batch, 'image_batch') 139 | image_batch = tf.identity(image_batch, 'input') 140 | label_batch = tf.identity(label_batch, 'label_batch') 141 | 142 | print('Total number of classes: %d' % nrof_classes) 143 | print('Total number of examples: %d' % len(image_list)) 144 | 145 | print('Building training graph') 146 | 147 | # 将指数衰减应用到学习率上 148 | learning_rate = tf.train.exponential_decay(learning_rate= learning_rate_placeholder, 149 | global_step = global_step, 150 | decay_steps=args.learning_rate_decay_epochs * args.epoch_size, 151 | decay_rate=args.learning_rate_decay_factor, 152 | staircase = True) 153 | #decay_steps=args.learning_rate_decay_epochs * args.epoch_size, 154 | 155 | tf.summary.scalar('learning_rate', learning_rate) 156 | 157 | # Build the inference graph 158 | prelogits, _ = network.inference(image_batch,args.keep_probability,phase_train=phase_train_placeholder, 159 | bottleneck_layer_size=args.embedding_size,weight_decay=args.weight_decay) 160 | 161 | prelogits = tf.Print(prelogits, [tf.shape(prelogits)], message='prelogits shape = ', 162 | summarize=4, first_n=1) 163 | print("prelogits.shape = ",prelogits.get_shape().as_list()) 164 | 165 | # logits =slim.fully_connected(prelogits, len(train_set), activation_fn=None, 166 | # weights_initializer=tf.contrib.layers.xavier_initializer(), 167 | # weights_regularizer=slim.l2_regularizer(args.weight_decay), 168 | # scope='Logits', reuse=False) 169 | # 170 | # # Calculate the average cross entropy loss across the batch 171 | # cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( 172 | # labels=label_batch, logits=logits, name='cross_entropy_per_example') 173 | # tf.reduce_mean(cross_entropy, 
name='cross_entropy') 174 | _,cross_entropy_mean = soft_loss_nobias(prelogits,label_batch,len(train_set)) 175 | tf.add_to_collection('losses', cross_entropy_mean) 176 | 177 | regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) 178 | total_loss = tf.add_n([cross_entropy_mean] + regularization_losses,name='total_loss') 179 | 180 | train_op = faceUtil.train(total_loss, global_step, args.optimizer, learning_rate, 181 | args.moving_average_decay, tf.global_variables(), args.log_histograms) 182 | # print("global_variables len = {}".format(len(tf.global_variables()))) 183 | # print("local_variables len = {}".format(len(tf.local_variables()))) 184 | # print("trainable_variables len = {}".format(len(tf.trainable_variables()))) 185 | # for v in tf.trainable_variables() : 186 | # print("trainable_variables :{}".format(v.name)) 187 | # train_op = faceUtil.train(sphere_loss,global_step,args.optimizer,learning_rate, 188 | # args.moving_average_decay, tf.global_variables(), args.log_histograms) 189 | 190 | #创建saver 191 | variables = tf.trainable_variables() 192 | print("variables_trainable len = ", len(variables)) 193 | for v in variables: 194 | print('variables_trainable : {}'.format(v.name)) 195 | saver = tf.train.Saver(var_list=variables, max_to_keep=2) 196 | 197 | # variables_to_restore = [v for v in variables if v.name.split('/')[0] != 'Logits'] 198 | # print("variables_trainable len = ",len(variables)) 199 | # print("variables_to_restore len = ",len(variables_to_restore)) 200 | # # for v in variables_to_restore : 201 | # # print("variables_to_restore : ",v.name) 202 | # saver = tf.train.Saver(var_list=variables_to_restore,max_to_keep=3) 203 | 204 | 205 | # variables_trainable = tf.trainable_variables() 206 | # print("variables_trainable len = ",len(variables_trainable)) 207 | # # for v in variables_trainable : 208 | # # print('variables_trainable : {}'.format(v.name)) 209 | # variables_to_restore = 
slim.get_variables_to_restore(include=['InceptionResnetV1']) 210 | # print("variables_to_restore len = ",len(variables_to_restore)) 211 | # saver = tf.train.Saver(var_list=variables_to_restore,max_to_keep=3) 212 | 213 | 214 | 215 | # Build the summary operation based on the TF collection of Summaries. 216 | summary_op = tf.summary.merge_all() 217 | 218 | # 能够在gpu上分配的最大内存 219 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = args.gpu_memory_fraction) 220 | sess = tf.Session(config=tf.ConfigProto(gpu_options = gpu_options,log_device_placement = False)) 221 | 222 | # Initialize variables 223 | sess.run(tf.global_variables_initializer()) 224 | sess.run(tf.local_variables_initializer()) 225 | summary_writer = tf.summary.FileWriter(log_dir, sess.graph) 226 | 227 | # 获取线程坐标 228 | coord = tf.train.Coordinator() 229 | 230 | # 将队列中的所有Runner开始执行 231 | tf.train.start_queue_runners(coord=coord,sess=sess) 232 | 233 | with sess.as_default(): 234 | print('Running training') 235 | if pretrained_model : 236 | print('Restoring pretrained model_checkpoint_path: %s' % pretrained_model) 237 | saver.restore(sess,pretrained_model) 238 | 239 | # Training and validation loop 240 | print('Running training really') 241 | epoch = 0 242 | # 将所有数据过一遍的次数 243 | while epoch < args.max_nrof_epochs: 244 | 245 | #这里是返回当前的global_step值吗,step可以看做是全局的批处理个数 246 | step = sess.run(global_step,feed_dict=None) 247 | 248 | #epoch_size是一个epoch中批的个数 249 | # 这个epoch是全局的批处理个数除以一个epoch中批的个数得到epoch,这个epoch将用于求学习率 250 | epoch = step // args.epoch_size 251 | # Train for one epoch 252 | train(args, sess, epoch, image_list, label_list, index_dequeue_op, enqueue_op, image_paths_placeholder, labels_placeholder, 253 | learning_rate_placeholder, phase_train_placeholder, batch_size_placeholder, global_step, 254 | total_loss, train_op, summary_op, summary_writer, regularization_losses, args.learning_rate_schedule_file) 255 | 256 | # Save variables and the metagraph if it doesn't exist already 257 | 
save_variables_and_metagraph(sess, saver, summary_writer, model_dir, subdir, step) 258 | 259 | return model_dir 260 | 261 | 262 | 263 | def train(args, sess, epoch, image_list, label_list, index_dequeue_op, enqueue_op, image_paths_placeholder, labels_placeholder, 264 | learning_rate_placeholder, phase_train_placeholder, batch_size_placeholder, global_step, 265 | loss, train_op, summary_op, summary_writer, regularization_losses, learning_rate_schedule_file): 266 | 267 | batch_number = 0 268 | if args.learning_rate>0.0: 269 | 270 | lr = args.learning_rate 271 | else: 272 | lr = faceUtil.get_learning_rate_from_file(learning_rate_schedule_file, epoch) 273 | 274 | index_epoch = sess.run(index_dequeue_op) 275 | label_epoch = np.array(label_list)[index_epoch] 276 | image_epoch = np.array(image_list)[index_epoch] 277 | 278 | # Enqueue one epoch of image paths and labels 279 | labels_array = np.expand_dims(np.array(label_epoch),1) 280 | image_paths_array = np.expand_dims(np.array(image_epoch),1) 281 | sess.run(enqueue_op,{image_paths_placeholder:image_paths_array,labels_placeholder:labels_array}) 282 | 283 | # Training loop 284 | train_time = 0 285 | while batch_number < args.epoch_size: 286 | start_time = time.time() 287 | feed_dict = {learning_rate_placeholder: lr, phase_train_placeholder:True, batch_size_placeholder:args.batch_size} 288 | 289 | if (batch_number % 100 == 0) : 290 | err, _, step, reg_loss, summary_str = sess.run([loss,train_op,global_step,regularization_losses,summary_op],feed_dict=feed_dict) 291 | summary_writer.add_summary(summary_str, global_step=step) 292 | else : 293 | err, _, step, reg_loss = sess.run([loss, train_op, global_step, regularization_losses], feed_dict=feed_dict) 294 | 295 | duration = time.time() - start_time 296 | print('global_step[%d],Epoch: [%d][%d/%d]\tTime %.3f\tLoss %2.3f\tRegLoss %2.3f' % 297 | (step,epoch, batch_number+1, args.epoch_size, duration, err, np.sum(reg_loss))) 298 | batch_number += 1 299 | train_time += duration 300 
| 301 | # Add validation loss and accuracy to summary 302 | summary = tf.Summary() 303 | 304 | #pylint: disable=maybe-no-member 305 | summary.value.add(tag='time/total', simple_value=train_time) 306 | summary_writer.add_summary(summary, step) 307 | return step 308 | 309 | 310 | 311 | def save_variables_and_metagraph(sess, saver, summary_writer, model_dir, model_name, step): 312 | 313 | # Save the model checkpoint 314 | print('Saving variables') 315 | start_time = time.time() 316 | checkpoint_path = os.path.join(model_dir,'model-%s.ckpt' % model_name) 317 | saver.save(sess,checkpoint_path,global_step=step,write_meta_graph=False) 318 | save_time_variables = time.time() - start_time 319 | print('Variables saved in %.2f seconds' % save_time_variables) 320 | metagraph_filename = os.path.join(model_dir,'model-%s.meta' % model_name) 321 | save_time_metagraph = 0 322 | if not os.path.exists(metagraph_filename): 323 | print('Saving metagraph') 324 | start_time = time.time() 325 | saver.export_meta_graph(metagraph_filename) 326 | save_time_metagraph = time.time() - start_time 327 | print('Metagraph saved in %.2f seconds' % save_time_metagraph) 328 | 329 | summary = tf.Summary() 330 | #pylint: disable=maybe-no-member 331 | summary.value.add(tag='time/save_variables', simple_value=save_time_variables) 332 | summary.value.add(tag='time/save_metagraph', simple_value=save_time_metagraph) 333 | summary_writer.add_summary(summary, step) 334 | 335 | 336 | 337 | 338 | 339 | 340 | 341 | 342 | 343 | 344 | 345 | 346 | 347 | 348 | 349 | 350 | 351 | 352 | 353 | 354 | 355 | def parse_arguments(argv): 356 | parser = argparse.ArgumentParser() 357 | print("parser = argparse.ArgumentParser()") 358 | parser.add_argument('--data_dir',type=str,default='align/casia_maxpy_mtcnnpy_182') 359 | parser.add_argument('--gpu_memory_fraction',type=float,default=0.8) 360 | parser.add_argument('--pretrained_model', type=str,default = 
'/home/huyu/models/facenet/20180209-115727/model-20180209-115727.ckpt-14400', 361 | help='Load a pretrained model before training starts.') 362 | default = '/home/huyu/models/facenet/20180209-114624/model-20180209-114624.ckpt-0', 363 | #default='/home/huyu/models/facenet/20180208-210946/model-20180208-210946.ckpt-1200', 364 | # default='modeltrained/20170512/model-20170512-110547.ckpt-250000', 365 | 366 | parser.add_argument('--max_nrof_epochs', type=int, default=200) 367 | parser.add_argument('--batch_size', type=int, default=64) 368 | parser.add_argument('--image_size',type=int,default=160) 369 | parser.add_argument('--epoch_size', type=int, default=300) 370 | parser.add_argument('--embedding_size', type=int, default=128) 371 | parser.add_argument('--random_crop', 372 | help='Performs random cropping of training images. If false, the center image_size pixels from the training images are used. ' + 373 | 'If the size of the images in the data directory is equal to image_size no cropping is performed', action='store_true') 374 | parser.add_argument('--random_flip', 375 | help='Performs random horizontal flipping of training images.', action='store_true') 376 | parser.add_argument('--random_rotate', 377 | help='Performs random rotations of training images.', action='store_true') 378 | parser.add_argument('--keep_probability', type=float, 379 | help='Keep probability of dropout for the fully connected layer(s).', default=1) 380 | parser.add_argument('--weight_decay', type=float, 381 | help='L2 weight regularization.', default=0.0) 382 | parser.add_argument('--learning_rate', type=float, 383 | help='Initial learning rate. 
If set to a negative value a learning rate ' + 384 | 'schedule can be specified in the file "learning_rate_schedule.txt"', default=0.005) 385 | parser.add_argument('--optimizer', type=str, choices=['ADAGRAD', 'ADADELTA', 'ADAM', 'RMSPROP', 'MOM'], 386 | help='The optimization algorithm to use', default='ADAM') 387 | parser.add_argument('--learning_rate_decay_epochs', type=int, 388 | help='Number of epochs between learning rate decay.', default=5) 389 | parser.add_argument('--learning_rate_decay_factor', type=float, 390 | help='Learning rate decay factor.', default=0.8) 391 | parser.add_argument('--moving_average_decay', type=float, 392 | help='Exponential decay for tracking of training parameters.', default=0.9999) 393 | parser.add_argument('--seed', type=int, 394 | help='Random seed.', default=666) 395 | parser.add_argument('--nrof_preprocess_threads', type=int, 396 | help='Number of preprocessing (data loading and augmentation) threads.', default=4) 397 | parser.add_argument('--log_histograms', 398 | help='Enables logging of weight/bias histograms in tensorboard.', action='store_true') 399 | parser.add_argument('--learning_rate_schedule_file', type=str, 400 | help='File containing the learning rate schedule that is used when learning_rate is set to to -1.', default='data/learning_rate_classifier_casia.txt') 401 | parser.add_argument('--filter_filename', type=str, 402 | help='File containing image data used for dataset filtering', default='') 403 | parser.add_argument('--filter_percentile', type=float, 404 | help='Keep only the percentile images closed to its class center', default=100.0) 405 | parser.add_argument('--filter_min_nrof_images_per_class', type=int, 406 | help='Keep only the classes with this number of examples or more', default=0) 407 | parser.add_argument('--logs_base_dir', type=str, 408 | help='Directory where to write event logs.', default='~/logs/facenet') 409 | parser.add_argument('--models_base_dir', type=str, 410 | help='Directory where to 
write trained models and checkpoints.', default='~/models/facenet') 411 | 412 | 413 | return parser.parse_args(argv) 414 | 415 | 416 | if __name__ == '__main__': 417 | main(parse_arguments(sys.argv[1:])) 418 | -------------------------------------------------------------------------------- /src/train_softmax_no_train_inception.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import os.path 7 | import sys 8 | import time 9 | from datetime import datetime 10 | 11 | import numpy as np 12 | from tensorflow.python.framework import ops 13 | from tensorflow.python.ops import array_ops 14 | from tensorflow.python.ops import data_flow_ops 15 | import tensorflow.contrib.slim as slim 16 | 17 | from loss.sphere import * 18 | from models import inception_resnet_no_train as network 19 | from utils import faceUtil 20 | 21 | 22 | def main(args): 23 | print("main start") 24 | np.random.seed(seed=args.seed) 25 | #train_set = ImageClass list 26 | train_set = faceUtil.get_dataset(args.data_dir) 27 | 28 | #总类别 29 | nrof_classes = len(train_set) 30 | print(nrof_classes) 31 | 32 | #subdir =20171122-112109 33 | subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S') 34 | 35 | #log_dir = c:\User\logs\facenet\20171122- 36 | log_dir = os.path.join(os.path.expanduser(args.logs_base_dir),subdir) 37 | if not os.path.isdir(log_dir): 38 | os.makedirs(log_dir) 39 | print("log_dir =",log_dir) 40 | 41 | # model_dir =c:\User/models/facenet/2017;;; 42 | model_dir = os.path.join(os.path.expanduser(args.models_base_dir), subdir) 43 | if not os.path.isdir(model_dir): # Create the model directory if it doesn't exist 44 | os.makedirs(model_dir) 45 | 46 | print("model_dir =", model_dir) 47 | pretrained_model = None 48 | if args.pretrained_model: 49 | # pretrained_model = os.path.expanduser(args.pretrained_model) 50 | # 
pretrained_model = tf.train.get_checkpoint_state(args.pretrained_model) 51 | pretrained_model = args.pretrained_model 52 | print('Pre-trained model: %s' % pretrained_model) 53 | 54 | 55 | # Write arguments to a text file 56 | faceUtil.write_arguments_to_file(args, os.path.join(log_dir, 'arguments.txt')) 57 | print("write_arguments_to_file") 58 | with tf.Graph().as_default(): 59 | tf.set_random_seed(args.seed) 60 | global_step = tf.Variable(0,trainable=False) 61 | 62 | #两个列表 image_list= 图片地址列表, label_list = 对应label列表,两个大小相同 63 | image_list, label_list = faceUtil.get_image_paths_and_labels(train_set) 64 | assert len(image_list) > 0 , 'dataset is empty' 65 | print("len(image_list) = ",len(image_list)) 66 | 67 | # Create a queue that produces indices into the image_list and label_list 68 | labels = ops.convert_to_tensor(label_list,dtype=tf.int64) 69 | range_size = array_ops.shape(labels)[0] 70 | range_size = tf.Print(range_size, [tf.shape(range_size)],message='Shape of range_input_producer range_size : ',summarize=4, first_n=1) 71 | 72 | #产生一个队列,队列包含0到range_size-1的元素,打乱 73 | index_queue = tf.train.range_input_producer(range_size,num_epochs=None,shuffle=True,seed=None,capacity=32) 74 | 75 | #从index_queue中取出 args.batch_size*args.epoch_size 个元素,用来从image_list, label_list中取出一部分feed给网络 76 | index_dequeue_op = index_queue.dequeue_many(args.batch_size * args.epoch_size,'index_dequeue') 77 | 78 | #学习率 79 | learning_rate_placeholder = tf.placeholder(tf.float32,name='learning_rate') 80 | #批大小 arg.batch_size 81 | batch_size_placeholder = tf.placeholder(tf.int32,name='batch_size') 82 | #是否训练中 83 | phase_train_placeholder = tf.placeholder(tf.bool,name='phase_train') 84 | #图像路径 大小 arg.batch_size * arg.epoch_size 85 | image_paths_placeholder = tf.placeholder(tf.string,shape=[None,1],name='image_paths') 86 | #图像标签 大小:arg.batch_size * arg.epoch_size 87 | labels_placeholder = tf.placeholder(tf.int64,shape=[None,1],name='labels') 88 | 89 | #新建一个队列,数据流操作,fifo,先入先出 90 | input_queue = 
data_flow_ops.FIFOQueue(capacity=100000,dtypes=[tf.string,tf.int64],shapes=[(1,),(1,)],shared_name=None,name=None) 91 | 92 | # enqueue_many返回的是一个操作 ,入站的数量是 len(image_paths_placeholder) = 从index_queue中取出 args.batch_size*args.epoch_size个元素 93 | enqueue_op = input_queue.enqueue_many([image_paths_placeholder,labels_placeholder],name='enqueue_op') 94 | 95 | nrof_preprocess_threads = 4 96 | images_and_labels = [] 97 | 98 | for _ in range(nrof_preprocess_threads): 99 | filenames , label = input_queue.dequeue() 100 | # label = tf.Print(label,[tf.shape(label)],message='Shape of one thread input_queue.dequeue label : ', 101 | # summarize=4,first_n=1) 102 | # filenames = tf.Print(filenames, [tf.shape(filenames)], message='Shape of one thread input_queue.dequeue filenames : ', 103 | # summarize=4, first_n=1) 104 | print("one thread input_queue.dequeue len = ",tf.shape(label)) 105 | images =[] 106 | for filenames in tf.unstack(filenames): 107 | file_contents = tf.read_file(filenames) 108 | image = tf.image.decode_image(file_contents,channels=3) 109 | 110 | if args.random_rotate: 111 | image = tf.py_func(faceUtil.random_rotate_image, [image], tf.uint8) 112 | 113 | if args.random_crop: 114 | image = tf.random_crop(image,[args.image_size,args.image_size,3]) 115 | 116 | else: 117 | image = tf.image.resize_image_with_crop_or_pad(image,args.image_size,args.image_size) 118 | 119 | if args.random_flip: 120 | image = tf.image.random_flip_left_right(image) 121 | 122 | image.set_shape((args.image_size,args.image_size,3)) 123 | images.append(tf.image.per_image_standardization(image)) 124 | 125 | #从队列中取出名字 解析为image 然后加进images_and_labels 可能长度 = 4 * 126 | images_and_labels.append([images,label]) 127 | 128 | #最终一次进入网络的数据: 长应该度 = batch_size_placeholder 129 | image_batch, label_batch = tf.train.batch_join(images_and_labels,batch_size=batch_size_placeholder, 130 | shapes=[(args.image_size,args.image_size,3),()], 131 | enqueue_many = True, 132 | capacity = 4 * nrof_preprocess_threads * 
args.batch_size, 133 | allow_smaller_final_batch=True) 134 | print('final input net image_batch len = ',tf.shape(image_batch)) 135 | 136 | image_batch = tf.Print(image_batch, [tf.shape(image_batch)], message='final input net image_batch shape = ', 137 | summarize=4, first_n=1) 138 | image_batch = tf.identity(image_batch, 'image_batch') 139 | image_batch = tf.identity(image_batch, 'input') 140 | label_batch = tf.identity(label_batch, 'label_batch') 141 | 142 | print('Total number of classes: %d' % nrof_classes) 143 | print('Total number of examples: %d' % len(image_list)) 144 | 145 | print('Building training graph') 146 | 147 | # 将指数衰减应用到学习率上 148 | learning_rate = tf.train.exponential_decay(learning_rate= learning_rate_placeholder, 149 | global_step = global_step, 150 | decay_steps=args.learning_rate_decay_epochs * args.epoch_size, 151 | decay_rate=args.learning_rate_decay_factor, 152 | staircase = True) 153 | #decay_steps=args.learning_rate_decay_epochs * args.epoch_size, 154 | 155 | tf.summary.scalar('learning_rate', learning_rate) 156 | 157 | # Build the inference graph 158 | prelogits, _ = network.inference(image_batch,args.keep_probability,phase_train=phase_train_placeholder, 159 | bottleneck_layer_size=args.embedding_size,weight_decay=args.weight_decay) 160 | 161 | prelogits = tf.Print(prelogits, [tf.shape(prelogits)], message='prelogits shape = ', 162 | summarize=4, first_n=1) 163 | print("prelogits.shape = ",prelogits.get_shape().as_list()) 164 | 165 | # logits =slim.fully_connected(prelogits, len(train_set), activation_fn=None, 166 | # weights_initializer=tf.contrib.layers.xavier_initializer(), 167 | # weights_regularizer=slim.l2_regularizer(args.weight_decay), 168 | # scope='Logits', reuse=False) 169 | # 170 | # # Calculate the average cross entropy loss across the batch 171 | # cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( 172 | # labels=label_batch, logits=logits, name='cross_entropy_per_example') 173 | # tf.reduce_mean(cross_entropy, 
name='cross_entropy') 174 | _,cross_entropy_mean = soft_loss(prelogits,label_batch,len(train_set)) 175 | tf.add_to_collection('losses', cross_entropy_mean) 176 | 177 | regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) 178 | total_loss = tf.add_n([cross_entropy_mean] + regularization_losses,name='total_loss') 179 | 180 | train_op = faceUtil.train(total_loss, global_step, args.optimizer, learning_rate, 181 | args.moving_average_decay, tf.trainable_variables(), args.log_histograms) 182 | # print("global_variables len = {}".format(len(tf.global_variables()))) 183 | # print("local_variables len = {}".format(len(tf.local_variables()))) 184 | # print("trainable_variables len = {}".format(len(tf.trainable_variables()))) 185 | # for v in tf.trainable_variables() : 186 | # print("trainable_variables :{}".format(v.name)) 187 | # train_op = faceUtil.train(sphere_loss,global_step,args.optimizer,learning_rate, 188 | # args.moving_average_decay, tf.global_variables(), args.log_histograms) 189 | 190 | #创建saver 191 | variables = tf.trainable_variables() 192 | print("variables_trainable len = ", len(variables)) 193 | # for v in variables: 194 | # print('variables_trainable : {}'.format(v.name)) 195 | saver = tf.train.Saver(var_list=variables, max_to_keep=2) 196 | 197 | variables_to_restore = slim.get_variables_to_restore(include=['InceptionResnetV1']) 198 | 199 | print("variables_to_restore len = ", len(variables_to_restore)) 200 | saver_restore = tf.train.Saver(var_list=variables_to_restore) 201 | 202 | # variables_to_restore = [v for v in variables if v.name.split('/')[0] != 'Logits'] 203 | # print("variables_trainable len = ",len(variables)) 204 | # print("variables_to_restore len = ",len(variables_to_restore)) 205 | # # for v in variables_to_restore : 206 | # # print("variables_to_restore : ",v.name) 207 | # saver = tf.train.Saver(var_list=variables_to_restore,max_to_keep=3) 208 | 209 | 210 | # variables_trainable = tf.trainable_variables() 211 | # 
print("variables_trainable len = ",len(variables_trainable)) 212 | # # for v in variables_trainable : 213 | # # print('variables_trainable : {}'.format(v.name)) 214 | # variables_to_restore = slim.get_variables_to_restore(include=['InceptionResnetV1']) 215 | # print("variables_to_restore len = ",len(variables_to_restore)) 216 | # saver = tf.train.Saver(var_list=variables_to_restore,max_to_keep=3) 217 | 218 | 219 | 220 | # Build the summary operation based on the TF collection of Summaries. 221 | summary_op = tf.summary.merge_all() 222 | 223 | # 能够在gpu上分配的最大内存 224 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = args.gpu_memory_fraction) 225 | sess = tf.Session(config=tf.ConfigProto(gpu_options = gpu_options,log_device_placement = False)) 226 | 227 | # Initialize variables 228 | sess.run(tf.global_variables_initializer()) 229 | sess.run(tf.local_variables_initializer()) 230 | summary_writer = tf.summary.FileWriter(log_dir, sess.graph) 231 | 232 | # 获取线程坐标 233 | coord = tf.train.Coordinator() 234 | 235 | # 将队列中的所有Runner开始执行 236 | tf.train.start_queue_runners(coord=coord,sess=sess) 237 | 238 | with sess.as_default(): 239 | print('Running training') 240 | if pretrained_model : 241 | print('Restoring pretrained model_checkpoint_path: %s' % pretrained_model) 242 | saver_restore.restore(sess,pretrained_model) 243 | 244 | # Training and validation loop 245 | print('Running training really') 246 | epoch = 0 247 | # 将所有数据过一遍的次数 248 | while epoch < args.max_nrof_epochs: 249 | 250 | #这里是返回当前的global_step值吗,step可以看做是全局的批处理个数 251 | step = sess.run(global_step,feed_dict=None) 252 | 253 | #epoch_size是一个epoch中批的个数 254 | # 这个epoch是全局的批处理个数除以一个epoch中批的个数得到epoch,这个epoch将用于求学习率 255 | epoch = step // args.epoch_size 256 | # Train for one epoch 257 | train(args, sess, epoch, image_list, label_list, index_dequeue_op, enqueue_op, image_paths_placeholder, labels_placeholder, 258 | learning_rate_placeholder, phase_train_placeholder, batch_size_placeholder, global_step, 259 | 
total_loss, train_op, summary_op, summary_writer, regularization_losses, args.learning_rate_schedule_file) 260 | 261 | # Save variables and the metagraph if it doesn't exist already 262 | save_variables_and_metagraph(sess, saver, summary_writer, model_dir, subdir, step) 263 | 264 | return model_dir 265 | 266 | 267 | 268 | def train(args, sess, epoch, image_list, label_list, index_dequeue_op, enqueue_op, image_paths_placeholder, labels_placeholder, 269 | learning_rate_placeholder, phase_train_placeholder, batch_size_placeholder, global_step, 270 | loss, train_op, summary_op, summary_writer, regularization_losses, learning_rate_schedule_file): 271 | 272 | batch_number = 0 273 | if args.learning_rate>0.0: 274 | 275 | lr = args.learning_rate 276 | else: 277 | lr = faceUtil.get_learning_rate_from_file(learning_rate_schedule_file, epoch) 278 | 279 | index_epoch = sess.run(index_dequeue_op) 280 | label_epoch = np.array(label_list)[index_epoch] 281 | image_epoch = np.array(image_list)[index_epoch] 282 | 283 | # Enqueue one epoch of image paths and labels 284 | labels_array = np.expand_dims(np.array(label_epoch),1) 285 | image_paths_array = np.expand_dims(np.array(image_epoch),1) 286 | sess.run(enqueue_op,{image_paths_placeholder:image_paths_array,labels_placeholder:labels_array}) 287 | 288 | # Training loop 289 | train_time = 0 290 | while batch_number < args.epoch_size: 291 | 292 | start_time = time.time() 293 | feed_dict = {learning_rate_placeholder: lr, phase_train_placeholder:True, batch_size_placeholder:args.batch_size} 294 | 295 | if (batch_number % 100 == 0) : 296 | err, _, step, reg_loss, summary_str = sess.run([loss,train_op,global_step,regularization_losses,summary_op],feed_dict=feed_dict) 297 | summary_writer.add_summary(summary_str, global_step=step) 298 | else : 299 | err, _, step, reg_loss = sess.run([loss, train_op, global_step, regularization_losses], feed_dict=feed_dict) 300 | 301 | duration = time.time() - start_time 302 | print('global_step[%d],Epoch: 
[%d][%d/%d]\tTime %.3f\tLoss %2.3f\tRegLoss %2.3f' % 303 | (step,epoch, batch_number+1, args.epoch_size, duration, err, np.sum(reg_loss))) 304 | batch_number += 1 305 | train_time += duration 306 | 307 | # Add validation loss and accuracy to summary 308 | summary = tf.Summary() 309 | 310 | #pylint: disable=maybe-no-member 311 | summary.value.add(tag='time/total', simple_value=train_time) 312 | summary_writer.add_summary(summary, step) 313 | return step 314 | 315 | 316 | 317 | def save_variables_and_metagraph(sess, saver, summary_writer, model_dir, model_name, step): 318 | 319 | # Save the model checkpoint 320 | print('Saving variables') 321 | start_time = time.time() 322 | checkpoint_path = os.path.join(model_dir,'model-%s.ckpt' % model_name) 323 | saver.save(sess,checkpoint_path,global_step=step,write_meta_graph=False) 324 | save_time_variables = time.time() - start_time 325 | print('Variables saved in %.2f seconds' % save_time_variables) 326 | metagraph_filename = os.path.join(model_dir,'model-%s.meta' % model_name) 327 | save_time_metagraph = 0 328 | if not os.path.exists(metagraph_filename): 329 | print('Saving metagraph') 330 | start_time = time.time() 331 | saver.export_meta_graph(metagraph_filename) 332 | save_time_metagraph = time.time() - start_time 333 | print('Metagraph saved in %.2f seconds' % save_time_metagraph) 334 | 335 | summary = tf.Summary() 336 | #pylint: disable=maybe-no-member 337 | summary.value.add(tag='time/save_variables', simple_value=save_time_variables) 338 | summary.value.add(tag='time/save_metagraph', simple_value=save_time_metagraph) 339 | summary_writer.add_summary(summary, step) 340 | 341 | 342 | 343 | 344 | 345 | 346 | 347 | 348 | 349 | 350 | 351 | 352 | 353 | 354 | 355 | 356 | 357 | 358 | 359 | 360 | 361 | def parse_arguments(argv): 362 | parser = argparse.ArgumentParser() 363 | print("parser = argparse.ArgumentParser()") 364 | parser.add_argument('--data_dir',type=str,default='align/casia_maxpy_mtcnnpy_182') 365 | 
parser.add_argument('--gpu_memory_fraction',type=float,default=0.8) 366 | parser.add_argument('--pretrained_model', type=str, 367 | help='Load a pretrained model before training starts.') 368 | default = '/home/huyu/models/facenet/20180209-114624/model-20180209-114624.ckpt-0', 369 | #default='/home/huyu/models/facenet/20180208-210946/model-20180208-210946.ckpt-1200', 370 | # default='modeltrained/20170512/model-20170512-110547.ckpt-250000', 371 | 372 | parser.add_argument('--max_nrof_epochs', type=int, default=200) 373 | parser.add_argument('--batch_size', type=int, default=64) 374 | parser.add_argument('--image_size',type=int,default=160) 375 | parser.add_argument('--epoch_size', type=int, default=300) 376 | parser.add_argument('--embedding_size', type=int, default=128) 377 | parser.add_argument('--random_crop', 378 | help='Performs random cropping of training images. If false, the center image_size pixels from the training images are used. ' + 379 | 'If the size of the images in the data directory is equal to image_size no cropping is performed', action='store_true') 380 | parser.add_argument('--random_flip', 381 | help='Performs random horizontal flipping of training images.', action='store_true') 382 | parser.add_argument('--random_rotate', 383 | help='Performs random rotations of training images.', action='store_true') 384 | parser.add_argument('--keep_probability', type=float, 385 | help='Keep probability of dropout for the fully connected layer(s).', default=0.8) 386 | parser.add_argument('--weight_decay', type=float, 387 | help='L2 weight regularization.', default=1e-8) 388 | parser.add_argument('--learning_rate', type=float, 389 | help='Initial learning rate. 
If set to a negative value a learning rate ' + 390 | 'schedule can be specified in the file "learning_rate_schedule.txt"', default=0.05) 391 | parser.add_argument('--optimizer', type=str, choices=['ADAGRAD', 'ADADELTA', 'ADAM', 'RMSPROP', 'MOM'], 392 | help='The optimization algorithm to use', default='ADAM') 393 | parser.add_argument('--learning_rate_decay_epochs', type=int, 394 | help='Number of epochs between learning rate decay.', default=1) 395 | parser.add_argument('--learning_rate_decay_factor', type=float, 396 | help='Learning rate decay factor.', default=0.9) 397 | parser.add_argument('--moving_average_decay', type=float, 398 | help='Exponential decay for tracking of training parameters.', default=0.9999) 399 | parser.add_argument('--seed', type=int, 400 | help='Random seed.', default=666) 401 | parser.add_argument('--nrof_preprocess_threads', type=int, 402 | help='Number of preprocessing (data loading and augmentation) threads.', default=4) 403 | parser.add_argument('--log_histograms', 404 | help='Enables logging of weight/bias histograms in tensorboard.', action='store_true') 405 | parser.add_argument('--learning_rate_schedule_file', type=str, 406 | help='File containing the learning rate schedule that is used when learning_rate is set to to -1.', default='data/learning_rate_classifier_casia.txt') 407 | parser.add_argument('--filter_filename', type=str, 408 | help='File containing image data used for dataset filtering', default='') 409 | parser.add_argument('--filter_percentile', type=float, 410 | help='Keep only the percentile images closed to its class center', default=100.0) 411 | parser.add_argument('--filter_min_nrof_images_per_class', type=int, 412 | help='Keep only the classes with this number of examples or more', default=0) 413 | parser.add_argument('--logs_base_dir', type=str, 414 | help='Directory where to write event logs.', default='~/logs/facenet') 415 | parser.add_argument('--models_base_dir', type=str, 416 | help='Directory where to write 
trained models and checkpoints.', default='~/models/facenet') 417 | 418 | 419 | return parser.parse_args(argv) 420 | 421 | 422 | if __name__ == '__main__': 423 | main(parse_arguments(sys.argv[1:])) 424 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/utils/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/faceUtil.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/utils/__pycache__/faceUtil.cpython-35.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/faceUtil.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/utils/__pycache__/faceUtil.cpython-36.pyc 
-------------------------------------------------------------------------------- /src/utils/faceUtil.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os 7 | from subprocess import Popen, PIPE 8 | import tensorflow as tf 9 | from tensorflow.python.framework import ops 10 | import numpy as np 11 | from scipy import misc 12 | from sklearn.model_selection import KFold 13 | from scipy import interpolate 14 | from tensorflow.python.training import training 15 | import random 16 | import re 17 | from tensorflow.python.platform import gfile 18 | 19 | 20 | def center_loss_angel(features, label, alfa, nrof_classes): 21 | 22 | nrof_features = features.get_shape()[1] 23 | centers = tf.get_variable('centers', [nrof_classes, nrof_features], dtype=tf.float32, 24 | initializer=tf.constant_initializer(1), trainable=False) 25 | 26 | centers_unit = tf.nn.l2_normalize(centers,dim=1) 27 | features_unit = tf.nn.l2_normalize(centers,dim=1) 28 | 29 | label = tf.reshape(label, [-1]) 30 | centers_batch = tf.gather(centers_unit, label) 31 | 32 | cos_theta = tf.reduce_sum( tf.multiply(features_unit , centers_batch),axis=1) 33 | 34 | diff = (1 - alfa) * (centers_batch - features) 35 | centers = tf.scatter_sub(centers, label, diff) 36 | 37 | loss = tf.reduce_mean(tf.square( cos_theta)) 38 | return loss, centers 39 | 40 | 41 | def center_loss(features, label, alfa, nrof_classes): 42 | """Center loss based on the paper "A Discriminative Feature Learning Approach for Deep Face Recognition" 43 | (http://ydwen.github.io/papers/WenECCV16.pdf) 44 | """ 45 | nrof_features = features.get_shape()[1] 46 | centers = tf.get_variable('centers', [nrof_classes, nrof_features], dtype=tf.float32, 47 | initializer=tf.constant_initializer(0), trainable=False) 48 | label = tf.reshape(label, [-1]) 49 | centers_batch = tf.gather(centers, label) 
50 | diff = (1 - alfa) * (centers_batch - features) 51 | centers = tf.scatter_sub(centers, label, diff) 52 | loss = tf.reduce_mean(tf.square(features - centers_batch)) 53 | return loss, centers 54 | 55 | 56 | def get_image_paths_and_labels(dataset): 57 | image_paths_flat = [] 58 | labels_flat = [] 59 | for i in range(len(dataset)): 60 | image_paths_flat += dataset[i].image_paths 61 | labels_flat += [i] * len(dataset[i].image_paths) 62 | return image_paths_flat, labels_flat 63 | 64 | 65 | #excise 66 | def get_image_andlabes(dataset): 67 | image_path =[] 68 | label = [] 69 | 70 | for i in range(len(dataset)): 71 | image_path += dataset[i].image_paths 72 | label += [i] * len(dataset[i].image_paths) 73 | return image_path,label 74 | 75 | def shuffle_examples(image_paths, labels): 76 | shuffle_list = list(zip(image_paths, labels)) 77 | random.shuffle(shuffle_list) 78 | image_paths_shuff, labels_shuff = zip(*shuffle_list) 79 | return image_paths_shuff, labels_shuff 80 | 81 | def read_images_from_disk(input_queue): 82 | """Consumes a single filename and label as a ' '-delimited string. 83 | Args: 84 | filename_and_label_tensor: A scalar string tensor. 85 | Returns: 86 | Two tensors: the decoded image, and the string label. 
87 | """ 88 | label = input_queue[1] 89 | file_contents = tf.read_file(input_queue[0]) 90 | example = tf.image.decode_image(file_contents, channels=3) 91 | return example, label 92 | 93 | def random_rotate_image(image): 94 | angle = np.random.uniform(low=-10.0, high=10.0) 95 | return misc.imrotate(image, angle, 'bicubic') 96 | 97 | def read_and_augment_data(image_list, label_list, image_size, batch_size, max_nrof_epochs, 98 | random_crop, random_flip, random_rotate, nrof_preprocess_threads, shuffle=True): 99 | 100 | images = ops.convert_to_tensor(image_list, dtype=tf.string) 101 | labels = ops.convert_to_tensor(label_list, dtype=tf.int32) 102 | 103 | # Makes an input queue 104 | input_queue = tf.train.slice_input_producer([images, labels], 105 | num_epochs=max_nrof_epochs, shuffle=shuffle) 106 | 107 | images_and_labels = [] 108 | for _ in range(nrof_preprocess_threads): 109 | image, label = read_images_from_disk(input_queue) 110 | if random_rotate: 111 | image = tf.py_func(random_rotate_image, [image], tf.uint8) 112 | if random_crop: 113 | image = tf.random_crop(image, [image_size, image_size, 3]) 114 | else: 115 | image = tf.image.resize_image_with_crop_or_pad(image, image_size, image_size) 116 | if random_flip: 117 | image = tf.image.random_flip_left_right(image) 118 | #pylint: disable=no-member 119 | image.set_shape((image_size, image_size, 3)) 120 | image = tf.image.per_image_standardization(image) 121 | images_and_labels.append([image, label]) 122 | 123 | image_batch, label_batch = tf.train.batch_join( 124 | images_and_labels, batch_size=batch_size, 125 | capacity=4 * nrof_preprocess_threads * batch_size, 126 | allow_smaller_final_batch=True) 127 | 128 | return image_batch, label_batch 129 | 130 | def _add_loss_summaries(total_loss): 131 | """Add summaries for losses. 132 | 133 | Generates moving average for all losses and associated summaries for 134 | visualizing the performance of the network. 135 | 136 | Args: 137 | total_loss: Total loss from loss(). 
138 | Returns: 139 | loss_averages_op: op for generating moving averages of losses. 140 | """ 141 | # Compute the moving average of all individual losses and the total loss. 142 | loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg') 143 | losses = tf.get_collection('losses') 144 | loss_averages_op = loss_averages.apply(losses + [total_loss]) 145 | 146 | # Attach a scalar summmary to all individual losses and the total loss; do the 147 | # same for the averaged version of the losses. 148 | for l in losses + [total_loss]: 149 | # Name each loss as '(raw)' and name the moving average version of the loss 150 | # as the original loss name. 151 | tf.summary.scalar(l.op.name +' (raw)', l) 152 | tf.summary.scalar(l.op.name, loss_averages.average(l)) 153 | 154 | return loss_averages_op 155 | 156 | def train(total_loss, global_step, optimizer, learning_rate, moving_average_decay, update_gradient_vars, log_histograms=True): 157 | # Generate moving averages of all losses and associated summaries. 158 | loss_averages_op = _add_loss_summaries(total_loss) 159 | 160 | # Compute gradients. 161 | with tf.control_dependencies([loss_averages_op]): 162 | if optimizer=='ADAGRAD': 163 | opt = tf.train.AdagradOptimizer(learning_rate) 164 | elif optimizer=='ADADELTA': 165 | opt = tf.train.AdadeltaOptimizer(learning_rate, rho=0.9, epsilon=1e-6) 166 | elif optimizer=='ADAM': 167 | opt = tf.train.AdamOptimizer(learning_rate, beta1=0.9, beta2=0.999, epsilon=0.1) 168 | elif optimizer=='RMSPROP': 169 | opt = tf.train.RMSPropOptimizer(learning_rate, decay=0.9, momentum=0.9, epsilon=1.0) 170 | elif optimizer=='MOM': 171 | opt = tf.train.MomentumOptimizer(learning_rate, 0.9, use_nesterov=True) 172 | else: 173 | raise ValueError('Invalid optimization algorithm') 174 | 175 | grads = opt.compute_gradients(total_loss, update_gradient_vars) 176 | 177 | # Apply gradients. 
178 | apply_gradient_op = opt.apply_gradients(grads, global_step=global_step) 179 | 180 | # Add histograms for trainable variables. 181 | if log_histograms: 182 | for var in tf.trainable_variables(): 183 | tf.summary.histogram(var.op.name, var) 184 | 185 | # Add histograms for gradients. 186 | if log_histograms: 187 | for grad, var in grads: 188 | if grad is not None: 189 | tf.summary.histogram(var.op.name + '/gradients', grad) 190 | 191 | # Track the moving averages of all trainable variables. 192 | variable_averages = tf.train.ExponentialMovingAverage( 193 | moving_average_decay, global_step) 194 | variables_averages_op = variable_averages.apply(tf.trainable_variables()) 195 | 196 | with tf.control_dependencies([apply_gradient_op, variables_averages_op]): 197 | train_op = tf.no_op(name='train') 198 | 199 | return train_op 200 | 201 | def prewhiten(x): 202 | mean = np.mean(x) 203 | std = np.std(x) 204 | std_adj = np.maximum(std, 1.0/np.sqrt(x.size)) 205 | y = np.multiply(np.subtract(x, mean), 1/std_adj) 206 | return y 207 | 208 | def crop(image, random_crop, image_size): 209 | if image.shape[1]>image_size: 210 | sz1 = int(image.shape[1]//2) 211 | sz2 = int(image_size//2) 212 | if random_crop: 213 | diff = sz1-sz2 214 | (h, v) = (np.random.randint(-diff, diff+1), np.random.randint(-diff, diff+1)) 215 | else: 216 | (h, v) = (0,0) 217 | image = image[(sz1-sz2+v):(sz1+sz2+v),(sz1-sz2+h):(sz1+sz2+h),:] 218 | return image 219 | 220 | def flip(image, random_flip): 221 | if random_flip and np.random.choice([True, False]): 222 | image = np.fliplr(image) 223 | return image 224 | 225 | def to_rgb(img): 226 | w, h = img.shape 227 | ret = np.empty((w, h, 3), dtype=np.uint8) 228 | ret[:, :, 0] = ret[:, :, 1] = ret[:, :, 2] = img 229 | return ret 230 | 231 | def load_data(image_paths, do_random_crop, do_random_flip, image_size, do_prewhiten=True): 232 | nrof_samples = len(image_paths) 233 | images = np.zeros((nrof_samples, image_size, image_size, 3)) 234 | for i in 
range(nrof_samples): 235 | img = misc.imread(image_paths[i]) 236 | if img.ndim == 2: 237 | img = to_rgb(img) 238 | if do_prewhiten: 239 | img = prewhiten(img) 240 | img = crop(img, do_random_crop, image_size) 241 | img = flip(img, do_random_flip) 242 | images[i,:,:,:] = img 243 | return images 244 | 245 | def get_label_batch(label_data, batch_size, batch_index): 246 | nrof_examples = np.size(label_data, 0) 247 | j = batch_index*batch_size % nrof_examples 248 | if j+batch_size<=nrof_examples: 249 | batch = label_data[j:j+batch_size] 250 | else: 251 | x1 = label_data[j:nrof_examples] 252 | x2 = label_data[0:nrof_examples-j] 253 | batch = np.vstack([x1,x2]) 254 | batch_int = batch.astype(np.int64) 255 | return batch_int 256 | 257 | def get_batch(image_data, batch_size, batch_index): 258 | nrof_examples = np.size(image_data, 0) 259 | j = batch_index*batch_size % nrof_examples 260 | if j+batch_size<=nrof_examples: 261 | batch = image_data[j:j+batch_size,:,:,:] 262 | else: 263 | x1 = image_data[j:nrof_examples,:,:,:] 264 | x2 = image_data[0:nrof_examples-j,:,:,:] 265 | batch = np.vstack([x1,x2]) 266 | batch_float = batch.astype(np.float32) 267 | return batch_float 268 | 269 | def get_triplet_batch(triplets, batch_index, batch_size): 270 | ax, px, nx = triplets 271 | a = get_batch(ax, int(batch_size/3), batch_index) 272 | p = get_batch(px, int(batch_size/3), batch_index) 273 | n = get_batch(nx, int(batch_size/3), batch_index) 274 | batch = np.vstack([a, p, n]) 275 | return batch 276 | 277 | def get_learning_rate_from_file(filename, epoch): 278 | with open(filename, 'r') as f: 279 | for line in f.readlines(): 280 | line = line.split('#', 1)[0] 281 | if line: 282 | par = line.strip().split(':') 283 | e = int(par[0]) 284 | lr = float(par[1]) 285 | if e <= epoch: 286 | learning_rate = lr 287 | else: 288 | return learning_rate 289 | 290 | #name -String, image_paths : list 291 | class ImageClass(): 292 | "Stores the paths to images for a given class" 293 | def __init__(self, 
name, image_paths): 294 | self.name = name 295 | self.image_paths = image_paths 296 | 297 | def __str__(self): 298 | return self.name + ', ' + str(len(self.image_paths)) + ' images' 299 | 300 | def __len__(self): 301 | return len(self.image_paths) 302 | 303 | #return ImageClass list 304 | def get_dataset(paths, has_class_directories=True): 305 | dataset = [] 306 | for path in paths.split(':'): 307 | path_exp = os.path.expanduser(path) 308 | classes = os.listdir(path_exp) 309 | classes.sort() 310 | nrof_classes = len(classes) 311 | for i in range(nrof_classes): 312 | 313 | #/casia/casia_maxpy_mtcnnalign_182_160/000001...000000089 ... 314 | class_name = classes[i] 315 | #facedir=c:User/User/datasets/casia/casia_maxpy_mtcnnalign_182_160/000001.... 316 | facedir = os.path.join(path_exp, class_name) 317 | #image_paths=c:User/User/datasets/casia/casia_maxpy_mtcnnalign_182_160/0000.../....img list 318 | image_paths = get_image_paths(facedir) 319 | #class_name = 0000..., image_paths = ....img list 320 | dataset.append(ImageClass(class_name, image_paths)) 321 | 322 | return dataset 323 | 324 | def get_image_paths(facedir): 325 | image_paths = [] 326 | if os.path.isdir(facedir): 327 | images = os.listdir(facedir) 328 | image_paths = [os.path.join(facedir,img) for img in images] 329 | return image_paths 330 | 331 | def split_dataset(dataset, split_ratio, mode): 332 | if mode=='SPLIT_CLASSES': 333 | nrof_classes = len(dataset) 334 | class_indices = np.arange(nrof_classes) 335 | np.random.shuffle(class_indices) 336 | split = int(round(nrof_classes*split_ratio)) 337 | train_set = [dataset[i] for i in class_indices[0:split]] 338 | test_set = [dataset[i] for i in class_indices[split:-1]] 339 | elif mode=='SPLIT_IMAGES': 340 | train_set = [] 341 | test_set = [] 342 | min_nrof_images = 2 343 | for cls in dataset: 344 | paths = cls.image_paths 345 | np.random.shuffle(paths) 346 | split = int(round(len(paths)*split_ratio)) 347 | if split1: 381 | raise ValueError('There should not be 
more than one meta file in the model directory (%s)' % model_dir) 382 | meta_file = meta_files[0] 383 | meta_files = [s for s in files if '.ckpt' in s] 384 | max_step = -1 385 | for f in files: 386 | step_str = re.match(r'(^model-[\w\- ]+.ckpt-(\d+))', f) 387 | if step_str is not None and len(step_str.groups())>=2: 388 | step = int(step_str.groups()[1]) 389 | if step > max_step: 390 | max_step = step 391 | ckpt_file = step_str.groups()[0] 392 | return meta_file, ckpt_file 393 | 394 | def calculate_roc(thresholds, embeddings1, embeddings2, actual_issame, nrof_folds=10): 395 | assert(embeddings1.shape[0] == embeddings2.shape[0]) 396 | assert(embeddings1.shape[1] == embeddings2.shape[1]) 397 | nrof_pairs = min(len(actual_issame), embeddings1.shape[0]) 398 | nrof_thresholds = len(thresholds) 399 | k_fold = KFold(n_splits=nrof_folds, shuffle=False) 400 | 401 | tprs = np.zeros((nrof_folds,nrof_thresholds)) 402 | fprs = np.zeros((nrof_folds,nrof_thresholds)) 403 | accuracy = np.zeros((nrof_folds)) 404 | 405 | diff = np.subtract(embeddings1, embeddings2) 406 | dist = np.sum(np.square(diff),1) 407 | indices = np.arange(nrof_pairs) 408 | 409 | for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): 410 | 411 | # Find the best threshold for the fold 412 | acc_train = np.zeros((nrof_thresholds)) 413 | for threshold_idx, threshold in enumerate(thresholds): 414 | _, _, acc_train[threshold_idx] = calculate_accuracy(threshold, dist[train_set], actual_issame[train_set]) 415 | best_threshold_index = np.argmax(acc_train) 416 | for threshold_idx, threshold in enumerate(thresholds): 417 | tprs[fold_idx,threshold_idx], fprs[fold_idx,threshold_idx], _ = calculate_accuracy(threshold, dist[test_set], actual_issame[test_set]) 418 | _, _, accuracy[fold_idx] = calculate_accuracy(thresholds[best_threshold_index], dist[test_set], actual_issame[test_set]) 419 | 420 | tpr = np.mean(tprs,0) 421 | fpr = np.mean(fprs,0) 422 | return tpr, fpr, accuracy 423 | 424 | def 
calculate_accuracy(threshold, dist, actual_issame): 425 | predict_issame = np.less(dist, threshold) 426 | tp = np.sum(np.logical_and(predict_issame, actual_issame)) 427 | fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame))) 428 | tn = np.sum(np.logical_and(np.logical_not(predict_issame), np.logical_not(actual_issame))) 429 | fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame)) 430 | 431 | tpr = 0 if (tp+fn==0) else float(tp) / float(tp+fn) 432 | fpr = 0 if (fp+tn==0) else float(fp) / float(fp+tn) 433 | acc = float(tp+tn)/dist.size 434 | return tpr, fpr, acc 435 | 436 | 437 | 438 | def calculate_val(thresholds, embeddings1, embeddings2, actual_issame, far_target, nrof_folds=10): 439 | assert(embeddings1.shape[0] == embeddings2.shape[0]) 440 | assert(embeddings1.shape[1] == embeddings2.shape[1]) 441 | nrof_pairs = min(len(actual_issame), embeddings1.shape[0]) 442 | nrof_thresholds = len(thresholds) 443 | k_fold = KFold(n_splits=nrof_folds, shuffle=False) 444 | 445 | val = np.zeros(nrof_folds) 446 | far = np.zeros(nrof_folds) 447 | 448 | diff = np.subtract(embeddings1, embeddings2) 449 | dist = np.sum(np.square(diff),1) 450 | indices = np.arange(nrof_pairs) 451 | 452 | for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): 453 | 454 | # Find the threshold that gives FAR = far_target 455 | far_train = np.zeros(nrof_thresholds) 456 | for threshold_idx, threshold in enumerate(thresholds): 457 | _, far_train[threshold_idx] = calculate_val_far(threshold, dist[train_set], actual_issame[train_set]) 458 | if np.max(far_train)>=far_target: 459 | f = interpolate.interp1d(far_train, thresholds, kind='slinear') 460 | threshold = f(far_target) 461 | else: 462 | threshold = 0.0 463 | 464 | val[fold_idx], far[fold_idx] = calculate_val_far(threshold, dist[test_set], actual_issame[test_set]) 465 | 466 | val_mean = np.mean(val) 467 | far_mean = np.mean(far) 468 | val_std = np.std(val) 469 | return val_mean, val_std, far_mean 
470 | 471 | 472 | def calculate_val_far(threshold, dist, actual_issame): 473 | predict_issame = np.less(dist, threshold) 474 | true_accept = np.sum(np.logical_and(predict_issame, actual_issame)) 475 | false_accept = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame))) 476 | n_same = np.sum(actual_issame) 477 | n_diff = np.sum(np.logical_not(actual_issame)) 478 | val = float(true_accept) / float(n_same) 479 | far = float(false_accept) / float(n_diff) 480 | return val, far 481 | 482 | def store_revision_info(src_path, output_dir, arg_string): 483 | 484 | # Get git hash 485 | gitproc = Popen(['git', 'rev-parse', 'HEAD'], stdout = PIPE, cwd=src_path) 486 | (stdout, _) = gitproc.communicate() 487 | git_hash = stdout.strip() 488 | 489 | # Get local changes 490 | gitproc = Popen(['git', 'diff', 'HEAD'], stdout = PIPE, cwd=src_path) 491 | (stdout, _) = gitproc.communicate() 492 | git_diff = stdout.strip() 493 | 494 | # Store a text file in the log directory 495 | rev_info_filename = os.path.join(output_dir, 'revision_info.txt') 496 | with open(rev_info_filename, "w") as text_file: 497 | text_file.write('arguments: %s\n--------------------\n' % arg_string) 498 | text_file.write('git hash: %s\n--------------------\n' % git_hash) 499 | text_file.write('%s' % git_diff) 500 | 501 | def list_variables(filename): 502 | reader = training.NewCheckpointReader(filename) 503 | variable_map = reader.get_variable_to_shape_map() 504 | names = sorted(variable_map.keys()) 505 | return names 506 | 507 | def put_images_on_grid(images, shape=(16,8)): 508 | nrof_images = images.shape[0] 509 | img_size = images.shape[1] 510 | bw = 3 511 | img = np.zeros((shape[1]*(img_size+bw)+bw, shape[0]*(img_size+bw)+bw, 3), np.float32) 512 | for i in range(shape[1]): 513 | x_start = i*(img_size+bw)+bw 514 | for j in range(shape[0]): 515 | img_index = i*shape[0]+j 516 | if img_index>=nrof_images: 517 | break 518 | y_start = j*(img_size+bw)+bw 519 | img[x_start:x_start+img_size, 
y_start:y_start+img_size, :] = images[img_index, :, :, :] 520 | if img_index>=nrof_images: 521 | break 522 | return img 523 | 524 | def write_arguments_to_file(args, filename): 525 | with open(filename, 'w') as f: 526 | for key, value in vars(args).items(): 527 | f.write('%s: %s\n' % (key, str(value))) 528 | -------------------------------------------------------------------------------- /test/.ipynb_checkpoints/sphereloss-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import tensorflow as tf\n", 12 | "import tensorflow.contrib.slim as slim\n", 13 | "import numpy as np\n", 14 | "from math import pi\n", 15 | "from builtins import range\n", 16 | "import numpy as np" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": { 23 | "collapsed": true 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "def sphereloss(inputs,label,classes,fraction = 1, scope='Logits',reuse=None,m =4,eplion = 1e-8):\n", 28 | " \"\"\"\n", 29 | " inputs tensor shape=[batch,features_num]\n", 30 | " labels tensor shape=[batch] each unit belong num_outputs\n", 31 | " \n", 32 | " \"\"\"\n", 33 | " inputs_shape = inputs.get_shape().as_list()\n", 34 | " with tf.variable_scope(name_or_scope=scope):\n", 35 | " weight = tf.Variable(initial_value=tf.random_normal((classes,inputs_shape[1])) * tf.sqrt(2 / inputs_shape[1]),dtype=tf.float32,name='weight') # shaep =classes, features,\n", 36 | " print(\"weight shape = \",weight.get_shape().as_list())\n", 37 | " \n", 38 | " weight_unit = tf.nn.l2_normalize(weight,dim=1)\n", 39 | " print(\"weight_unit shape = \",weight_unit.get_shape().as_list())\n", 40 | " \n", 41 | " inputs_mo = tf.sqrt(tf.reduce_sum(tf.square(inputs),axis=1)+eplion) #shape=[batch\n", 42 | " print(\"inputs_mo shape = 
\",inputs_mo.get_shape().as_list())\n", 43 | " \n", 44 | " inputs_unit = tf.nn.l2_normalize(inputs,dim=1) #shape = [batch,features_num]\n", 45 | " print(\"inputs_unit shape = \",inputs_unit.get_shape().as_list())\n", 46 | " \n", 47 | " logits = tf.matmul(inputs,tf.transpose(weight_unit)) #shape = [batch,classes] x * w_unit\n", 48 | " print(\"logits shape = \",logits.get_shape().as_list())\n", 49 | " \n", 50 | " weight_unit_batch = tf.gather(weight_unit,label) # shaep =batch,features_num,\n", 51 | " print(\"weight_unit_batch shape = \",weight_unit_batch.get_shape().as_list())\n", 52 | " \n", 53 | " logits_inputs = tf.reduce_sum(tf.multiply(inputs,weight_unit_batch),axis=1) # shaep =batch,\n", 54 | " \n", 55 | " print(\"logits_inputs shape = \",logits_inputs.get_shape().as_list())\n", 56 | " \n", 57 | " \n", 58 | " cos_theta = tf.reduce_sum(tf.multiply(inputs_unit,weight_unit_batch),axis=1) # shaep =batch,\n", 59 | " print(\"cos_theta shape = \",cos_theta.get_shape().as_list())\n", 60 | " \n", 61 | " cos_theta_square = tf.square(cos_theta)\n", 62 | " cos_theta_biq = tf.pow(cos_theta,4)\n", 63 | " sign0 = tf.sign(cos_theta)\n", 64 | " sign2 = tf.sign(2 * cos_theta_square-1)\n", 65 | " sign3 = tf.multiply(sign2,sign0)\n", 66 | " sign4 = 2 * sign0 +sign3 -3\n", 67 | " cos_far_theta = sign3 * (8 * cos_theta_biq - 8 * cos_theta_square + 1) + sign4\n", 68 | " print(\"cos_far_theta = \",cos_far_theta.get_shape().as_list())\n", 69 | " \n", 70 | " logit_ii = tf.multiply(cos_far_theta,inputs_mo)#shape = batch \n", 71 | " print(\"logit_ii shape = \",logit_ii.get_shape().as_list())\n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " index = tf.constant(list(range(0, inputs_shape[0])), tf.int32)\n", 76 | " index_labels = tf.stack([index, label], axis = 1)\n", 77 | " index_logits = tf.scatter_nd(index_labels,tf.subtract(logit_ii,logits_inputs),logits.get_shape())\n", 78 | " print(\"index_logits shape = \",logit_ii.get_shape().as_list())\n", 79 | " \n", 80 | " logits_final = 
tf.add(logits,index_logits)\n", 81 | " logits_final = fraction * logits_final + (1 - fraction) * logits\n", 82 | " \n", 83 | " \n", 84 | " loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label,logits=logits_final))\n", 85 | " \n", 86 | " return logits_final,loss\n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " " 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 3, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "name": "stdout", 100 | "output_type": "stream", 101 | "text": [ 102 | "[[ 0.54340494 0.27836939 0.42451759 ..., 0.49216697 0.40288033\n", 103 | " 0.3542983 ]\n", 104 | " [ 0.50061432 0.44517663 0.09043279 ..., 0.91975536 0.84960733\n", 105 | " 0.25446654]\n", 106 | " [ 0.87755554 0.43513019 0.72949434 ..., 0.34675413 0.1095646 0.378327 ]\n", 107 | " ..., \n", 108 | " [ 0.37716757 0.75750263 0.29912515 ..., 0.53643677 0.63122505\n", 109 | " 0.17644686]\n", 110 | " [ 0.5396841 0.55346366 0.30105263 ..., 0.97211104 0.51628182\n", 111 | " 0.44451879]\n", 112 | " [ 0.29864351 0.8312721 0.28519946 ..., 0.09984372 0.71015638\n", 113 | " 0.37341943]]\n", 114 | "[ 282 9194 3716 2977 9688 3074 487 6080 467 4974 1458 4028 9966 5131 583\n", 115 | " 2955]\n" 116 | ] 117 | } 118 | ], 119 | "source": [ 120 | "np.random.seed(100)\n", 121 | "batch = 16\n", 122 | "feating = 128\n", 123 | "classess = 10000\n", 124 | "# inputsinputs = [[-1.0,1.0],[1.0,1.0],[1.0,1.0]]\n", 125 | "\n", 126 | "inputsinputs = np.random.rand(batch,feating) \n", 127 | "print(inputsinputs)\n", 128 | "inputslables = np.random.randint(0,classess,batch)\n", 129 | "print(inputslables)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 4, 135 | "metadata": { 136 | "collapsed": true 137 | }, 138 | "outputs": [], 139 | "source": [ 140 | "inputs_place = tf.placeholder(dtype=tf.float32,shape=(batch,feating),name='inputs')\n", 141 | "labels_place = tf.placeholder(dtype=tf.int32,shape=(batch),name='labels')" 142 | ] 143 | }, 144 | { 
145 | "cell_type": "code", 146 | "execution_count": 5, 147 | "metadata": {}, 148 | "outputs": [ 149 | { 150 | "name": "stdout", 151 | "output_type": "stream", 152 | "text": [ 153 | "weight shape = [10000, 128]\n", 154 | "weight_unit shape = [10000, 128]\n", 155 | "inputs_mo shape = [16]\n", 156 | "inputs_unit shape = [16, 128]\n", 157 | "logits shape = [16, 10000]\n", 158 | "weight_unit_batch shape = [16, 128]\n", 159 | "logits_inputs shape = [16]\n", 160 | "cos_theta shape = [16]\n", 161 | "cos_far_theta = [16]\n", 162 | "logit_ii shape = [16]\n", 163 | "index_logits shape = [16]\n" 164 | ] 165 | } 166 | ], 167 | "source": [ 168 | "_,loss = sphereloss(inputs_place,labels_place,classess)\n", 169 | "optimizer = tf.train.AdamOptimizer(learning_rate=0.05)\n", 170 | "train_op = optimizer.minimize(loss)" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 6, 176 | "metadata": {}, 177 | "outputs": [ 178 | { 179 | "name": "stdout", 180 | "output_type": "stream", 181 | "text": [ 182 | "[28.798058, None]\n", 183 | "[22.013681, None]\n", 184 | "[15.798929, None]\n", 185 | "[13.435317, None]\n", 186 | "[12.685493, None]\n", 187 | "[11.346285, None]\n", 188 | "[9.4595337, None]\n", 189 | "[7.3966951, None]\n", 190 | "[5.5232167, None]\n", 191 | "[4.1032286, None]\n", 192 | "[3.2184033, None]\n", 193 | "[2.7606359, None]\n", 194 | "[2.5476418, None]\n", 195 | "[2.4361296, None]\n", 196 | "[2.3488777, None]\n", 197 | "[2.2575321, None]\n", 198 | "[2.1591675, None]\n", 199 | "[2.0600286, None]\n", 200 | "[1.9670352, None]\n", 201 | "[1.884537, None]\n", 202 | "[1.8141112, None]\n", 203 | "[1.7553788, None]\n", 204 | "[1.7070088, None]\n", 205 | "[1.6673062, None]\n", 206 | "[1.6345801, None]\n", 207 | "[1.6072876, None]\n", 208 | "[1.5841055, None]\n", 209 | "[1.5639385, None]\n", 210 | "[1.5459807, None]\n", 211 | "[1.5295949, None]\n", 212 | "[1.5144105, None]\n", 213 | "[1.5001791, None]\n", 214 | "[1.4868002, None]\n", 215 | "[1.474241, None]\n", 216 
| "[1.462525, None]\n", 217 | "[1.4517064, None]\n", 218 | "[1.4418349, None]\n", 219 | "[1.432952, None]\n", 220 | "[1.42506, None]\n", 221 | "[1.418117, None]\n", 222 | "[1.4120786, None]\n", 223 | "[1.4068146, None]\n", 224 | "[1.4022279, None]\n", 225 | "[1.3981946, None]\n", 226 | "[1.394609, None]\n", 227 | "[1.3913933, None]\n", 228 | "[1.3884966, None]\n", 229 | "[1.3858668, None]\n", 230 | "[1.3834978, None]\n", 231 | "[1.3813753, None]\n", 232 | "[1.3794972, None]\n", 233 | "[1.377805, None]\n", 234 | "[1.3763015, None]\n", 235 | "[1.3749441, None]\n", 236 | "[1.3737118, None]\n", 237 | "[1.3725736, None]\n", 238 | "[1.3715333, None]\n", 239 | "[1.3705664, None]\n", 240 | "[1.3696976, None]\n", 241 | "[1.3689098, None]\n", 242 | "[1.3682067, None]\n", 243 | "[1.3675909, None]\n", 244 | "[1.3670415, None]\n", 245 | "[1.3665429, None]\n", 246 | "[1.3661082, None]\n", 247 | "[1.3656967, None]\n", 248 | "[1.3653129, None]\n", 249 | "[1.3649571, None]\n", 250 | "[1.364624, None]\n", 251 | "[1.3643059, None]\n", 252 | "[1.3640261, None]\n", 253 | "[1.3637731, None]\n", 254 | "[1.3635414, None]\n", 255 | "[1.3633375, None]\n", 256 | "[1.3631665, None]\n", 257 | "[1.3629972, None]\n", 258 | "[1.3628507, None]\n", 259 | "[1.3627158, None]\n", 260 | "[1.3625877, None]\n", 261 | "[1.3624818, None]\n", 262 | "[1.3623769, None]\n", 263 | "[1.3622849, None]\n", 264 | "[1.3622081, None]\n", 265 | "[1.3621382, None]\n", 266 | "[1.3620652, None]\n", 267 | "[1.362, None]\n", 268 | "[1.3619499, None]\n", 269 | "[1.3618976, None]\n", 270 | "[1.3618547, None]\n", 271 | "[1.3618151, None]\n", 272 | "[1.3617785, None]\n", 273 | "[1.3617359, None]\n", 274 | "[1.3617026, None]\n", 275 | "[1.3616755, None]\n", 276 | "[1.3616362, None]\n", 277 | "[1.3616104, None]\n", 278 | "[1.3615811, None]\n", 279 | "[1.361555, None]\n", 280 | "[1.361526, None]\n", 281 | "[1.3615122, None]\n", 282 | "[1.3614895, None]\n", 283 | "[1.36147, None]\n", 284 | "[1.3614599, None]\n", 285 | "[1.3614435, 
None]\n", 286 | "[1.3614256, None]\n", 287 | "[1.3614086, None]\n", 288 | "[1.361397, None]\n", 289 | "[1.361378, None]\n", 290 | "[1.3613749, None]\n", 291 | "[1.361356, None]\n", 292 | "[1.3613384, None]\n", 293 | "[1.3613338, None]\n", 294 | "[1.3613162, None]\n", 295 | "[1.3613122, None]\n", 296 | "[1.3613, None]\n", 297 | "[1.361288, None]\n", 298 | "[1.3612826, None]\n", 299 | "[1.3612705, None]\n", 300 | "[1.3612605, None]\n", 301 | "[1.3612491, None]\n", 302 | "[1.3612472, None]\n", 303 | "[1.3612354, None]\n", 304 | "[1.3612268, None]\n", 305 | "[1.3612187, None]\n", 306 | "[1.3612064, None]\n", 307 | "[1.3612019, None]\n", 308 | "[1.3611991, None]\n", 309 | "[1.3611846, None]\n", 310 | "[1.3611834, None]\n", 311 | "[1.3611796, None]\n", 312 | "[1.3611608, None]\n", 313 | "[1.3611588, None]\n", 314 | "[1.3611513, None]\n", 315 | "[1.3611465, None]\n", 316 | "[1.3611407, None]\n", 317 | "[1.3611329, None]\n", 318 | "[1.3611214, None]\n", 319 | "[1.3611224, None]\n", 320 | "[1.3611122, None]\n", 321 | "[1.3611006, None]\n", 322 | "[1.361099, None]\n", 323 | "[1.3610919, None]\n", 324 | "[1.3610821, None]\n", 325 | "[1.3610785, None]\n", 326 | "[1.3610735, None]\n", 327 | "[1.361064, None]\n", 328 | "[1.3610643, None]\n", 329 | "[1.3610544, None]\n", 330 | "[1.3610488, None]\n", 331 | "[1.3610449, None]\n", 332 | "[1.3610425, None]\n", 333 | "[1.3610251, None]\n", 334 | "[1.361027, None]\n", 335 | "[1.3610241, None]\n", 336 | "[1.3610187, None]\n", 337 | "[1.361007, None]\n", 338 | "[1.361002, None]\n", 339 | "[1.3610014, None]\n", 340 | "[1.3609916, None]\n", 341 | "[1.3609868, None]\n", 342 | "[1.3609872, None]\n", 343 | "[1.3609767, None]\n", 344 | "[1.3609693, None]\n", 345 | "[1.360967, None]\n", 346 | "[1.3609564, None]\n", 347 | "[1.3609531, None]\n", 348 | "[1.3609515, None]\n", 349 | "[1.3609464, None]\n", 350 | "[1.3609388, None]\n", 351 | "[1.360932, None]\n", 352 | "[1.3609326, None]\n", 353 | "[1.3609302, None]\n", 354 | "[1.3609217, None]\n", 
355 | "[1.3609126, None]\n", 356 | "[1.3609116, None]\n", 357 | "[1.3609014, None]\n", 358 | "[1.3608999, None]\n", 359 | "[1.3609022, None]\n", 360 | "[1.360888, None]\n", 361 | "[1.3608881, None]\n", 362 | "[1.3608785, None]\n", 363 | "[1.3608744, None]\n", 364 | "[1.3608811, None]\n", 365 | "[1.3608711, None]\n", 366 | "[1.3608688, None]\n", 367 | "[1.3608608, None]\n", 368 | "[1.3608541, None]\n", 369 | "[1.3608596, None]\n", 370 | "[1.3608459, None]\n", 371 | "[1.3608503, None]\n", 372 | "[1.3608425, None]\n", 373 | "[1.3608347, None]\n", 374 | "[1.3608333, None]\n", 375 | "[1.3608232, None]\n", 376 | "[1.3608229, None]\n", 377 | "[1.360822, None]\n", 378 | "[1.3608146, None]\n", 379 | "[1.3608112, None]\n", 380 | "[1.3608062, None]\n", 381 | "[1.3608042, None]\n", 382 | "[1.3608019, None]\n", 383 | "[1.3607998, None]\n", 384 | "[1.3607969, None]\n", 385 | "[1.360793, None]\n", 386 | "[1.3607876, None]\n", 387 | "[1.3607824, None]\n", 388 | "[1.360777, None]\n", 389 | "[1.3607728, None]\n", 390 | "[1.3607719, None]\n", 391 | "[1.3607681, None]\n", 392 | "[1.3607724, None]\n", 393 | "[1.3607639, None]\n", 394 | "[1.3607619, None]\n", 395 | "[1.3607515, None]\n", 396 | "[1.3607494, None]\n", 397 | "[1.3607481, None]\n", 398 | "[1.3607519, None]\n", 399 | "[1.3607433, None]\n", 400 | "[1.3607445, None]\n", 401 | "[1.3607424, None]\n", 402 | "[1.3607278, None]\n", 403 | "[1.3607357, None]\n", 404 | "[1.3607287, None]\n", 405 | "[1.3607242, None]\n", 406 | "[1.3607204, None]\n", 407 | "[1.3607173, None]\n", 408 | "[1.360713, None]\n", 409 | "[1.3607132, None]\n", 410 | "[1.3607059, None]\n", 411 | "[1.3607048, None]\n", 412 | "[1.3607016, None]\n", 413 | "[1.3607044, None]\n", 414 | "[1.3607032, None]\n", 415 | "[1.3606958, None]\n", 416 | "[1.3606982, None]\n", 417 | "[1.3606921, None]\n", 418 | "[1.3606881, None]\n", 419 | "[1.3606858, None]\n", 420 | "[1.3606811, None]\n", 421 | "[1.3606766, None]\n", 422 | "[1.3606761, None]\n", 423 | "[1.3606787, None]\n", 424 
| "[1.3606739, None]\n", 425 | "[1.360667, None]\n", 426 | "[1.3606656, None]\n", 427 | "[1.3606657, None]\n", 428 | "[1.3606608, None]\n", 429 | "[1.360661, None]\n", 430 | "[1.3606596, None]\n", 431 | "[1.36066, None]\n", 432 | "[1.3606548, None]\n", 433 | "[1.3606544, None]\n", 434 | "[1.3606534, None]\n", 435 | "[1.3606489, None]\n", 436 | "[1.3606441, None]\n", 437 | "[1.3606439, None]\n", 438 | "[1.3606385, None]\n", 439 | "[1.3606369, None]\n", 440 | "[1.3606359, None]\n", 441 | "[1.3606341, None]\n", 442 | "[1.3606316, None]\n", 443 | "[1.3606304, None]\n", 444 | "[1.3606286, None]\n", 445 | "[1.3606267, None]\n", 446 | "[1.3606237, None]\n", 447 | "[1.3606184, None]\n", 448 | "[1.3606191, None]\n", 449 | "[1.3606161, None]\n", 450 | "[1.3606155, None]\n", 451 | "[1.3606112, None]\n", 452 | "[1.3606111, None]\n", 453 | "[1.3606093, None]\n", 454 | "[1.3606052, None]\n", 455 | "[1.3606048, None]\n", 456 | "[1.3606048, None]\n", 457 | "[1.360605, None]\n", 458 | "[1.3606012, None]\n", 459 | "[1.3605998, None]\n", 460 | "[1.3605975, None]\n", 461 | "[1.3605914, None]\n", 462 | "[1.3605908, None]\n", 463 | "[1.3605893, None]\n", 464 | "[1.36059, None]\n", 465 | "[1.3605886, None]\n", 466 | "[1.3605853, None]\n", 467 | "[1.3605857, None]\n", 468 | "[1.3605833, None]\n", 469 | "[1.3605798, None]\n", 470 | "[1.3605828, None]\n", 471 | "[1.3605783, None]\n", 472 | "[1.360579, None]\n", 473 | "[1.3605747, None]\n", 474 | "[1.3605775, None]\n", 475 | "[1.3605781, None]\n", 476 | "[1.360574, None]\n", 477 | "[1.3605719, None]\n", 478 | "[1.3605723, None]\n", 479 | "[1.3605675, None]\n", 480 | "[1.3605683, None]\n", 481 | "[1.3605654, None]\n", 482 | "[1.3605647, None]\n", 483 | "[1.3605644, None]\n", 484 | "[1.3605622, None]\n", 485 | "[1.360563, None]\n", 486 | "[1.3605621, None]\n", 487 | "[1.3605603, None]\n", 488 | "[1.360559, None]\n", 489 | "[1.360559, None]\n", 490 | "[1.3605585, None]\n", 491 | "[1.3605564, None]\n", 492 | "[1.3605515, None]\n", 493 | 
"[1.3605492, None]\n", 494 | "[1.360549, None]\n", 495 | "[1.3605447, None]\n", 496 | "[1.3605459, None]\n", 497 | "[1.3605437, None]\n", 498 | "[1.3605459, None]\n", 499 | "[1.3605437, None]\n", 500 | "[1.3605419, None]\n", 501 | "[1.3605409, None]\n", 502 | "[1.3605397, None]\n", 503 | "[1.3605382, None]\n", 504 | "[1.3605388, None]\n", 505 | "[1.3605382, None]\n", 506 | "[1.3605369, None]\n", 507 | "[1.3605351, None]\n", 508 | "[1.3605323, None]\n", 509 | "[1.3605349, None]\n", 510 | "[1.3605349, None]\n", 511 | "[1.3605313, None]\n", 512 | "[1.3605309, None]\n", 513 | "[1.3605285, None]\n", 514 | "[1.360528, None]\n", 515 | "[1.3605281, None]\n", 516 | "[1.3605267, None]\n", 517 | "[1.3605261, None]\n", 518 | "[1.3605261, None]\n", 519 | "[1.3605255, None]\n", 520 | "[1.3605223, None]\n", 521 | "[1.3605224, None]\n", 522 | "[1.3605227, None]\n", 523 | "[1.3605223, None]\n", 524 | "[1.3605227, None]\n", 525 | "[1.3605201, None]\n", 526 | "[1.3605181, None]\n", 527 | "[1.3605177, None]\n", 528 | "[1.3605196, None]\n", 529 | "[1.3605171, None]\n", 530 | "[1.3605158, None]\n", 531 | "[1.3605144, None]\n", 532 | "[1.3605165, None]\n", 533 | "[1.3605158, None]\n", 534 | "[1.3605144, None]\n", 535 | "[1.3605145, None]\n", 536 | "[1.3605134, None]\n", 537 | "[1.3605131, None]\n", 538 | "[1.3605111, None]\n", 539 | "[1.3605098, None]\n", 540 | "[1.3605087, None]\n", 541 | "[1.3605084, None]\n", 542 | "[1.360508, None]\n", 543 | "[1.3605077, None]\n", 544 | "[1.360508, None]\n", 545 | "[1.3605061, None]\n", 546 | "[1.3605061, None]\n", 547 | "[1.3605049, None]\n", 548 | "[1.3605019, None]\n", 549 | "[1.3605043, None]\n", 550 | "[1.360503, None]\n", 551 | "[1.3605008, None]\n", 552 | "[1.3604989, None]\n", 553 | "[1.3604987, None]\n", 554 | "[1.360498, None]\n", 555 | "[1.360498, None]\n", 556 | "[1.3604956, None]\n", 557 | "[1.3604977, None]\n", 558 | "[1.3604954, None]\n", 559 | "[1.3604962, None]\n", 560 | "[1.360498, None]\n", 561 | "[1.3604977, None]\n", 562 | 
"[1.3604957, None]\n", 563 | "[1.3604968, None]\n", 564 | "[1.3604938, None]\n", 565 | "[1.3604937, None]\n", 566 | "[1.3604914, None]\n", 567 | "[1.360492, None]\n", 568 | "[1.3604913, None]\n", 569 | "[1.3604912, None]\n", 570 | "[1.3604898, None]\n", 571 | "[1.36049, None]\n", 572 | "[1.3604873, None]\n", 573 | "[1.3604846, None]\n", 574 | "[1.3604844, None]\n", 575 | "[1.3604836, None]\n", 576 | "[1.3604845, None]\n", 577 | "[1.3604846, None]\n", 578 | "[1.3604861, None]\n", 579 | "[1.3604875, None]\n", 580 | "[1.3604866, None]\n", 581 | "[1.3604854, None]\n", 582 | "[1.3604875, None]\n", 583 | "[1.3604851, None]\n", 584 | "[1.3604836, None]\n", 585 | "[1.3604836, None]\n", 586 | "[1.3604813, None]\n", 587 | "[1.360482, None]\n", 588 | "[1.360482, None]\n", 589 | "[1.3604822, None]\n", 590 | "[1.3604785, None]\n", 591 | "[1.3604805, None]\n", 592 | "[1.3604803, None]\n", 593 | "[1.3604807, None]\n", 594 | "[1.3604786, None]\n", 595 | "[1.3604782, None]\n", 596 | "[1.3604753, None]\n", 597 | "[1.3604755, None]\n", 598 | "[1.3604745, None]\n", 599 | "[1.3604743, None]\n", 600 | "[1.3604729, None]\n", 601 | "[1.3604747, None]\n", 602 | "[1.3604743, None]\n", 603 | "[1.3604729, None]\n", 604 | "[1.3604736, None]\n", 605 | "[1.3604716, None]\n", 606 | "[1.3604699, None]\n", 607 | "[1.360468, None]\n", 608 | "[1.3604715, None]\n", 609 | "[1.3604712, None]\n", 610 | "[1.3604712, None]\n", 611 | "[1.3604708, None]\n", 612 | "[1.3604695, None]\n", 613 | "[1.360469, None]\n", 614 | "[1.3604705, None]\n", 615 | "[1.3604693, None]\n", 616 | "[1.3604724, None]\n", 617 | "[1.3604687, None]\n", 618 | "[1.3604704, None]\n", 619 | "[1.3604703, None]\n", 620 | "[1.3604716, None]\n", 621 | "[1.3604703, None]\n", 622 | "[1.3604674, None]\n", 623 | "[1.3604668, None]\n", 624 | "[1.3604646, None]\n", 625 | "[1.3604633, None]\n", 626 | "[1.3604618, None]\n", 627 | "[1.36046, None]\n", 628 | "[1.3604622, None]\n", 629 | "[1.3604629, None]\n", 630 | "[1.3604609, None]\n", 631 | 
"[1.3604593, None]\n", 632 | "[1.3604608, None]\n", 633 | "[1.3604633, None]\n", 634 | "[1.3604617, None]\n", 635 | "[1.3604594, None]\n", 636 | "[1.3604612, None]\n", 637 | "[1.3604605, None]\n", 638 | "[1.3604623, None]\n", 639 | "[1.36046, None]\n", 640 | "[1.3604598, None]\n" 641 | ] 642 | }, 643 | { 644 | "name": "stdout", 645 | "output_type": "stream", 646 | "text": [ 647 | "[1.3604589, None]\n", 648 | "[1.3604591, None]\n", 649 | "[1.3604579, None]\n", 650 | "[1.3604562, None]\n", 651 | "[1.3604577, None]\n", 652 | "[1.360456, None]\n", 653 | "[1.360455, None]\n", 654 | "[1.3604559, None]\n", 655 | "[1.3604553, None]\n", 656 | "[1.3604565, None]\n", 657 | "[1.3604563, None]\n", 658 | "[1.360455, None]\n", 659 | "[1.360455, None]\n", 660 | "[1.3604553, None]\n", 661 | "[1.3604541, None]\n", 662 | "[1.3604548, None]\n", 663 | "[1.3604541, None]\n", 664 | "[1.3604541, None]\n", 665 | "[1.3604527, None]\n", 666 | "[1.3604522, None]\n", 667 | "[1.3604528, None]\n", 668 | "[1.3604532, None]\n", 669 | "[1.3604512, None]\n", 670 | "[1.3604516, None]\n", 671 | "[1.3604524, None]\n", 672 | "[1.3604527, None]\n", 673 | "[1.3604497, None]\n", 674 | "[1.3604493, None]\n", 675 | "[1.3604485, None]\n", 676 | "[1.3604513, None]\n", 677 | "[1.3604528, None]\n", 678 | "[1.3604487, None]\n", 679 | "[1.3604469, None]\n", 680 | "[1.3604493, None]\n", 681 | "[1.3604496, None]\n", 682 | "[1.3604503, None]\n", 683 | "[1.3604494, None]\n", 684 | "[1.3604501, None]\n", 685 | "[1.3604507, None]\n", 686 | "[1.3604501, None]\n", 687 | "[1.3604491, None]\n", 688 | "[1.3604496, None]\n", 689 | "[1.3604499, None]\n", 690 | "[1.3604475, None]\n", 691 | "[1.3604467, None]\n", 692 | "[1.3604462, None]\n", 693 | "[1.3604478, None]\n", 694 | "[1.3604475, None]\n", 695 | "[1.360445, None]\n", 696 | "[1.3604461, None]\n", 697 | "[1.3604487, None]\n", 698 | "[1.3604486, None]\n", 699 | "[1.360446, None]\n", 700 | "[1.3604457, None]\n", 701 | "[1.3604449, None]\n", 702 | "[1.3604445, None]\n", 703 
| "[1.3604444, None]\n", 704 | "[1.3604429, None]\n", 705 | "[1.360445, None]\n", 706 | "[1.3604469, None]\n", 707 | "[1.3604497, None]\n", 708 | "[1.3604454, None]\n", 709 | "[1.3604436, None]\n", 710 | "[1.3604431, None]\n", 711 | "[1.3604397, None]\n", 712 | "[1.36044, None]\n", 713 | "[1.3604395, None]\n", 714 | "[1.3604413, None]\n", 715 | "[1.3604409, None]\n", 716 | "[1.3604414, None]\n", 717 | "[1.3604426, None]\n", 718 | "[1.360442, None]\n", 719 | "[1.3604419, None]\n", 720 | "[1.360441, None]\n", 721 | "[1.3604395, None]\n", 722 | "[1.3604409, None]\n", 723 | "[1.3604407, None]\n", 724 | "[1.3604411, None]\n", 725 | "[1.3604399, None]\n", 726 | "[1.3604391, None]\n", 727 | "[1.3604399, None]\n", 728 | "[1.3604398, None]\n", 729 | "[1.3604392, None]\n", 730 | "[1.3604378, None]\n", 731 | "[1.3604375, None]\n", 732 | "[1.3604369, None]\n", 733 | "[1.3604362, None]\n", 734 | "[1.3604376, None]\n", 735 | "[1.3604379, None]\n", 736 | "[1.3604394, None]\n", 737 | "[1.3604398, None]\n", 738 | "[1.3604391, None]\n", 739 | "[1.3604403, None]\n", 740 | "[1.3604391, None]\n", 741 | "[1.3604394, None]\n", 742 | "[1.3604375, None]\n", 743 | "[1.3604392, None]\n", 744 | "[1.3604352, None]\n", 745 | "[1.3604355, None]\n", 746 | "[1.3604343, None]\n", 747 | "[1.3604351, None]\n", 748 | "[1.360438, None]\n", 749 | "[1.3604381, None]\n", 750 | "[1.3604383, None]\n", 751 | "[1.3604379, None]\n", 752 | "[1.3604364, None]\n", 753 | "[1.3604376, None]\n", 754 | "[1.3604357, None]\n", 755 | "[1.3604338, None]\n", 756 | "[1.3604321, None]\n", 757 | "[1.3604355, None]\n", 758 | "[1.3604357, None]\n", 759 | "[1.3604352, None]\n", 760 | "[1.3604356, None]\n", 761 | "[1.3604376, None]\n", 762 | "[1.3604378, None]\n", 763 | "[1.36044, None]\n", 764 | "[1.360438, None]\n", 765 | "[1.3604378, None]\n", 766 | "[1.3604367, None]\n", 767 | "[1.3604358, None]\n", 768 | "[1.3604348, None]\n", 769 | "[1.3604372, None]\n", 770 | "[1.3604383, None]\n", 771 | "[1.3604379, None]\n", 772 | 
"[1.3604367, None]\n", 773 | "[1.360436, None]\n", 774 | "[1.360435, None]\n", 775 | "[1.3604355, None]\n", 776 | "[1.3604378, None]\n", 777 | "[1.3604331, None]\n", 778 | "[1.3604321, None]\n", 779 | "[1.3604342, None]\n", 780 | "[1.360435, None]\n", 781 | "[1.3604349, None]\n", 782 | "[1.3604363, None]\n", 783 | "[1.360435, None]\n", 784 | "[1.3604364, None]\n", 785 | "[1.3604354, None]\n", 786 | "[1.3604329, None]\n", 787 | "[1.3604338, None]\n", 788 | "[1.3604351, None]\n", 789 | "[1.3604313, None]\n", 790 | "[1.3604326, None]\n", 791 | "[1.3604326, None]\n", 792 | "[1.360433, None]\n", 793 | "[1.3604329, None]\n", 794 | "[1.3604329, None]\n", 795 | "[1.3604331, None]\n", 796 | "[1.3604341, None]\n", 797 | "[1.3604329, None]\n", 798 | "[1.3604306, None]\n", 799 | "[1.3604308, None]\n", 800 | "[1.360431, None]\n", 801 | "[1.3604319, None]\n", 802 | "[1.3604324, None]\n", 803 | "[1.3604325, None]\n", 804 | "[1.3604331, None]\n", 805 | "[1.3604326, None]\n", 806 | "[1.3604331, None]\n", 807 | "[1.3604305, None]\n", 808 | "[1.3604348, None]\n", 809 | "[1.3604363, None]\n", 810 | "[1.3604343, None]\n", 811 | "[1.3604345, None]\n", 812 | "[1.3604338, None]\n", 813 | "[1.3604319, None]\n", 814 | "[1.3604341, None]\n", 815 | "[1.3604336, None]\n", 816 | "[1.360433, None]\n", 817 | "[1.3604335, None]\n", 818 | "[1.3604336, None]\n", 819 | "[1.3604324, None]\n", 820 | "[1.3604321, None]\n", 821 | "[1.360431, None]\n", 822 | "[1.3604298, None]\n", 823 | "[1.3604317, None]\n", 824 | "[1.3604319, None]\n", 825 | "[1.3604339, None]\n", 826 | "[1.3604307, None]\n", 827 | "[1.36043, None]\n", 828 | "[1.3604279, None]\n", 829 | "[1.3604274, None]\n", 830 | "[1.360428, None]\n", 831 | "[1.3604295, None]\n", 832 | "[1.3604305, None]\n", 833 | "[1.360433, None]\n", 834 | "[1.3604312, None]\n", 835 | "[1.3604314, None]\n", 836 | "[1.3604312, None]\n", 837 | "[1.3604314, None]\n", 838 | "[1.3604307, None]\n", 839 | "[1.3604276, None]\n", 840 | "[1.36043, None]\n", 841 | "[1.3604295, 
None]\n", 842 | "[1.3604283, None]\n", 843 | "[1.3604296, None]\n", 844 | "[1.3604307, None]\n", 845 | "[1.3604293, None]\n", 846 | "[1.3604269, None]\n", 847 | "[1.3604298, None]\n", 848 | "[1.3604293, None]\n", 849 | "[1.360431, None]\n", 850 | "[1.3604283, None]\n", 851 | "[1.3604295, None]\n", 852 | "[1.3604281, None]\n", 853 | "[1.360429, None]\n", 854 | "[1.3604276, None]\n", 855 | "[1.3604293, None]\n", 856 | "[1.3604281, None]\n", 857 | "[1.36043, None]\n", 858 | "[1.3604287, None]\n", 859 | "[1.3604318, None]\n", 860 | "[1.3604293, None]\n", 861 | "[1.360429, None]\n", 862 | "[1.360431, None]\n", 863 | "[1.360431, None]\n", 864 | "[1.3604298, None]\n", 865 | "[1.3604295, None]\n", 866 | "[1.3604302, None]\n", 867 | "[1.3604301, None]\n", 868 | "[1.3604285, None]\n", 869 | "[1.3604292, None]\n", 870 | "[1.3604294, None]\n", 871 | "[1.3604269, None]\n", 872 | "[1.3604283, None]\n", 873 | "[1.3604267, None]\n", 874 | "[1.3604273, None]\n", 875 | "[1.3604298, None]\n", 876 | "[1.3604314, None]\n", 877 | "[1.3604314, None]\n", 878 | "[1.3604289, None]\n", 879 | "[1.3604271, None]\n", 880 | "[1.3604271, None]\n", 881 | "[1.3604275, None]\n", 882 | "[1.3604274, None]\n", 883 | "[1.3604264, None]\n", 884 | "[1.3604262, None]\n", 885 | "[1.3604267, None]\n", 886 | "[1.3604271, None]\n", 887 | "[1.3604273, None]\n", 888 | "[1.3604271, None]\n", 889 | "[1.3604261, None]\n", 890 | "[1.3604271, None]\n", 891 | "[1.3604295, None]\n", 892 | "[1.3604293, None]\n", 893 | "[1.3604302, None]\n", 894 | "[1.3604292, None]\n", 895 | "[1.360428, None]\n", 896 | "[1.3604281, None]\n", 897 | "[1.3604275, None]\n", 898 | "[1.36043, None]\n", 899 | "[1.3604298, None]\n", 900 | "[1.3604288, None]\n", 901 | "[1.3604269, None]\n", 902 | "[1.3604301, None]\n", 903 | "[1.3604273, None]\n", 904 | "[1.3604276, None]\n", 905 | "[1.3604263, None]\n", 906 | "[1.3604269, None]\n", 907 | "[1.3604264, None]\n", 908 | "[1.3604271, None]\n", 909 | "[1.3604283, None]\n", 910 | "[1.3604276, 
None]\n", 911 | "[1.3604282, None]\n", 912 | "[1.3604295, None]\n", 913 | "[1.360431, None]\n", 914 | "[1.3604285, None]\n", 915 | "[1.3604279, None]\n", 916 | "[1.360428, None]\n", 917 | "[1.3604285, None]\n", 918 | "[1.3604294, None]\n", 919 | "[1.3604258, None]\n", 920 | "[1.3604246, None]\n", 921 | "[1.3604274, None]\n", 922 | "[1.3604257, None]\n", 923 | "[1.3604282, None]\n", 924 | "[1.3604288, None]\n", 925 | "[1.360427, None]\n", 926 | "[1.36043, None]\n", 927 | "[1.360429, None]\n", 928 | "[1.3604261, None]\n", 929 | "[1.3604274, None]\n", 930 | "[1.360427, None]\n", 931 | "[1.3604282, None]\n", 932 | "[1.3604257, None]\n", 933 | "[1.360425, None]\n", 934 | "[1.3604255, None]\n", 935 | "[1.3604276, None]\n", 936 | "[1.3604242, None]\n", 937 | "[1.3604239, None]\n", 938 | "[1.360423, None]\n", 939 | "[1.3604244, None]\n", 940 | "[1.3604262, None]\n", 941 | "[1.3604252, None]\n", 942 | "[1.360425, None]\n", 943 | "[1.3604248, None]\n", 944 | "[1.3604254, None]\n", 945 | "[1.3604243, None]\n", 946 | "[1.3604245, None]\n", 947 | "[1.3604279, None]\n", 948 | "[1.3604275, None]\n", 949 | "[1.3604265, None]\n", 950 | "[1.3604273, None]\n", 951 | "[1.3604257, None]\n", 952 | "[1.3604243, None]\n", 953 | "[1.3604246, None]\n", 954 | "[1.3604252, None]\n", 955 | "[1.3604249, None]\n", 956 | "[1.3604231, None]\n", 957 | "[1.3604263, None]\n", 958 | "[1.3604259, None]\n", 959 | "[1.3604276, None]\n", 960 | "[1.3604275, None]\n", 961 | "[1.3604251, None]\n", 962 | "[1.3604236, None]\n", 963 | "[1.3604252, None]\n", 964 | "[1.3604267, None]\n", 965 | "[1.3604243, None]\n", 966 | "[1.3604249, None]\n", 967 | "[1.3604248, None]\n", 968 | "[1.3604262, None]\n", 969 | "[1.3604263, None]\n", 970 | "[1.3604248, None]\n", 971 | "[1.3604274, None]\n", 972 | "[1.3604275, None]\n", 973 | "[1.360425, None]\n", 974 | "[1.3604255, None]\n", 975 | "[1.3604246, None]\n", 976 | "[1.3604271, None]\n", 977 | "[1.3604283, None]\n", 978 | "[1.3604262, None]\n", 979 | "[1.3604264, None]\n", 
980 | "[1.3604268, None]\n", 981 | "[1.3604243, None]\n", 982 | "[1.3604248, None]\n", 983 | "[1.3604245, None]\n", 984 | "[1.3604263, None]\n", 985 | "[1.3604263, None]\n", 986 | "[1.3604258, None]\n", 987 | "[1.3604271, None]\n", 988 | "[1.3604263, None]\n", 989 | "[1.3604257, None]\n", 990 | "[1.3604239, None]\n", 991 | "[1.3604259, None]\n", 992 | "[1.3604271, None]\n", 993 | "[1.3604268, None]\n", 994 | "[1.3604255, None]\n", 995 | "[1.3604261, None]\n", 996 | "[1.3604271, None]\n", 997 | "[1.3604279, None]\n", 998 | "[1.3604274, None]\n", 999 | "[1.3604281, None]\n", 1000 | "[1.3604271, None]\n", 1001 | "[1.3604259, None]\n", 1002 | "[1.3604265, None]\n", 1003 | "[1.3604267, None]\n", 1004 | "[1.3604261, None]\n", 1005 | "[1.3604276, None]\n", 1006 | "[1.3604269, None]\n", 1007 | "[1.3604259, None]\n", 1008 | "[1.3604256, None]\n", 1009 | "[1.3604279, None]\n", 1010 | "[1.3604287, None]\n", 1011 | "[1.3604264, None]\n", 1012 | "[1.3604261, None]\n", 1013 | "[1.3604269, None]\n", 1014 | "[1.3604252, None]\n", 1015 | "[1.3604276, None]\n", 1016 | "[1.3604264, None]\n", 1017 | "[1.3604277, None]\n", 1018 | "[1.3604286, None]\n", 1019 | "[1.3604273, None]\n", 1020 | "[1.360427, None]\n", 1021 | "[1.3604261, None]\n", 1022 | "[1.3604271, None]\n", 1023 | "[1.3604256, None]\n", 1024 | "[1.3604269, None]\n", 1025 | "[1.3604234, None]\n", 1026 | "[1.3604236, None]\n", 1027 | "[1.3604265, None]\n", 1028 | "[1.3604293, None]\n", 1029 | "[1.3604304, None]\n", 1030 | "[1.3604267, None]\n", 1031 | "[1.3604267, None]\n", 1032 | "[1.3604252, None]\n", 1033 | "[1.3604257, None]\n", 1034 | "[1.3604243, None]\n", 1035 | "[1.3604269, None]\n", 1036 | "[1.3604283, None]\n", 1037 | "[1.3604264, None]\n", 1038 | "[1.3604258, None]\n", 1039 | "[1.3604232, None]\n", 1040 | "[1.3604242, None]\n", 1041 | "[1.3604244, None]\n", 1042 | "[1.3604279, None]\n", 1043 | "[1.3604269, None]\n", 1044 | "[1.3604261, None]\n", 1045 | "[1.360427, None]\n", 1046 | "[1.3604258, None]\n", 1047 | 
"[1.3604244, None]\n", 1048 | "[1.3604242, None]\n", 1049 | "[1.3604244, None]\n", 1050 | "[1.3604263, None]\n", 1051 | "[1.3604267, None]\n", 1052 | "[1.3604273, None]\n", 1053 | "[1.3604286, None]\n", 1054 | "[1.3604256, None]\n", 1055 | "[1.3604264, None]\n", 1056 | "[1.3604249, None]\n", 1057 | "[1.360425, None]\n", 1058 | "[1.3604254, None]\n", 1059 | "[1.3604261, None]\n", 1060 | "[1.3604252, None]\n", 1061 | "[1.3604255, None]\n", 1062 | "[1.3604243, None]\n", 1063 | "[1.360425, None]\n", 1064 | "[1.3604245, None]\n", 1065 | "[1.3604258, None]\n", 1066 | "[1.3604255, None]\n", 1067 | "[1.3604252, None]\n", 1068 | "[1.3604255, None]\n", 1069 | "[1.360425, None]\n", 1070 | "[1.360428, None]\n", 1071 | "[1.3604273, None]\n", 1072 | "[1.3604264, None]\n", 1073 | "[1.3604252, None]\n", 1074 | "[1.3604249, None]\n", 1075 | "[1.3604248, None]\n", 1076 | "[1.3604248, None]\n", 1077 | "[1.3604236, None]\n", 1078 | "[1.3604268, None]\n", 1079 | "[1.360425, None]\n", 1080 | "[1.3604256, None]\n", 1081 | "[1.360424, None]\n", 1082 | "[1.3604226, None]\n", 1083 | "[1.3604225, None]\n", 1084 | "[1.3604236, None]\n", 1085 | "[1.3604262, None]\n", 1086 | "[1.3604259, None]\n", 1087 | "[1.3604248, None]\n", 1088 | "[1.3604267, None]\n", 1089 | "[1.3604249, None]\n", 1090 | "[1.3604257, None]\n", 1091 | "[1.3604252, None]\n", 1092 | "[1.3604257, None]\n", 1093 | "[1.3604264, None]\n", 1094 | "[1.3604224, None]\n", 1095 | "[1.3604233, None]\n", 1096 | "[1.360424, None]\n", 1097 | "[1.3604245, None]\n", 1098 | "[1.360425, None]\n", 1099 | "[1.3604236, None]\n", 1100 | "[1.3604234, None]\n", 1101 | "[1.3604234, None]\n", 1102 | "[1.3604246, None]\n", 1103 | "[1.360424, None]\n", 1104 | "[1.3604255, None]\n", 1105 | "[1.3604273, None]\n" 1106 | ] 1107 | }, 1108 | { 1109 | "name": "stdout", 1110 | "output_type": "stream", 1111 | "text": [ 1112 | "[1.3604263, None]\n", 1113 | "[1.3604246, None]\n", 1114 | "[1.3604249, None]\n", 1115 | "[1.3604234, None]\n", 1116 | "[1.360425, 
None]\n", 1117 | "[1.3604245, None]\n", 1118 | "[1.3604246, None]\n", 1119 | "[1.3604239, None]\n", 1120 | "[1.3604271, None]\n", 1121 | "[1.3604264, None]\n", 1122 | "[1.3604269, None]\n", 1123 | "[1.3604271, None]\n", 1124 | "[1.3604252, None]\n", 1125 | "[1.3604259, None]\n", 1126 | "[1.3604225, None]\n", 1127 | "[1.3604227, None]\n", 1128 | "[1.3604248, None]\n", 1129 | "[1.3604262, None]\n", 1130 | "[1.3604263, None]\n", 1131 | "[1.3604252, None]\n", 1132 | "[1.3604264, None]\n", 1133 | "[1.3604258, None]\n", 1134 | "[1.3604263, None]\n", 1135 | "[1.3604259, None]\n", 1136 | "[1.360425, None]\n", 1137 | "[1.3604273, None]\n", 1138 | "[1.3604286, None]\n", 1139 | "[1.360429, None]\n", 1140 | "[1.3604264, None]\n", 1141 | "[1.3604259, None]\n", 1142 | "[1.3604256, None]\n", 1143 | "[1.3604267, None]\n", 1144 | "[1.3604224, None]\n", 1145 | "[1.3604224, None]\n", 1146 | "[1.3604245, None]\n", 1147 | "[1.3604246, None]\n", 1148 | "[1.360423, None]\n", 1149 | "[1.3604246, None]\n", 1150 | "[1.3604249, None]\n", 1151 | "[1.3604242, None]\n", 1152 | "[1.3604236, None]\n", 1153 | "[1.3604255, None]\n", 1154 | "[1.360424, None]\n", 1155 | "[1.3604248, None]\n", 1156 | "[1.3604264, None]\n", 1157 | "[1.3604243, None]\n", 1158 | "[1.3604252, None]\n", 1159 | "[1.3604257, None]\n", 1160 | "[1.3604244, None]\n", 1161 | "[1.3604257, None]\n", 1162 | "[1.3604285, None]\n", 1163 | "[1.3604257, None]\n", 1164 | "[1.360424, None]\n", 1165 | "[1.360424, None]\n", 1166 | "[1.3604226, None]\n", 1167 | "[1.3604243, None]\n", 1168 | "[1.3604251, None]\n", 1169 | "[1.3604252, None]\n", 1170 | "[1.3604233, None]\n", 1171 | "[1.3604245, None]\n", 1172 | "[1.3604236, None]\n", 1173 | "[1.3604261, None]\n", 1174 | "[1.3604252, None]\n", 1175 | "[1.3604292, None]\n", 1176 | "[1.3604267, None]\n", 1177 | "[1.3604248, None]\n", 1178 | "[1.360425, None]\n", 1179 | "[1.3604238, None]\n", 1180 | "[1.3604261, None]\n", 1181 | "[1.3604243, None]\n", 1182 | "[1.3604236, None]\n", 1183 | 
"[1.3604259, None]\n", 1184 | "[1.3604236, None]\n", 1185 | "[1.3604255, None]\n", 1186 | "[1.360425, None]\n", 1187 | "[1.3604259, None]\n", 1188 | "[1.360425, None]\n", 1189 | "[1.360425, None]\n", 1190 | "[1.3604234, None]\n", 1191 | "[1.3604248, None]\n", 1192 | "[1.3604254, None]\n", 1193 | "[1.3604251, None]\n" 1194 | ] 1195 | } 1196 | ], 1197 | "source": [ 1198 | "with tf.Session() as sess:\n", 1199 | " sess.run(tf.global_variables_initializer())\n", 1200 | " \n", 1201 | " for eopch in range(1000): \n", 1202 | " print(sess.run([loss,train_op],feed_dict={inputs_place:inputsinputs,labels_place:inputslables}))" 1203 | ] 1204 | }, 1205 | { 1206 | "cell_type": "code", 1207 | "execution_count": null, 1208 | "metadata": { 1209 | "collapsed": true 1210 | }, 1211 | "outputs": [], 1212 | "source": [ 1213 | "def sphereloss_mine(inputs,label,classes,weights_decay = 0.05,scope='Logits',reuse=None,m =4,eplion = 1e-12):\n", 1214 | " \n", 1215 | " \"\"\"\n", 1216 | " A-Softmax (SphereFace) loss. inputs: tensor shape=[batch,features_num]; label: integer class ids, shape=[batch]; classes: number of classes.\n", 1217 | " Only m=4 is implemented (cos(m*theta) is hard-coded as cos(4*theta)); weights_decay and reuse are not used.\n", 1218 | " Returns (k, loss): k = floor(theta/(pi/m)) per sample, shape=[batch]; loss = scalar mean A-Softmax loss.\n", 1219 | " \"\"\"\n", 1220 | " features_num = inputs.get_shape().as_list()[1]\n", 1221 | "\n", 1222 | " with tf.variable_scope(name_or_scope=scope):\n", 1223 | " weight = tf.Variable(initial_value=tf.random_normal((classes,features_num),stddev=0.01),dtype=tf.float32,name='weight') # shape =classes, features,\n", 1224 | " print(\"weight shape = \",weight.get_shape().as_list())\n", 1225 | " \n", 1226 | " weight_unit = tf.nn.l2_normalize(weight,dim=1)\n", 1227 | " print(\"weight_unit shape = \",weight_unit.get_shape().as_list())\n", 1228 | " \n", 1229 | " inputs_mo = tf.sqrt(tf.reduce_sum(tf.square(inputs),axis=1)+eplion) #shape=[batch\n", 1230 | " print(\"inputs_mo shape = \",inputs_mo.get_shape().as_list())\n", 1231 | " \n", 1232 | " inputs_unit = tf.nn.l2_normalize(inputs,dim=1) #shape = [batch,features_num]\n", 1233 | " print(\"inputs_unit shape =
\",inputs_unit.get_shape().as_list())\n", 1234 | " \n", 1235 | " logits = tf.matmul(inputs,tf.transpose(weight_unit)) #shape = [batch,classes] x * w_unit\n", 1236 | " print(\"logits shape = \",logits.get_shape().as_list())\n", 1237 | " \n", 1238 | " weight_unit_batch = tf.gather(weight_unit,label) # shape =batch,features_num,\n", 1239 | " print(\"weight_unit_batch shape = \",weight_unit_batch.get_shape().as_list())\n", 1240 | " \n", 1241 | " logits_inputs = tf.reduce_sum(tf.multiply(inputs,weight_unit_batch),axis=1) # shape =batch,\n", 1242 | " \n", 1243 | " print(\"logits_inputs shape = \",logits_inputs.get_shape().as_list())\n", 1244 | " \n", 1245 | " cos_theta = tf.reduce_sum(tf.multiply(inputs_unit,weight_unit_batch),axis=1) # shape =batch,\n", 1246 | " print(\"cos_theta shape = \",cos_theta.get_shape().as_list())\n", 1247 | " \n", 1248 | " # k picks the piece of psi(theta) = (-1)^k cos(m*theta) - 2k; computed as a plain tensor (no tf.Variable/tf.assign), so no batch-shaped variable is created\n", 1249 | " k = tf.floor_div(tf.acos(tf.clip_by_value(cos_theta,-1.0,1.0)),pi/m)+0.0 # shape =batch; clip keeps acos finite when rounding pushes |cos_theta| past 1\n", 1250 | " print(\"k shape = \",k.get_shape().as_list())\n", 1251 | " \n", 1252 | " cos_four_theta =tf.multiply(tf.pow(cos_theta,4),8) - tf.multiply(tf.pow(cos_theta,2),8)+1 #shape = batch, cos(4*theta) = 8cos^4(theta) - 8cos^2(theta) + 1\n", 1253 | " print(\"cos_four_theta shape = \",cos_four_theta.get_shape().as_list())\n", 1254 | " \n", 1255 | " cos_far_theta =tf.add(tf.multiply(tf.pow(-1.0,k) , cos_four_theta), tf.multiply(-2.0 , k)) #shape = batch\n", 1256 | " print(\"cos_far_theta = \",cos_far_theta.get_shape().as_list())\n", 1257 | " \n", 1258 | " \n", 1259 | " \n", 1260 | " logit_ii = tf.multiply(cos_far_theta,inputs_mo)#shape = batch \n", 1261 | " print(\"logit_ii shape = \",logit_ii.get_shape().as_list())\n", 1262 | " \n", 1263 | " logit_ii_exp = tf.exp(logit_ii)\n", 1264 | " logits_exp = tf.reduce_sum(tf.exp(logits),axis=1) \n", 1265 | " logits_inputs_exp = tf.exp(logits_inputs)\n", 1266 | " loss_exp = tf.divide (logit_ii_exp,tf.subtract(tf.add(logit_ii_exp,logits_exp ),logits_inputs_exp)) \n", 1267 | "
#shape = batch\n", 1268 | " print(\"loss_exp shape = \",loss_exp.get_shape().as_list())\n", 1269 | " loss = -tf.reduce_mean(tf.log(loss_exp))\n", 1270 | " print(\"loss shape = \",loss.get_shape().as_list())\n", 1271 | " \n", 1272 | "# return weight,weight_unit,weight_unit_batch,inputs,inputs_mo,inputs_unit, logits,logits_inputs,cos_theta,k,cos_four_theta,cos_far_theta,logit_ii,loss_exp,loss\n", 1273 | "# return weight,weight_unit,weight_unit_batch,inputs_mo,inputs_unit, logits,logits_inputs,cos_theta,k,cos_four_theta,cos_far_theta,logit_ii,loss_exp, loss\n", 1274 | " return k,loss" 1275 | ] 1276 | } 1277 | ], 1278 | "metadata": { 1279 | "kernelspec": { 1280 | "display_name": "Python 3", 1281 | "language": "python", 1282 | "name": "python3" 1283 | }, 1284 | "language_info": { 1285 | "codemirror_mode": { 1286 | "name": "ipython", 1287 | "version": 3 1288 | }, 1289 | "file_extension": ".py", 1290 | "mimetype": "text/x-python", 1291 | "name": "python", 1292 | "nbconvert_exporter": "python", 1293 | "pygments_lexer": "ipython3", 1294 | "version": "3.6.3" 1295 | } 1296 | }, 1297 | "nbformat": 4, 1298 | "nbformat_minor": 2 1299 | } 1300 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/test/__init__.py --------------------------------------------------------------------------------