├── .gitignore
├── .idea
│   ├── inspectionProfiles
│   │   └── Project_Default.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── sphereface_tensorflow.iml
│   ├── vcs.xml
│   └── workspace.xml
├── README.md
├── data
│   └── __init__.py
├── requirements.txt
├── src
│   ├── __init__.py
│   ├── align
│   │   └── __init__.py
│   ├── data
│   │   └── learning_rate_classifier_casia.txt
│   ├── loss
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-35.pyc
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── sphere.cpython-35.pyc
│   │   │   └── sphere.cpython-36.pyc
│   │   └── sphere.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-35.pyc
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   └── inception_resnet.cpython-35.pyc
│   │   ├── inception_resnet.py
│   │   ├── inception_resnet_no_train.py
│   │   ├── small_inceptin_resnet.py
│   │   └── squeezenet.py
│   ├── printPretrainedTensor.py
│   ├── train_demo.py
│   ├── train_softmax_demo.py
│   ├── train_softmax_no_train_inception.py
│   └── utils
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-35.pyc
│       │   ├── __init__.cpython-36.pyc
│       │   ├── faceUtil.cpython-35.pyc
│       │   └── faceUtil.cpython-36.pyc
│       └── faceUtil.py
└── test
    ├── Loss_ASoftmax.ipynb
    ├── .ipynb_checkpoints
    │   ├── Loss_ASoftmax-checkpoint.ipynb
    │   └── sphereloss-checkpoint.ipynb
    ├── __init__.py
    └── sphereloss.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | src/align/casia_maxpy_mtcnnpy_182
2 | src/align
3 | src/modeltrained/20170512
4 | src/models/inception_resnet_v1.py
5 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/sphereface_tensorflow.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Introduction
2 | 
3 | This is a simple re-implementation of the A-Softmax loss proposed in the paper [SphereFace: Deep Hypersphere Embedding for Face Recognition](https://arxiv.org/abs/1704.08063). Please cite the paper if it helps your work.
4 | This repository contains only the sphere loss and a training demo.
5 | Thanks to [sphereface](https://github.com/wy1iu/sphereface), which provided a lot of inspiration.
6 | 
7 | ## Requirements
8 | - TensorFlow 1.2+
9 | - scipy
10 | - scikit-learn
11 | - opencv-python
12 | - h5py
13 | - matplotlib
14 | - Pillow
15 | - requests
16 | - psutil
17 | 
18 | 
--------------------------------------------------------------------------------
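Note (not part of the original README): the loss defined in src/loss/sphere.py can be attached to any embedding network. A minimal sketch, assuming the script runs from src/ (as train_demo.py does), with a hypothetical 512-d embedding and 10575 classes (the commonly cited CASIA-WebFace identity count):

import tensorflow as tf

from loss.sphere import sphereloss

# Hypothetical inputs: 512-d embeddings for a batch of 64 face crops.
embeddings = tf.placeholder(tf.float32, shape=[64, 512], name='embeddings')
labels = tf.placeholder(tf.int64, shape=[64], name='labels')

logits, loss = sphereloss(embeddings, labels, classes=10575, batch_size=64)
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)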
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/data/__init__.py
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow==1.2
2 | scipy
3 | scikit-learn
4 | opencv-python
5 | h5py
6 | matplotlib
7 | Pillow
8 | requests
9 | psutil
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/__init__.py
--------------------------------------------------------------------------------
/src/align/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/align/__init__.py
--------------------------------------------------------------------------------
/src/data/learning_rate_classifier_casia.txt:
--------------------------------------------------------------------------------
1 | # Learning rate schedule
2 | # Maps an epoch number to a learning rate
3 | 0: 0.001
4 | 65: 0.0001
5 | 77: 0.0001
6 | 1000: 0.0001
--------------------------------------------------------------------------------
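Each line of this schedule file maps an epoch threshold to the learning rate used from that epoch onward; train_demo.py reads it through faceUtil.get_learning_rate_from_file when args.learning_rate is not positive. faceUtil.py is not shown in this dump, so the following parser is only a sketch of the presumed behavior:

def get_learning_rate_from_file(filename, epoch):
    # Sketch: each 'epoch: lr' line sets the rate from that epoch onward;
    # the last threshold <= epoch wins.
    learning_rate = None
    with open(filename) as f:
        for line in f:
            line = line.split('#', 1)[0].strip()  # drop comments and blank lines
            if not line:
                continue
            e, lr = line.split(':', 1)
            if int(e) <= epoch:
                learning_rate = float(lr)
    return learning_rate

With the schedule above this yields 0.001 for epochs 0-64 and 0.0001 from epoch 65 on.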
/src/loss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/loss/__init__.py
--------------------------------------------------------------------------------
/src/loss/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/loss/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/src/loss/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/loss/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/src/loss/__pycache__/sphere.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/loss/__pycache__/sphere.cpython-35.pyc
--------------------------------------------------------------------------------
/src/loss/__pycache__/sphere.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/loss/__pycache__/sphere.cpython-36.pyc
--------------------------------------------------------------------------------
/src/loss/sphere.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | def sphereloss(inputs, label, classes, batch_size, fraction=1, scope='Logits', reuse=None, m=4, epsilon=1e-8):
4 |     """
5 |     inputs: tensor, shape = [batch, features_num]
6 |     label: tensor, shape = [batch]; each entry is a class index in [0, classes)
7 |     Note: the margin is effectively fixed at m = 4 by the closed-form cos(4*theta) used below.
8 |     """
9 | inputs_shape = inputs.get_shape().as_list()
10 | with tf.variable_scope(name_or_scope=scope):
11 |         weight = tf.Variable(initial_value=tf.random_normal((classes, inputs_shape[1])) * tf.sqrt(2 / inputs_shape[1]), dtype=tf.float32, name='weights')  # shape = [classes, features_num]
12 | print("weight shape = ",weight.get_shape().as_list())
13 |
14 | weight_unit = tf.nn.l2_normalize(weight,dim=1)
15 | print("weight_unit shape = ",weight_unit.get_shape().as_list())
16 |
17 |         inputs_mo = tf.sqrt(tf.reduce_sum(tf.square(inputs), axis=1) + epsilon)  # shape = [batch]
18 | print("inputs_mo shape = ",inputs_mo.get_shape().as_list())
19 |
20 | inputs_unit = tf.nn.l2_normalize(inputs,dim=1) #shape = [batch,features_num]
21 | print("inputs_unit shape = ",inputs_unit.get_shape().as_list())
22 |
23 | logits = tf.matmul(inputs,tf.transpose(weight_unit)) #shape = [batch,classes] x * w_unit
24 | print("logits shape = ",logits.get_shape().as_list())
25 |
26 |         weight_unit_batch = tf.gather(weight_unit, label)  # shape = [batch, features_num]
27 | print("weight_unit_batch shape = ",weight_unit_batch.get_shape().as_list())
28 |
29 |         logits_inputs = tf.reduce_sum(tf.multiply(inputs, weight_unit_batch), axis=1)  # shape = [batch]
30 |
31 | print("logits_inputs shape = ",logits_inputs.get_shape().as_list())
32 |
33 |         cos_theta = tf.reduce_sum(tf.multiply(inputs_unit, weight_unit_batch), axis=1)  # shape = [batch]
34 | print("cos_theta shape = ",cos_theta.get_shape().as_list())
35 |
36 |         cos_theta_square = tf.square(cos_theta)
37 |         cos_theta_biq = tf.pow(cos_theta, 4)
38 |         sign0 = tf.sign(cos_theta)
39 |         sign2 = tf.sign(2 * cos_theta_square - 1)  # sign(cos(2*theta))
40 |         sign3 = tf.multiply(sign2, sign0)  # equals (-1)**k for theta in [k*pi/4, (k+1)*pi/4]
41 |         sign4 = 2 * sign0 + sign3 - 3  # equals -2*k
42 |         cos_far_theta = sign3 * (8 * cos_theta_biq - 8 * cos_theta_square + 1) + sign4  # psi(theta) = (-1)**k * cos(4*theta) - 2*k
43 | print("cos_far_theta = ",cos_far_theta.get_shape().as_list())
44 |
45 | logit_ii = tf.multiply(cos_far_theta,inputs_mo)#shape = batch
46 | print("logit_ii shape = ",logit_ii.get_shape().as_list())
47 |
48 | index_range = tf.range(start=0,limit= tf.shape(inputs,out_type=tf.int64)[0],delta=1,dtype=tf.int64)
49 | index_labels = tf.stack([index_range, label], axis = 1)
50 | index_logits = tf.scatter_nd(index_labels,tf.subtract(logit_ii,logits_inputs), tf.shape(logits,out_type=tf.int64))
51 |         print("index_logits shape = ", index_logits.get_shape().as_list())
52 |
53 | logits_final = tf.add(logits,index_logits)
54 | logits_final = fraction * logits_final + (1 - fraction) * logits
55 |
56 |
57 | loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label,logits=logits_final))
58 |
59 | return logits_final,loss
60 |
61 |
62 | def soft_loss(inputs,label,classes,scope='Logits'):
63 |     """
64 |     inputs: tensor, shape = [batch, features_num]
65 |     label: tensor, shape = [batch]; each entry is a class index in [0, classes)
66 | 
67 |     """
68 | inputs_shape = inputs.get_shape().as_list()
69 | with tf.variable_scope(name_or_scope=scope):
70 | weight = tf.Variable(initial_value=tf.random_normal((classes,inputs_shape[1])) * tf.sqrt(2 / inputs_shape[1]),
71 |                              dtype=tf.float32, name='weights')  # shape = [classes, features_num]
72 | bias = tf.Variable(initial_value=tf.zeros(classes),dtype=tf.float32,name='bias')
73 | print("weight shape = ",weight.get_shape().as_list())
74 | print("bias shape = ", bias.get_shape().as_list())
75 |
76 | weight = tf.Print(weight, [tf.shape(weight)], message='logits weights shape = ',summarize=4, first_n=1)
77 | logits = tf.nn.bias_add(tf.matmul(inputs,tf.transpose(weight)),bias,name='logits')
78 | loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label,logits=logits,
79 | name='cross_entropy_per_example'),
80 | name='cross_entropy')
81 |
82 | return logits,loss
83 |
84 | def soft_loss_nobias(inputs,label,classes,scope='Logits'):
85 |     """
86 |     inputs: tensor, shape = [batch, features_num]
87 |     label: tensor, shape = [batch]; each entry is a class index in [0, classes)
88 | 
89 |     """
90 | inputs_shape = inputs.get_shape().as_list()
91 | with tf.variable_scope(name_or_scope=scope):
92 | weight = tf.Variable(initial_value=tf.random_normal((classes,inputs_shape[1])) * tf.sqrt(2 / inputs_shape[1]),
93 |                              dtype=tf.float32, name='weights')  # shape = [classes, features_num]
94 | print("weight shape = ",weight.get_shape().as_list())
95 |
96 | weight = tf.Print(weight, [tf.shape(weight)], message='logits weights shape = ',summarize=4, first_n=1)
97 |         logits = tf.matmul(inputs, tf.transpose(weight), name='logits')
98 | loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label,logits=logits,
99 | name='cross_entropy_per_example'),
100 | name='cross_entropy')
101 |
102 | return logits,loss
103 |
--------------------------------------------------------------------------------
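The chain of sign terms in sphereloss is the closed form of the SphereFace margin function for m = 4: writing theta for the angle between a feature and its class weight, and k for the index of the interval [k*pi/4, (k+1)*pi/4] containing theta, sign3 equals (-1)**k, sign4 equals -2*k, and cos_far_theta is psi(theta) = (-1)**k * cos(4*theta) - 2*k, which decreases monotonically on [0, pi]. A standalone numpy check of that identity (not part of the repository):

import numpy as np

m = 4
# stay off the exact interval boundaries, where np.sign(0) = 0 would differ
theta = np.linspace(1e-3, np.pi - 1e-3, 9999)
c = np.cos(theta)

sign0 = np.sign(c)
sign2 = np.sign(2.0 * c ** 2 - 1.0)   # sign(cos(2*theta))
sign3 = sign2 * sign0                 # (-1)**k
sign4 = 2.0 * sign0 + sign3 - 3.0     # -2*k
trick = sign3 * (8.0 * c ** 4 - 8.0 * c ** 2 + 1.0) + sign4

k = np.minimum((theta * m / np.pi).astype(int), m - 1)
psi = (-1.0) ** k * np.cos(m * theta) - 2.0 * k

assert np.allclose(trick, psi, atol=1e-6)
print("sign trick matches psi(theta) on (0, pi)")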
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 |
3 |
--------------------------------------------------------------------------------
/src/models/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/models/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/models/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/inception_resnet.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/models/__pycache__/inception_resnet.cpython-35.pyc
--------------------------------------------------------------------------------
/src/models/inception_resnet.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import tensorflow as tf
7 | import tensorflow.contrib.slim as slim
8 |
9 | # Inception-Resnet-A
10 | def incep_res_a(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
11 | """Builds the 35x35 resnet block."""
12 | with tf.variable_scope(scope, 'Block35', [net], reuse=reuse):
13 | with tf.variable_scope('Branch_0'):
14 | tower_conv = slim.conv2d(net, 32, 1, scope='Conv2d_1x1')
15 | with tf.variable_scope('Branch_1'):
16 | tower_conv1_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1')
17 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 32, 3, scope='Conv2d_0b_3x3')
18 | with tf.variable_scope('Branch_2'):
19 | tower_conv2_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1')
20 | tower_conv2_1 = slim.conv2d(tower_conv2_0, 32, 3, scope='Conv2d_0b_3x3')
21 | tower_conv2_2 = slim.conv2d(tower_conv2_1, 32, 3, scope='Conv2d_0c_3x3')
22 | mixed = tf.concat([tower_conv, tower_conv1_1, tower_conv2_2], 3)
23 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
24 | activation_fn=None, scope='Conv2d_1x1')
25 | net += scale * up
26 | if activation_fn:
27 | net = activation_fn(net)
28 | return net
29 |
30 | # Inception-Resnet-B
31 | def incep_res_b(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
32 | """Builds the 17x17 resnet block."""
33 | with tf.variable_scope(scope, 'Block17', [net], reuse=reuse):
34 | with tf.variable_scope('Branch_0'):
35 | tower_conv = slim.conv2d(net, 128, 1, scope='Conv2d_1x1')
36 | with tf.variable_scope('Branch_1'):
37 | tower_conv1_0 = slim.conv2d(net, 128, 1, scope='Conv2d_0a_1x1')
38 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 128, [1, 7],
39 | scope='Conv2d_0b_1x7')
40 | tower_conv1_2 = slim.conv2d(tower_conv1_1, 128, [7, 1],
41 | scope='Conv2d_0c_7x1')
42 | mixed = tf.concat([tower_conv, tower_conv1_2], 3)
43 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
44 | activation_fn=None, scope='Conv2d_1x1')
45 | net += scale * up
46 | if activation_fn:
47 | net = activation_fn(net)
48 | return net
49 |
50 |
51 | # Inception-Resnet-C
52 | def incep_res_c(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
53 | """Builds the 8x8 resnet block."""
54 | with tf.variable_scope(scope, 'Block8', [net], reuse=reuse):
55 | with tf.variable_scope('Branch_0'):
56 | tower_conv = slim.conv2d(net, 192, 1, scope='Conv2d_1x1')
57 | with tf.variable_scope('Branch_1'):
58 | tower_conv1_0 = slim.conv2d(net, 192, 1, scope='Conv2d_0a_1x1')
59 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 192, [1, 3],
60 | scope='Conv2d_0b_1x3')
61 | tower_conv1_2 = slim.conv2d(tower_conv1_1, 192, [3, 1],
62 | scope='Conv2d_0c_3x1')
63 | mixed = tf.concat([tower_conv, tower_conv1_2], 3)
64 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
65 | activation_fn=None, scope='Conv2d_1x1')
66 | net += scale * up
67 | if activation_fn:
68 | net = activation_fn(net)
69 | return net
70 |
71 | def reduction_a(net, k, l, m, n):
72 | with tf.variable_scope('Branch_0'):
73 | tower_conv = slim.conv2d(net, n, 3, stride=2, padding='VALID',
74 | scope='Conv2d_1a_3x3')
75 | with tf.variable_scope('Branch_1'):
76 | tower_conv1_0 = slim.conv2d(net, k, 1, scope='Conv2d_0a_1x1')
77 | tower_conv1_1 = slim.conv2d(tower_conv1_0, l, 3,
78 | scope='Conv2d_0b_3x3')
79 | tower_conv1_2 = slim.conv2d(tower_conv1_1, m, 3,
80 | stride=2, padding='VALID',
81 | scope='Conv2d_1a_3x3')
82 | with tf.variable_scope('Branch_2'):
83 | tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID',
84 | scope='MaxPool_1a_3x3')
85 | net = tf.concat([tower_conv, tower_conv1_2, tower_pool], 3)
86 | return net
87 |
88 | def reduction_b(net):
89 | with tf.variable_scope('Branch_0'):
90 | tower_conv = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
91 | tower_conv_1 = slim.conv2d(tower_conv, 384, 3, stride=2,
92 | padding='VALID', scope='Conv2d_1a_3x3')
93 | with tf.variable_scope('Branch_1'):
94 | tower_conv1 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
95 | tower_conv1_1 = slim.conv2d(tower_conv1, 256, 3, stride=2,
96 | padding='VALID', scope='Conv2d_1a_3x3')
97 | with tf.variable_scope('Branch_2'):
98 | tower_conv2 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
99 | tower_conv2_1 = slim.conv2d(tower_conv2, 256, 3,
100 | scope='Conv2d_0b_3x3')
101 | tower_conv2_2 = slim.conv2d(tower_conv2_1, 256, 3, stride=2,
102 | padding='VALID', scope='Conv2d_1a_3x3')
103 | with tf.variable_scope('Branch_3'):
104 | tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID',
105 | scope='MaxPool_1a_3x3')
106 | net = tf.concat([tower_conv_1, tower_conv1_1,
107 | tower_conv2_2, tower_pool], 3)
108 | return net
109 |
110 | def inference(images, keep_probability, phase_train=True,
111 | bottleneck_layer_size=128, weight_decay=0.0, reuse=None):
112 | batch_norm_params = {
113 | # Decay for the moving averages.
114 | 'decay': 0.995,
115 | # epsilon to prevent 0s in variance.
116 | 'epsilon': 0.001,
117 | # force in-place updates of mean and variance estimates
118 | 'updates_collections': None,
119 | # Moving averages ends up in the trainable variables collection
120 | 'variables_collections': [ tf.GraphKeys.TRAINABLE_VARIABLES ],
121 | }
122 |
123 | with slim.arg_scope([slim.conv2d, slim.fully_connected],
124 | weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
125 | weights_regularizer=slim.l2_regularizer(weight_decay),
126 | normalizer_fn=slim.batch_norm,
127 | normalizer_params=batch_norm_params):
128 | return inception_resnet_v1(images, is_training=phase_train,
129 | dropout_keep_prob=keep_probability, bottleneck_layer_size=bottleneck_layer_size, reuse=reuse)
130 |
131 |
132 | def inception_resnet_v1(inputs, is_training=True,
133 | dropout_keep_prob=0.8,
134 | bottleneck_layer_size=128,
135 | reuse=None,
136 | scope='InceptionResnetV1'):
137 | """Creates the Inception Resnet V1 model.
138 | Args:
139 | inputs: a 4-D tensor of size [batch_size, height, width, 3].
140 | num_classes: number of predicted classes.
141 | is_training: whether is training or not.
142 | dropout_keep_prob: float, the fraction to keep before final layer.
143 | reuse: whether or not the network and its variables should be reused. To be
144 | able to reuse 'scope' must be given.
145 | scope: Optional variable_scope.
146 | Returns:
147 | logits: the logits outputs of the model.
148 | end_points: the set of end_points from the inception model.
149 | """
150 | end_points = {}
151 |
152 | with tf.variable_scope(scope, 'InceptionResnetV1', [inputs], reuse=reuse):
153 | with slim.arg_scope([slim.batch_norm, slim.dropout],
154 | is_training=is_training):
155 | with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
156 | stride=1, padding='SAME'):
157 |
158 | # 149 x 149 x 32
159 | net = slim.conv2d(inputs, 32, 3, stride=2, padding='VALID',
160 | scope='Conv2d_1a_3x3')
161 | end_points['Conv2d_1a_3x3'] = net
162 | # 147 x 147 x 32
163 | net = slim.conv2d(net, 32, 3, padding='VALID',
164 | scope='Conv2d_2a_3x3')
165 | end_points['Conv2d_2a_3x3'] = net
166 | # 147 x 147 x 64
167 | net = slim.conv2d(net, 64, 3, scope='Conv2d_2b_3x3')
168 | end_points['Conv2d_2b_3x3'] = net
169 | # 73 x 73 x 64
170 | net = slim.max_pool2d(net, 3, stride=2, padding='VALID',
171 | scope='MaxPool_3a_3x3')
172 | end_points['MaxPool_3a_3x3'] = net
173 | # 73 x 73 x 80
174 | net = slim.conv2d(net, 80, 1, padding='VALID',
175 | scope='Conv2d_3b_1x1')
176 | end_points['Conv2d_3b_1x1'] = net
177 | # 71 x 71 x 192
178 | net = slim.conv2d(net, 192, 3, padding='VALID',
179 | scope='Conv2d_4a_3x3')
180 | end_points['Conv2d_4a_3x3'] = net
181 | # 35 x 35 x 256
182 | net = slim.conv2d(net, 256, 3, stride=2, padding='VALID',
183 | scope='Conv2d_4b_3x3')
184 | end_points['Conv2d_4b_3x3'] = net
185 |
186 | # 5 x Inception-resnet-A 35 x 35 x 256
187 | net = slim.repeat(net, 5, incep_res_a, scale=0.17)
188 | end_points['Mixed_5a'] = net
189 |
190 | # Reduction-A
191 | with tf.variable_scope('Mixed_6a'):
192 | net = reduction_a(net, 192, 192, 256, 384)
193 | end_points['Mixed_6a'] = net
194 |
195 | # 10 x Inception-Resnet-B
196 | net = slim.repeat(net, 10, incep_res_b, scale=0.10)
197 | end_points['Mixed_6b'] = net
198 |
199 | # Reduction-B
200 | with tf.variable_scope('Mixed_7a'):
201 | net = reduction_b(net)
202 | end_points['Mixed_7a'] = net
203 |
204 | # 5 x Inception-Resnet-C
205 | net = slim.repeat(net, 5, incep_res_c, scale=0.20)
206 | end_points['Mixed_8a'] = net
207 |
208 | net = incep_res_c(net, activation_fn=None)
209 | end_points['Mixed_8b'] = net
210 |
211 | with tf.variable_scope('Logits'):
212 | end_points['PrePool'] = net
213 | #pylint: disable=no-member
214 | net = slim.avg_pool2d(net, net.get_shape()[1:3], padding='VALID',
215 | scope='AvgPool_1a_8x8')
216 | net = slim.flatten(net)
217 |
218 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
219 | scope='Dropout')
220 |
221 | end_points['PreLogitsFlatten'] = net
222 |
223 | net = slim.fully_connected(net, bottleneck_layer_size, activation_fn=None,
224 | scope='Bottleneck', reuse=False)
225 |
226 | return net, end_points
227 |
--------------------------------------------------------------------------------
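A note not in the original source: the spatial sizes in the comments above follow the stock 149x149 Inception-ResNet-v1 stem, while train_demo.py feeds 160x160 crops (the --image_size default). The network tolerates either because the final average pool reads the feature map size from net.get_shape()[1:3]. A quick graph-construction sanity check, assuming the module is imported as below (train_demo.py itself imports the gitignored models/inception_resnet_v1.py):

import tensorflow as tf
from models import inception_resnet as network  # run from src/

images = tf.placeholder(tf.float32, [None, 160, 160, 3])
prelogits, end_points = network.inference(
    images, keep_probability=0.8, phase_train=True, bottleneck_layer_size=128)
print(prelogits.get_shape().as_list())              # -> [None, 128]
print(end_points['PrePool'].get_shape().as_list())  # 3x3 spatial map for 160x160 input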
/src/models/inception_resnet_no_train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Contains the definition of the Inception Resnet V1 architecture.
17 | As described in http://arxiv.org/abs/1602.07261.
18 | Inception-v4, Inception-ResNet and the Impact of Residual Connections
19 | on Learning
20 | Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi
21 | """
22 | from __future__ import absolute_import
23 | from __future__ import division
24 | from __future__ import print_function
25 |
26 | import tensorflow as tf
27 | import tensorflow.contrib.slim as slim
28 |
29 | # Inception-Resnet-A
30 | def block35(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
31 | """Builds the 35x35 resnet block."""
32 | with tf.variable_scope(scope, 'Block35', [net], reuse=reuse):
33 | with tf.variable_scope('Branch_0'):
34 | tower_conv = slim.conv2d(net, 32, 1, scope='Conv2d_1x1')
35 | with tf.variable_scope('Branch_1'):
36 | tower_conv1_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1')
37 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 32, 3, scope='Conv2d_0b_3x3')
38 | with tf.variable_scope('Branch_2'):
39 | tower_conv2_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1')
40 | tower_conv2_1 = slim.conv2d(tower_conv2_0, 32, 3, scope='Conv2d_0b_3x3')
41 | tower_conv2_2 = slim.conv2d(tower_conv2_1, 32, 3, scope='Conv2d_0c_3x3')
42 | mixed = tf.concat([tower_conv, tower_conv1_1, tower_conv2_2], 3)
43 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
44 | activation_fn=None, scope='Conv2d_1x1')
45 | net += scale * up
46 | if activation_fn:
47 | net = activation_fn(net)
48 | return net
49 |
50 | # Inception-Resnet-B
51 | def block17(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
52 | """Builds the 17x17 resnet block."""
53 | with tf.variable_scope(scope, 'Block17', [net], reuse=reuse):
54 | with tf.variable_scope('Branch_0'):
55 | tower_conv = slim.conv2d(net, 128, 1, scope='Conv2d_1x1')
56 | with tf.variable_scope('Branch_1'):
57 | tower_conv1_0 = slim.conv2d(net, 128, 1, scope='Conv2d_0a_1x1')
58 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 128, [1, 7],
59 | scope='Conv2d_0b_1x7')
60 | tower_conv1_2 = slim.conv2d(tower_conv1_1, 128, [7, 1],
61 | scope='Conv2d_0c_7x1')
62 | mixed = tf.concat([tower_conv, tower_conv1_2], 3)
63 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
64 | activation_fn=None, scope='Conv2d_1x1')
65 | net += scale * up
66 | if activation_fn:
67 | net = activation_fn(net)
68 | return net
69 |
70 |
71 | # Inception-Resnet-C
72 | def block8(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
73 | """Builds the 8x8 resnet block."""
74 | with tf.variable_scope(scope, 'Block8', [net], reuse=reuse):
75 | with tf.variable_scope('Branch_0'):
76 | tower_conv = slim.conv2d(net, 192, 1, scope='Conv2d_1x1')
77 | with tf.variable_scope('Branch_1'):
78 | tower_conv1_0 = slim.conv2d(net, 192, 1, scope='Conv2d_0a_1x1')
79 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 192, [1, 3],
80 | scope='Conv2d_0b_1x3')
81 | tower_conv1_2 = slim.conv2d(tower_conv1_1, 192, [3, 1],
82 | scope='Conv2d_0c_3x1')
83 | mixed = tf.concat([tower_conv, tower_conv1_2], 3)
84 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
85 | activation_fn=None, scope='Conv2d_1x1')
86 | net += scale * up
87 | if activation_fn:
88 | net = activation_fn(net)
89 | return net
90 |
91 | def reduction_a(net, k, l, m, n):
92 | with tf.variable_scope('Branch_0'):
93 | tower_conv = slim.conv2d(net, n, 3, stride=2, padding='VALID',
94 | scope='Conv2d_1a_3x3')
95 | with tf.variable_scope('Branch_1'):
96 | tower_conv1_0 = slim.conv2d(net, k, 1, scope='Conv2d_0a_1x1')
97 | tower_conv1_1 = slim.conv2d(tower_conv1_0, l, 3,
98 | scope='Conv2d_0b_3x3')
99 | tower_conv1_2 = slim.conv2d(tower_conv1_1, m, 3,
100 | stride=2, padding='VALID',
101 | scope='Conv2d_1a_3x3')
102 | with tf.variable_scope('Branch_2'):
103 | tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID',
104 | scope='MaxPool_1a_3x3')
105 | net = tf.concat([tower_conv, tower_conv1_2, tower_pool], 3)
106 | return net
107 |
108 | def reduction_b(net):
109 | with tf.variable_scope('Branch_0'):
110 | tower_conv = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
111 | tower_conv_1 = slim.conv2d(tower_conv, 384, 3, stride=2,
112 | padding='VALID', scope='Conv2d_1a_3x3')
113 | with tf.variable_scope('Branch_1'):
114 | tower_conv1 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
115 | tower_conv1_1 = slim.conv2d(tower_conv1, 256, 3, stride=2,
116 | padding='VALID', scope='Conv2d_1a_3x3')
117 | with tf.variable_scope('Branch_2'):
118 | tower_conv2 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
119 | tower_conv2_1 = slim.conv2d(tower_conv2, 256, 3,
120 | scope='Conv2d_0b_3x3')
121 | tower_conv2_2 = slim.conv2d(tower_conv2_1, 256, 3, stride=2,
122 | padding='VALID', scope='Conv2d_1a_3x3')
123 | with tf.variable_scope('Branch_3'):
124 | tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID',
125 | scope='MaxPool_1a_3x3')
126 | net = tf.concat([tower_conv_1, tower_conv1_1,
127 | tower_conv2_2, tower_pool], 3)
128 | return net
129 |
130 | def inference(images, keep_probability, phase_train=True,
131 | bottleneck_layer_size=128, weight_decay=0.0, reuse=None):
132 | batch_norm_params = {
133 | # Decay for the moving averages.
134 | 'decay': 0.995,
135 | # epsilon to prevent 0s in variance.
136 | 'epsilon': 0.001,
137 | # force in-place updates of mean and variance estimates
138 | 'updates_collections': None,
139 | # Moving averages ends up in the trainable variables collection
140 | 'variables_collections': [ tf.GraphKeys.TRAINABLE_VARIABLES ],
141 | }
142 |
143 | with slim.arg_scope([slim.conv2d, slim.fully_connected],
144 | weights_initializer=tf.contrib.layers.xavier_initializer(),
145 | # tf.truncated_normal_initializer(stddev=0.01),
146 | weights_regularizer=slim.l2_regularizer(weight_decay),
147 | normalizer_fn=slim.batch_norm,
148 | normalizer_params=batch_norm_params):
149 | return inception_resnet_v1(images, is_training=phase_train,
150 | dropout_keep_prob=keep_probability, bottleneck_layer_size=bottleneck_layer_size, reuse=reuse)
151 |
152 |
153 | def inception_resnet_v1(inputs, is_training=True,
154 | dropout_keep_prob=0.8,
155 | bottleneck_layer_size=128,
156 | reuse=None,
157 | scope='InceptionResnetV1'):
158 | """Creates the Inception Resnet V1 model.
159 | Args:
160 | inputs: a 4-D tensor of size [batch_size, height, width, 3].
161 | num_classes: number of predicted classes.
162 | is_training: whether is training or not.
163 | dropout_keep_prob: float, the fraction to keep before final layer.
164 | reuse: whether or not the network and its variables should be reused. To be
165 | able to reuse 'scope' must be given.
166 | scope: Optional variable_scope.
167 | Returns:
168 | logits: the logits outputs of the model.
169 | end_points: the set of end_points from the inception model.
170 | """
171 | end_points = {}
172 |
173 | with tf.variable_scope(scope, 'InceptionResnetV1', [inputs], reuse=reuse):
174 | with slim.arg_scope([slim.batch_norm, slim.dropout],
175 | is_training=is_training):
176 | with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
177 | stride=1, padding='SAME'):
178 | with slim.arg_scope([slim.conv2d,slim.fully_connected],trainable=False):
179 |
180 | # 149 x 149 x 32
181 | net = slim.conv2d(inputs, 32, 3, stride=2, padding='VALID',
182 | scope='Conv2d_1a_3x3')
183 | end_points['Conv2d_1a_3x3'] = net
184 | # 147 x 147 x 32
185 | net = slim.conv2d(net, 32, 3, padding='VALID',
186 | scope='Conv2d_2a_3x3')
187 | end_points['Conv2d_2a_3x3'] = net
188 | # 147 x 147 x 64
189 | net = slim.conv2d(net, 64, 3, scope='Conv2d_2b_3x3')
190 | end_points['Conv2d_2b_3x3'] = net
191 | # 73 x 73 x 64
192 | net = slim.max_pool2d(net, 3, stride=2, padding='VALID',
193 | scope='MaxPool_3a_3x3')
194 | end_points['MaxPool_3a_3x3'] = net
195 | # 73 x 73 x 80
196 | net = slim.conv2d(net, 80, 1, padding='VALID',
197 | scope='Conv2d_3b_1x1')
198 | end_points['Conv2d_3b_1x1'] = net
199 | # 71 x 71 x 192
200 | net = slim.conv2d(net, 192, 3, padding='VALID',
201 | scope='Conv2d_4a_3x3')
202 | end_points['Conv2d_4a_3x3'] = net
203 | # 35 x 35 x 256
204 | net = slim.conv2d(net, 256, 3, stride=2, padding='VALID',
205 | scope='Conv2d_4b_3x3')
206 | end_points['Conv2d_4b_3x3'] = net
207 |
208 | # 5 x Inception-resnet-A 35 x 35 x 256
209 | net = slim.repeat(net, 5, block35, scale=0.17)
210 | end_points['Mixed_5a'] = net
211 |
212 | # Reduction-A
213 | with tf.variable_scope('Mixed_6a'):
214 | net = reduction_a(net, 192, 192, 256, 384)
215 | end_points['Mixed_6a'] = net
216 |
217 | # 10 x Inception-Resnet-B
218 | net = slim.repeat(net, 10, block17, scale=0.10)
219 | end_points['Mixed_6b'] = net
220 |
221 | # Reduction-B
222 | with tf.variable_scope('Mixed_7a'):
223 | net = reduction_b(net)
224 | end_points['Mixed_7a'] = net
225 |
226 | # 5 x Inception-Resnet-C
227 | net = slim.repeat(net, 5, block8, scale=0.20)
228 | end_points['Mixed_8a'] = net
229 |
230 | net = block8(net, activation_fn=None)
231 | end_points['Mixed_8b'] = net
232 |
233 | with tf.variable_scope('Logits'):
234 | end_points['PrePool'] = net
235 | #pylint: disable=no-member
236 | net = slim.avg_pool2d(net, net.get_shape()[1:3], padding='VALID',
237 | scope='AvgPool_1a_8x8')
238 | net = slim.flatten(net)
239 |
240 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
241 | scope='Dropout')
242 |
243 | end_points['PreLogitsFlatten'] = net
244 |
245 | net = slim.fully_connected(net, bottleneck_layer_size, activation_fn=None,
246 | scope='Bottleneck', reuse=False)
247 |
248 | return net, end_points
249 |
--------------------------------------------------------------------------------
/src/models/small_inceptin_resnet.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import tensorflow as tf
7 | import tensorflow.contrib.slim as slim
8 |
9 | # Inception-Resnet-A
10 | def block35(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
11 | """Builds the 35x35 resnet block."""
12 | with tf.variable_scope(scope, 'Block35', [net], reuse=reuse):
13 | with tf.variable_scope('Branch_0'):
14 | tower_conv = slim.conv2d(net, 32, 1, scope='Conv2d_1x1')
15 | with tf.variable_scope('Branch_1'):
16 | tower_conv1_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1')
17 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 32, 3, scope='Conv2d_0b_3x3')
18 | with tf.variable_scope('Branch_2'):
19 | tower_conv2_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1')
20 | tower_conv2_1 = slim.conv2d(tower_conv2_0, 32, 3, scope='Conv2d_0b_3x3')
21 | tower_conv2_2 = slim.conv2d(tower_conv2_1, 32, 3, scope='Conv2d_0c_3x3')
22 | mixed = tf.concat([tower_conv, tower_conv1_1, tower_conv2_2], 3)
23 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
24 | activation_fn=None, scope='Conv2d_1x1')
25 | net += scale * up
26 | if activation_fn:
27 | net = activation_fn(net)
28 | return net
29 |
30 | # Inception-Resnet-B
31 | def block17(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
32 | """Builds the 17x17 resnet block."""
33 | with tf.variable_scope(scope, 'Block17', [net], reuse=reuse):
34 | with tf.variable_scope('Branch_0'):
35 | tower_conv = slim.conv2d(net, 128, 1, scope='Conv2d_1x1')
36 | with tf.variable_scope('Branch_1'):
37 | tower_conv1_0 = slim.conv2d(net, 128, 1, scope='Conv2d_0a_1x1')
38 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 128, [1, 7],
39 | scope='Conv2d_0b_1x7')
40 | tower_conv1_2 = slim.conv2d(tower_conv1_1, 128, [7, 1],
41 | scope='Conv2d_0c_7x1')
42 | mixed = tf.concat([tower_conv, tower_conv1_2], 3)
43 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
44 | activation_fn=None, scope='Conv2d_1x1')
45 | net += scale * up
46 | if activation_fn:
47 | net = activation_fn(net)
48 | return net
49 |
50 |
51 | # Inception-Resnet-C
52 | def block8(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
53 | """Builds the 8x8 resnet block."""
54 | with tf.variable_scope(scope, 'Block8', [net], reuse=reuse):
55 | with tf.variable_scope('Branch_0'):
56 | tower_conv = slim.conv2d(net, 192, 1, scope='Conv2d_1x1')
57 | with tf.variable_scope('Branch_1'):
58 | tower_conv1_0 = slim.conv2d(net, 192, 1, scope='Conv2d_0a_1x1')
59 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 192, [1, 3],
60 | scope='Conv2d_0b_1x3')
61 | tower_conv1_2 = slim.conv2d(tower_conv1_1, 192, [3, 1],
62 | scope='Conv2d_0c_3x1')
63 | mixed = tf.concat([tower_conv, tower_conv1_2], 3)
64 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
65 | activation_fn=None, scope='Conv2d_1x1')
66 | net += scale * up
67 | if activation_fn:
68 | net = activation_fn(net)
69 | return net
70 |
71 | def reduction_a(net, k, l, m, n):
72 | with tf.variable_scope('Branch_0'):
73 | tower_conv = slim.conv2d(net, n, 3, stride=2, padding='VALID',
74 | scope='Conv2d_1a_3x3')
75 | with tf.variable_scope('Branch_1'):
76 | tower_conv1_0 = slim.conv2d(net, k, 1, scope='Conv2d_0a_1x1')
77 | tower_conv1_1 = slim.conv2d(tower_conv1_0, l, 3,
78 | scope='Conv2d_0b_3x3')
79 | tower_conv1_2 = slim.conv2d(tower_conv1_1, m, 3,
80 | stride=2, padding='VALID',
81 | scope='Conv2d_1a_3x3')
82 | with tf.variable_scope('Branch_2'):
83 | tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID',
84 | scope='MaxPool_1a_3x3')
85 | net = tf.concat([tower_conv, tower_conv1_2, tower_pool], 3)
86 | return net
87 |
88 | def reduction_b(net):
89 | with tf.variable_scope('Branch_0'):
90 | tower_conv = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
91 | tower_conv_1 = slim.conv2d(tower_conv, 384, 3, stride=1,
92 | padding='VALID', scope='Conv2d_1a_3x3')
93 | with tf.variable_scope('Branch_1'):
94 | tower_conv1 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
95 | tower_conv1_1 = slim.conv2d(tower_conv1, 256, 3, stride=1,
96 | padding='VALID', scope='Conv2d_1a_3x3')
97 | with tf.variable_scope('Branch_2'):
98 | tower_conv2 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
99 | tower_conv2_1 = slim.conv2d(tower_conv2, 256, 3,
100 | scope='Conv2d_0b_3x3')
101 | tower_conv2_2 = slim.conv2d(tower_conv2_1, 256, 3, stride=1,
102 | padding='VALID', scope='Conv2d_1a_3x3')
103 | with tf.variable_scope('Branch_3'):
104 | tower_pool = slim.max_pool2d(net, 3, stride=1, padding='VALID',
105 | scope='MaxPool_1a_3x3')
106 | net = tf.concat([tower_conv_1, tower_conv1_1,
107 | tower_conv2_2, tower_pool], 3)
108 | return net
109 |
110 | def inference(images, keep_probability, phase_train=True,
111 | bottleneck_layer_size=128, weight_decay=0.0, reuse=None):
112 | batch_norm_params = {
113 | # Decay for the moving averages.
114 | 'decay': 0.995,
115 | # epsilon to prevent 0s in variance.
116 | 'epsilon': 0.001,
117 | # force in-place updates of mean and variance estimates
118 | 'updates_collections': None,
119 | # Moving averages ends up in the trainable variables collection
120 | 'variables_collections': [ tf.GraphKeys.TRAINABLE_VARIABLES ],
121 | }
122 |
123 | with slim.arg_scope([slim.conv2d, slim.fully_connected],
124 | weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
125 | weights_regularizer=slim.l2_regularizer(weight_decay),
126 | normalizer_fn=slim.batch_norm,
127 | normalizer_params=batch_norm_params):
128 | return inception_resnet_v1(images, is_training=phase_train,
129 | dropout_keep_prob=keep_probability, bottleneck_layer_size=bottleneck_layer_size, reuse=reuse)
130 |
131 |
132 | def inception_resnet_v1(inputs, is_training=True,
133 | dropout_keep_prob=0.8,
134 | bottleneck_layer_size=128,
135 | reuse=None,
136 | scope='InceptionResnetV1'):
137 | """Creates model
138 | Args:
139 | inputs: a 4-D tensor of size [batch_size, 32, 32, 3].
140 | num_classes: number of predicted classes.
141 | is_training: whether is training or not.
142 | dropout_keep_prob: float, the fraction to keep before final layer.
143 | reuse: whether or not the network and its variables should be reused.
144 | scope: Optional variable_scope.
145 | Returns:
146 | logits: the logits outputs of the model.
147 | end_points: the set of end_points from the inception model.
148 | """
149 | end_points = {}
150 |
151 | with tf.variable_scope(scope, 'InceptionResnetV1', [inputs], reuse=reuse):
152 | with slim.arg_scope([slim.batch_norm, slim.dropout],
153 | is_training=is_training):
154 | with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
155 | stride=1, padding='SAME'):
156 |
157 |
158 | #31 x 31 x 32
159 | net = slim.conv2d(inputs,32,3,stride=1,padding='VALID',scope='conv_1_3x3')
160 |
161 | #15 * 15 * 64
162 | net = slim.conv2d(net,64,3,stride=2,padding='VALID',scope='conv_2_3x3')
163 |
164 | #7 * 7 * 96
165 | net = slim.conv2d(net,96,3,stride=2,padding='VALID',scope='conv_3_3x3')
166 |
167 | #7 * 7 * 96
168 | net = slim.repeat(net, 4, block35, scale=0.17)
169 |
170 | #4 * 4 * 224
171 | with tf.variable_scope('Mixed_6a'):
172 | net = reduction_a(net, 32, 32, 64, 96)
173 |
174 | net = slim.repeat(net, 4, block8, scale=0.20)
175 |
176 | with tf.variable_scope('Mixed_7a'):
177 | net = reduction_b(net)
178 |
179 |
180 | with tf.variable_scope('Logits'):
181 |
182 | #pylint: disable=no-member
183 | net = slim.avg_pool2d(net, net.get_shape()[1:3], padding='VALID',
184 | scope='AvgPool_1a_8x8')
185 | net = slim.flatten(net)
186 |
187 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
188 | scope='Dropout')
189 |
190 | end_points['PreLogitsFlatten'] = net
191 |
192 | net = slim.fully_connected(net, bottleneck_layer_size, activation_fn=None,
193 | scope='Bottleneck', reuse=False)
194 |
195 | return net, end_points
196 |
197 |
--------------------------------------------------------------------------------
/src/models/squeezenet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import tensorflow as tf
6 | import tensorflow.contrib.slim as slim
7 |
8 | def fire_module(inputs,
9 | squeeze_depth,
10 | expand_depth,
11 | reuse=None,
12 | scope=None,
13 | outputs_collections=None):
14 | with tf.variable_scope(scope, 'fire', [inputs], reuse=reuse):
15 | with slim.arg_scope([slim.conv2d, slim.max_pool2d],
16 | outputs_collections=None):
17 | net = squeeze(inputs, squeeze_depth)
18 | outputs = expand(net, expand_depth)
19 | return outputs
20 |
21 | def squeeze(inputs, num_outputs):
22 | return slim.conv2d(inputs, num_outputs, [1, 1], stride=1, scope='squeeze')
23 |
24 | def expand(inputs, num_outputs):
25 | with tf.variable_scope('expand'):
26 | e1x1 = slim.conv2d(inputs, num_outputs, [1, 1], stride=1, scope='1x1')
27 | e3x3 = slim.conv2d(inputs, num_outputs, [3, 3], scope='3x3')
28 | return tf.concat([e1x1, e3x3], 3)
29 |
30 | def inference(images, keep_probability, phase_train=True, bottleneck_layer_size=128, weight_decay=0.0, reuse=None):
31 | batch_norm_params = {
32 | # Decay for the moving averages.
33 | 'decay': 0.995,
34 | # epsilon to prevent 0s in variance.
35 | 'epsilon': 0.001,
36 | # force in-place updates of mean and variance estimates
37 | 'updates_collections': None,
38 | # Moving averages ends up in the trainable variables collection
39 | 'variables_collections': [ tf.GraphKeys.TRAINABLE_VARIABLES ],
40 | }
41 | with slim.arg_scope([slim.conv2d, slim.fully_connected],
42 | weights_initializer=slim.xavier_initializer_conv2d(uniform=True),
43 | weights_regularizer=slim.l2_regularizer(weight_decay),
44 | normalizer_fn=slim.batch_norm,
45 | normalizer_params=batch_norm_params):
46 | with tf.variable_scope('squeezenet', [images], reuse=reuse):
47 | with slim.arg_scope([slim.batch_norm, slim.dropout],
48 | is_training=phase_train):
49 | net = slim.conv2d(images, 96, [7, 7], stride=2, scope='conv1')
50 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='maxpool1')
51 | net = fire_module(net, 16, 64, scope='fire2')
52 | net = fire_module(net, 16, 64, scope='fire3')
53 | net = fire_module(net, 32, 128, scope='fire4')
54 | net = slim.max_pool2d(net, [2, 2], stride=2, scope='maxpool4')
55 | net = fire_module(net, 32, 128, scope='fire5')
56 | net = fire_module(net, 48, 192, scope='fire6')
57 | net = fire_module(net, 48, 192, scope='fire7')
58 | net = fire_module(net, 64, 256, scope='fire8')
59 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='maxpool8')
60 | net = fire_module(net, 64, 256, scope='fire9')
61 | net = slim.dropout(net, keep_probability)
62 | net = slim.conv2d(net, 1000, [1, 1], activation_fn=None, normalizer_fn=None, scope='conv10')
63 | net = slim.avg_pool2d(net, net.get_shape()[1:3], scope='avgpool10')
64 | net = tf.squeeze(net, [1, 2], name='logits')
65 | net = slim.fully_connected(net, bottleneck_layer_size, activation_fn=None,
66 | scope='Bottleneck', reuse=False)
67 | return net, None
68 |
--------------------------------------------------------------------------------
/src/printPretrainedTensor.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tensorflow.python import pywrap_tensorflow
3 | checkpoint_path = 'modeltrained/20170512/model-20170512-110547.ckpt-250000'
4 |
5 | reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
6 | var_to_shape_map = reader.get_variable_to_shape_map()
7 | print(len(var_to_shape_map))
8 | # for key in var_to_shape_map:
9 | # print("tensor_name: ", key)
10 | # print(reader.get_tensor(key).shape)
11 |
--------------------------------------------------------------------------------
/src/train_demo.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import argparse
6 | import os.path
7 | import sys
8 | import time
9 | from datetime import datetime
10 |
11 | import numpy as np
12 | from tensorflow.python.framework import ops
13 | from tensorflow.python.ops import array_ops
14 | from tensorflow.python.ops import data_flow_ops
15 |
16 | from loss.sphere import *
17 | from models import inception_resnet_v1 as network
18 | from utils import faceUtil
19 |
20 |
21 | def main(args):
22 | print("main start")
23 | np.random.seed(seed=args.seed)
24 | #train_set = ImageClass list
25 | train_set = faceUtil.get_dataset(args.data_dir)
26 |
27 |     # total number of classes
28 | nrof_classes = len(train_set)
29 | print(nrof_classes)
30 |
31 |     # subdir, e.g. 20171122-112109
32 | subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
33 |
34 |     # log_dir = <logs_base_dir>/<subdir>
35 | log_dir = os.path.join(os.path.expanduser(args.logs_base_dir),subdir)
36 | if not os.path.isdir(log_dir):
37 | os.makedirs(log_dir)
38 | print("log_dir =",log_dir)
39 |
40 |     # model_dir = <models_base_dir>/<subdir>
41 | model_dir = os.path.join(os.path.expanduser(args.models_base_dir), subdir)
42 | if not os.path.isdir(model_dir): # Create the model directory if it doesn't exist
43 | os.makedirs(model_dir)
44 |
45 | print("model_dir =", model_dir)
46 | pretrained_model = None
47 | if args.pretrained_model:
48 | # pretrained_model = os.path.expanduser(args.pretrained_model)
49 | # pretrained_model = tf.train.get_checkpoint_state(args.pretrained_model)
50 | pretrained_model = args.pretrained_model
51 | print('Pre-trained model: %s' % pretrained_model)
52 |
53 |
54 | # Write arguments to a text file
55 | faceUtil.write_arguments_to_file(args, os.path.join(log_dir, 'arguments.txt'))
56 | print("write_arguments_to_file")
57 | with tf.Graph().as_default():
58 | tf.set_random_seed(args.seed)
59 | global_step = tf.Variable(0,trainable=False)
60 |
61 |         # two lists of equal length: image_list = image file paths, label_list = the corresponding labels
62 | image_list, label_list = faceUtil.get_image_paths_and_labels(train_set)
63 | assert len(image_list) > 0 , 'dataset is empty'
64 | print("len(image_list) = ",len(image_list))
65 |
66 | # Create a queue that produces indices into the image_list and label_list
67 | labels = ops.convert_to_tensor(label_list,dtype=tf.int64)
68 | range_size = array_ops.shape(labels)[0]
69 | range_size = tf.Print(range_size, [tf.shape(range_size)],message='Shape of range_input_producer range_size : ',summarize=4, first_n=1)
70 |
71 |         # create a queue that produces the indices 0..range_size-1 in shuffled order
72 |         index_queue = tf.train.range_input_producer(range_size,num_epochs=None,shuffle=True,seed=None,capacity=32)
73 | 
74 |         # dequeue args.batch_size * args.epoch_size indices, used to select the part of image_list / label_list fed to the network this epoch
75 |         index_dequeue_op = index_queue.dequeue_many(args.batch_size * args.epoch_size,'index_dequeue')
76 |
77 |         # learning rate
78 |         learning_rate_placeholder = tf.placeholder(tf.float32,name='learning_rate')
79 |         # batch size (args.batch_size)
80 |         batch_size_placeholder = tf.placeholder(tf.int32,name='batch_size')
81 |         # whether we are in the training phase
82 |         phase_train_placeholder = tf.placeholder(tf.bool,name='phase_train')
83 |         # image paths, size: args.batch_size * args.epoch_size
84 |         image_paths_placeholder = tf.placeholder(tf.string,shape=[None,1],name='image_paths')
85 |         # image labels, size: args.batch_size * args.epoch_size
86 |         labels_placeholder = tf.placeholder(tf.int64,shape=[None,1],name='labels')
87 |
88 |         # create a FIFO (first in, first out) queue for the input pipeline
89 | input_queue = data_flow_ops.FIFOQueue(capacity=100000,dtypes=[tf.string,tf.int64],shapes=[(1,),(1,)],shared_name=None,name=None)
90 |
91 |         # enqueue_many returns an op; it enqueues len(image_paths_placeholder) elements, i.e. the args.batch_size * args.epoch_size indices dequeued from index_queue
92 | enqueue_op = input_queue.enqueue_many([image_paths_placeholder,labels_placeholder],name='enqueue_op')
93 |
94 | nrof_preprocess_threads = 4
95 | images_and_labels = []
96 |
97 | for _ in range(nrof_preprocess_threads):
98 | filenames , label = input_queue.dequeue()
99 | # label = tf.Print(label,[tf.shape(label)],message='Shape of one thread input_queue.dequeue label : ',
100 | # summarize=4,first_n=1)
101 | # filenames = tf.Print(filenames, [tf.shape(filenames)], message='Shape of one thread input_queue.dequeue filenames : ',
102 | # summarize=4, first_n=1)
103 | print("one thread input_queue.dequeue len = ",tf.shape(label))
104 | images =[]
105 | for filenames in tf.unstack(filenames):
106 | file_contents = tf.read_file(filenames)
107 | image = tf.image.decode_image(file_contents,channels=3)
108 |
109 | if args.random_rotate:
110 | image = tf.py_func(faceUtil.random_rotate_image, [image], tf.uint8)
111 |
112 | if args.random_crop:
113 | image = tf.random_crop(image,[args.image_size,args.image_size,3])
114 |
115 | else:
116 | image = tf.image.resize_image_with_crop_or_pad(image,args.image_size,args.image_size)
117 |
118 | if args.random_flip:
119 | image = tf.image.random_flip_left_right(image)
120 |
121 | image.set_shape((args.image_size,args.image_size,3))
122 | images.append(tf.image.per_image_standardization(image))
123 |
124 |             # dequeue filenames, decode them into images, then append to images_and_labels (one entry per preprocessing thread, so likely length = 4)
125 | images_and_labels.append([images,label])
126 |
127 |         # the batch actually fed to the network; its length should equal batch_size_placeholder
128 | image_batch, label_batch = tf.train.batch_join(images_and_labels,batch_size=batch_size_placeholder,
129 | shapes=[(args.image_size,args.image_size,3),()],
130 | enqueue_many = True,
131 | capacity = 4 * nrof_preprocess_threads * args.batch_size,
132 | allow_smaller_final_batch=True)
133 | print('final input net image_batch len = ',tf.shape(image_batch))
134 |
135 | image_batch = tf.Print(image_batch, [tf.shape(image_batch)], message='final input net image_batch shape = ',
136 | summarize=4, first_n=1)
137 | image_batch = tf.identity(image_batch, 'image_batch')
138 | image_batch = tf.identity(image_batch, 'input')
139 | label_batch = tf.identity(label_batch, 'label_batch')
140 |
141 | print('Total number of classes: %d' % nrof_classes)
142 | print('Total number of examples: %d' % len(image_list))
143 |
144 | print('Building training graph')
145 |
146 |         # apply exponential decay to the learning rate
147 | learning_rate = tf.train.exponential_decay(learning_rate= learning_rate_placeholder,
148 | global_step = global_step,
149 | decay_steps=args.learning_rate_decay_epochs * args.epoch_size,
150 | decay_rate=args.learning_rate_decay_factor,
151 | staircase = True)
152 | #decay_steps=args.learning_rate_decay_epochs * args.epoch_size,
153 |
154 | tf.summary.scalar('learning_rate', learning_rate)
155 |
156 | # Build the inference graph
157 | prelogits, _ = network.inference(image_batch,args.keep_probability,phase_train=phase_train_placeholder,
158 | bottleneck_layer_size=args.embedding_size,weight_decay=args.weight_decay)
159 | print("prelogits.shape = ",prelogits.get_shape().as_list())
160 |
161 | logits,sphere_loss = sphereloss(prelogits,label_batch,len(train_set),batch_size=args.batch_size)
162 |
163 | tf.add_to_collection('losses',sphere_loss)
164 | regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
165 | total_loss = tf.add_n([sphere_loss] + regularization_losses,name='total_loss')
166 |
167 | train_op = faceUtil.train(total_loss, global_step, args.optimizer, learning_rate,
168 | args.moving_average_decay, tf.global_variables(), args.log_histograms)
169 | # print("global_variables len = {}".format(len(tf.global_variables())))
170 | # print("local_variables len = {}".format(len(tf.local_variables())))
171 | # print("trainable_variables len = {}".format(len(tf.trainable_variables())))
172 | # for v in tf.trainable_variables() :
173 | # print("trainable_variables :{}".format(v.name))
174 | # train_op = faceUtil.train(sphere_loss,global_step,args.optimizer,learning_rate,
175 | # args.moving_average_decay, tf.global_variables(), args.log_histograms)
176 |
177 |         # create the saver
178 | variables = tf.trainable_variables()
179 | print("variables_trainable len = ", len(variables))
180 | for v in variables:
181 | print('variables_trainable : {}'.format(v.name))
182 | saver = tf.train.Saver(var_list=variables, max_to_keep=2)
183 |
184 | # variables_to_restore = [v for v in variables if v.name.split('/')[0] != 'Logits']
185 | # print("variables_trainable len = ",len(variables))
186 | # print("variables_to_restore len = ",len(variables_to_restore))
187 | # # for v in variables_to_restore :
188 | # # print("variables_to_restore : ",v.name)
189 | # saver = tf.train.Saver(var_list=variables_to_restore,max_to_keep=3)
190 |
191 |
192 | # variables_trainable = tf.trainable_variables()
193 | # print("variables_trainable len = ",len(variables_trainable))
194 | # # for v in variables_trainable :
195 | # # print('variables_trainable : {}'.format(v.name))
196 | # variables_to_restore = slim.get_variables_to_restore(include=['InceptionResnetV1'])
197 | # print("variables_to_restore len = ",len(variables_to_restore))
198 | # saver = tf.train.Saver(var_list=variables_to_restore,max_to_keep=3)
199 |
200 |
201 |
202 | # Build the summary operation based on the TF collection of Summaries.
203 | summary_op = tf.summary.merge_all()
204 |
205 |         # maximum fraction of GPU memory the process may allocate
206 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = args.gpu_memory_fraction)
207 | sess = tf.Session(config=tf.ConfigProto(gpu_options = gpu_options,log_device_placement = False))
208 |
209 | # Initialize variables
210 | sess.run(tf.global_variables_initializer())
211 | sess.run(tf.local_variables_initializer())
212 | summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
213 |
214 |         # create the thread coordinator
215 |         coord = tf.train.Coordinator()
216 | 
217 |         # start all queue runners
218 |         tf.train.start_queue_runners(coord=coord,sess=sess)
219 |
220 | with sess.as_default():
221 | print('Running training')
222 | if pretrained_model :
223 | print('Restoring pretrained model_checkpoint_path: %s' % pretrained_model)
224 | saver.restore(sess,pretrained_model)
225 |
226 | # Training and validation loop
227 | print('Running training really')
228 | epoch = 0
229 |             # max_nrof_epochs = number of passes over the whole dataset
230 | while epoch < args.max_nrof_epochs:
231 |
232 |                 # fetch the current value of global_step; step can be read as the global number of batches processed so far
233 | step = sess.run(global_step,feed_dict=None)
234 |
235 |                 # epoch_size is the number of batches in one epoch
236 |                 # epoch = global batch count // batches per epoch; it is used to look up the learning rate
237 | epoch = step // args.epoch_size
238 | # Train for one epoch
239 | train(args, sess, epoch, image_list, label_list, index_dequeue_op, enqueue_op, image_paths_placeholder, labels_placeholder,
240 | learning_rate_placeholder, phase_train_placeholder, batch_size_placeholder, global_step,
241 | total_loss, train_op, summary_op, summary_writer, regularization_losses, args.learning_rate_schedule_file)
242 |
243 | # Save variables and the metagraph if it doesn't exist already
244 | save_variables_and_metagraph(sess, saver, summary_writer, model_dir, subdir, step)
245 |
246 | return model_dir
247 |
248 |
249 |
250 | def train(args, sess, epoch, image_list, label_list, index_dequeue_op, enqueue_op, image_paths_placeholder, labels_placeholder,
251 | learning_rate_placeholder, phase_train_placeholder, batch_size_placeholder, global_step,
252 | loss, train_op, summary_op, summary_writer, regularization_losses, learning_rate_schedule_file):
253 |
254 | batch_number = 0
255 | if args.learning_rate>0.0:
256 |
257 | lr = args.learning_rate
258 | else:
259 | lr = faceUtil.get_learning_rate_from_file(learning_rate_schedule_file, epoch)
260 |
261 | index_epoch = sess.run(index_dequeue_op)
262 | label_epoch = np.array(label_list)[index_epoch]
263 | image_epoch = np.array(image_list)[index_epoch]
264 |
265 | # Enqueue one epoch of image paths and labels
266 | labels_array = np.expand_dims(np.array(label_epoch),1)
267 | image_paths_array = np.expand_dims(np.array(image_epoch),1)
268 | sess.run(enqueue_op,{image_paths_placeholder:image_paths_array,labels_placeholder:labels_array})
269 |
270 | # Training loop
271 | train_time = 0
272 | while batch_number < args.epoch_size:
273 | start_time = time.time()
274 | feed_dict = {learning_rate_placeholder: lr, phase_train_placeholder:True, batch_size_placeholder:args.batch_size}
275 |
276 | if (batch_number % 100 == 0) :
277 | err, _, step, reg_loss, summary_str = sess.run([loss,train_op,global_step,regularization_losses,summary_op],feed_dict=feed_dict)
278 | summary_writer.add_summary(summary_str, global_step=step)
279 | else :
280 | err, _, step, reg_loss = sess.run([loss, train_op, global_step, regularization_losses], feed_dict=feed_dict)
281 |
282 | duration = time.time() - start_time
283 | print('global_step[%d],Epoch: [%d][%d/%d]\tTime %.3f\tLoss %2.3f\tRegLoss %2.3f' %
284 | (step,epoch, batch_number+1, args.epoch_size, duration, err, np.sum(reg_loss)))
285 | batch_number += 1
286 | train_time += duration
287 |
288 | # Add the total training time for this epoch to the summary
289 | summary = tf.Summary()
290 |
291 | #pylint: disable=maybe-no-member
292 | summary.value.add(tag='time/total', simple_value=train_time)
293 | summary_writer.add_summary(summary, step)
294 | return step
295 |
296 |
297 |
298 | def save_variables_and_metagraph(sess, saver, summary_writer, model_dir, model_name, step):
299 |
300 | # Save the model checkpoint
301 | print('Saving variables')
302 | start_time = time.time()
303 | checkpoint_path = os.path.join(model_dir,'model-%s.ckpt' % model_name)
304 | saver.save(sess,checkpoint_path,global_step=step,write_meta_graph=False)
305 | save_time_variables = time.time() - start_time
306 | print('Variables saved in %.2f seconds' % save_time_variables)
307 | metagraph_filename = os.path.join(model_dir,'model-%s.meta' % model_name)
308 | save_time_metagraph = 0
309 | if not os.path.exists(metagraph_filename):
310 | print('Saving metagraph')
311 | start_time = time.time()
312 | saver.export_meta_graph(metagraph_filename)
313 | save_time_metagraph = time.time() - start_time
314 | print('Metagraph saved in %.2f seconds' % save_time_metagraph)
315 |
316 | summary = tf.Summary()
317 | #pylint: disable=maybe-no-member
318 | summary.value.add(tag='time/save_variables', simple_value=save_time_variables)
319 | summary.value.add(tag='time/save_metagraph', simple_value=save_time_metagraph)
320 | summary_writer.add_summary(summary, step)
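# Note: the metagraph (.meta) is exported only on the first save into a model
# directory; subsequent calls just write new variable checkpoints, of which the
# saver keeps only the most recent few (max_to_keep).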
321 |
322 |
323 |
324 |
325 |
326 |
327 |
328 |
329 |
330 |
331 |
332 |
333 |
334 |
335 |
336 |
337 |
338 |
339 |
340 |
341 |
342 | def parse_arguments(argv):
343 | parser = argparse.ArgumentParser()
344 | print("parser = argparse.ArgumentParser()")
345 | parser.add_argument('--data_dir',type=str,default='align/casia_maxpy_mtcnnpy_182')
346 | parser.add_argument('--gpu_memory_fraction',type=float,default=0.8)
347 | parser.add_argument('--pretrained_model', type=str,default = '/home/huyu/models/facenet/20180211-122649/model-20180211-122649.ckpt-9000',
348 | help='Load a pretrained model before training starts.')
349 | #default = '/home/huyu/models/facenet/20180209-114624/model-20180209-114624.ckpt-0',
350 | #default='/home/huyu/models/facenet/20180207-203905/model-20180207-203905.ckpt-21000',
351 | # default='modeltrained/20170512/model-20170512-110547.ckpt-250000',
352 |
353 | parser.add_argument('--max_nrof_epochs', type=int, default=100)
354 | parser.add_argument('--batch_size', type=int, default=64)
355 | parser.add_argument('--image_size',type=int,default=160)
356 | parser.add_argument('--epoch_size', type=int, default=300)
357 | parser.add_argument('--embedding_size', type=int, default=128)
358 | parser.add_argument('--random_crop',
359 | help='Performs random cropping of training images. If false, the center image_size pixels from the training images are used. ' +
360 | 'If the size of the images in the data directory is equal to image_size no cropping is performed', action='store_true')
361 | parser.add_argument('--random_flip',
362 | help='Performs random horizontal flipping of training images.', action='store_true')
363 | parser.add_argument('--random_rotate',
364 | help='Performs random rotations of training images.', action='store_true')
365 | parser.add_argument('--keep_probability', type=float,
366 | help='Keep probability of dropout for the fully connected layer(s).', default=1)
367 | parser.add_argument('--weight_decay', type=float,
368 | help='L2 weight regularization.', default=0.0)
369 | parser.add_argument('--learning_rate', type=float,
370 | help='Initial learning rate. If set to a negative value a learning rate ' +
371 | 'schedule can be specified in the file "learning_rate_schedule.txt"', default=0.0001)
372 | parser.add_argument('--optimizer', type=str, choices=['ADAGRAD', 'ADADELTA', 'ADAM', 'RMSPROP', 'MOM'],
373 | help='The optimization algorithm to use', default='ADAM')
374 | parser.add_argument('--learning_rate_decay_epochs', type=int,
375 | help='Number of epochs between learning rate decay.', default=100)
376 | parser.add_argument('--learning_rate_decay_factor', type=float,
377 | help='Learning rate decay factor.', default=1)
378 | parser.add_argument('--moving_average_decay', type=float,
379 | help='Exponential decay for tracking of training parameters.', default=0.9999)
380 | parser.add_argument('--seed', type=int,
381 | help='Random seed.', default=666)
382 | parser.add_argument('--nrof_preprocess_threads', type=int,
383 | help='Number of preprocessing (data loading and augmentation) threads.', default=4)
384 | parser.add_argument('--log_histograms',
385 | help='Enables logging of weight/bias histograms in tensorboard.', action='store_true')
386 | parser.add_argument('--learning_rate_schedule_file', type=str,
388 | help='File containing the learning rate schedule that is used when learning_rate is set to -1.', default='data/learning_rate_classifier_casia.txt')
388 | parser.add_argument('--filter_filename', type=str,
389 | help='File containing image data used for dataset filtering', default='')
390 | parser.add_argument('--filter_percentile', type=float,
391 | help='Keep only the percentile of images closest to their class center', default=100.0)
392 | parser.add_argument('--filter_min_nrof_images_per_class', type=int,
393 | help='Keep only the classes with this number of examples or more', default=0)
394 | parser.add_argument('--logs_base_dir', type=str,
395 | help='Directory where to write event logs.', default='~/logs/facenet')
396 | parser.add_argument('--models_base_dir', type=str,
397 | help='Directory where to write trained models and checkpoints.', default='~/models/facenet')
398 |
399 |
400 | return parser.parse_args(argv)
401 |
402 |
403 | if __name__ == '__main__':
404 | main(parse_arguments(sys.argv[1:]))
405 |
--------------------------------------------------------------------------------
/src/train_softmax_demo.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import argparse
6 | import os.path
7 | import sys
8 | import time
9 | from datetime import datetime
10 |
11 | import numpy as np
12 | from tensorflow.python.framework import ops
13 | from tensorflow.python.ops import array_ops
14 | from tensorflow.python.ops import data_flow_ops
15 | import tensorflow.contrib.slim as slim
16 |
17 | from loss.sphere import *
18 | from models import inception_resnet_v1 as network
19 | from utils import faceUtil
20 |
21 |
22 | def main(args):
23 | print("main start")
24 | np.random.seed(seed=args.seed)
25 | #train_set = ImageClass list
26 | train_set = faceUtil.get_dataset(args.data_dir)
27 |
28 | # Total number of classes
29 | nrof_classes = len(train_set)
30 | print(nrof_classes)
31 |
32 | #subdir =20171122-112109
33 | subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
34 |
35 | #log_dir = c:\User\logs\facenet\20171122-
36 | log_dir = os.path.join(os.path.expanduser(args.logs_base_dir),subdir)
37 | if not os.path.isdir(log_dir):
38 | os.makedirs(log_dir)
39 | print("log_dir =",log_dir)
40 |
41 | # model_dir =c:\User/models/facenet/2017;;;
42 | model_dir = os.path.join(os.path.expanduser(args.models_base_dir), subdir)
43 | if not os.path.isdir(model_dir): # Create the model directory if it doesn't exist
44 | os.makedirs(model_dir)
45 |
46 | print("model_dir =", model_dir)
47 | pretrained_model = None
48 | if args.pretrained_model:
49 | # pretrained_model = os.path.expanduser(args.pretrained_model)
50 | # pretrained_model = tf.train.get_checkpoint_state(args.pretrained_model)
51 | pretrained_model = args.pretrained_model
52 | print('Pre-trained model: %s' % pretrained_model)
53 |
54 |
55 | # Write arguments to a text file
56 | faceUtil.write_arguments_to_file(args, os.path.join(log_dir, 'arguments.txt'))
57 | print("write_arguments_to_file")
58 | with tf.Graph().as_default():
59 | tf.set_random_seed(args.seed)
60 | global_step = tf.Variable(0,trainable=False)
61 |
62 | # Two lists of the same length: image_list = image file paths, label_list = corresponding labels
63 | image_list, label_list = faceUtil.get_image_paths_and_labels(train_set)
64 | assert len(image_list) > 0 , 'dataset is empty'
65 | print("len(image_list) = ",len(image_list))
66 |
67 | # Create a queue that produces indices into the image_list and label_list
68 | labels = ops.convert_to_tensor(label_list,dtype=tf.int64)
69 | range_size = array_ops.shape(labels)[0]
70 | range_size = tf.Print(range_size, [tf.shape(range_size)],message='Shape of range_input_producer range_size : ',summarize=4, first_n=1)
71 |
72 | # Create a queue that yields the indices 0 .. range_size-1 in shuffled order
73 | index_queue = tf.train.range_input_producer(range_size,num_epochs=None,shuffle=True,seed=None,capacity=32)
74 |
75 | # Dequeue args.batch_size * args.epoch_size indices from index_queue, used to pick the slice of image_list/label_list fed to the network
76 | index_dequeue_op = index_queue.dequeue_many(args.batch_size * args.epoch_size,'index_dequeue')
77 |
78 | # Learning rate
79 | learning_rate_placeholder = tf.placeholder(tf.float32,name='learning_rate')
80 | # Batch size (args.batch_size)
81 | batch_size_placeholder = tf.placeholder(tf.int32,name='batch_size')
82 | # Whether we are in the training phase
83 | phase_train_placeholder = tf.placeholder(tf.bool,name='phase_train')
84 | # Image paths; size: args.batch_size * args.epoch_size
85 | image_paths_placeholder = tf.placeholder(tf.string,shape=[None,1],name='image_paths')
86 | # Image labels; size: args.batch_size * args.epoch_size
87 | labels_placeholder = tf.placeholder(tf.int64,shape=[None,1],name='labels')
88 |
89 | # Create a FIFO queue for the input pipeline
90 | input_queue = data_flow_ops.FIFOQueue(capacity=100000,dtypes=[tf.string,tf.int64],shapes=[(1,),(1,)],shared_name=None,name=None)
91 |
92 | # enqueue_many returns an op; the number of enqueued elements is len(image_paths_placeholder), i.e. the args.batch_size * args.epoch_size indices taken from index_queue
93 | enqueue_op = input_queue.enqueue_many([image_paths_placeholder,labels_placeholder],name='enqueue_op')
94 |
95 | nrof_preprocess_threads = 4
96 | images_and_labels = []
97 |
98 | for _ in range(nrof_preprocess_threads):
99 | filenames , label = input_queue.dequeue()
100 | # label = tf.Print(label,[tf.shape(label)],message='Shape of one thread input_queue.dequeue label : ',
101 | # summarize=4,first_n=1)
102 | # filenames = tf.Print(filenames, [tf.shape(filenames)], message='Shape of one thread input_queue.dequeue filenames : ',
103 | # summarize=4, first_n=1)
104 | print("one thread input_queue.dequeue len = ",tf.shape(label))
105 | images =[]
106 | for filename in tf.unstack(filenames):
107 | file_contents = tf.read_file(filename)
108 | image = tf.image.decode_image(file_contents,channels=3)
109 |
110 | if args.random_rotate:
111 | image = tf.py_func(faceUtil.random_rotate_image, [image], tf.uint8)
112 |
113 | if args.random_crop:
114 | image = tf.random_crop(image,[args.image_size,args.image_size,3])
115 |
116 | else:
117 | image = tf.image.resize_image_with_crop_or_pad(image,args.image_size,args.image_size)
118 |
119 | if args.random_flip:
120 | image = tf.image.random_flip_left_right(image)
121 |
122 | image.set_shape((args.image_size,args.image_size,3))
123 | images.append(tf.image.per_image_standardization(image))
124 |
125 | # Dequeue filenames from the queue, decode them into images, and append to images_and_labels; final length = nrof_preprocess_threads (4)
126 | images_and_labels.append([images,label])
127 |
128 | # Data that finally enters the network in one step; its length should equal batch_size_placeholder
129 | image_batch, label_batch = tf.train.batch_join(images_and_labels,batch_size=batch_size_placeholder,
130 | shapes=[(args.image_size,args.image_size,3),()],
131 | enqueue_many = True,
132 | capacity = 4 * nrof_preprocess_threads * args.batch_size,
133 | allow_smaller_final_batch=True)
134 | print('final input net image_batch len = ',tf.shape(image_batch))
135 |
136 | image_batch = tf.Print(image_batch, [tf.shape(image_batch)], message='final input net image_batch shape = ',
137 | summarize=4, first_n=1)
138 | image_batch = tf.identity(image_batch, 'image_batch')
139 | image_batch = tf.identity(image_batch, 'input')
140 | label_batch = tf.identity(label_batch, 'label_batch')
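# Pipeline recap (descriptive note): index_queue yields batch_size*epoch_size
# shuffled indices per epoch; enqueue_op pushes the matching (path, label)
# pairs into the FIFO input_queue; nrof_preprocess_threads reader threads
# decode and augment the images; batch_join then assembles batches of
# batch_size_placeholder elements for the network.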
141 |
142 | print('Total number of classes: %d' % nrof_classes)
143 | print('Total number of examples: %d' % len(image_list))
144 |
145 | print('Building training graph')
146 |
147 | # Apply exponential decay to the learning rate
148 | learning_rate = tf.train.exponential_decay(learning_rate= learning_rate_placeholder,
149 | global_step = global_step,
150 | decay_steps=args.learning_rate_decay_epochs * args.epoch_size,
151 | decay_rate=args.learning_rate_decay_factor,
152 | staircase = True)
153 | #decay_steps=args.learning_rate_decay_epochs * args.epoch_size,
154 |
155 | tf.summary.scalar('learning_rate', learning_rate)
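# With staircase=True the effective rate is
#   lr * decay_factor ** (global_step // decay_steps).
# A sketch with this script's defaults (lr=0.005, decay_epochs=5,
# epoch_size=300, factor=0.8): decay_steps = 5*300 = 1500, so the rate becomes
# 0.005 * 0.8 ** (global_step // 1500), i.e. it drops by 20% every 1500 steps.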
156 |
157 | # Build the inference graph
158 | prelogits, _ = network.inference(image_batch,args.keep_probability,phase_train=phase_train_placeholder,
159 | bottleneck_layer_size=args.embedding_size,weight_decay=args.weight_decay)
160 |
161 | prelogits = tf.Print(prelogits, [tf.shape(prelogits)], message='prelogits shape = ',
162 | summarize=4, first_n=1)
163 | print("prelogits.shape = ",prelogits.get_shape().as_list())
164 |
165 | # logits =slim.fully_connected(prelogits, len(train_set), activation_fn=None,
166 | # weights_initializer=tf.contrib.layers.xavier_initializer(),
167 | # weights_regularizer=slim.l2_regularizer(args.weight_decay),
168 | # scope='Logits', reuse=False)
169 | #
170 | # # Calculate the average cross entropy loss across the batch
171 | # cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
172 | # labels=label_batch, logits=logits, name='cross_entropy_per_example')
173 | # tf.reduce_mean(cross_entropy, name='cross_entropy')
174 | _,cross_entropy_mean = soft_loss_nobias(prelogits,label_batch,len(train_set))
175 | tf.add_to_collection('losses', cross_entropy_mean)
176 |
177 | regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
178 | total_loss = tf.add_n([cross_entropy_mean] + regularization_losses,name='total_loss')
179 |
180 | train_op = faceUtil.train(total_loss, global_step, args.optimizer, learning_rate,
181 | args.moving_average_decay, tf.global_variables(), args.log_histograms)
182 | # print("global_variables len = {}".format(len(tf.global_variables())))
183 | # print("local_variables len = {}".format(len(tf.local_variables())))
184 | # print("trainable_variables len = {}".format(len(tf.trainable_variables())))
185 | # for v in tf.trainable_variables() :
186 | # print("trainable_variables :{}".format(v.name))
187 | # train_op = faceUtil.train(sphere_loss,global_step,args.optimizer,learning_rate,
188 | # args.moving_average_decay, tf.global_variables(), args.log_histograms)
189 |
190 | # Create the saver
191 | variables = tf.trainable_variables()
192 | print("variables_trainable len = ", len(variables))
193 | for v in variables:
194 | print('variables_trainable : {}'.format(v.name))
195 | saver = tf.train.Saver(var_list=variables, max_to_keep=2)
196 |
197 | # variables_to_restore = [v for v in variables if v.name.split('/')[0] != 'Logits']
198 | # print("variables_trainable len = ",len(variables))
199 | # print("variables_to_restore len = ",len(variables_to_restore))
200 | # # for v in variables_to_restore :
201 | # # print("variables_to_restore : ",v.name)
202 | # saver = tf.train.Saver(var_list=variables_to_restore,max_to_keep=3)
203 |
204 |
205 | # variables_trainable = tf.trainable_variables()
206 | # print("variables_trainable len = ",len(variables_trainable))
207 | # # for v in variables_trainable :
208 | # # print('variables_trainable : {}'.format(v.name))
209 | # variables_to_restore = slim.get_variables_to_restore(include=['InceptionResnetV1'])
210 | # print("variables_to_restore len = ",len(variables_to_restore))
211 | # saver = tf.train.Saver(var_list=variables_to_restore,max_to_keep=3)
212 |
213 |
214 |
215 | # Build the summary operation based on the TF collection of Summaries.
216 | summary_op = tf.summary.merge_all()
217 |
218 | # Maximum fraction of GPU memory this process is allowed to allocate
219 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = args.gpu_memory_fraction)
220 | sess = tf.Session(config=tf.ConfigProto(gpu_options = gpu_options,log_device_placement = False))
221 |
222 | # Initialize variables
223 | sess.run(tf.global_variables_initializer())
224 | sess.run(tf.local_variables_initializer())
225 | summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
226 |
227 | # Create a coordinator for the queue-runner threads
228 | coord = tf.train.Coordinator()
229 |
230 | # Start all queue runners registered in the graph
231 | tf.train.start_queue_runners(coord=coord,sess=sess)
232 |
233 | with sess.as_default():
234 | print('Running training')
235 | if pretrained_model :
236 | print('Restoring pretrained model_checkpoint_path: %s' % pretrained_model)
237 | saver.restore(sess,pretrained_model)
238 |
239 | # Training and validation loop
240 | print('Running training really')
241 | epoch = 0
242 | # Number of full passes over the dataset
243 | while epoch < args.max_nrof_epochs:
244 |
245 | # Fetch the current global_step value; step can be read as the global batch count
246 | step = sess.run(global_step,feed_dict=None)
247 |
248 | # epoch_size is the number of batches in one epoch
249 | # epoch = global batch count // batches per epoch; this epoch index is used to look up the learning rate
250 | epoch = step // args.epoch_size
251 | # Train for one epoch
252 | train(args, sess, epoch, image_list, label_list, index_dequeue_op, enqueue_op, image_paths_placeholder, labels_placeholder,
253 | learning_rate_placeholder, phase_train_placeholder, batch_size_placeholder, global_step,
254 | total_loss, train_op, summary_op, summary_writer, regularization_losses, args.learning_rate_schedule_file)
255 |
256 | # Save variables and the metagraph if it doesn't exist already
257 | save_variables_and_metagraph(sess, saver, summary_writer, model_dir, subdir, step)
258 |
259 | return model_dir
260 |
261 |
262 |
263 | def train(args, sess, epoch, image_list, label_list, index_dequeue_op, enqueue_op, image_paths_placeholder, labels_placeholder,
264 | learning_rate_placeholder, phase_train_placeholder, batch_size_placeholder, global_step,
265 | loss, train_op, summary_op, summary_writer, regularization_losses, learning_rate_schedule_file):
266 |
267 | batch_number = 0
268 | if args.learning_rate>0.0:
269 |
270 | lr = args.learning_rate
271 | else:
272 | lr = faceUtil.get_learning_rate_from_file(learning_rate_schedule_file, epoch)
273 |
274 | index_epoch = sess.run(index_dequeue_op)
275 | label_epoch = np.array(label_list)[index_epoch]
276 | image_epoch = np.array(image_list)[index_epoch]
277 |
278 | # Enqueue one epoch of image paths and labels
279 | labels_array = np.expand_dims(np.array(label_epoch),1)
280 | image_paths_array = np.expand_dims(np.array(image_epoch),1)
281 | sess.run(enqueue_op,{image_paths_placeholder:image_paths_array,labels_placeholder:labels_array})
282 |
283 | # Training loop
284 | train_time = 0
285 | while batch_number < args.epoch_size:
286 | start_time = time.time()
287 | feed_dict = {learning_rate_placeholder: lr, phase_train_placeholder:True, batch_size_placeholder:args.batch_size}
288 |
289 | if (batch_number % 100 == 0) :
290 | err, _, step, reg_loss, summary_str = sess.run([loss,train_op,global_step,regularization_losses,summary_op],feed_dict=feed_dict)
291 | summary_writer.add_summary(summary_str, global_step=step)
292 | else :
293 | err, _, step, reg_loss = sess.run([loss, train_op, global_step, regularization_losses], feed_dict=feed_dict)
294 |
295 | duration = time.time() - start_time
296 | print('global_step[%d],Epoch: [%d][%d/%d]\tTime %.3f\tLoss %2.3f\tRegLoss %2.3f' %
297 | (step,epoch, batch_number+1, args.epoch_size, duration, err, np.sum(reg_loss)))
298 | batch_number += 1
299 | train_time += duration
300 |
301 | # Add the total training time for this epoch to the summary
302 | summary = tf.Summary()
303 |
304 | #pylint: disable=maybe-no-member
305 | summary.value.add(tag='time/total', simple_value=train_time)
306 | summary_writer.add_summary(summary, step)
307 | return step
308 |
309 |
310 |
311 | def save_variables_and_metagraph(sess, saver, summary_writer, model_dir, model_name, step):
312 |
313 | # Save the model checkpoint
314 | print('Saving variables')
315 | start_time = time.time()
316 | checkpoint_path = os.path.join(model_dir,'model-%s.ckpt' % model_name)
317 | saver.save(sess,checkpoint_path,global_step=step,write_meta_graph=False)
318 | save_time_variables = time.time() - start_time
319 | print('Variables saved in %.2f seconds' % save_time_variables)
320 | metagraph_filename = os.path.join(model_dir,'model-%s.meta' % model_name)
321 | save_time_metagraph = 0
322 | if not os.path.exists(metagraph_filename):
323 | print('Saving metagraph')
324 | start_time = time.time()
325 | saver.export_meta_graph(metagraph_filename)
326 | save_time_metagraph = time.time() - start_time
327 | print('Metagraph saved in %.2f seconds' % save_time_metagraph)
328 |
329 | summary = tf.Summary()
330 | #pylint: disable=maybe-no-member
331 | summary.value.add(tag='time/save_variables', simple_value=save_time_variables)
332 | summary.value.add(tag='time/save_metagraph', simple_value=save_time_metagraph)
333 | summary_writer.add_summary(summary, step)
334 |
335 |
336 |
337 |
338 |
339 |
340 |
341 |
342 |
343 |
344 |
345 |
346 |
347 |
348 |
349 |
350 |
351 |
352 |
353 |
354 |
355 | def parse_arguments(argv):
356 | parser = argparse.ArgumentParser()
357 | print("parser = argparse.ArgumentParser()")
358 | parser.add_argument('--data_dir',type=str,default='align/casia_maxpy_mtcnnpy_182')
359 | parser.add_argument('--gpu_memory_fraction',type=float,default=0.8)
360 | parser.add_argument('--pretrained_model', type=str,default = '/home/huyu/models/facenet/20180209-115727/model-20180209-115727.ckpt-14400',
361 | help='Load a pretrained model before training starts.')
362 | #default = '/home/huyu/models/facenet/20180209-114624/model-20180209-114624.ckpt-0',
363 | #default='/home/huyu/models/facenet/20180208-210946/model-20180208-210946.ckpt-1200',
364 | # default='modeltrained/20170512/model-20170512-110547.ckpt-250000',
365 |
366 | parser.add_argument('--max_nrof_epochs', type=int, default=200)
367 | parser.add_argument('--batch_size', type=int, default=64)
368 | parser.add_argument('--image_size',type=int,default=160)
369 | parser.add_argument('--epoch_size', type=int, default=300)
370 | parser.add_argument('--embedding_size', type=int, default=128)
371 | parser.add_argument('--random_crop',
372 | help='Performs random cropping of training images. If false, the center image_size pixels from the training images are used. ' +
373 | 'If the size of the images in the data directory is equal to image_size no cropping is performed', action='store_true')
374 | parser.add_argument('--random_flip',
375 | help='Performs random horizontal flipping of training images.', action='store_true')
376 | parser.add_argument('--random_rotate',
377 | help='Performs random rotations of training images.', action='store_true')
378 | parser.add_argument('--keep_probability', type=float,
379 | help='Keep probability of dropout for the fully connected layer(s).', default=1)
380 | parser.add_argument('--weight_decay', type=float,
381 | help='L2 weight regularization.', default=0.0)
382 | parser.add_argument('--learning_rate', type=float,
383 | help='Initial learning rate. If set to a negative value a learning rate ' +
384 | 'schedule can be specified in the file "learning_rate_schedule.txt"', default=0.005)
385 | parser.add_argument('--optimizer', type=str, choices=['ADAGRAD', 'ADADELTA', 'ADAM', 'RMSPROP', 'MOM'],
386 | help='The optimization algorithm to use', default='ADAM')
387 | parser.add_argument('--learning_rate_decay_epochs', type=int,
388 | help='Number of epochs between learning rate decay.', default=5)
389 | parser.add_argument('--learning_rate_decay_factor', type=float,
390 | help='Learning rate decay factor.', default=0.8)
391 | parser.add_argument('--moving_average_decay', type=float,
392 | help='Exponential decay for tracking of training parameters.', default=0.9999)
393 | parser.add_argument('--seed', type=int,
394 | help='Random seed.', default=666)
395 | parser.add_argument('--nrof_preprocess_threads', type=int,
396 | help='Number of preprocessing (data loading and augmentation) threads.', default=4)
397 | parser.add_argument('--log_histograms',
398 | help='Enables logging of weight/bias histograms in tensorboard.', action='store_true')
399 | parser.add_argument('--learning_rate_schedule_file', type=str,
400 | help='File containing the learning rate schedule that is used when learning_rate is set to -1.', default='data/learning_rate_classifier_casia.txt')
401 | parser.add_argument('--filter_filename', type=str,
402 | help='File containing image data used for dataset filtering', default='')
403 | parser.add_argument('--filter_percentile', type=float,
404 | help='Keep only the percentile of images closest to their class center', default=100.0)
405 | parser.add_argument('--filter_min_nrof_images_per_class', type=int,
406 | help='Keep only the classes with this number of examples or more', default=0)
407 | parser.add_argument('--logs_base_dir', type=str,
408 | help='Directory where to write event logs.', default='~/logs/facenet')
409 | parser.add_argument('--models_base_dir', type=str,
410 | help='Directory where to write trained models and checkpoints.', default='~/models/facenet')
411 |
412 |
413 | return parser.parse_args(argv)
414 |
415 |
416 | if __name__ == '__main__':
417 | main(parse_arguments(sys.argv[1:]))
418 |
--------------------------------------------------------------------------------
/src/train_softmax_no_train_inception.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import argparse
6 | import os.path
7 | import sys
8 | import time
9 | from datetime import datetime
10 |
11 | import numpy as np
12 | from tensorflow.python.framework import ops
13 | from tensorflow.python.ops import array_ops
14 | from tensorflow.python.ops import data_flow_ops
15 | import tensorflow.contrib.slim as slim
16 |
17 | from loss.sphere import *
18 | from models import inception_resnet_no_train as network
19 | from utils import faceUtil
20 |
21 |
22 | def main(args):
23 | print("main start")
24 | np.random.seed(seed=args.seed)
25 | #train_set = ImageClass list
26 | train_set = faceUtil.get_dataset(args.data_dir)
27 |
28 | # Total number of classes
29 | nrof_classes = len(train_set)
30 | print(nrof_classes)
31 |
32 | #subdir =20171122-112109
33 | subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
34 |
35 | #log_dir = c:\User\logs\facenet\20171122-
36 | log_dir = os.path.join(os.path.expanduser(args.logs_base_dir),subdir)
37 | if not os.path.isdir(log_dir):
38 | os.makedirs(log_dir)
39 | print("log_dir =",log_dir)
40 |
41 | # model_dir =c:\User/models/facenet/2017;;;
42 | model_dir = os.path.join(os.path.expanduser(args.models_base_dir), subdir)
43 | if not os.path.isdir(model_dir): # Create the model directory if it doesn't exist
44 | os.makedirs(model_dir)
45 |
46 | print("model_dir =", model_dir)
47 | pretrained_model = None
48 | if args.pretrained_model:
49 | # pretrained_model = os.path.expanduser(args.pretrained_model)
50 | # pretrained_model = tf.train.get_checkpoint_state(args.pretrained_model)
51 | pretrained_model = args.pretrained_model
52 | print('Pre-trained model: %s' % pretrained_model)
53 |
54 |
55 | # Write arguments to a text file
56 | faceUtil.write_arguments_to_file(args, os.path.join(log_dir, 'arguments.txt'))
57 | print("write_arguments_to_file")
58 | with tf.Graph().as_default():
59 | tf.set_random_seed(args.seed)
60 | global_step = tf.Variable(0,trainable=False)
61 |
62 | # Two lists of the same length: image_list = image file paths, label_list = corresponding labels
63 | image_list, label_list = faceUtil.get_image_paths_and_labels(train_set)
64 | assert len(image_list) > 0 , 'dataset is empty'
65 | print("len(image_list) = ",len(image_list))
66 |
67 | # Create a queue that produces indices into the image_list and label_list
68 | labels = ops.convert_to_tensor(label_list,dtype=tf.int64)
69 | range_size = array_ops.shape(labels)[0]
70 | range_size = tf.Print(range_size, [tf.shape(range_size)],message='Shape of range_input_producer range_size : ',summarize=4, first_n=1)
71 |
72 | # Create a queue that yields the indices 0 .. range_size-1 in shuffled order
73 | index_queue = tf.train.range_input_producer(range_size,num_epochs=None,shuffle=True,seed=None,capacity=32)
74 |
75 | # Dequeue args.batch_size * args.epoch_size indices from index_queue, used to pick the slice of image_list/label_list fed to the network
76 | index_dequeue_op = index_queue.dequeue_many(args.batch_size * args.epoch_size,'index_dequeue')
77 |
78 | # Learning rate
79 | learning_rate_placeholder = tf.placeholder(tf.float32,name='learning_rate')
80 | # Batch size (args.batch_size)
81 | batch_size_placeholder = tf.placeholder(tf.int32,name='batch_size')
82 | # Whether we are in the training phase
83 | phase_train_placeholder = tf.placeholder(tf.bool,name='phase_train')
84 | # Image paths; size: args.batch_size * args.epoch_size
85 | image_paths_placeholder = tf.placeholder(tf.string,shape=[None,1],name='image_paths')
86 | # Image labels; size: args.batch_size * args.epoch_size
87 | labels_placeholder = tf.placeholder(tf.int64,shape=[None,1],name='labels')
88 |
89 | # Create a FIFO queue for the input pipeline
90 | input_queue = data_flow_ops.FIFOQueue(capacity=100000,dtypes=[tf.string,tf.int64],shapes=[(1,),(1,)],shared_name=None,name=None)
91 |
92 | # enqueue_many returns an op; the number of enqueued elements is len(image_paths_placeholder), i.e. the args.batch_size * args.epoch_size indices taken from index_queue
93 | enqueue_op = input_queue.enqueue_many([image_paths_placeholder,labels_placeholder],name='enqueue_op')
94 |
95 | nrof_preprocess_threads = 4
96 | images_and_labels = []
97 |
98 | for _ in range(nrof_preprocess_threads):
99 | filenames , label = input_queue.dequeue()
100 | # label = tf.Print(label,[tf.shape(label)],message='Shape of one thread input_queue.dequeue label : ',
101 | # summarize=4,first_n=1)
102 | # filenames = tf.Print(filenames, [tf.shape(filenames)], message='Shape of one thread input_queue.dequeue filenames : ',
103 | # summarize=4, first_n=1)
104 | print("one thread input_queue.dequeue len = ",tf.shape(label))
105 | images =[]
106 | for filename in tf.unstack(filenames):
107 | file_contents = tf.read_file(filename)
108 | image = tf.image.decode_image(file_contents,channels=3)
109 |
110 | if args.random_rotate:
111 | image = tf.py_func(faceUtil.random_rotate_image, [image], tf.uint8)
112 |
113 | if args.random_crop:
114 | image = tf.random_crop(image,[args.image_size,args.image_size,3])
115 |
116 | else:
117 | image = tf.image.resize_image_with_crop_or_pad(image,args.image_size,args.image_size)
118 |
119 | if args.random_flip:
120 | image = tf.image.random_flip_left_right(image)
121 |
122 | image.set_shape((args.image_size,args.image_size,3))
123 | images.append(tf.image.per_image_standardization(image))
124 |
125 | # Dequeue filenames from the queue, decode them into images, and append to images_and_labels; final length = nrof_preprocess_threads (4)
126 | images_and_labels.append([images,label])
127 |
128 | # Data that finally enters the network in one step; its length should equal batch_size_placeholder
129 | image_batch, label_batch = tf.train.batch_join(images_and_labels,batch_size=batch_size_placeholder,
130 | shapes=[(args.image_size,args.image_size,3),()],
131 | enqueue_many = True,
132 | capacity = 4 * nrof_preprocess_threads * args.batch_size,
133 | allow_smaller_final_batch=True)
134 | print('final input net image_batch len = ',tf.shape(image_batch))
135 |
136 | image_batch = tf.Print(image_batch, [tf.shape(image_batch)], message='final input net image_batch shape = ',
137 | summarize=4, first_n=1)
138 | image_batch = tf.identity(image_batch, 'image_batch')
139 | image_batch = tf.identity(image_batch, 'input')
140 | label_batch = tf.identity(label_batch, 'label_batch')
141 |
142 | print('Total number of classes: %d' % nrof_classes)
143 | print('Total number of examples: %d' % len(image_list))
144 |
145 | print('Building training graph')
146 |
147 | # Apply exponential decay to the learning rate
148 | learning_rate = tf.train.exponential_decay(learning_rate= learning_rate_placeholder,
149 | global_step = global_step,
150 | decay_steps=args.learning_rate_decay_epochs * args.epoch_size,
151 | decay_rate=args.learning_rate_decay_factor,
152 | staircase = True)
153 | #decay_steps=args.learning_rate_decay_epochs * args.epoch_size,
154 |
155 | tf.summary.scalar('learning_rate', learning_rate)
156 |
157 | # Build the inference graph
158 | prelogits, _ = network.inference(image_batch,args.keep_probability,phase_train=phase_train_placeholder,
159 | bottleneck_layer_size=args.embedding_size,weight_decay=args.weight_decay)
160 |
161 | prelogits = tf.Print(prelogits, [tf.shape(prelogits)], message='prelogits shape = ',
162 | summarize=4, first_n=1)
163 | print("prelogits.shape = ",prelogits.get_shape().as_list())
164 |
165 | # logits =slim.fully_connected(prelogits, len(train_set), activation_fn=None,
166 | # weights_initializer=tf.contrib.layers.xavier_initializer(),
167 | # weights_regularizer=slim.l2_regularizer(args.weight_decay),
168 | # scope='Logits', reuse=False)
169 | #
170 | # # Calculate the average cross entropy loss across the batch
171 | # cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
172 | # labels=label_batch, logits=logits, name='cross_entropy_per_example')
173 | # tf.reduce_mean(cross_entropy, name='cross_entropy')
174 | _,cross_entropy_mean = soft_loss(prelogits,label_batch,len(train_set))
175 | tf.add_to_collection('losses', cross_entropy_mean)
176 |
177 | regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
178 | total_loss = tf.add_n([cross_entropy_mean] + regularization_losses,name='total_loss')
179 |
180 | train_op = faceUtil.train(total_loss, global_step, args.optimizer, learning_rate,
181 | args.moving_average_decay, tf.trainable_variables(), args.log_histograms)
182 | # print("global_variables len = {}".format(len(tf.global_variables())))
183 | # print("local_variables len = {}".format(len(tf.local_variables())))
184 | # print("trainable_variables len = {}".format(len(tf.trainable_variables())))
185 | # for v in tf.trainable_variables() :
186 | # print("trainable_variables :{}".format(v.name))
187 | # train_op = faceUtil.train(sphere_loss,global_step,args.optimizer,learning_rate,
188 | # args.moving_average_decay, tf.global_variables(), args.log_histograms)
189 |
190 | # Create the saver
191 | variables = tf.trainable_variables()
192 | print("variables_trainable len = ", len(variables))
193 | # for v in variables:
194 | # print('variables_trainable : {}'.format(v.name))
195 | saver = tf.train.Saver(var_list=variables, max_to_keep=2)
196 |
197 | variables_to_restore = slim.get_variables_to_restore(include=['InceptionResnetV1'])
198 |
199 | print("variables_to_restore len = ", len(variables_to_restore))
200 | saver_restore = tf.train.Saver(var_list=variables_to_restore)
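# Only variables under the 'InceptionResnetV1' scope are restored from the
# pretrained checkpoint below; the newly added loss-layer variables keep their
# fresh initialization. Since the network here is inception_resnet_no_train,
# the Inception weights are presumably frozen and only the new layers train.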
201 |
202 | # variables_to_restore = [v for v in variables if v.name.split('/')[0] != 'Logits']
203 | # print("variables_trainable len = ",len(variables))
204 | # print("variables_to_restore len = ",len(variables_to_restore))
205 | # # for v in variables_to_restore :
206 | # # print("variables_to_restore : ",v.name)
207 | # saver = tf.train.Saver(var_list=variables_to_restore,max_to_keep=3)
208 |
209 |
210 | # variables_trainable = tf.trainable_variables()
211 | # print("variables_trainable len = ",len(variables_trainable))
212 | # # for v in variables_trainable :
213 | # # print('variables_trainable : {}'.format(v.name))
214 | # variables_to_restore = slim.get_variables_to_restore(include=['InceptionResnetV1'])
215 | # print("variables_to_restore len = ",len(variables_to_restore))
216 | # saver = tf.train.Saver(var_list=variables_to_restore,max_to_keep=3)
217 |
218 |
219 |
220 | # Build the summary operation based on the TF collection of Summaries.
221 | summary_op = tf.summary.merge_all()
222 |
223 | # Maximum fraction of GPU memory this process is allowed to allocate
224 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = args.gpu_memory_fraction)
225 | sess = tf.Session(config=tf.ConfigProto(gpu_options = gpu_options,log_device_placement = False))
226 |
227 | # Initialize variables
228 | sess.run(tf.global_variables_initializer())
229 | sess.run(tf.local_variables_initializer())
230 | summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
231 |
232 | # Create a coordinator for the queue-runner threads
233 | coord = tf.train.Coordinator()
234 |
235 | # Start all queue runners registered in the graph
236 | tf.train.start_queue_runners(coord=coord,sess=sess)
237 |
238 | with sess.as_default():
239 | print('Running training')
240 | if pretrained_model :
241 | print('Restoring pretrained model_checkpoint_path: %s' % pretrained_model)
242 | saver_restore.restore(sess,pretrained_model)
243 |
244 | # Training and validation loop
245 | print('Running training really')
246 | epoch = 0
247 | # Number of full passes over the dataset
248 | while epoch < args.max_nrof_epochs:
249 |
250 | # Fetch the current global_step value; step can be read as the global batch count
251 | step = sess.run(global_step,feed_dict=None)
252 |
253 | # epoch_size is the number of batches in one epoch
254 | # epoch = global batch count // batches per epoch; this epoch index is used to look up the learning rate
255 | epoch = step // args.epoch_size
256 | # Train for one epoch
257 | train(args, sess, epoch, image_list, label_list, index_dequeue_op, enqueue_op, image_paths_placeholder, labels_placeholder,
258 | learning_rate_placeholder, phase_train_placeholder, batch_size_placeholder, global_step,
259 | total_loss, train_op, summary_op, summary_writer, regularization_losses, args.learning_rate_schedule_file)
260 |
261 | # Save variables and the metagraph if it doesn't exist already
262 | save_variables_and_metagraph(sess, saver, summary_writer, model_dir, subdir, step)
263 |
264 | return model_dir
265 |
266 |
267 |
268 | def train(args, sess, epoch, image_list, label_list, index_dequeue_op, enqueue_op, image_paths_placeholder, labels_placeholder,
269 | learning_rate_placeholder, phase_train_placeholder, batch_size_placeholder, global_step,
270 | loss, train_op, summary_op, summary_writer, regularization_losses, learning_rate_schedule_file):
271 |
272 | batch_number = 0
273 | if args.learning_rate>0.0:
274 |
275 | lr = args.learning_rate
276 | else:
277 | lr = faceUtil.get_learning_rate_from_file(learning_rate_schedule_file, epoch)
278 |
279 | index_epoch = sess.run(index_dequeue_op)
280 | label_epoch = np.array(label_list)[index_epoch]
281 | image_epoch = np.array(image_list)[index_epoch]
282 |
283 | # Enqueue one epoch of image paths and labels
284 | labels_array = np.expand_dims(np.array(label_epoch),1)
285 | image_paths_array = np.expand_dims(np.array(image_epoch),1)
286 | sess.run(enqueue_op,{image_paths_placeholder:image_paths_array,labels_placeholder:labels_array})
287 |
288 | # Training loop
289 | train_time = 0
290 | while batch_number < args.epoch_size:
291 |
292 | start_time = time.time()
293 | feed_dict = {learning_rate_placeholder: lr, phase_train_placeholder:True, batch_size_placeholder:args.batch_size}
294 |
295 | if (batch_number % 100 == 0) :
296 | err, _, step, reg_loss, summary_str = sess.run([loss,train_op,global_step,regularization_losses,summary_op],feed_dict=feed_dict)
297 | summary_writer.add_summary(summary_str, global_step=step)
298 | else :
299 | err, _, step, reg_loss = sess.run([loss, train_op, global_step, regularization_losses], feed_dict=feed_dict)
300 |
301 | duration = time.time() - start_time
302 | print('global_step[%d],Epoch: [%d][%d/%d]\tTime %.3f\tLoss %2.3f\tRegLoss %2.3f' %
303 | (step,epoch, batch_number+1, args.epoch_size, duration, err, np.sum(reg_loss)))
304 | batch_number += 1
305 | train_time += duration
306 |
307 | # Add the total training time for this epoch to the summary
308 | summary = tf.Summary()
309 |
310 | #pylint: disable=maybe-no-member
311 | summary.value.add(tag='time/total', simple_value=train_time)
312 | summary_writer.add_summary(summary, step)
313 | return step
314 |
315 |
316 |
317 | def save_variables_and_metagraph(sess, saver, summary_writer, model_dir, model_name, step):
318 |
319 | # Save the model checkpoint
320 | print('Saving variables')
321 | start_time = time.time()
322 | checkpoint_path = os.path.join(model_dir,'model-%s.ckpt' % model_name)
323 | saver.save(sess,checkpoint_path,global_step=step,write_meta_graph=False)
324 | save_time_variables = time.time() - start_time
325 | print('Variables saved in %.2f seconds' % save_time_variables)
326 | metagraph_filename = os.path.join(model_dir,'model-%s.meta' % model_name)
327 | save_time_metagraph = 0
328 | if not os.path.exists(metagraph_filename):
329 | print('Saving metagraph')
330 | start_time = time.time()
331 | saver.export_meta_graph(metagraph_filename)
332 | save_time_metagraph = time.time() - start_time
333 | print('Metagraph saved in %.2f seconds' % save_time_metagraph)
334 |
335 | summary = tf.Summary()
336 | #pylint: disable=maybe-no-member
337 | summary.value.add(tag='time/save_variables', simple_value=save_time_variables)
338 | summary.value.add(tag='time/save_metagraph', simple_value=save_time_metagraph)
339 | summary_writer.add_summary(summary, step)
340 |
341 |
342 |
343 |
344 |
345 |
346 |
347 |
348 |
349 |
350 |
351 |
352 |
353 |
354 |
355 |
356 |
357 |
358 |
359 |
360 |
361 | def parse_arguments(argv):
362 | parser = argparse.ArgumentParser()
363 | print("parser = argparse.ArgumentParser()")
364 | parser.add_argument('--data_dir',type=str,default='align/casia_maxpy_mtcnnpy_182')
365 | parser.add_argument('--gpu_memory_fraction',type=float,default=0.8)
366 | parser.add_argument('--pretrained_model', type=str,
367 | help='Load a pretrained model before training starts.')
368 | #default = '/home/huyu/models/facenet/20180209-114624/model-20180209-114624.ckpt-0',
369 | #default='/home/huyu/models/facenet/20180208-210946/model-20180208-210946.ckpt-1200',
370 | # default='modeltrained/20170512/model-20170512-110547.ckpt-250000',
371 |
372 | parser.add_argument('--max_nrof_epochs', type=int, default=200)
373 | parser.add_argument('--batch_size', type=int, default=64)
374 | parser.add_argument('--image_size',type=int,default=160)
375 | parser.add_argument('--epoch_size', type=int, default=300)
376 | parser.add_argument('--embedding_size', type=int, default=128)
377 | parser.add_argument('--random_crop',
378 | help='Performs random cropping of training images. If false, the center image_size pixels from the training images are used. ' +
379 | 'If the size of the images in the data directory is equal to image_size no cropping is performed', action='store_true')
380 | parser.add_argument('--random_flip',
381 | help='Performs random horizontal flipping of training images.', action='store_true')
382 | parser.add_argument('--random_rotate',
383 | help='Performs random rotations of training images.', action='store_true')
384 | parser.add_argument('--keep_probability', type=float,
385 | help='Keep probability of dropout for the fully connected layer(s).', default=0.8)
386 | parser.add_argument('--weight_decay', type=float,
387 | help='L2 weight regularization.', default=1e-8)
388 | parser.add_argument('--learning_rate', type=float,
389 | help='Initial learning rate. If set to a negative value a learning rate ' +
390 | 'schedule can be specified in the file "learning_rate_schedule.txt"', default=0.05)
391 | parser.add_argument('--optimizer', type=str, choices=['ADAGRAD', 'ADADELTA', 'ADAM', 'RMSPROP', 'MOM'],
392 | help='The optimization algorithm to use', default='ADAM')
393 | parser.add_argument('--learning_rate_decay_epochs', type=int,
394 | help='Number of epochs between learning rate decay.', default=1)
395 | parser.add_argument('--learning_rate_decay_factor', type=float,
396 | help='Learning rate decay factor.', default=0.9)
397 | parser.add_argument('--moving_average_decay', type=float,
398 | help='Exponential decay for tracking of training parameters.', default=0.9999)
399 | parser.add_argument('--seed', type=int,
400 | help='Random seed.', default=666)
401 | parser.add_argument('--nrof_preprocess_threads', type=int,
402 | help='Number of preprocessing (data loading and augmentation) threads.', default=4)
403 | parser.add_argument('--log_histograms',
404 | help='Enables logging of weight/bias histograms in tensorboard.', action='store_true')
405 | parser.add_argument('--learning_rate_schedule_file', type=str,
406 | help='File containing the learning rate schedule that is used when learning_rate is set to -1.', default='data/learning_rate_classifier_casia.txt')
407 | parser.add_argument('--filter_filename', type=str,
408 | help='File containing image data used for dataset filtering', default='')
409 | parser.add_argument('--filter_percentile', type=float,
410 | help='Keep only the percentile of images closest to their class center', default=100.0)
411 | parser.add_argument('--filter_min_nrof_images_per_class', type=int,
412 | help='Keep only the classes with this number of examples or more', default=0)
413 | parser.add_argument('--logs_base_dir', type=str,
414 | help='Directory where to write event logs.', default='~/logs/facenet')
415 | parser.add_argument('--models_base_dir', type=str,
416 | help='Directory where to write trained models and checkpoints.', default='~/models/facenet')
417 |
418 |
419 | return parser.parse_args(argv)
420 |
421 |
422 | if __name__ == '__main__':
423 | main(parse_arguments(sys.argv[1:]))
424 |
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/utils/__init__.py
--------------------------------------------------------------------------------
/src/utils/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/utils/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/faceUtil.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/utils/__pycache__/faceUtil.cpython-35.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/faceUtil.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/src/utils/__pycache__/faceUtil.cpython-36.pyc
--------------------------------------------------------------------------------
/src/utils/faceUtil.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import os
7 | from subprocess import Popen, PIPE
8 | import tensorflow as tf
9 | from tensorflow.python.framework import ops
10 | import numpy as np
11 | from scipy import misc
12 | from sklearn.model_selection import KFold
13 | from scipy import interpolate
14 | from tensorflow.python.training import training
15 | import random
16 | import re
17 | from tensorflow.python.platform import gfile
18 |
19 |
20 | def center_loss_angel(features, label, alfa, nrof_classes):
21 |
22 | nrof_features = features.get_shape()[1]
23 | centers = tf.get_variable('centers', [nrof_classes, nrof_features], dtype=tf.float32,
24 | initializer=tf.constant_initializer(1), trainable=False)
25 |
26 | centers_unit = tf.nn.l2_normalize(centers,dim=1)
27 | features_unit = tf.nn.l2_normalize(features,dim=1)
28 |
29 | label = tf.reshape(label, [-1])
30 | centers_batch = tf.gather(centers_unit, label)
31 |
32 | cos_theta = tf.reduce_sum( tf.multiply(features_unit , centers_batch),axis=1)
33 |
34 | diff = (1 - alfa) * (centers_batch - features)
35 | centers = tf.scatter_sub(centers, label, diff)
36 |
37 | loss = tf.reduce_mean(tf.square( cos_theta))
38 | return loss, centers
39 |
40 |
41 | def center_loss(features, label, alfa, nrof_classes):
42 | """Center loss based on the paper "A Discriminative Feature Learning Approach for Deep Face Recognition"
43 | (http://ydwen.github.io/papers/WenECCV16.pdf)
44 | """
45 | nrof_features = features.get_shape()[1]
46 | centers = tf.get_variable('centers', [nrof_classes, nrof_features], dtype=tf.float32,
47 | initializer=tf.constant_initializer(0), trainable=False)
48 | label = tf.reshape(label, [-1])
49 | centers_batch = tf.gather(centers, label)
50 | diff = (1 - alfa) * (centers_batch - features)
51 | centers = tf.scatter_sub(centers, label, diff)
52 | loss = tf.reduce_mean(tf.square(features - centers_batch))
53 | return loss, centers
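# Minimal usage sketch (shapes, alfa and the 0.5 weighting are illustrative,
# not taken from this repo):
#   embeddings: [batch, 128] float tensor; labels: [batch] int tensor
#   loss, _ = center_loss(embeddings, labels, alfa=0.95, nrof_classes=10575)
#   total_loss = cross_entropy_mean + 0.5 * loss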
54 |
55 |
56 | def get_image_paths_and_labels(dataset):
57 | image_paths_flat = []
58 | labels_flat = []
59 | for i in range(len(dataset)):
60 | image_paths_flat += dataset[i].image_paths
61 | labels_flat += [i] * len(dataset[i].image_paths)
62 | return image_paths_flat, labels_flat
63 |
64 |
65 | #exercise
66 | def get_image_andlabes(dataset):
67 | image_path =[]
68 | label = []
69 |
70 | for i in range(len(dataset)):
71 | image_path += dataset[i].image_paths
72 | label += [i] * len(dataset[i].image_paths)
73 | return image_path,label
74 |
75 | def shuffle_examples(image_paths, labels):
76 | shuffle_list = list(zip(image_paths, labels))
77 | random.shuffle(shuffle_list)
78 | image_paths_shuff, labels_shuff = zip(*shuffle_list)
79 | return image_paths_shuff, labels_shuff
80 |
81 | def read_images_from_disk(input_queue):
82 | """Consumes a single filename and label as a ' '-delimited string.
83 | Args:
84 | filename_and_label_tensor: A scalar string tensor.
85 | Returns:
86 | Two tensors: the decoded image, and the string label.
87 | """
88 | label = input_queue[1]
89 | file_contents = tf.read_file(input_queue[0])
90 | example = tf.image.decode_image(file_contents, channels=3)
91 | return example, label
92 |
93 | def random_rotate_image(image):
94 | angle = np.random.uniform(low=-10.0, high=10.0)
95 | return misc.imrotate(image, angle, 'bicubic')
96 |
97 | def read_and_augment_data(image_list, label_list, image_size, batch_size, max_nrof_epochs,
98 | random_crop, random_flip, random_rotate, nrof_preprocess_threads, shuffle=True):
99 |
100 | images = ops.convert_to_tensor(image_list, dtype=tf.string)
101 | labels = ops.convert_to_tensor(label_list, dtype=tf.int32)
102 |
103 | # Makes an input queue
104 | input_queue = tf.train.slice_input_producer([images, labels],
105 | num_epochs=max_nrof_epochs, shuffle=shuffle)
106 |
107 | images_and_labels = []
108 | for _ in range(nrof_preprocess_threads):
109 | image, label = read_images_from_disk(input_queue)
110 | if random_rotate:
111 | image = tf.py_func(random_rotate_image, [image], tf.uint8)
112 | if random_crop:
113 | image = tf.random_crop(image, [image_size, image_size, 3])
114 | else:
115 | image = tf.image.resize_image_with_crop_or_pad(image, image_size, image_size)
116 | if random_flip:
117 | image = tf.image.random_flip_left_right(image)
118 | #pylint: disable=no-member
119 | image.set_shape((image_size, image_size, 3))
120 | image = tf.image.per_image_standardization(image)
121 | images_and_labels.append([image, label])
122 |
123 | image_batch, label_batch = tf.train.batch_join(
124 | images_and_labels, batch_size=batch_size,
125 | capacity=4 * nrof_preprocess_threads * batch_size,
126 | allow_smaller_final_batch=True)
127 |
128 | return image_batch, label_batch
129 |
130 | def _add_loss_summaries(total_loss):
131 | """Add summaries for losses.
132 |
133 | Generates moving average for all losses and associated summaries for
134 | visualizing the performance of the network.
135 |
136 | Args:
137 | total_loss: Total loss from loss().
138 | Returns:
139 | loss_averages_op: op for generating moving averages of losses.
140 | """
141 | # Compute the moving average of all individual losses and the total loss.
142 | loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
143 | losses = tf.get_collection('losses')
144 | loss_averages_op = loss_averages.apply(losses + [total_loss])
145 |
146 | # Attach a scalar summary to all individual losses and the total loss; do the
147 | # same for the averaged version of the losses.
148 | for l in losses + [total_loss]:
149 | # Name each loss as '(raw)' and name the moving average version of the loss
150 | # as the original loss name.
151 | tf.summary.scalar(l.op.name +' (raw)', l)
152 | tf.summary.scalar(l.op.name, loss_averages.average(l))
153 |
154 | return loss_averages_op
155 |
156 | def train(total_loss, global_step, optimizer, learning_rate, moving_average_decay, update_gradient_vars, log_histograms=True):
157 | # Generate moving averages of all losses and associated summaries.
158 | loss_averages_op = _add_loss_summaries(total_loss)
159 |
160 | # Compute gradients.
161 | with tf.control_dependencies([loss_averages_op]):
162 | if optimizer=='ADAGRAD':
163 | opt = tf.train.AdagradOptimizer(learning_rate)
164 | elif optimizer=='ADADELTA':
165 | opt = tf.train.AdadeltaOptimizer(learning_rate, rho=0.9, epsilon=1e-6)
166 | elif optimizer=='ADAM':
167 | opt = tf.train.AdamOptimizer(learning_rate, beta1=0.9, beta2=0.999, epsilon=0.1)
168 | elif optimizer=='RMSPROP':
169 | opt = tf.train.RMSPropOptimizer(learning_rate, decay=0.9, momentum=0.9, epsilon=1.0)
170 | elif optimizer=='MOM':
171 | opt = tf.train.MomentumOptimizer(learning_rate, 0.9, use_nesterov=True)
172 | else:
173 | raise ValueError('Invalid optimization algorithm')
174 |
175 | grads = opt.compute_gradients(total_loss, update_gradient_vars)
176 |
177 | # Apply gradients.
178 | apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
179 |
180 | # Add histograms for trainable variables.
181 | if log_histograms:
182 | for var in tf.trainable_variables():
183 | tf.summary.histogram(var.op.name, var)
184 |
185 | # Add histograms for gradients.
186 | if log_histograms:
187 | for grad, var in grads:
188 | if grad is not None:
189 | tf.summary.histogram(var.op.name + '/gradients', grad)
190 |
191 | # Track the moving averages of all trainable variables.
192 | variable_averages = tf.train.ExponentialMovingAverage(
193 | moving_average_decay, global_step)
194 | variables_averages_op = variable_averages.apply(tf.trainable_variables())
195 |
196 | with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
197 | train_op = tf.no_op(name='train')
198 |
199 | return train_op
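# The returned train_op computes nothing by itself; its control dependencies
# make a single sess.run(train_op) both apply the gradients and update the
# exponential moving averages.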
200 |
201 | def prewhiten(x):
202 | mean = np.mean(x)
203 | std = np.std(x)
204 | std_adj = np.maximum(std, 1.0/np.sqrt(x.size))
205 | y = np.multiply(np.subtract(x, mean), 1/std_adj)
206 | return y
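# prewhiten maps an image to roughly zero mean and unit variance; the std is
# clamped from below by 1/sqrt(x.size) so near-constant images do not divide
# by (almost) zero.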
207 |
208 | def crop(image, random_crop, image_size):
209 | if image.shape[1]>image_size:
210 | sz1 = int(image.shape[1]//2)
211 | sz2 = int(image_size//2)
212 | if random_crop:
213 | diff = sz1-sz2
214 | (h, v) = (np.random.randint(-diff, diff+1), np.random.randint(-diff, diff+1))
215 | else:
216 | (h, v) = (0,0)
217 | image = image[(sz1-sz2+v):(sz1+sz2+v),(sz1-sz2+h):(sz1+sz2+h),:]
218 | return image
219 |
220 | def flip(image, random_flip):
221 | if random_flip and np.random.choice([True, False]):
222 | image = np.fliplr(image)
223 | return image
224 |
225 | def to_rgb(img):
226 | w, h = img.shape
227 | ret = np.empty((w, h, 3), dtype=np.uint8)
228 | ret[:, :, 0] = ret[:, :, 1] = ret[:, :, 2] = img
229 | return ret
230 |
231 | def load_data(image_paths, do_random_crop, do_random_flip, image_size, do_prewhiten=True):
232 | nrof_samples = len(image_paths)
233 | images = np.zeros((nrof_samples, image_size, image_size, 3))
234 | for i in range(nrof_samples):
235 | img = misc.imread(image_paths[i])
236 | if img.ndim == 2:
237 | img = to_rgb(img)
238 | if do_prewhiten:
239 | img = prewhiten(img)
240 | img = crop(img, do_random_crop, image_size)
241 | img = flip(img, do_random_flip)
242 | images[i,:,:,:] = img
243 | return images
244 |
245 | def get_label_batch(label_data, batch_size, batch_index):
246 |     nrof_examples = np.size(label_data, 0)
247 |     j = batch_index*batch_size % nrof_examples
248 |     if j+batch_size<=nrof_examples:
249 |         batch = label_data[j:j+batch_size]
250 |     else:
    |         # wrap around: take the tail, then just enough labels from the head to fill the batch
251 |         x1 = label_data[j:nrof_examples]
252 |         x2 = label_data[0:batch_size-(nrof_examples-j)]
253 |         batch = np.hstack([x1,x2])  # labels are 1-D, so concatenate rather than stack rows
254 |     batch_int = batch.astype(np.int64)
255 |     return batch_int
256 |
257 | def get_batch(image_data, batch_size, batch_index):
258 |     nrof_examples = np.size(image_data, 0)
259 |     j = batch_index*batch_size % nrof_examples
260 |     if j+batch_size<=nrof_examples:
261 |         batch = image_data[j:j+batch_size,:,:,:]
262 |     else:
    |         # wrap around: tail of the data plus just enough images from the head to fill the batch
263 |         x1 = image_data[j:nrof_examples,:,:,:]
264 |         x2 = image_data[0:batch_size-(nrof_examples-j),:,:,:]
265 |         batch = np.vstack([x1,x2])
266 |     batch_float = batch.astype(np.float32)
267 |     return batch_float
268 |
269 | def get_triplet_batch(triplets, batch_index, batch_size):
270 | ax, px, nx = triplets
271 | a = get_batch(ax, int(batch_size/3), batch_index)
272 | p = get_batch(px, int(batch_size/3), batch_index)
273 | n = get_batch(nx, int(batch_size/3), batch_index)
274 | batch = np.vstack([a, p, n])
275 | return batch
276 |
277 | def get_learning_rate_from_file(filename, epoch):
    |     # Schedule file: one "epoch:learning_rate" pair per line; '#' starts a comment.
    |     learning_rate = 0.0  # fallback if the first entry is already beyond `epoch`
278 |     with open(filename, 'r') as f:
279 |         for line in f.readlines():
280 |             line = line.split('#', 1)[0]
281 |             if line:
282 |                 par = line.strip().split(':')
283 |                 e = int(par[0])
284 |                 lr = float(par[1])
285 |                 if e <= epoch:
286 |                     learning_rate = lr
287 |                 else:
288 |                     return learning_rate
    |     return learning_rate  # past the last entry: return the last rate (previously fell through to None)
289 |
290 | # name: string; image_paths: list of image file paths
291 | class ImageClass():
292 | "Stores the paths to images for a given class"
293 | def __init__(self, name, image_paths):
294 | self.name = name
295 | self.image_paths = image_paths
296 |
297 | def __str__(self):
298 | return self.name + ', ' + str(len(self.image_paths)) + ' images'
299 |
300 | def __len__(self):
301 | return len(self.image_paths)
302 |
303 | # Returns a list of ImageClass objects, one per class directory under each path.
304 | def get_dataset(paths, has_class_directories=True):
305 | dataset = []
306 | for path in paths.split(':'):
307 | path_exp = os.path.expanduser(path)
308 | classes = os.listdir(path_exp)
309 | classes.sort()
310 | nrof_classes = len(classes)
311 | for i in range(nrof_classes):
312 |
313 |             # class directories are identities, e.g. .../casia_maxpy_mtcnnalign_182_160/000001 ...
314 |             class_name = classes[i]
315 |             # facedir = <dataset root>/<class_name>
316 |             facedir = os.path.join(path_exp, class_name)
317 |             # image_paths = list of image file paths for this identity
318 |             image_paths = get_image_paths(facedir)
319 |             # one ImageClass per identity
320 |             dataset.append(ImageClass(class_name, image_paths))
321 |
322 | return dataset
323 |
324 | def get_image_paths(facedir):
325 | image_paths = []
326 | if os.path.isdir(facedir):
327 | images = os.listdir(facedir)
328 | image_paths = [os.path.join(facedir,img) for img in images]
329 | return image_paths
330 |
331 | def split_dataset(dataset, split_ratio, mode):
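    |     # Two modes: SPLIT_CLASSES assigns whole identities to train or test;
    |     # SPLIT_IMAGES splits each identity's images between the two sets.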
332 | if mode=='SPLIT_CLASSES':
333 | nrof_classes = len(dataset)
334 | class_indices = np.arange(nrof_classes)
335 | np.random.shuffle(class_indices)
336 | split = int(round(nrof_classes*split_ratio))
337 |         train_set = [dataset[i] for i in class_indices[0:split]]
338 |         test_set = [dataset[i] for i in class_indices[split:]]  # '[split:-1]' would silently drop the last class
339 | elif mode=='SPLIT_IMAGES':
340 | train_set = []
341 | test_set = []
342 | min_nrof_images = 2
343 | for cls in dataset:
344 | paths = cls.image_paths
345 | np.random.shuffle(paths)
346 | split = int(round(len(paths)*split_ratio))
347 |             if split<min_nrof_images:
348 |                 continue  # Not enough images for the test set - skip this class
349 |             train_set.append(ImageClass(cls.name, paths[0:split]))
350 |             test_set.append(ImageClass(cls.name, paths[split:]))
351 |     else:
352 |         raise ValueError('Invalid train/test split mode "%s"' % mode)
353 |     return train_set, test_set
354 | 
355 | def load_model(model):
356 |     # Check if the model is a model directory (containing a metagraph and a checkpoint file)
357 |     #  or if it is a protobuf file with a frozen graph
358 |     model_exp = os.path.expanduser(model)
359 |     if (os.path.isfile(model_exp)):
360 |         print('Model filename: %s' % model_exp)
361 |         with gfile.FastGFile(model_exp,'rb') as f:
362 |             graph_def = tf.GraphDef()
363 |             graph_def.ParseFromString(f.read())
364 |             tf.import_graph_def(graph_def, name='')
365 |     else:
366 |         print('Model directory: %s' % model_exp)
367 |         meta_file, ckpt_file = get_model_filenames(model_exp)
368 | 
369 |         print('Metagraph file: %s' % meta_file)
370 |         print('Checkpoint file: %s' % ckpt_file)
371 | 
372 |         saver = tf.train.import_meta_graph(os.path.join(model_exp, meta_file))
373 |         saver.restore(tf.get_default_session(), os.path.join(model_exp, ckpt_file))
374 | 
375 | def get_model_filenames(model_dir):
376 |     files = os.listdir(model_dir)
377 |     meta_files = [s for s in files if s.endswith('.meta')]
378 |     if len(meta_files)==0:
379 |         raise ValueError('No meta file found in the model directory (%s)' % model_dir)
380 |     elif len(meta_files)>1:
381 |         raise ValueError('There should not be more than one meta file in the model directory (%s)' % model_dir)
382 | meta_file = meta_files[0]
383 | meta_files = [s for s in files if '.ckpt' in s]
384 | max_step = -1
385 | for f in files:
386 |         step_str = re.match(r'(^model-[\w\- ]+\.ckpt-(\d+))', f)
387 | if step_str is not None and len(step_str.groups())>=2:
388 | step = int(step_str.groups()[1])
389 | if step > max_step:
390 | max_step = step
391 | ckpt_file = step_str.groups()[0]
392 | return meta_file, ckpt_file
393 |
394 | def calculate_roc(thresholds, embeddings1, embeddings2, actual_issame, nrof_folds=10):
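    |     # K-fold ROC over squared-L2 distances between embedding pairs: each fold
    |     # picks the threshold that maximizes training accuracy, then evaluates
    |     # TPR/FPR/accuracy on its held-out pairs; tpr/fpr are averaged over folds.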
395 | assert(embeddings1.shape[0] == embeddings2.shape[0])
396 | assert(embeddings1.shape[1] == embeddings2.shape[1])
397 | nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
398 | nrof_thresholds = len(thresholds)
399 | k_fold = KFold(n_splits=nrof_folds, shuffle=False)
400 |
401 | tprs = np.zeros((nrof_folds,nrof_thresholds))
402 | fprs = np.zeros((nrof_folds,nrof_thresholds))
403 | accuracy = np.zeros((nrof_folds))
404 |
405 | diff = np.subtract(embeddings1, embeddings2)
406 | dist = np.sum(np.square(diff),1)
407 | indices = np.arange(nrof_pairs)
408 |
409 | for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
410 |
411 | # Find the best threshold for the fold
412 | acc_train = np.zeros((nrof_thresholds))
413 | for threshold_idx, threshold in enumerate(thresholds):
414 | _, _, acc_train[threshold_idx] = calculate_accuracy(threshold, dist[train_set], actual_issame[train_set])
415 | best_threshold_index = np.argmax(acc_train)
416 | for threshold_idx, threshold in enumerate(thresholds):
417 | tprs[fold_idx,threshold_idx], fprs[fold_idx,threshold_idx], _ = calculate_accuracy(threshold, dist[test_set], actual_issame[test_set])
418 | _, _, accuracy[fold_idx] = calculate_accuracy(thresholds[best_threshold_index], dist[test_set], actual_issame[test_set])
419 |
420 | tpr = np.mean(tprs,0)
421 | fpr = np.mean(fprs,0)
422 | return tpr, fpr, accuracy
423 |
424 | def calculate_accuracy(threshold, dist, actual_issame):
425 | predict_issame = np.less(dist, threshold)
426 | tp = np.sum(np.logical_and(predict_issame, actual_issame))
427 | fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
428 | tn = np.sum(np.logical_and(np.logical_not(predict_issame), np.logical_not(actual_issame)))
429 | fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))
430 |
431 | tpr = 0 if (tp+fn==0) else float(tp) / float(tp+fn)
432 | fpr = 0 if (fp+tn==0) else float(fp) / float(fp+tn)
433 | acc = float(tp+tn)/dist.size
434 | return tpr, fpr, acc
435 |
436 |
437 |
438 | def calculate_val(thresholds, embeddings1, embeddings2, actual_issame, far_target, nrof_folds=10):
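    |     # For each fold, linearly interpolate the distance threshold that gives
    |     # FAR = far_target on the training pairs, then measure VAL and FAR on the
    |     # held-out pairs; returns their means (and the std of VAL) across folds.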
439 | assert(embeddings1.shape[0] == embeddings2.shape[0])
440 | assert(embeddings1.shape[1] == embeddings2.shape[1])
441 | nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
442 | nrof_thresholds = len(thresholds)
443 | k_fold = KFold(n_splits=nrof_folds, shuffle=False)
444 |
445 | val = np.zeros(nrof_folds)
446 | far = np.zeros(nrof_folds)
447 |
448 | diff = np.subtract(embeddings1, embeddings2)
449 | dist = np.sum(np.square(diff),1)
450 | indices = np.arange(nrof_pairs)
451 |
452 | for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
453 |
454 | # Find the threshold that gives FAR = far_target
455 | far_train = np.zeros(nrof_thresholds)
456 | for threshold_idx, threshold in enumerate(thresholds):
457 | _, far_train[threshold_idx] = calculate_val_far(threshold, dist[train_set], actual_issame[train_set])
458 | if np.max(far_train)>=far_target:
459 | f = interpolate.interp1d(far_train, thresholds, kind='slinear')
460 | threshold = f(far_target)
461 | else:
462 | threshold = 0.0
463 |
464 | val[fold_idx], far[fold_idx] = calculate_val_far(threshold, dist[test_set], actual_issame[test_set])
465 |
466 | val_mean = np.mean(val)
467 | far_mean = np.mean(far)
468 | val_std = np.std(val)
469 | return val_mean, val_std, far_mean
470 |
471 |
472 | def calculate_val_far(threshold, dist, actual_issame):
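    |     # VAL = fraction of genuine (same-identity) pairs accepted at `threshold`;
    |     # FAR = fraction of impostor (different-identity) pairs accepted.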
473 | predict_issame = np.less(dist, threshold)
474 | true_accept = np.sum(np.logical_and(predict_issame, actual_issame))
475 | false_accept = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
476 | n_same = np.sum(actual_issame)
477 | n_diff = np.sum(np.logical_not(actual_issame))
478 | val = float(true_accept) / float(n_same)
479 | far = float(false_accept) / float(n_diff)
480 | return val, far
481 |
482 | def store_revision_info(src_path, output_dir, arg_string):
483 |
484 | # Get git hash
485 | gitproc = Popen(['git', 'rev-parse', 'HEAD'], stdout = PIPE, cwd=src_path)
486 | (stdout, _) = gitproc.communicate()
487 | git_hash = stdout.strip()
488 |
489 | # Get local changes
490 | gitproc = Popen(['git', 'diff', 'HEAD'], stdout = PIPE, cwd=src_path)
491 | (stdout, _) = gitproc.communicate()
492 | git_diff = stdout.strip()
493 |
494 | # Store a text file in the log directory
495 | rev_info_filename = os.path.join(output_dir, 'revision_info.txt')
496 | with open(rev_info_filename, "w") as text_file:
497 | text_file.write('arguments: %s\n--------------------\n' % arg_string)
498 | text_file.write('git hash: %s\n--------------------\n' % git_hash)
499 | text_file.write('%s' % git_diff)
500 |
501 | def list_variables(filename):
502 | reader = training.NewCheckpointReader(filename)
503 | variable_map = reader.get_variable_to_shape_map()
504 | names = sorted(variable_map.keys())
505 | return names
506 |
507 | def put_images_on_grid(images, shape=(16,8)):
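    |     # Tile up to shape[0]*shape[1] images onto one canvas with a bw = 3 px
    |     # border between cells; images beyond the grid capacity are dropped.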
508 | nrof_images = images.shape[0]
509 | img_size = images.shape[1]
510 | bw = 3
511 | img = np.zeros((shape[1]*(img_size+bw)+bw, shape[0]*(img_size+bw)+bw, 3), np.float32)
512 | for i in range(shape[1]):
513 | x_start = i*(img_size+bw)+bw
514 | for j in range(shape[0]):
515 | img_index = i*shape[0]+j
516 | if img_index>=nrof_images:
517 | break
518 | y_start = j*(img_size+bw)+bw
519 | img[x_start:x_start+img_size, y_start:y_start+img_size, :] = images[img_index, :, :, :]
520 | if img_index>=nrof_images:
521 | break
522 | return img
523 |
524 | def write_arguments_to_file(args, filename):
525 | with open(filename, 'w') as f:
526 | for key, value in vars(args).items():
527 | f.write('%s: %s\n' % (key, str(value)))
528 |
--------------------------------------------------------------------------------
/test/.ipynb_checkpoints/sphereloss-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import tensorflow as tf\n",
12 | "import tensorflow.contrib.slim as slim\n",
13 | "import numpy as np\n",
14 | "from math import pi\n",
15 | "from builtins import range\n",
16 | "import numpy as np"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 2,
22 | "metadata": {
23 | "collapsed": true
24 | },
25 | "outputs": [],
26 | "source": [
27 | "def sphereloss(inputs,label,classes,fraction = 1, scope='Logits',reuse=None,m =4,eplion = 1e-8):\n",
28 | " \"\"\"\n",
29 | " inputs tensor shape=[batch,features_num]\n",
30 | " labels tensor shape=[batch] each unit belong num_outputs\n",
31 | " \n",
32 | " \"\"\"\n",
33 | " inputs_shape = inputs.get_shape().as_list()\n",
34 | " with tf.variable_scope(name_or_scope=scope):\n",
35 | " weight = tf.Variable(initial_value=tf.random_normal((classes,inputs_shape[1])) * tf.sqrt(2 / inputs_shape[1]),dtype=tf.float32,name='weight') # shaep =classes, features,\n",
36 | " print(\"weight shape = \",weight.get_shape().as_list())\n",
37 | " \n",
38 | " weight_unit = tf.nn.l2_normalize(weight,dim=1)\n",
39 | " print(\"weight_unit shape = \",weight_unit.get_shape().as_list())\n",
40 | " \n",
41 | " inputs_mo = tf.sqrt(tf.reduce_sum(tf.square(inputs),axis=1)+eplion) #shape=[batch\n",
42 | " print(\"inputs_mo shape = \",inputs_mo.get_shape().as_list())\n",
43 | " \n",
44 | " inputs_unit = tf.nn.l2_normalize(inputs,dim=1) #shape = [batch,features_num]\n",
45 | " print(\"inputs_unit shape = \",inputs_unit.get_shape().as_list())\n",
46 | " \n",
47 | " logits = tf.matmul(inputs,tf.transpose(weight_unit)) #shape = [batch,classes] x * w_unit\n",
48 | " print(\"logits shape = \",logits.get_shape().as_list())\n",
49 | " \n",
50 | " weight_unit_batch = tf.gather(weight_unit,label) # shaep =batch,features_num,\n",
51 | " print(\"weight_unit_batch shape = \",weight_unit_batch.get_shape().as_list())\n",
52 | " \n",
53 | " logits_inputs = tf.reduce_sum(tf.multiply(inputs,weight_unit_batch),axis=1) # shaep =batch,\n",
54 | " \n",
55 | " print(\"logits_inputs shape = \",logits_inputs.get_shape().as_list())\n",
56 | " \n",
57 | " \n",
58 | " cos_theta = tf.reduce_sum(tf.multiply(inputs_unit,weight_unit_batch),axis=1) # shaep =batch,\n",
59 | " print(\"cos_theta shape = \",cos_theta.get_shape().as_list())\n",
60 | " \n",
61 | " cos_theta_square = tf.square(cos_theta)\n",
62 | " cos_theta_biq = tf.pow(cos_theta,4)\n",
63 | " sign0 = tf.sign(cos_theta)\n",
64 | " sign2 = tf.sign(2 * cos_theta_square-1)\n",
65 | " sign3 = tf.multiply(sign2,sign0)\n",
66 | " sign4 = 2 * sign0 +sign3 -3\n",
67 | " cos_far_theta = sign3 * (8 * cos_theta_biq - 8 * cos_theta_square + 1) + sign4\n",
68 | " print(\"cos_far_theta = \",cos_far_theta.get_shape().as_list())\n",
69 | " \n",
70 | " logit_ii = tf.multiply(cos_far_theta,inputs_mo)#shape = batch \n",
71 | " print(\"logit_ii shape = \",logit_ii.get_shape().as_list())\n",
72 | " \n",
73 | " \n",
74 | " \n",
75 | " index = tf.constant(list(range(0, inputs_shape[0])), tf.int32)\n",
76 | " index_labels = tf.stack([index, label], axis = 1)\n",
77 | " index_logits = tf.scatter_nd(index_labels,tf.subtract(logit_ii,logits_inputs),logits.get_shape())\n",
78 | " print(\"index_logits shape = \",logit_ii.get_shape().as_list())\n",
79 | " \n",
80 | " logits_final = tf.add(logits,index_logits)\n",
81 | " logits_final = fraction * logits_final + (1 - fraction) * logits\n",
82 | " \n",
83 | " \n",
84 | " loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label,logits=logits_final))\n",
85 | " \n",
86 | " return logits_final,loss\n",
87 | " \n",
88 | " \n",
89 | " \n",
90 | " "
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "execution_count": 3,
96 | "metadata": {},
97 | "outputs": [
98 | {
99 | "name": "stdout",
100 | "output_type": "stream",
101 | "text": [
102 | "[[ 0.54340494 0.27836939 0.42451759 ..., 0.49216697 0.40288033\n",
103 | " 0.3542983 ]\n",
104 | " [ 0.50061432 0.44517663 0.09043279 ..., 0.91975536 0.84960733\n",
105 | " 0.25446654]\n",
106 | " [ 0.87755554 0.43513019 0.72949434 ..., 0.34675413 0.1095646 0.378327 ]\n",
107 | " ..., \n",
108 | " [ 0.37716757 0.75750263 0.29912515 ..., 0.53643677 0.63122505\n",
109 | " 0.17644686]\n",
110 | " [ 0.5396841 0.55346366 0.30105263 ..., 0.97211104 0.51628182\n",
111 | " 0.44451879]\n",
112 | " [ 0.29864351 0.8312721 0.28519946 ..., 0.09984372 0.71015638\n",
113 | " 0.37341943]]\n",
114 | "[ 282 9194 3716 2977 9688 3074 487 6080 467 4974 1458 4028 9966 5131 583\n",
115 | " 2955]\n"
116 | ]
117 | }
118 | ],
119 | "source": [
120 | "np.random.seed(100)\n",
121 | "batch = 16\n",
122 | "feating = 128\n",
123 | "classess = 10000\n",
124 | "# inputsinputs = [[-1.0,1.0],[1.0,1.0],[1.0,1.0]]\n",
125 | "\n",
126 | "inputsinputs = np.random.rand(batch,feating) \n",
127 | "print(inputsinputs)\n",
128 | "inputslables = np.random.randint(0,classess,batch)\n",
129 | "print(inputslables)"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": 4,
135 | "metadata": {
136 | "collapsed": true
137 | },
138 | "outputs": [],
139 | "source": [
140 | "inputs_place = tf.placeholder(dtype=tf.float32,shape=(batch,feating),name='inputs')\n",
141 | "labels_place = tf.placeholder(dtype=tf.int32,shape=(batch),name='labels')"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": 5,
147 | "metadata": {},
148 | "outputs": [
149 | {
150 | "name": "stdout",
151 | "output_type": "stream",
152 | "text": [
153 | "weight shape = [10000, 128]\n",
154 | "weight_unit shape = [10000, 128]\n",
155 | "inputs_mo shape = [16]\n",
156 | "inputs_unit shape = [16, 128]\n",
157 | "logits shape = [16, 10000]\n",
158 | "weight_unit_batch shape = [16, 128]\n",
159 | "logits_inputs shape = [16]\n",
160 | "cos_theta shape = [16]\n",
161 | "cos_far_theta = [16]\n",
162 | "logit_ii shape = [16]\n",
163 | "index_logits shape = [16]\n"
164 | ]
165 | }
166 | ],
167 | "source": [
168 | "_,loss = sphereloss(inputs_place,labels_place,classess)\n",
169 | "optimizer = tf.train.AdamOptimizer(learning_rate=0.05)\n",
170 | "train_op = optimizer.minimize(loss)"
171 | ]
172 | },
173 | {
174 | "cell_type": "code",
175 | "execution_count": 6,
176 | "metadata": {},
177 | "outputs": [
178 | {
179 | "name": "stdout",
180 | "output_type": "stream",
181 | "text": [
182 | "[28.798058, None]\n",
183 | "[22.013681, None]\n",
184 | "[15.798929, None]\n",
185 | "[13.435317, None]\n",
186 | "[12.685493, None]\n",
187 | "[11.346285, None]\n",
188 | "[9.4595337, None]\n",
189 | "[7.3966951, None]\n",
190 | "[5.5232167, None]\n",
191 | "[4.1032286, None]\n",
192 | "[3.2184033, None]\n",
193 | "[2.7606359, None]\n",
194 | "[2.5476418, None]\n",
195 | "[2.4361296, None]\n",
196 | "[2.3488777, None]\n",
197 | "[2.2575321, None]\n",
198 | "[2.1591675, None]\n",
199 | "[2.0600286, None]\n",
200 | "[1.9670352, None]\n",
201 | "[1.884537, None]\n",
202 | "[1.8141112, None]\n",
203 | "[1.7553788, None]\n",
204 | "[1.7070088, None]\n",
205 | "[1.6673062, None]\n",
206 | "[1.6345801, None]\n",
207 | "[1.6072876, None]\n",
208 | "[1.5841055, None]\n",
209 | "[1.5639385, None]\n",
210 | "[1.5459807, None]\n",
211 | "[1.5295949, None]\n",
212 | "[1.5144105, None]\n",
213 | "[1.5001791, None]\n",
214 | "[1.4868002, None]\n",
215 | "[1.474241, None]\n",
216 | "[1.462525, None]\n",
217 | "[1.4517064, None]\n",
218 | "[1.4418349, None]\n",
219 | "[1.432952, None]\n",
220 | "[1.42506, None]\n",
221 | "[1.418117, None]\n",
222 | "[1.4120786, None]\n",
223 | "[1.4068146, None]\n",
224 | "[1.4022279, None]\n",
225 | "[1.3981946, None]\n",
226 | "[1.394609, None]\n",
227 | "[1.3913933, None]\n",
228 | "[1.3884966, None]\n",
229 | "[1.3858668, None]\n",
230 | "[1.3834978, None]\n",
231 | "[1.3813753, None]\n",
232 | "[1.3794972, None]\n",
233 | "[1.377805, None]\n",
234 | "[1.3763015, None]\n",
235 | "[1.3749441, None]\n",
236 | "[1.3737118, None]\n",
237 | "[1.3725736, None]\n",
238 | "[1.3715333, None]\n",
239 | "[1.3705664, None]\n",
240 | "[1.3696976, None]\n",
241 | "[1.3689098, None]\n",
242 | "[1.3682067, None]\n",
243 | "[1.3675909, None]\n",
244 | "[1.3670415, None]\n",
245 | "[1.3665429, None]\n",
246 | "[1.3661082, None]\n",
247 | "[1.3656967, None]\n",
248 | "[1.3653129, None]\n",
249 | "[1.3649571, None]\n",
250 | "[1.364624, None]\n",
251 | "[1.3643059, None]\n",
252 | "[1.3640261, None]\n",
253 | "[1.3637731, None]\n",
254 | "[1.3635414, None]\n",
255 | "[1.3633375, None]\n",
256 | "[1.3631665, None]\n",
257 | "[1.3629972, None]\n",
258 | "[1.3628507, None]\n",
259 | "[1.3627158, None]\n",
260 | "[1.3625877, None]\n",
261 | "[1.3624818, None]\n",
262 | "[1.3623769, None]\n",
263 | "[1.3622849, None]\n",
264 | "[1.3622081, None]\n",
265 | "[1.3621382, None]\n",
266 | "[1.3620652, None]\n",
267 | "[1.362, None]\n",
268 | "[1.3619499, None]\n",
269 | "[1.3618976, None]\n",
270 | "[1.3618547, None]\n",
271 | "[1.3618151, None]\n",
272 | "[1.3617785, None]\n",
273 | "[1.3617359, None]\n",
274 | "[1.3617026, None]\n",
275 | "[1.3616755, None]\n",
276 | "[1.3616362, None]\n",
277 | "[1.3616104, None]\n",
278 | "[1.3615811, None]\n",
279 | "[1.361555, None]\n",
280 | "[1.361526, None]\n",
281 | "[1.3615122, None]\n",
282 | "[1.3614895, None]\n",
283 | "[1.36147, None]\n",
284 | "[1.3614599, None]\n",
285 | "[1.3614435, None]\n",
286 | "[1.3614256, None]\n",
287 | "[1.3614086, None]\n",
288 | "[1.361397, None]\n",
289 | "[1.361378, None]\n",
290 | "[1.3613749, None]\n",
291 | "[1.361356, None]\n",
292 | "[1.3613384, None]\n",
293 | "[1.3613338, None]\n",
294 | "[1.3613162, None]\n",
295 | "[1.3613122, None]\n",
296 | "[1.3613, None]\n",
297 | "[1.361288, None]\n",
298 | "[1.3612826, None]\n",
299 | "[1.3612705, None]\n",
300 | "[1.3612605, None]\n",
301 | "[1.3612491, None]\n",
302 | "[1.3612472, None]\n",
303 | "[1.3612354, None]\n",
304 | "[1.3612268, None]\n",
305 | "[1.3612187, None]\n",
306 | "[1.3612064, None]\n",
307 | "[1.3612019, None]\n",
308 | "[1.3611991, None]\n",
309 | "[1.3611846, None]\n",
310 | "[1.3611834, None]\n",
311 | "[1.3611796, None]\n",
312 | "[1.3611608, None]\n",
313 | "[1.3611588, None]\n",
314 | "[1.3611513, None]\n",
315 | "[1.3611465, None]\n",
316 | "[1.3611407, None]\n",
317 | "[1.3611329, None]\n",
318 | "[1.3611214, None]\n",
319 | "[1.3611224, None]\n",
320 | "[1.3611122, None]\n",
321 | "[1.3611006, None]\n",
322 | "[1.361099, None]\n",
323 | "[1.3610919, None]\n",
324 | "[1.3610821, None]\n",
325 | "[1.3610785, None]\n",
326 | "[1.3610735, None]\n",
327 | "[1.361064, None]\n",
328 | "[1.3610643, None]\n",
329 | "[1.3610544, None]\n",
330 | "[1.3610488, None]\n",
331 | "[1.3610449, None]\n",
332 | "[1.3610425, None]\n",
333 | "[1.3610251, None]\n",
334 | "[1.361027, None]\n",
335 | "[1.3610241, None]\n",
336 | "[1.3610187, None]\n",
337 | "[1.361007, None]\n",
338 | "[1.361002, None]\n",
339 | "[1.3610014, None]\n",
340 | "[1.3609916, None]\n",
341 | "[1.3609868, None]\n",
342 | "[1.3609872, None]\n",
343 | "[1.3609767, None]\n",
344 | "[1.3609693, None]\n",
345 | "[1.360967, None]\n",
346 | "[1.3609564, None]\n",
347 | "[1.3609531, None]\n",
348 | "[1.3609515, None]\n",
349 | "[1.3609464, None]\n",
350 | "[1.3609388, None]\n",
351 | "[1.360932, None]\n",
352 | "[1.3609326, None]\n",
353 | "[1.3609302, None]\n",
354 | "[1.3609217, None]\n",
355 | "[1.3609126, None]\n",
356 | "[1.3609116, None]\n",
357 | "[1.3609014, None]\n",
358 | "[1.3608999, None]\n",
359 | "[1.3609022, None]\n",
360 | "[1.360888, None]\n",
361 | "[1.3608881, None]\n",
362 | "[1.3608785, None]\n",
363 | "[1.3608744, None]\n",
364 | "[1.3608811, None]\n",
365 | "[1.3608711, None]\n",
366 | "[1.3608688, None]\n",
367 | "[1.3608608, None]\n",
368 | "[1.3608541, None]\n",
369 | "[1.3608596, None]\n",
370 | "[1.3608459, None]\n",
371 | "[1.3608503, None]\n",
372 | "[1.3608425, None]\n",
373 | "[1.3608347, None]\n",
374 | "[1.3608333, None]\n",
375 | "[1.3608232, None]\n",
376 | "[1.3608229, None]\n",
377 | "[1.360822, None]\n",
378 | "[1.3608146, None]\n",
379 | "[1.3608112, None]\n",
380 | "[1.3608062, None]\n",
381 | "[1.3608042, None]\n",
382 | "[1.3608019, None]\n",
383 | "[1.3607998, None]\n",
384 | "[1.3607969, None]\n",
385 | "[1.360793, None]\n",
386 | "[1.3607876, None]\n",
387 | "[1.3607824, None]\n",
388 | "[1.360777, None]\n",
389 | "[1.3607728, None]\n",
390 | "[1.3607719, None]\n",
391 | "[1.3607681, None]\n",
392 | "[1.3607724, None]\n",
393 | "[1.3607639, None]\n",
394 | "[1.3607619, None]\n",
395 | "[1.3607515, None]\n",
396 | "[1.3607494, None]\n",
397 | "[1.3607481, None]\n",
398 | "[1.3607519, None]\n",
399 | "[1.3607433, None]\n",
400 | "[1.3607445, None]\n",
401 | "[1.3607424, None]\n",
402 | "[1.3607278, None]\n",
403 | "[1.3607357, None]\n",
404 | "[1.3607287, None]\n",
405 | "[1.3607242, None]\n",
406 | "[1.3607204, None]\n",
407 | "[1.3607173, None]\n",
408 | "[1.360713, None]\n",
409 | "[1.3607132, None]\n",
410 | "[1.3607059, None]\n",
411 | "[1.3607048, None]\n",
412 | "[1.3607016, None]\n",
413 | "[1.3607044, None]\n",
414 | "[1.3607032, None]\n",
415 | "[1.3606958, None]\n",
416 | "[1.3606982, None]\n",
417 | "[1.3606921, None]\n",
418 | "[1.3606881, None]\n",
419 | "[1.3606858, None]\n",
420 | "[1.3606811, None]\n",
421 | "[1.3606766, None]\n",
422 | "[1.3606761, None]\n",
423 | "[1.3606787, None]\n",
424 | "[1.3606739, None]\n",
425 | "[1.360667, None]\n",
426 | "[1.3606656, None]\n",
427 | "[1.3606657, None]\n",
428 | "[1.3606608, None]\n",
429 | "[1.360661, None]\n",
430 | "[1.3606596, None]\n",
431 | "[1.36066, None]\n",
432 | "[1.3606548, None]\n",
433 | "[1.3606544, None]\n",
434 | "[1.3606534, None]\n",
435 | "[1.3606489, None]\n",
436 | "[1.3606441, None]\n",
437 | "[1.3606439, None]\n",
438 | "[1.3606385, None]\n",
439 | "[1.3606369, None]\n",
440 | "[1.3606359, None]\n",
441 | "[1.3606341, None]\n",
442 | "[1.3606316, None]\n",
443 | "[1.3606304, None]\n",
444 | "[1.3606286, None]\n",
445 | "[1.3606267, None]\n",
446 | "[1.3606237, None]\n",
447 | "[1.3606184, None]\n",
448 | "[1.3606191, None]\n",
449 | "[1.3606161, None]\n",
450 | "[1.3606155, None]\n",
451 | "[1.3606112, None]\n",
452 | "[1.3606111, None]\n",
453 | "[1.3606093, None]\n",
454 | "[1.3606052, None]\n",
455 | "[1.3606048, None]\n",
456 | "[1.3606048, None]\n",
457 | "[1.360605, None]\n",
458 | "[1.3606012, None]\n",
459 | "[1.3605998, None]\n",
460 | "[1.3605975, None]\n",
461 | "[1.3605914, None]\n",
462 | "[1.3605908, None]\n",
463 | "[1.3605893, None]\n",
464 | "[1.36059, None]\n",
465 | "[1.3605886, None]\n",
466 | "[1.3605853, None]\n",
467 | "[1.3605857, None]\n",
468 | "[1.3605833, None]\n",
469 | "[1.3605798, None]\n",
470 | "[1.3605828, None]\n",
471 | "[1.3605783, None]\n",
472 | "[1.360579, None]\n",
473 | "[1.3605747, None]\n",
474 | "[1.3605775, None]\n",
475 | "[1.3605781, None]\n",
476 | "[1.360574, None]\n",
477 | "[1.3605719, None]\n",
478 | "[1.3605723, None]\n",
479 | "[1.3605675, None]\n",
480 | "[1.3605683, None]\n",
481 | "[1.3605654, None]\n",
482 | "[1.3605647, None]\n",
483 | "[1.3605644, None]\n",
484 | "[1.3605622, None]\n",
485 | "[1.360563, None]\n",
486 | "[1.3605621, None]\n",
487 | "[1.3605603, None]\n",
488 | "[1.360559, None]\n",
489 | "[1.360559, None]\n",
490 | "[1.3605585, None]\n",
491 | "[1.3605564, None]\n",
492 | "[1.3605515, None]\n",
493 | "[1.3605492, None]\n",
494 | "[1.360549, None]\n",
495 | "[1.3605447, None]\n",
496 | "[1.3605459, None]\n",
497 | "[1.3605437, None]\n",
498 | "[1.3605459, None]\n",
499 | "[1.3605437, None]\n",
500 | "[1.3605419, None]\n",
501 | "[1.3605409, None]\n",
502 | "[1.3605397, None]\n",
503 | "[1.3605382, None]\n",
504 | "[1.3605388, None]\n",
505 | "[1.3605382, None]\n",
506 | "[1.3605369, None]\n",
507 | "[1.3605351, None]\n",
508 | "[1.3605323, None]\n",
509 | "[1.3605349, None]\n",
510 | "[1.3605349, None]\n",
511 | "[1.3605313, None]\n",
512 | "[1.3605309, None]\n",
513 | "[1.3605285, None]\n",
514 | "[1.360528, None]\n",
515 | "[1.3605281, None]\n",
516 | "[1.3605267, None]\n",
517 | "[1.3605261, None]\n",
518 | "[1.3605261, None]\n",
519 | "[1.3605255, None]\n",
520 | "[1.3605223, None]\n",
521 | "[1.3605224, None]\n",
522 | "[1.3605227, None]\n",
523 | "[1.3605223, None]\n",
524 | "[1.3605227, None]\n",
525 | "[1.3605201, None]\n",
526 | "[1.3605181, None]\n",
527 | "[1.3605177, None]\n",
528 | "[1.3605196, None]\n",
529 | "[1.3605171, None]\n",
530 | "[1.3605158, None]\n",
531 | "[1.3605144, None]\n",
532 | "[1.3605165, None]\n",
533 | "[1.3605158, None]\n",
534 | "[1.3605144, None]\n",
535 | "[1.3605145, None]\n",
536 | "[1.3605134, None]\n",
537 | "[1.3605131, None]\n",
538 | "[1.3605111, None]\n",
539 | "[1.3605098, None]\n",
540 | "[1.3605087, None]\n",
541 | "[1.3605084, None]\n",
542 | "[1.360508, None]\n",
543 | "[1.3605077, None]\n",
544 | "[1.360508, None]\n",
545 | "[1.3605061, None]\n",
546 | "[1.3605061, None]\n",
547 | "[1.3605049, None]\n",
548 | "[1.3605019, None]\n",
549 | "[1.3605043, None]\n",
550 | "[1.360503, None]\n",
551 | "[1.3605008, None]\n",
552 | "[1.3604989, None]\n",
553 | "[1.3604987, None]\n",
554 | "[1.360498, None]\n",
555 | "[1.360498, None]\n",
556 | "[1.3604956, None]\n",
557 | "[1.3604977, None]\n",
558 | "[1.3604954, None]\n",
559 | "[1.3604962, None]\n",
560 | "[1.360498, None]\n",
561 | "[1.3604977, None]\n",
562 | "[1.3604957, None]\n",
563 | "[1.3604968, None]\n",
564 | "[1.3604938, None]\n",
565 | "[1.3604937, None]\n",
566 | "[1.3604914, None]\n",
567 | "[1.360492, None]\n",
568 | "[1.3604913, None]\n",
569 | "[1.3604912, None]\n",
570 | "[1.3604898, None]\n",
571 | "[1.36049, None]\n",
572 | "[1.3604873, None]\n",
573 | "[1.3604846, None]\n",
574 | "[1.3604844, None]\n",
575 | "[1.3604836, None]\n",
576 | "[1.3604845, None]\n",
577 | "[1.3604846, None]\n",
578 | "[1.3604861, None]\n",
579 | "[1.3604875, None]\n",
580 | "[1.3604866, None]\n",
581 | "[1.3604854, None]\n",
582 | "[1.3604875, None]\n",
583 | "[1.3604851, None]\n",
584 | "[1.3604836, None]\n",
585 | "[1.3604836, None]\n",
586 | "[1.3604813, None]\n",
587 | "[1.360482, None]\n",
588 | "[1.360482, None]\n",
589 | "[1.3604822, None]\n",
590 | "[1.3604785, None]\n",
591 | "[1.3604805, None]\n",
592 | "[1.3604803, None]\n",
593 | "[1.3604807, None]\n",
594 | "[1.3604786, None]\n",
595 | "[1.3604782, None]\n",
596 | "[1.3604753, None]\n",
597 | "[1.3604755, None]\n",
598 | "[1.3604745, None]\n",
599 | "[1.3604743, None]\n",
600 | "[1.3604729, None]\n",
601 | "[1.3604747, None]\n",
602 | "[1.3604743, None]\n",
603 | "[1.3604729, None]\n",
604 | "[1.3604736, None]\n",
605 | "[1.3604716, None]\n",
606 | "[1.3604699, None]\n",
607 | "[1.360468, None]\n",
608 | "[1.3604715, None]\n",
609 | "[1.3604712, None]\n",
610 | "[1.3604712, None]\n",
611 | "[1.3604708, None]\n",
612 | "[1.3604695, None]\n",
613 | "[1.360469, None]\n",
614 | "[1.3604705, None]\n",
615 | "[1.3604693, None]\n",
616 | "[1.3604724, None]\n",
617 | "[1.3604687, None]\n",
618 | "[1.3604704, None]\n",
619 | "[1.3604703, None]\n",
620 | "[1.3604716, None]\n",
621 | "[1.3604703, None]\n",
622 | "[1.3604674, None]\n",
623 | "[1.3604668, None]\n",
624 | "[1.3604646, None]\n",
625 | "[1.3604633, None]\n",
626 | "[1.3604618, None]\n",
627 | "[1.36046, None]\n",
628 | "[1.3604622, None]\n",
629 | "[1.3604629, None]\n",
630 | "[1.3604609, None]\n",
631 | "[1.3604593, None]\n",
632 | "[1.3604608, None]\n",
633 | "[1.3604633, None]\n",
634 | "[1.3604617, None]\n",
635 | "[1.3604594, None]\n",
636 | "[1.3604612, None]\n",
637 | "[1.3604605, None]\n",
638 | "[1.3604623, None]\n",
639 | "[1.36046, None]\n",
640 | "[1.3604598, None]\n"
641 | ]
642 | },
643 | {
644 | "name": "stdout",
645 | "output_type": "stream",
646 | "text": [
647 | "[1.3604589, None]\n",
648 | "[1.3604591, None]\n",
649 | "[1.3604579, None]\n",
650 | "[1.3604562, None]\n",
651 | "[1.3604577, None]\n",
652 | "[1.360456, None]\n",
653 | "[1.360455, None]\n",
654 | "[1.3604559, None]\n",
655 | "[1.3604553, None]\n",
656 | "[1.3604565, None]\n",
657 | "[1.3604563, None]\n",
658 | "[1.360455, None]\n",
659 | "[1.360455, None]\n",
660 | "[1.3604553, None]\n",
661 | "[1.3604541, None]\n",
662 | "[1.3604548, None]\n",
663 | "[1.3604541, None]\n",
664 | "[1.3604541, None]\n",
665 | "[1.3604527, None]\n",
666 | "[1.3604522, None]\n",
667 | "[1.3604528, None]\n",
668 | "[1.3604532, None]\n",
669 | "[1.3604512, None]\n",
670 | "[1.3604516, None]\n",
671 | "[1.3604524, None]\n",
672 | "[1.3604527, None]\n",
673 | "[1.3604497, None]\n",
674 | "[1.3604493, None]\n",
675 | "[1.3604485, None]\n",
676 | "[1.3604513, None]\n",
677 | "[1.3604528, None]\n",
678 | "[1.3604487, None]\n",
679 | "[1.3604469, None]\n",
680 | "[1.3604493, None]\n",
681 | "[1.3604496, None]\n",
682 | "[1.3604503, None]\n",
683 | "[1.3604494, None]\n",
684 | "[1.3604501, None]\n",
685 | "[1.3604507, None]\n",
686 | "[1.3604501, None]\n",
687 | "[1.3604491, None]\n",
688 | "[1.3604496, None]\n",
689 | "[1.3604499, None]\n",
690 | "[1.3604475, None]\n",
691 | "[1.3604467, None]\n",
692 | "[1.3604462, None]\n",
693 | "[1.3604478, None]\n",
694 | "[1.3604475, None]\n",
695 | "[1.360445, None]\n",
696 | "[1.3604461, None]\n",
697 | "[1.3604487, None]\n",
698 | "[1.3604486, None]\n",
699 | "[1.360446, None]\n",
700 | "[1.3604457, None]\n",
701 | "[1.3604449, None]\n",
702 | "[1.3604445, None]\n",
703 | "[1.3604444, None]\n",
704 | "[1.3604429, None]\n",
705 | "[1.360445, None]\n",
706 | "[1.3604469, None]\n",
707 | "[1.3604497, None]\n",
708 | "[1.3604454, None]\n",
709 | "[1.3604436, None]\n",
710 | "[1.3604431, None]\n",
711 | "[1.3604397, None]\n",
712 | "[1.36044, None]\n",
713 | "[1.3604395, None]\n",
714 | "[1.3604413, None]\n",
715 | "[1.3604409, None]\n",
716 | "[1.3604414, None]\n",
717 | "[1.3604426, None]\n",
718 | "[1.360442, None]\n",
719 | "[1.3604419, None]\n",
720 | "[1.360441, None]\n",
721 | "[1.3604395, None]\n",
722 | "[1.3604409, None]\n",
723 | "[1.3604407, None]\n",
724 | "[1.3604411, None]\n",
725 | "[1.3604399, None]\n",
726 | "[1.3604391, None]\n",
727 | "[1.3604399, None]\n",
728 | "[1.3604398, None]\n",
729 | "[1.3604392, None]\n",
730 | "[1.3604378, None]\n",
731 | "[1.3604375, None]\n",
732 | "[1.3604369, None]\n",
733 | "[1.3604362, None]\n",
734 | "[1.3604376, None]\n",
735 | "[1.3604379, None]\n",
736 | "[1.3604394, None]\n",
737 | "[1.3604398, None]\n",
738 | "[1.3604391, None]\n",
739 | "[1.3604403, None]\n",
740 | "[1.3604391, None]\n",
741 | "[1.3604394, None]\n",
742 | "[1.3604375, None]\n",
743 | "[1.3604392, None]\n",
744 | "[1.3604352, None]\n",
745 | "[1.3604355, None]\n",
746 | "[1.3604343, None]\n",
747 | "[1.3604351, None]\n",
748 | "[1.360438, None]\n",
749 | "[1.3604381, None]\n",
750 | "[1.3604383, None]\n",
751 | "[1.3604379, None]\n",
752 | "[1.3604364, None]\n",
753 | "[1.3604376, None]\n",
754 | "[1.3604357, None]\n",
755 | "[1.3604338, None]\n",
756 | "[1.3604321, None]\n",
757 | "[1.3604355, None]\n",
758 | "[1.3604357, None]\n",
759 | "[1.3604352, None]\n",
760 | "[1.3604356, None]\n",
761 | "[1.3604376, None]\n",
762 | "[1.3604378, None]\n",
763 | "[1.36044, None]\n",
764 | "[1.360438, None]\n",
765 | "[1.3604378, None]\n",
766 | "[1.3604367, None]\n",
767 | "[1.3604358, None]\n",
768 | "[1.3604348, None]\n",
769 | "[1.3604372, None]\n",
770 | "[1.3604383, None]\n",
771 | "[1.3604379, None]\n",
772 | "[1.3604367, None]\n",
773 | "[1.360436, None]\n",
774 | "[1.360435, None]\n",
775 | "[1.3604355, None]\n",
776 | "[1.3604378, None]\n",
777 | "[1.3604331, None]\n",
778 | "[1.3604321, None]\n",
779 | "[1.3604342, None]\n",
780 | "[1.360435, None]\n",
781 | "[1.3604349, None]\n",
782 | "[1.3604363, None]\n",
783 | "[1.360435, None]\n",
784 | "[1.3604364, None]\n",
785 | "[1.3604354, None]\n",
786 | "[1.3604329, None]\n",
787 | "[1.3604338, None]\n",
788 | "[1.3604351, None]\n",
789 | "[1.3604313, None]\n",
790 | "[1.3604326, None]\n",
791 | "[1.3604326, None]\n",
792 | "[1.360433, None]\n",
793 | "[1.3604329, None]\n",
794 | "[1.3604329, None]\n",
795 | "[1.3604331, None]\n",
796 | "[1.3604341, None]\n",
797 | "[1.3604329, None]\n",
798 | "[1.3604306, None]\n",
799 | "[1.3604308, None]\n",
800 | "[1.360431, None]\n",
801 | "[1.3604319, None]\n",
802 | "[1.3604324, None]\n",
803 | "[1.3604325, None]\n",
804 | "[1.3604331, None]\n",
805 | "[1.3604326, None]\n",
806 | "[1.3604331, None]\n",
807 | "[1.3604305, None]\n",
808 | "[1.3604348, None]\n",
809 | "[1.3604363, None]\n",
810 | "[1.3604343, None]\n",
811 | "[1.3604345, None]\n",
812 | "[1.3604338, None]\n",
813 | "[1.3604319, None]\n",
814 | "[1.3604341, None]\n",
815 | "[1.3604336, None]\n",
816 | "[1.360433, None]\n",
817 | "[1.3604335, None]\n",
818 | "[1.3604336, None]\n",
819 | "[1.3604324, None]\n",
820 | "[1.3604321, None]\n",
821 | "[1.360431, None]\n",
822 | "[1.3604298, None]\n",
823 | "[1.3604317, None]\n",
824 | "[1.3604319, None]\n",
825 | "[1.3604339, None]\n",
826 | "[1.3604307, None]\n",
827 | "[1.36043, None]\n",
828 | "[1.3604279, None]\n",
829 | "[1.3604274, None]\n",
830 | "[1.360428, None]\n",
831 | "[1.3604295, None]\n",
832 | "[1.3604305, None]\n",
833 | "[1.360433, None]\n",
834 | "[1.3604312, None]\n",
835 | "[1.3604314, None]\n",
836 | "[1.3604312, None]\n",
837 | "[1.3604314, None]\n",
838 | "[1.3604307, None]\n",
839 | "[1.3604276, None]\n",
840 | "[1.36043, None]\n",
841 | "[1.3604295, None]\n",
842 | "[1.3604283, None]\n",
843 | "[1.3604296, None]\n",
844 | "[1.3604307, None]\n",
845 | "[1.3604293, None]\n",
846 | "[1.3604269, None]\n",
847 | "[1.3604298, None]\n",
848 | "[1.3604293, None]\n",
849 | "[1.360431, None]\n",
850 | "[1.3604283, None]\n",
851 | "[1.3604295, None]\n",
852 | "[1.3604281, None]\n",
853 | "[1.360429, None]\n",
854 | "[1.3604276, None]\n",
855 | "[1.3604293, None]\n",
856 | "[1.3604281, None]\n",
857 | "[1.36043, None]\n",
858 | "[1.3604287, None]\n",
859 | "[1.3604318, None]\n",
860 | "[1.3604293, None]\n",
861 | "[1.360429, None]\n",
862 | "[1.360431, None]\n",
863 | "[1.360431, None]\n",
864 | "[1.3604298, None]\n",
865 | "[1.3604295, None]\n",
866 | "[1.3604302, None]\n",
867 | "[1.3604301, None]\n",
868 | "[1.3604285, None]\n",
869 | "[1.3604292, None]\n",
870 | "[1.3604294, None]\n",
871 | "[1.3604269, None]\n",
872 | "[1.3604283, None]\n",
873 | "[1.3604267, None]\n",
874 | "[1.3604273, None]\n",
875 | "[1.3604298, None]\n",
876 | "[1.3604314, None]\n",
877 | "[1.3604314, None]\n",
878 | "[1.3604289, None]\n",
879 | "[1.3604271, None]\n",
880 | "[1.3604271, None]\n",
881 | "[1.3604275, None]\n",
882 | "[1.3604274, None]\n",
883 | "[1.3604264, None]\n",
884 | "[1.3604262, None]\n",
885 | "[1.3604267, None]\n",
886 | "[1.3604271, None]\n",
887 | "[1.3604273, None]\n",
888 | "[1.3604271, None]\n",
889 | "[1.3604261, None]\n",
890 | "[1.3604271, None]\n",
891 | "[1.3604295, None]\n",
892 | "[1.3604293, None]\n",
893 | "[1.3604302, None]\n",
894 | "[1.3604292, None]\n",
895 | "[1.360428, None]\n",
896 | "[1.3604281, None]\n",
897 | "[1.3604275, None]\n",
898 | "[1.36043, None]\n",
899 | "[1.3604298, None]\n",
900 | "[1.3604288, None]\n",
901 | "[1.3604269, None]\n",
902 | "[1.3604301, None]\n",
903 | "[1.3604273, None]\n",
904 | "[1.3604276, None]\n",
905 | "[1.3604263, None]\n",
906 | "[1.3604269, None]\n",
907 | "[1.3604264, None]\n",
908 | "[1.3604271, None]\n",
909 | "[1.3604283, None]\n",
910 | "[1.3604276, None]\n",
911 | "[1.3604282, None]\n",
912 | "[1.3604295, None]\n",
913 | "[1.360431, None]\n",
914 | "[1.3604285, None]\n",
915 | "[1.3604279, None]\n",
916 | "[1.360428, None]\n",
917 | "[1.3604285, None]\n",
918 | "[1.3604294, None]\n",
919 | "[1.3604258, None]\n",
920 | "[1.3604246, None]\n",
921 | "[1.3604274, None]\n",
922 | "[1.3604257, None]\n",
923 | "[1.3604282, None]\n",
924 | "[1.3604288, None]\n",
925 | "[1.360427, None]\n",
926 | "[1.36043, None]\n",
927 | "[1.360429, None]\n",
928 | "[1.3604261, None]\n",
929 | "[1.3604274, None]\n",
930 | "[1.360427, None]\n",
931 | "[1.3604282, None]\n",
932 | "[1.3604257, None]\n",
933 | "[1.360425, None]\n",
934 | "[1.3604255, None]\n",
935 | "[1.3604276, None]\n",
936 | "[1.3604242, None]\n",
937 | "[1.3604239, None]\n",
938 | "[1.360423, None]\n",
939 | "[1.3604244, None]\n",
940 | "[1.3604262, None]\n",
941 | "[1.3604252, None]\n",
942 | "[1.360425, None]\n",
943 | "[1.3604248, None]\n",
944 | "[1.3604254, None]\n",
945 | "[1.3604243, None]\n",
946 | "[1.3604245, None]\n",
947 | "[1.3604279, None]\n",
948 | "[1.3604275, None]\n",
949 | "[1.3604265, None]\n",
950 | "[1.3604273, None]\n",
951 | "[1.3604257, None]\n",
952 | "[1.3604243, None]\n",
953 | "[1.3604246, None]\n",
954 | "[1.3604252, None]\n",
955 | "[1.3604249, None]\n",
956 | "[1.3604231, None]\n",
957 | "[1.3604263, None]\n",
958 | "[1.3604259, None]\n",
959 | "[1.3604276, None]\n",
960 | "[1.3604275, None]\n",
961 | "[1.3604251, None]\n",
962 | "[1.3604236, None]\n",
963 | "[1.3604252, None]\n",
964 | "[1.3604267, None]\n",
965 | "[1.3604243, None]\n",
966 | "[1.3604249, None]\n",
967 | "[1.3604248, None]\n",
968 | "[1.3604262, None]\n",
969 | "[1.3604263, None]\n",
970 | "[1.3604248, None]\n",
971 | "[1.3604274, None]\n",
972 | "[1.3604275, None]\n",
973 | "[1.360425, None]\n",
974 | "[1.3604255, None]\n",
975 | "[1.3604246, None]\n",
976 | "[1.3604271, None]\n",
977 | "[1.3604283, None]\n",
978 | "[1.3604262, None]\n",
979 | "[1.3604264, None]\n",
980 | "[1.3604268, None]\n",
981 | "[1.3604243, None]\n",
982 | "[1.3604248, None]\n",
983 | "[1.3604245, None]\n",
984 | "[1.3604263, None]\n",
985 | "[1.3604263, None]\n",
986 | "[1.3604258, None]\n",
987 | "[1.3604271, None]\n",
988 | "[1.3604263, None]\n",
989 | "[1.3604257, None]\n",
990 | "[1.3604239, None]\n",
991 | "[1.3604259, None]\n",
992 | "[1.3604271, None]\n",
993 | "[1.3604268, None]\n",
994 | "[1.3604255, None]\n",
995 | "[1.3604261, None]\n",
996 | "[1.3604271, None]\n",
997 | "[1.3604279, None]\n",
998 | "[1.3604274, None]\n",
999 | "[1.3604281, None]\n",
1000 | "[1.3604271, None]\n",
1001 | "[1.3604259, None]\n",
1002 | "[1.3604265, None]\n",
1003 | "[1.3604267, None]\n",
1004 | "[1.3604261, None]\n",
1005 | "[1.3604276, None]\n",
1006 | "[1.3604269, None]\n",
1007 | "[1.3604259, None]\n",
1008 | "[1.3604256, None]\n",
1009 | "[1.3604279, None]\n",
1010 | "[1.3604287, None]\n",
1011 | "[1.3604264, None]\n",
1012 | "[1.3604261, None]\n",
1013 | "[1.3604269, None]\n",
1014 | "[1.3604252, None]\n",
1015 | "[1.3604276, None]\n",
1016 | "[1.3604264, None]\n",
1017 | "[1.3604277, None]\n",
1018 | "[1.3604286, None]\n",
1019 | "[1.3604273, None]\n",
1020 | "[1.360427, None]\n",
1021 | "[1.3604261, None]\n",
1022 | "[1.3604271, None]\n",
1023 | "[1.3604256, None]\n",
1024 | "[1.3604269, None]\n",
1025 | "[1.3604234, None]\n",
1026 | "[1.3604236, None]\n",
1027 | "[1.3604265, None]\n",
1028 | "[1.3604293, None]\n",
1029 | "[1.3604304, None]\n",
1030 | "[1.3604267, None]\n",
1031 | "[1.3604267, None]\n",
1032 | "[1.3604252, None]\n",
1033 | "[1.3604257, None]\n",
1034 | "[1.3604243, None]\n",
1035 | "[1.3604269, None]\n",
1036 | "[1.3604283, None]\n",
1037 | "[1.3604264, None]\n",
1038 | "[1.3604258, None]\n",
1039 | "[1.3604232, None]\n",
1040 | "[1.3604242, None]\n",
1041 | "[1.3604244, None]\n",
1042 | "[1.3604279, None]\n",
1043 | "[1.3604269, None]\n",
1044 | "[1.3604261, None]\n",
1045 | "[1.360427, None]\n",
1046 | "[1.3604258, None]\n",
1047 | "[1.3604244, None]\n",
1048 | "[1.3604242, None]\n",
1049 | "[1.3604244, None]\n",
1050 | "[1.3604263, None]\n",
1051 | "[1.3604267, None]\n",
1052 | "[1.3604273, None]\n",
1053 | "[1.3604286, None]\n",
1054 | "[1.3604256, None]\n",
1055 | "[1.3604264, None]\n",
1056 | "[1.3604249, None]\n",
1057 | "[1.360425, None]\n",
1058 | "[1.3604254, None]\n",
1059 | "[1.3604261, None]\n",
1060 | "[1.3604252, None]\n",
1061 | "[1.3604255, None]\n",
1062 | "[1.3604243, None]\n",
1063 | "[1.360425, None]\n",
1064 | "[1.3604245, None]\n",
1065 | "[1.3604258, None]\n",
1066 | "[1.3604255, None]\n",
1067 | "[1.3604252, None]\n",
1068 | "[1.3604255, None]\n",
1069 | "[1.360425, None]\n",
1070 | "[1.360428, None]\n",
1071 | "[1.3604273, None]\n",
1072 | "[1.3604264, None]\n",
1073 | "[1.3604252, None]\n",
1074 | "[1.3604249, None]\n",
1075 | "[1.3604248, None]\n",
1076 | "[1.3604248, None]\n",
1077 | "[1.3604236, None]\n",
1078 | "[1.3604268, None]\n",
1079 | "[1.360425, None]\n",
1080 | "[1.3604256, None]\n",
1081 | "[1.360424, None]\n",
1082 | "[1.3604226, None]\n",
1083 | "[1.3604225, None]\n",
1084 | "[1.3604236, None]\n",
1085 | "[1.3604262, None]\n",
1086 | "[1.3604259, None]\n",
1087 | "[1.3604248, None]\n",
1088 | "[1.3604267, None]\n",
1089 | "[1.3604249, None]\n",
1090 | "[1.3604257, None]\n",
1091 | "[1.3604252, None]\n",
1092 | "[1.3604257, None]\n",
1093 | "[1.3604264, None]\n",
1094 | "[1.3604224, None]\n",
1095 | "[1.3604233, None]\n",
1096 | "[1.360424, None]\n",
1097 | "[1.3604245, None]\n",
1098 | "[1.360425, None]\n",
1099 | "[1.3604236, None]\n",
1100 | "[1.3604234, None]\n",
1101 | "[1.3604234, None]\n",
1102 | "[1.3604246, None]\n",
1103 | "[1.360424, None]\n",
1104 | "[1.3604255, None]\n",
1105 | "[1.3604273, None]\n"
1106 | ]
1107 | },
1108 | {
1109 | "name": "stdout",
1110 | "output_type": "stream",
1111 | "text": [
1112 | "[1.3604263, None]\n",
1113 | "[1.3604246, None]\n",
1114 | "[1.3604249, None]\n",
1115 | "[1.3604234, None]\n",
1116 | "[1.360425, None]\n",
1117 | "[1.3604245, None]\n",
1118 | "[1.3604246, None]\n",
1119 | "[1.3604239, None]\n",
1120 | "[1.3604271, None]\n",
1121 | "[1.3604264, None]\n",
1122 | "[1.3604269, None]\n",
1123 | "[1.3604271, None]\n",
1124 | "[1.3604252, None]\n",
1125 | "[1.3604259, None]\n",
1126 | "[1.3604225, None]\n",
1127 | "[1.3604227, None]\n",
1128 | "[1.3604248, None]\n",
1129 | "[1.3604262, None]\n",
1130 | "[1.3604263, None]\n",
1131 | "[1.3604252, None]\n",
1132 | "[1.3604264, None]\n",
1133 | "[1.3604258, None]\n",
1134 | "[1.3604263, None]\n",
1135 | "[1.3604259, None]\n",
1136 | "[1.360425, None]\n",
1137 | "[1.3604273, None]\n",
1138 | "[1.3604286, None]\n",
1139 | "[1.360429, None]\n",
1140 | "[1.3604264, None]\n",
1141 | "[1.3604259, None]\n",
1142 | "[1.3604256, None]\n",
1143 | "[1.3604267, None]\n",
1144 | "[1.3604224, None]\n",
1145 | "[1.3604224, None]\n",
1146 | "[1.3604245, None]\n",
1147 | "[1.3604246, None]\n",
1148 | "[1.360423, None]\n",
1149 | "[1.3604246, None]\n",
1150 | "[1.3604249, None]\n",
1151 | "[1.3604242, None]\n",
1152 | "[1.3604236, None]\n",
1153 | "[1.3604255, None]\n",
1154 | "[1.360424, None]\n",
1155 | "[1.3604248, None]\n",
1156 | "[1.3604264, None]\n",
1157 | "[1.3604243, None]\n",
1158 | "[1.3604252, None]\n",
1159 | "[1.3604257, None]\n",
1160 | "[1.3604244, None]\n",
1161 | "[1.3604257, None]\n",
1162 | "[1.3604285, None]\n",
1163 | "[1.3604257, None]\n",
1164 | "[1.360424, None]\n",
1165 | "[1.360424, None]\n",
1166 | "[1.3604226, None]\n",
1167 | "[1.3604243, None]\n",
1168 | "[1.3604251, None]\n",
1169 | "[1.3604252, None]\n",
1170 | "[1.3604233, None]\n",
1171 | "[1.3604245, None]\n",
1172 | "[1.3604236, None]\n",
1173 | "[1.3604261, None]\n",
1174 | "[1.3604252, None]\n",
1175 | "[1.3604292, None]\n",
1176 | "[1.3604267, None]\n",
1177 | "[1.3604248, None]\n",
1178 | "[1.360425, None]\n",
1179 | "[1.3604238, None]\n",
1180 | "[1.3604261, None]\n",
1181 | "[1.3604243, None]\n",
1182 | "[1.3604236, None]\n",
1183 | "[1.3604259, None]\n",
1184 | "[1.3604236, None]\n",
1185 | "[1.3604255, None]\n",
1186 | "[1.360425, None]\n",
1187 | "[1.3604259, None]\n",
1188 | "[1.360425, None]\n",
1189 | "[1.360425, None]\n",
1190 | "[1.3604234, None]\n",
1191 | "[1.3604248, None]\n",
1192 | "[1.3604254, None]\n",
1193 | "[1.3604251, None]\n"
1194 | ]
1195 | }
1196 | ],
1197 | "source": [
1198 | "with tf.Session() as sess:\n",
1199 | " sess.run(tf.global_variables_initializer())\n",
1200 | " \n",
1201 | " for eopch in range(1000): \n",
1202 | " print(sess.run([loss,train_op],feed_dict={inputs_place:inputsinputs,labels_place:inputslables}))"
1203 | ]
1204 | },
1205 | {
1206 | "cell_type": "code",
1207 | "execution_count": null,
1208 | "metadata": {
1209 | "collapsed": true
1210 | },
1211 | "outputs": [],
1212 | "source": [
1213 | "def sphereloss_mine(inputs,label,classes,weights_decay = 0.05,scope='Logits',reuse=None,m =4,eplion = 1e-12):\n",
1214 | " \n",
1215 | " \"\"\"\n",
1216 | " inputs tensor shape=[batch,features_num]\n",
1217 | " labels tensor shape=[batch] each unit belong num_outputs\n",
1218 | " \n",
1219 | " \"\"\"\n",
1220 | " features_num = inputs.get_shape().as_list()[1]\n",
1221 | "\n",
1222 | " with tf.variable_scope(name_or_scope=scope):\n",
1223 | " weight = tf.Variable(initial_value=tf.random_normal((classes,features_num),stddev=0.01),dtype=tf.float32,name='weight') # shaep =classes, features,\n",
1224 | " print(\"weight shape = \",weight.get_shape().as_list())\n",
1225 | " \n",
1226 | " weight_unit = tf.nn.l2_normalize(weight,dim=1)\n",
1227 | " print(\"weight_unit shape = \",weight_unit.get_shape().as_list())\n",
1228 | " \n",
1229 | " inputs_mo = tf.sqrt(tf.reduce_sum(tf.square(inputs),axis=1)+eplion) #shape=[batch\n",
1230 | " print(\"inputs_mo shape = \",inputs_mo.get_shape().as_list())\n",
1231 | " \n",
1232 | " inputs_unit = tf.nn.l2_normalize(inputs,dim=1) #shape = [batch,features_num]\n",
1233 | " print(\"inputs_unit shape = \",inputs_unit.get_shape().as_list())\n",
1234 | " \n",
1235 | " logits = tf.matmul(inputs,tf.transpose(weight_unit)) #shape = [batch,classes] x * w_unit\n",
1236 | " print(\"logits shape = \",logits.get_shape().as_list())\n",
1237 | " \n",
1238 | " weight_unit_batch = tf.gather(weight_unit,label) # shaep =batch,features_num,\n",
1239 | " print(\"weight_unit_batch shape = \",weight_unit_batch.get_shape().as_list())\n",
1240 | " \n",
1241 | " logits_inputs = tf.reduce_sum(tf.multiply(inputs,weight_unit_batch),axis=1) # shaep =batch,\n",
1242 | " \n",
1243 | " print(\"logits_inputs shape = \",logits_inputs.get_shape().as_list())\n",
1244 | " \n",
1245 | " cos_theta = tf.reduce_sum(tf.multiply(inputs_unit,weight_unit_batch),axis=1) # shaep =batch,\n",
1246 | " print(\"print shape = \",cos_theta.get_shape().as_list())\n",
1247 | " \n",
1248 | " k = tf.Variable(initial_value=tf.zeros_like(cos_theta),trainable=False)\n",
1249 | " k = tf.assign(k,tf.floor_div(tf.acos(cos_theta),pi/m)+0.0) # shaep =batch,\n",
1250 | " print(\"k shape = \",k.get_shape().as_list())\n",
1251 | " \n",
1252 | " cos_four_theta =tf.multiply(tf.pow(cos_theta,4),8) - tf.multiply(tf.pow(cos_theta,2),8)-1 #shape = batch\n",
1253 | " print(\"cos_four_theta shape = \",cos_four_theta.get_shape().as_list())\n",
1254 | " \n",
1255 | " cos_far_theta =tf.add(tf.multiply(tf.pow(-1.0,k) , cos_four_theta), tf.multiply(-2.0 , k)) #shape = batch\n",
1256 | " print(\"cos_far_theta = \",cos_far_theta.get_shape().as_list())\n",
1257 | " \n",
1258 | " \n",
1259 | " \n",
1260 | " logit_ii = tf.multiply(cos_far_theta,inputs_mo)#shape = batch \n",
1261 | " print(\"logit_ii shape = \",logit_ii.get_shape().as_list())\n",
1262 | " \n",
1263 | " logit_ii_exp = tf.exp(logit_ii)\n",
1264 | " logits_exp = tf.reduce_sum(tf.exp(logits),axis=1) \n",
1265 | " logits_inputs_exp = tf.exp(logits_inputs)\n",
1266 | " loss_exp = tf.divide (logit_ii_exp,tf.subtract(tf.add(logit_ii_exp,logits_exp ),logits_inputs_exp)) \n",
1267 | " #shape = batch\n",
1268 | " print(\"loss_exp shape = \",loss_exp.get_shape().as_list())\n",
1269 | " loss = -tf.reduce_mean(tf.log(loss_exp))\n",
1270 | " print(\"loss shape = \",loss.get_shape().as_list())\n",
1271 | " \n",
1272 | "# return weight,weight_unit,weight_unit_batch,inputs,inputs_mo,inputs_unit, logits,logits_inputs,cos_theta,k,cos_four_theta,cos_far_theta,logit_ii,loss_exp,loss\n",
1273 | "# return weight,weight_unit,weight_unit_batch,inputs_mo,inputs_unit, logits,logits_inputs,cos_theta,k,cos_four_theta,cos_far_theta,logit_ii,loss_exp, loss\n",
1274 | " return k,loss"
1275 | ]
1276 | }
1277 | ],
1278 | "metadata": {
1279 | "kernelspec": {
1280 | "display_name": "Python 3",
1281 | "language": "python",
1282 | "name": "python3"
1283 | },
1284 | "language_info": {
1285 | "codemirror_mode": {
1286 | "name": "ipython",
1287 | "version": 3
1288 | },
1289 | "file_extension": ".py",
1290 | "mimetype": "text/x-python",
1291 | "name": "python",
1292 | "nbconvert_exporter": "python",
1293 | "pygments_lexer": "ipython3",
1294 | "version": "3.6.3"
1295 | }
1296 | },
1297 | "nbformat": 4,
1298 | "nbformat_minor": 2
1299 | }
1300 |
--------------------------------------------------------------------------------
/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andrewhuman/sphereloss_tensorflow/a78002638573c48552adbfea6ecdb4c151a91884/test/__init__.py
--------------------------------------------------------------------------------