├── model ├── __init__.py └── embedding_model.py ├── utils ├── __init__.py └── os_utils.py ├── .gitignore ├── imgs └── contrastive_sampling.png ├── heads ├── __init__.py ├── CLS_head.py ├── direct.py ├── direct_normalize.py ├── fc1024_normalize.py └── fc1024.py ├── ranking ├── lifted_structured.py ├── __init__.py ├── angular.py ├── contrastive.py ├── npair.py ├── hard_triplet.py └── semi_hard_triplet.py ├── nets ├── __init__.py ├── README.md ├── resnet_v1_101.py ├── resnet_v1_50.py ├── mobilenet_v1_1_224.py ├── inception_utils.py ├── densenet169.py ├── resnet_utils.py ├── resnet_v1.py ├── inception_v1.py └── mobilenet_v1.py ├── aggregators.py ├── constants.py ├── tf2_test.py ├── eval.py ├── lbtoolbox.py ├── README.md ├── LICENSE ├── embed_tf2.py ├── embed.py ├── common.py └── train_tf2.py /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | \.idea/ 3 | 4 | __pycache__/ 5 | -------------------------------------------------------------------------------- /imgs/contrastive_sampling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ahmdtaha/tf_retrieval_baseline/HEAD/imgs/contrastive_sampling.png -------------------------------------------------------------------------------- /heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Used for the commandline flags. 2 | HEAD_CHOICES = ( 3 | 'direct', 4 | 'direct_normalize', 5 | 'fc1024', 6 | 'fc1024_normalize', 7 | ) 8 | -------------------------------------------------------------------------------- /ranking/lifted_structured.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def lifted_loss(labels,embeddings,margin): 4 | return tf.contrib.losses.metric_learning.lifted_struct_loss(labels,embeddings,margin=margin) 5 | -------------------------------------------------------------------------------- /nets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | # Used for the commandline flags. 
5 | 6 | NET_CHOICES = ( 7 | 'mobilenet_v1_1_224', 8 | 'resnet_v1_50', 9 | 'resnet_v1_101', 10 | 'densenet169', 11 | 'inception_v1', 12 | ) 13 | -------------------------------------------------------------------------------- /heads/CLS_head.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib import slim 3 | 4 | def head(intput, num_classes): 5 | 6 | output = slim.fully_connected( 7 | intput, num_classes, activation_fn=None, 8 | weights_initializer=tf.orthogonal_initializer()) 9 | 10 | return output 11 | -------------------------------------------------------------------------------- /aggregators.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def mean(embs): 5 | return np.mean(embs, axis=0) 6 | 7 | 8 | def normalized_mean(embs): 9 | embs = mean(embs) 10 | return embs / np.linalg.norm(embs, axis=1, keepdims=True) 11 | 12 | 13 | AGGREGATORS = { 14 | 'mean': mean, 15 | 'normalized_mean': normalized_mean, 16 | } 17 | -------------------------------------------------------------------------------- /ranking/__init__.py: -------------------------------------------------------------------------------- 1 | # Used for the commandline flags. 2 | LOSS_CHOICES = ( 3 | 'hard_triplet', 4 | 'semi_hard_triplet', 5 | 'lifted_loss', 6 | 'npairs_loss', 7 | 'angular_loss', 8 | 'contrastive_loss', 9 | ) 10 | 11 | METRIC_CHOICES = [ 12 | 'euclidean', 13 | 'sqeuclidean', 14 | 'cityblock', 15 | 'cosine', 16 | ] 17 | -------------------------------------------------------------------------------- /nets/README.md: -------------------------------------------------------------------------------- 1 | The following files are copy-pasted from TF-slim's model repository but had to be slightly adapted to fit in here. 2 | The original code is governed by the Apache 2.0 license. 3 | Any modifications by us are minor, but marked as such in the comments. 4 | 5 | These are the files concerned: 6 | 7 | - `resnet_utils.py` 8 | - `resnet_v1.py` 9 | - `mobilenet_v1.py` 10 | -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | # Where is the datasets' folders are (stanford online, CUB)? 2 | dataset_dir = '/mnt/data/datasets/' 3 | 4 | # where is the pre-trained model saved (inception, resnet, densenet) ? 
5 | trained_models_dir = '/mnt/data/pretrained/' 6 | 7 | # Where to save a new experiment details (log, tensorboard, trained retrieval models) 8 | experiment_root_dir = '/mnt/data/checkpoints/retrieval_models/' -------------------------------------------------------------------------------- /heads/direct.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib import slim 3 | 4 | def head(endpoints, embedding_dim, is_training,weights_regularizer=None): 5 | endpoints['emb'] = endpoints['emb_raw'] = slim.fully_connected( 6 | endpoints['model_output'], embedding_dim, activation_fn=None, 7 | weights_initializer=tf.orthogonal_initializer(), scope='emb') 8 | 9 | return endpoints 10 | -------------------------------------------------------------------------------- /heads/direct_normalize.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib import slim 3 | 4 | def head(endpoints, embedding_dim, is_training,weights_regularizer=None): 5 | endpoints['emb_raw'] = slim.fully_connected( 6 | endpoints['model_output'], embedding_dim, activation_fn=None, 7 | weights_initializer=tf.orthogonal_initializer(), scope='emb') 8 | endpoints['emb'] = tf.nn.l2_normalize(endpoints['emb_raw'], -1) 9 | 10 | return endpoints 11 | -------------------------------------------------------------------------------- /nets/resnet_v1_101.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from nets.resnet_v1 import resnet_v1_101, resnet_arg_scope 4 | 5 | _RGB_MEAN = [123.68, 116.78, 103.94] 6 | 7 | def endpoints(image, is_training): 8 | if image.get_shape().ndims != 4: 9 | raise ValueError('Input must be of size [batch, height, width, 3]') 10 | 11 | image = image - tf.constant(_RGB_MEAN, dtype=tf.float32, shape=(1,1,1,3)) 12 | 13 | with tf.contrib.slim.arg_scope(resnet_arg_scope(batch_norm_decay=0.9, weight_decay=0.0)): 14 | _, endpoints = resnet_v1_101(image, num_classes=None, is_training=is_training, global_pool=True) 15 | 16 | endpoints['model_output'] = endpoints['global_pool'] = tf.reduce_mean( 17 | endpoints['resnet_v1_101/block4'], [1, 2], name='pool5') 18 | 19 | return endpoints, 'resnet_v1_101' 20 | -------------------------------------------------------------------------------- /nets/resnet_v1_50.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from nets.resnet_v1 import resnet_v1_50, resnet_arg_scope 4 | 5 | _RGB_MEAN = [123.68, 116.78, 103.94] 6 | 7 | def endpoints(image, is_training,weight_decay=0.0): 8 | if image.get_shape().ndims != 4: 9 | raise ValueError('Input must be of size [batch, height, width, 3]') 10 | 11 | image = image - tf.constant(_RGB_MEAN, dtype=tf.float32, shape=(1,1,1,3)) 12 | 13 | with tf.contrib.slim.arg_scope(resnet_arg_scope(batch_norm_decay=0.9, weight_decay=weight_decay)): 14 | _, endpoints = resnet_v1_50(image, num_classes=None, is_training=is_training, global_pool=True) 15 | 16 | endpoints['model_output'] = endpoints['global_pool'] = tf.reduce_mean( 17 | endpoints['resnet_v1_50/block4'], [1, 2], name='pool5') 18 | 19 | return endpoints, 'resnet_v1_50' 20 | -------------------------------------------------------------------------------- /heads/fc1024_normalize.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from 
tensorflow.contrib import slim 3 | 4 | def head(endpoints, embedding_dim, is_training, weights_regularizer=None): 5 | predict_var = 0 6 | input = endpoints['model_output'] 7 | endpoints['head_output'] = slim.fully_connected( 8 | input, 1024, normalizer_fn=slim.batch_norm, 9 | normalizer_params={ 10 | 'decay': 0.9, 11 | 'epsilon': 1e-5, 12 | 'scale': True, 13 | 'is_training': is_training, 14 | 'updates_collections': tf.GraphKeys.UPDATE_OPS, 15 | }, 16 | weights_regularizer=weights_regularizer 17 | ) 18 | 19 | input_1 = endpoints['head_output'] 20 | 21 | endpoints['emb_raw'] = slim.fully_connected( 22 | input_1, embedding_dim + predict_var, activation_fn=None,weights_regularizer=weights_regularizer, 23 | weights_initializer=tf.orthogonal_initializer(), scope='emb') 24 | 25 | 26 | endpoints['emb'] = tf.nn.l2_normalize(endpoints['emb_raw'], -1) 27 | # endpoints['data_sigma'] = None 28 | print('Normalize batch embedding') 29 | return endpoints 30 | -------------------------------------------------------------------------------- /heads/fc1024.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib import slim 3 | 4 | def head(endpoints, embedding_dim, is_training, reuse=None, weights_regularizer=None): 5 | predict_var = 0 6 | 7 | input = endpoints['model_output'] 8 | endpoints['head_output'] = slim.fully_connected( 9 | input, 1024, normalizer_fn=slim.batch_norm, 10 | normalizer_params={ 11 | 'decay': 0.9, 12 | 'epsilon': 1e-5, 13 | 'scale': True, 14 | 'is_training': is_training, 15 | 'updates_collections': tf.GraphKeys.UPDATE_OPS, 16 | },scope='emb_h1', 17 | weights_regularizer=weights_regularizer, 18 | reuse=reuse) 19 | input_1 = endpoints['head_output'] 20 | 21 | endpoints['emb_raw'] = slim.fully_connected( 22 | input_1, embedding_dim+predict_var, activation_fn=None,weights_regularizer=weights_regularizer, 23 | weights_initializer=tf.orthogonal_initializer(), scope='emb',reuse=reuse) 24 | 25 | 26 | endpoints['emb'] = endpoints['emb_raw'] 27 | # endpoints['data_sigma'] = None 28 | print('batch embedding with none data_sigma') 29 | return endpoints 30 | -------------------------------------------------------------------------------- /tf2_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tensorflow.keras import datasets, layers, models 4 | import matplotlib.pyplot as plt 5 | 6 | (train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data() 7 | 8 | # Normalize pixel values to be between 0 and 1 9 | train_images, test_images = train_images / 255.0, test_images / 255.0 10 | 11 | 12 | class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 13 | 'dog', 'frog', 'horse', 'ship', 'truck'] 14 | model = models.Sequential() 15 | model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3))) 16 | model.add(layers.MaxPooling2D((2, 2))) 17 | model.add(layers.Conv2D(64, (3, 3), activation='relu')) 18 | model.add(layers.MaxPooling2D((2, 2))) 19 | model.add(layers.Conv2D(64, (3, 3), activation='relu')) 20 | 21 | model.add(layers.Flatten()) 22 | model.add(layers.Dense(64, activation='relu')) 23 | model.add(layers.Dense(10, activation='softmax')) 24 | 25 | 26 | model.compile(optimizer='adam', 27 | loss='sparse_categorical_crossentropy', 28 | metrics=['accuracy']) 29 | 30 | history = model.fit(train_images, train_labels, epochs=10, 31 | validation_data=(test_images, test_labels)) 32 | 
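# --- Optional follow-up (a minimal sketch, not part of the original script). ---
# tf2_test.py imports matplotlib.pyplot but never uses it; the lines below show
# one way to evaluate the trained model on the test split and to plot the
# accuracy curves recorded in `history`. With metrics=['accuracy'] as above,
# Keras stores the curves under the 'accuracy' / 'val_accuracy' keys.
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print('Test accuracy: {:.4f}'.format(test_acc))

plt.plot(history.history['accuracy'], label='train accuracy')
plt.plot(history.history['val_accuracy'], label='val accuracy')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()
plt.show()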
-------------------------------------------------------------------------------- /ranking/angular.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | def angular_loss(input_labels, anchor_features, pos_features, degree=45, batch_size=10, with_l2reg=False): 5 | ''' 6 | #NOTE: degree is degree!!! not radian value 7 | ''' 8 | if with_l2reg: 9 | reg_anchor = tf.reduce_mean(tf.reduce_sum(tf.square(anchor_features), 1)) 10 | reg_positive = tf.reduce_mean(tf.reduce_sum(tf.square(pos_features), 1)) 11 | l2loss = tf.multiply(0.25 * 0.002, reg_anchor + reg_positive, name='l2loss_angular') 12 | else: 13 | l2loss = 0.0 14 | 15 | alpha = np.deg2rad(degree) 16 | sq_tan_alpha = np.tan(alpha) ** 2 17 | 18 | # anchor_features = tf.nn.l2_normalize(anchor_features) 19 | # pos_features = tf.nn.l2_normalize(pos_features) 20 | 21 | # 2(1+(tan(alpha))^2 * xaTxp) 22 | # batch_size = 10 23 | xaTxp = tf.matmul(anchor_features, pos_features, transpose_a=False, transpose_b=True) 24 | sim_matrix_1 = tf.multiply(2.0 * (1.0 + sq_tan_alpha) * xaTxp, tf.eye(batch_size, dtype=tf.float32)) 25 | 26 | # 4((tan(alpha))^2(xa + xp)Txn 27 | xaPxpTxn = tf.matmul((anchor_features + pos_features), pos_features, transpose_a=False, transpose_b=True) 28 | sim_matrix_2 = tf.multiply(4.0 * sq_tan_alpha * xaPxpTxn, 29 | tf.ones_like(xaPxpTxn, dtype=tf.float32) - tf.eye(batch_size, dtype=tf.float32)) 30 | 31 | # similarity_matrix 32 | similarity_matrix = sim_matrix_1 + sim_matrix_2 33 | 34 | # do softmax cross-entropy 35 | lshape = tf.shape(input_labels) 36 | # assert lshape.shape == 1 37 | labels = tf.reshape(input_labels, [lshape[0], 1]) 38 | 39 | labels_remapped = tf.cast(tf.equal(labels, tf.transpose(labels)),tf.float32) 40 | labels_remapped /= tf.reduce_sum(labels_remapped, 1, keepdims=True) 41 | 42 | xent_loss = tf.nn.softmax_cross_entropy_with_logits(logits=similarity_matrix, labels=labels_remapped) 43 | xent_loss = tf.reduce_mean(xent_loss, name='xentropy_angular') 44 | 45 | 46 | return l2loss + xent_loss -------------------------------------------------------------------------------- /ranking/contrastive.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.ops import math_ops 3 | 4 | def contrastive_loss(labels, embeddings_anchor, embeddings_positive, 5 | margin=1.0): 6 | """Computes the contrastive loss. 7 | This loss encourages the embedding to be close to each other for 8 | the samples of the same label and the embedding to be far apart at least 9 | by the margin constant for the samples of different labels. 10 | See: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf 11 | Args: 12 | labels: 1-D tf.int32 `Tensor` with shape [batch_size] of 13 | binary labels indicating positive vs negative pair. 14 | embeddings_anchor: 2-D float `Tensor` of embedding vectors for the anchor 15 | images. Embeddings should be l2 normalized. 16 | embeddings_positive: 2-D float `Tensor` of embedding vectors for the 17 | positive images. Embeddings should be l2 normalized. 18 | margin: margin term in the loss definition. 19 | Returns: 20 | contrastive_loss: tf.float32 scalar. 
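    Note:
      With d = ||embeddings_anchor - embeddings_positive||_2 (a small epsilon is
      added under the square root so the gradient stays finite at zero), the
      per-pair loss computed below is
        labels * d^2 + (1 - labels) * max(margin - d, 0)^2,
      averaged over the batch.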
21 | """ 22 | # embeddings_anchor = tf.Print(embeddings_anchor,[tf.shape(embeddings_anchor),tf.shape(embeddings_positive)],'embeddings_anchor shapes') 23 | epsilon= 10e-6 24 | distances = math_ops.sqrt( 25 | math_ops.reduce_sum( 26 | math_ops.square(embeddings_anchor - embeddings_positive), 1) + epsilon) 27 | # distances = tf.Print(distances,[tf.shape(distances),distances],'distances ',summarize=1000) 28 | # Add contrastive loss for the siamese network. 29 | # label here is {0,1} for neg, pos. 30 | 31 | pos_loss = math_ops.to_float(labels) * math_ops.square(distances) 32 | # pos_loss = tf.Print(pos_loss, [tf.shape(pos_loss),pos_loss], 'pos_loss ',summarize=1000) 33 | neg_loss = (1. - math_ops.to_float(labels)) * math_ops.square(math_ops.maximum(margin - distances, 0.)) 34 | # neg_loss = tf.Print(neg_loss, [tf.shape(neg_loss),(1. - math_ops.to_float(labels)),math_ops.square(math_ops.maximum(margin - distances, 0.)),neg_loss], 'neg_loss ',summarize=1000) 35 | 36 | contrastive_loss = math_ops.reduce_mean(pos_loss + neg_loss, name='contrastive_loss') 37 | return contrastive_loss -------------------------------------------------------------------------------- /ranking/npair.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def npairs_loss_helper(labels, embeddings_anchor, embeddings_positive, 4 | reg_lambda=0.002, print_losses=False): 5 | """Computes the npairs loss. 6 | Npairs loss expects paired data where a pair is composed of samples from the 7 | same labels and each pairs in the minibatch have different labels. The loss 8 | has two components. The first component is the L2 regularizer on the 9 | embedding vectors. The second component is the sum of cross entropy loss 10 | which takes each row of the pair-wise similarity matrix as logits and 11 | the remapped one-hot labels as labels. 12 | See: http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf 13 | Args: 14 | labels: 1-D tf.int32 `Tensor` of shape [batch_size/2]. 15 | embeddings_anchor: 2-D Tensor of shape [batch_size/2, embedding_dim] for the 16 | embedding vectors for the anchor images. Embeddings should not be 17 | l2 normalized. 18 | embeddings_positive: 2-D Tensor of shape [batch_size/2, embedding_dim] for the 19 | embedding vectors for the positive images. Embeddings should not be 20 | l2 normalized. 21 | reg_lambda: Float. L2 regularization term on the embedding vectors. 22 | print_losses: Boolean. Option to print the xent and l2loss. 23 | Returns: 24 | npairs_loss: tf.float32 scalar. 25 | """ 26 | # pylint: enable=line-too-long 27 | # Add the regularizer on the embedding. 28 | reg_anchor = tf.reduce_mean( 29 | tf.reduce_sum(tf.square(embeddings_anchor), 1)) 30 | reg_positive = tf.reduce_mean( 31 | tf.reduce_sum(tf.square(embeddings_positive), 1)) 32 | l2loss = tf.multiply( 33 | 0.25 * reg_lambda, reg_anchor + reg_positive, name='l2loss') 34 | 35 | # Get per pair similarities. 36 | similarity_matrix = tf.matmul( 37 | embeddings_anchor, embeddings_positive, transpose_a=False, 38 | transpose_b=True) 39 | 40 | # Reshape [batch_size] label tensor to a [batch_size, 1] label tensor. 41 | lshape = tf.shape(labels) 42 | assert lshape.shape == 1 43 | labels = tf.reshape(labels, [lshape[0], 1]) 44 | 45 | labels_remapped = tf.cast( 46 | tf.equal(labels, tf.transpose(labels)), tf.float32) 47 | labels_remapped /= tf.reduce_sum(labels_remapped, 1, keepdims=True) 48 | 49 | # Add the softmax loss. 
50 | xent_loss = tf.nn.softmax_cross_entropy_with_logits( 51 | logits=similarity_matrix, labels=labels_remapped) 52 | xent_loss = tf.reduce_mean(xent_loss, name='xentropy') 53 | 54 | if print_losses: 55 | xent_loss = tf.Print( 56 | xent_loss, ['cross entropy:', xent_loss, 'l2loss:', l2loss]) 57 | 58 | return l2loss + xent_loss 59 | 60 | def npairs_loss(labels,embeddings_anchor,embeddings_positive,reg_lambda=0.002,print_losses=False): 61 | return npairs_loss_helper(labels, embeddings_anchor,embeddings_positive,reg_lambda=reg_lambda,print_losses=print_losses) 62 | 63 | # return tf.contrib.losses.metric_learning.npairs_loss(labels, embeddings_anchor,embeddings_positive,reg_lambda=reg_lambda,print_losses=print_losses) 64 | -------------------------------------------------------------------------------- /nets/mobilenet_v1_1_224.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from nets.mobilenet_v1 import mobilenet_v1 4 | from tensorflow.contrib import slim 5 | 6 | 7 | def endpoints(image, is_training,dropout_rate=0,weight_decay=None): 8 | if image.get_shape().ndims != 4: 9 | raise ValueError('Input must be of size [batch, height, width, 3]') 10 | 11 | image = tf.divide(image, 255.0) 12 | 13 | with tf.contrib.slim.arg_scope(mobilenet_v1_arg_scope(batch_norm_decay=0.9, weight_decay=0.0)): 14 | _, endpoints = mobilenet_v1(image, num_classes=1001, is_training=is_training) 15 | 16 | endpoints['model_output'] = endpoints['global_pool'] = tf.reduce_mean( 17 | endpoints['Conv2d_13_pointwise'], [1, 2], name='global_pool', keep_dims=False) 18 | 19 | return endpoints, 'MobilenetV1' 20 | 21 | 22 | # This is copied and modified from mobilenet_v1.py. 23 | def mobilenet_v1_arg_scope(is_training=True, 24 | batch_norm_decay=0.9997, 25 | batch_norm_epsilon=0.001, 26 | batch_norm_scale=True, 27 | weight_decay=0.00004, 28 | stddev=0.09, 29 | regularize_depthwise=False): 30 | 31 | """Defines the default MobilenetV1 arg scope. 32 | Args: 33 | is_training: Whether or not we're training the model. 34 | batch_norm_decay: The moving average decay when estimating layer activation 35 | statistics in batch normalization. 36 | batch_norm_epsilon: Small constant to prevent division by zero when 37 | normalizing activations by their variance in batch normalization. 38 | batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the 39 | activations in the batch normalization layer. 40 | weight_decay: The weight decay to use for regularizing the model. 41 | stddev: The standard deviation of the trunctated normal weight initializer. 42 | regularize_depthwise: Whether or not apply regularization on depthwise. 43 | Returns: 44 | An `arg_scope` to use for the mobilenet v1 model. 45 | """ 46 | batch_norm_params = { 47 | 'is_training': is_training, 48 | 'center': True, 49 | 'scale': batch_norm_scale, 50 | 'decay': batch_norm_decay, 51 | 'epsilon': batch_norm_epsilon, 52 | } 53 | 54 | # Set weight_decay for weights in Conv and DepthSepConv layers. 
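  # Depthwise (separable) convolutions are only L2-regularized when
  # regularize_depthwise=True; otherwise their weights_regularizer stays None,
  # as set up in the nested arg_scopes below.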
55 | weights_init = tf.truncated_normal_initializer(stddev=stddev) 56 | regularizer = tf.contrib.layers.l2_regularizer(weight_decay) 57 | if regularize_depthwise: 58 | depthwise_regularizer = regularizer 59 | else: 60 | depthwise_regularizer = None 61 | with slim.arg_scope([slim.conv2d, slim.separable_conv2d], 62 | weights_initializer=weights_init, 63 | activation_fn=tf.nn.relu6, normalizer_fn=slim.batch_norm): 64 | with slim.arg_scope([slim.batch_norm], **batch_norm_params): 65 | with slim.arg_scope([slim.conv2d], weights_regularizer=regularizer): 66 | with slim.arg_scope([slim.separable_conv2d], 67 | weights_regularizer=depthwise_regularizer) as sc: 68 | return sc 69 | -------------------------------------------------------------------------------- /model/embedding_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | class FC1024Head(tf.keras.Model): 4 | def __init__(self, cfg): 5 | super(FC1024Head, self).__init__() 6 | self.h_1024 = tf.keras.layers.Dense(1025, activation=None, 7 | kernel_initializer=tf.keras.initializers.Orthogonal()) 8 | self.batch_norm = tf.keras.layers.BatchNormalization( 9 | momentum = 0.9, 10 | epsilon=1e-5, 11 | scale=True, 12 | ) 13 | self.head = tf.keras.layers.Dense(cfg.embedding_dim, activation=None, 14 | kernel_initializer=tf.keras.initializers.Orthogonal()) 15 | def call(self, inputs): 16 | h1 = tf.keras.backend.relu(self.batch_norm(self.h_1024(inputs))) 17 | return self.head(h1) 18 | 19 | 20 | class DirectHead(tf.keras.Model): 21 | def __init__(self, cfg): 22 | super(DirectHead, self).__init__() 23 | self.head = tf.keras.layers.Dense(cfg.embedding_dim, activation=None, 24 | kernel_initializer=tf.keras.initializers.Orthogonal()) 25 | def call(self, inputs): 26 | return self.head(inputs) 27 | 28 | class EmbeddingModel(tf.keras.Model): 29 | 30 | def __init__(self, cfg): 31 | super(EmbeddingModel, self).__init__() 32 | self.cfg = cfg 33 | 34 | if cfg.model_name == 'inception_v1': 35 | self.base_model = tf.keras.applications.Xception(weights='imagenet', include_top=False) 36 | self.preprocess_input = tf.keras.applications.xception.preprocess_input 37 | elif cfg.model_name == 'resnet_v1_50': 38 | self.base_model = tf.keras.applications.ResNet50(weights='imagenet', include_top=False) 39 | self.preprocess_input = tf.keras.applications.resnet.preprocess_input 40 | elif cfg.model_name == 'densenet169': 41 | self.base_model = tf.keras.applications.DenseNet169(weights='imagenet', include_top=False) 42 | self.preprocess_input = tf.keras.applications.densenet.preprocess_input 43 | else: 44 | raise NotImplementedError('Invalid model_name {}'.format(cfg.model_name)) 45 | 46 | 47 | 48 | self.spatial_pooling = tf.keras.layers.GlobalAvgPool2D() 49 | if 'direct' in cfg.head_name: 50 | self.embedding_head = DirectHead(cfg) 51 | elif 'fc1024' in cfg.head_name: 52 | self.embedding_head = FC1024Head(cfg) 53 | else: 54 | raise NotImplementedError('Invalid head_name {}'.format(cfg.head_name)) 55 | 56 | 57 | self.l2_embedding = 'normalize' in cfg.head_name 58 | 59 | 60 | 61 | def call(self, images): 62 | base_model_output = self.base_model(images) 63 | 64 | base_model_output_pooled = self.spatial_pooling(base_model_output) 65 | batch_embedding = self.embedding_head(base_model_output_pooled ) 66 | if self.l2_embedding: 67 | reutrn_batch_embedding = tf.nn.l2_normalize(batch_embedding, -1) 68 | else: 69 | reutrn_batch_embedding = batch_embedding 70 | return reutrn_batch_embedding 71 | 72 | 
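# --- Usage sketch (illustrative addition; `Config` is a hypothetical stand-in
# for the cfg object the training script, e.g. train_tf2.py, passes in).
# EmbeddingModel only reads cfg.model_name, cfg.head_name and cfg.embedding_dim,
# so a minimal configuration object is enough to embed a batch of images.
if __name__ == '__main__':
    class Config:
        model_name = 'resnet_v1_50'     # or 'densenet169' / 'inception_v1' (see __init__ above)
        head_name = 'fc1024_normalize'  # any 'direct*' / 'fc1024*' head name
        embedding_dim = 128

    model = EmbeddingModel(Config())
    images = tf.random.uniform((4, 224, 224, 3), maxval=255.0)
    embeddings = model(model.preprocess_input(images))
    print(embeddings.shape)  # (4, 128); L2-normalized because 'normalize' is in head_name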
-------------------------------------------------------------------------------- /nets/inception_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains common code shared by all inception models. 16 | 17 | Usage of arg scope: 18 | with slim.arg_scope(inception_arg_scope()): 19 | logits, end_points = inception.inception_v3(images, num_classes, 20 | is_training=is_training) 21 | 22 | """ 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import tensorflow as tf 28 | 29 | slim = tf.contrib.slim 30 | 31 | 32 | def inception_arg_scope(weight_decay=0.00004, 33 | use_batch_norm=True, 34 | batch_norm_decay=0.9997, 35 | batch_norm_epsilon=0.001, 36 | activation_fn=tf.nn.relu, 37 | batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS, 38 | batch_norm_scale=False): 39 | """Defines the default arg scope for inception models. 40 | 41 | Args: 42 | weight_decay: The weight decay to use for regularizing the model. 43 | use_batch_norm: "If `True`, batch_norm is applied after each convolution. 44 | batch_norm_decay: Decay for batch norm moving average. 45 | batch_norm_epsilon: Small float added to variance to avoid dividing by zero 46 | in batch norm. 47 | activation_fn: Activation function for conv2d. 48 | batch_norm_updates_collections: Collection for the update ops for 49 | batch norm. 50 | batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the 51 | activations in the batch normalization layer. 52 | 53 | Returns: 54 | An `arg_scope` to use for the inception models. 55 | """ 56 | batch_norm_params = { 57 | # Decay for the moving averages. 58 | 'decay': batch_norm_decay, 59 | # epsilon to prevent 0s in variance. 60 | 'epsilon': batch_norm_epsilon, 61 | # collection containing update_ops. 62 | 'updates_collections': batch_norm_updates_collections, 63 | # use fused batch norm if possible. 64 | 'fused': None, 65 | 'scale': batch_norm_scale, 66 | } 67 | if use_batch_norm: 68 | normalizer_fn = slim.batch_norm 69 | normalizer_params = batch_norm_params 70 | else: 71 | normalizer_fn = None 72 | normalizer_params = {} 73 | # Set weight_decay for weights in Conv and FC layers. 
74 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 75 | weights_regularizer=slim.l2_regularizer(weight_decay)): 76 | with slim.arg_scope( 77 | [slim.conv2d], 78 | weights_initializer=slim.variance_scaling_initializer(), 79 | activation_fn=activation_fn, 80 | normalizer_fn=normalizer_fn, 81 | normalizer_params=normalizer_params) as sc: 82 | return sc 83 | -------------------------------------------------------------------------------- /utils/os_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import errno 4 | import csv 5 | from shutil import copyfile 6 | 7 | 8 | def get_last_part(path): 9 | return os.path.basename(os.path.normpath(path)) 10 | 11 | def copy_file(f,dst,rename=None): 12 | touch_dir(dst) 13 | # for f_idx,f in enumerate(src_file_lst): 14 | if os.path.exists(f): 15 | # print(f) 16 | if rename ==None: 17 | copyfile(f, os.path.join(dst,get_last_part(f))) 18 | else: 19 | _,ext = get_file_name_ext(f) 20 | copyfile(f, os.path.join(dst, rename+ext )) 21 | else: 22 | raise Exception('File not found') 23 | 24 | def copy_files(src_file_lst,dst,rename=None): 25 | touch_dir(dst) 26 | for f_idx,f in enumerate(src_file_lst): 27 | if os.path.exists(f): 28 | # print(f) 29 | if rename ==None: 30 | copyfile(f, os.path.join(dst,get_last_part(f))) 31 | else: 32 | _,ext = get_file_name_ext(f) 33 | copyfile(f, os.path.join(dst, rename[f_idx]+ext )) 34 | else: 35 | raise Exception('File not found') 36 | 37 | def dataset_tuples(dataset_path): 38 | return dataset_path + '_tuples_class' 39 | 40 | 41 | def get_dirs(base_path): 42 | return sorted([f for f in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, f))]) 43 | 44 | 45 | def get_files(base_path,extension,append_base=False): 46 | if (append_base): 47 | files =[os.path.join(base_path,f) for f in os.listdir(base_path) if (f.endswith(extension) and not f.startswith('.'))]; 48 | else: 49 | files = [f for f in os.listdir(base_path) if (f.endswith(extension) and not f.startswith('.'))]; 50 | return sorted(files); 51 | 52 | def csv_read(csv_file,has_header=False): 53 | rows = [] 54 | with open(csv_file, 'r') as csvfile: 55 | file_content = csv.reader(csvfile) 56 | if has_header: 57 | header = next(file_content, None) # skip the headers 58 | for row in file_content: 59 | rows.append(row) 60 | 61 | return rows 62 | 63 | def csv_write(csv_file,rows): 64 | with open(csv_file, mode='w') as file: 65 | rows_writer = csv.writer(file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL) 66 | for row in rows: 67 | rows_writer.writerow(row) 68 | 69 | 70 | 71 | def txt_read(path): 72 | with open(path) as f: 73 | content = f.readlines() 74 | lines = [x.strip() for x in content] 75 | return lines; 76 | 77 | def txt_write(path,lines): 78 | out_file = open(path, "w") 79 | for line in lines: 80 | out_file.write(line) 81 | out_file.write('\n') 82 | out_file.close() 83 | 84 | def pkl_write(path,data): 85 | pickle.dump(data, open(path, "wb")) 86 | 87 | 88 | def hot_one_vector(y, max): 89 | import numpy as np 90 | labels_hot_vector = np.zeros((y.shape[0], max),dtype=np.int32) 91 | labels_hot_vector[np.arange(y.shape[0]), y] = 1 92 | return labels_hot_vector 93 | 94 | def pkl_read(path): 95 | if(not os.path.exists(path)): 96 | return None; 97 | 98 | data = pickle.load(open(path, 'rb')) 99 | return data; 100 | 101 | def touch_dir(path): 102 | if(not os.path.exists(path)): 103 | os.makedirs(path) 104 | 105 | def touch_file_dir(file_path): 106 | if not 
os.path.exists(os.path.dirname(file_path)): 107 | try: 108 | os.makedirs(os.path.dirname(file_path)) 109 | except OSError as exc: # Guard against race condition 110 | if exc.errno != errno.EEXIST: 111 | raise 112 | 113 | 114 | 115 | 116 | def last_tuple_idx(path): 117 | files =[f for f in os.listdir(path) if (f.endswith('.jpg') and not f.startswith('.'))]; 118 | return len(files); 119 | 120 | def get_file_name_ext(inputFilepath): 121 | filename_w_ext = os.path.basename(inputFilepath) 122 | filename, file_extension = os.path.splitext(filename_w_ext) 123 | return filename, file_extension 124 | 125 | def get_latest_file(path,extension=''): 126 | files = get_files(path,extension=extension,append_base=True); 127 | return max(files, key=os.path.getctime) 128 | 129 | def dir_empty(path): 130 | if os.listdir(path) == []: 131 | return True; 132 | else: 133 | return False; 134 | 135 | def chkpt_exists(path): 136 | files = [f for f in os.listdir(path) if (f.find('.ckpt') > 0 and not f.startswith('.'))]; 137 | if len(files): 138 | return True; 139 | return False; -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import h5py 4 | sys.path.append('..') 5 | sys.path.append('/vulcan/scratch/ahmdtaha/libs/kmcuda/src') 6 | import common 7 | import logging.config 8 | import numpy as np 9 | import tensorflow as tf 10 | import constants as const 11 | from ranking import METRIC_CHOICES 12 | from sklearn.cluster import KMeans 13 | # from libKMCUDA import kmeans_cuda 14 | from scipy.spatial.distance import pdist 15 | from argparse import ArgumentParser, FileType 16 | from sklearn.metrics import normalized_mutual_info_score 17 | 18 | 19 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 20 | parser = ArgumentParser(description='Evaluate a ReID embedding.') 21 | 22 | parser.add_argument( 23 | '--gallery_dataset', required=True, 24 | help='Path to the gallery dataset csv file.') 25 | 26 | parser.add_argument( 27 | '--gallery_embeddings', required=True, 28 | help='Path to the h5 file containing the gallery embeddings.') 29 | 30 | parser.add_argument( 31 | '--metric', required=True, choices=METRIC_CHOICES, 32 | help='Which metric to use for the distance between embeddings.') 33 | 34 | parser.add_argument( 35 | '--filename', type=FileType('w'), 36 | help='Optional name of the json file to store the results in.') 37 | 38 | parser.add_argument( 39 | '--batch_size', default=256, type=common.positive_int, 40 | help='Batch size used during evaluation, adapt based on your memory usage.') 41 | 42 | def get_distance_matrix(x): 43 | """Get distance matrix given a matrix. Used in testing.""" 44 | square = np.sum(x ** 2.0, axis=1, keepdims=True) 45 | distance_square = square + square.transpose() - (2.0 * np.dot(x, x.transpose())) 46 | return np.sqrt(distance_square) 47 | 48 | 49 | def evaluate_emb(emb, labels): 50 | """Evaluate embeddings based on Recall@k.""" 51 | d_mat = get_distance_matrix(emb) 52 | names = [] 53 | accs = [] 54 | for k in [1, 2, 4, 8, 16]: 55 | names.append('Recall@%d' % k) 56 | correct, cnt = 0.0, 0.0 57 | for i in range(emb.shape[0]): 58 | d_mat[i, i] = 1e10 59 | nns = np.argpartition(d_mat[i], k)[:k] 60 | if any(labels[i] == labels[nn] for nn in nns): 61 | correct += 1 62 | cnt += 1 63 | accs.append(correct/cnt) 64 | return names, accs 65 | 66 | def main(argv): 67 | # Verify that parameters are set correctly. 
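    # Note: `exp_root`, used for the log file below, is a module-level variable
    # set in the __main__ block at the bottom of this file, not a parsed argument.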
68 | args = parser.parse_args(argv) 69 | 70 | gallery_pids, gallery_fids = common.load_dataset(args.gallery_dataset, None) 71 | 72 | log_file = os.path.join(exp_root, "recall_eval") 73 | logging.config.dictConfig(common.get_logging_dict(log_file)) 74 | log = logging.getLogger('recall_eval') 75 | 76 | with h5py.File(args.gallery_embeddings, 'r') as f_gallery: 77 | gallery_embs = np.array(f_gallery['emb']) 78 | #gallery_embs_var = np.array(f_gallery['emb_var']) 79 | #print('gallery_embs_var.shape =>',gallery_embs_var.shape) 80 | 81 | num_clusters = len(np.unique(gallery_pids)) 82 | print('Start clustering K ={}'.format(num_clusters)) 83 | 84 | log.info(exp_root) 85 | 86 | kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(gallery_embs) 87 | log.info('NMI :: {}'.format(normalized_mutual_info_score(gallery_pids, kmeans.labels_))) 88 | 89 | # centroids, assignments = kmeans_cuda(gallery_embs,num_clusters,seed=3) 90 | # log.info('NMI :: {}'.format(normalized_mutual_info_score(gallery_pids, assignments))) 91 | 92 | log.info('Clustering complete') 93 | 94 | 95 | log.info('Eval with Recall-K') 96 | names, accs = evaluate_emb(gallery_embs,gallery_pids) 97 | log.info(names) 98 | log.info(accs) 99 | 100 | if __name__ == '__main__': 101 | 102 | arg_experiment_root = const.experiment_root_dir 103 | dataset_name = 'cub' 104 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 105 | exp_dir = 'cub_densenet_direct_normalize_npairs_loss_m_0.2' 106 | foldername = 'emb' 107 | exp_root = os.path.join(arg_experiment_root+exp_dir,foldername) 108 | 109 | if dataset_name == 'cub': 110 | csv_file = 'cub' 111 | elif dataset_name == 'inshop': 112 | csv_file = 'deep_fashion' 113 | elif dataset_name == 'stanford': 114 | csv_file = 'stanford_online' 115 | 116 | else: 117 | raise NotImplementedError('dataset {} not valid'.format(dataset_name)) 118 | 119 | 120 | argv = [ 121 | 122 | '--gallery_dataset','./data/'+csv_file+'_test.csv', 123 | '--gallery_embeddings',os.path.join(exp_root ,'test_embeddings_augmented.h5'), 124 | '--metric','euclidean', 125 | '--filename',os.path.join(exp_root ,'market1501_evaluation.json'), 126 | ] 127 | main(argv) 128 | 129 | 130 | -------------------------------------------------------------------------------- /ranking/hard_triplet.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import tensorflow as tf 3 | 4 | def all_diffs(a, b): 5 | """ Returns a tensor of all combinations of a - b. 6 | 7 | Args: 8 | a (2D tensor): A batch of vectors shaped (B1, F). 9 | b (2D tensor): A batch of vectors shaped (B2, F). 10 | 11 | Returns: 12 | The matrix of all pairwise differences between all vectors in `a` and in 13 | `b`, will be of shape (B1, B2). 14 | 15 | Note: 16 | For convenience, if either `a` or `b` is a `Distribution` object, its 17 | mean is used. 18 | """ 19 | return tf.expand_dims(a, axis=1) - tf.expand_dims(b, axis=0) 20 | 21 | def cdist(a, b, metric='euclidean'): 22 | """Similar to scipy.spatial's cdist, but symbolic. 23 | 24 | The currently supported metrics can be listed as `cdist.supported_metrics` and are: 25 | - 'euclidean', although with a fudge-factor epsilon. 26 | - 'sqeuclidean', the squared euclidean. 27 | - 'cityblock', the manhattan or L1 distance. 28 | 29 | Args: 30 | a (2D tensor): The left-hand side, shaped (B1, F). 31 | b (2D tensor): The right-hand side, shaped (B2, F). 32 | metric (string): Which distance metric to use, see notes. 
33 | 34 | Returns: 35 | The matrix of all pairwise distances between all vectors in `a` and in 36 | `b`, will be of shape (B1, B2). 37 | 38 | Note: 39 | When a square root is taken (such as in the Euclidean case), a small 40 | epsilon is added because the gradient of the square-root at zero is 41 | undefined. Thus, it will never return exact zero in these cases. 42 | """ 43 | with tf.name_scope("cdist"): 44 | diffs = all_diffs(a, b) 45 | if metric == 'sqeuclidean': 46 | return tf.reduce_sum(tf.square(diffs), axis=-1) 47 | elif metric == 'euclidean': 48 | return tf.sqrt(tf.reduce_sum(tf.square(diffs), axis=-1) + 1e-12) 49 | elif metric == 'cityblock': 50 | return tf.reduce_sum(tf.abs(diffs), axis=-1) 51 | elif metric == 'cosine': 52 | # https://stackoverflow.com/questions/48485373/pairwise-cosine-similarity-using-tensorflow 53 | # normalized_input = tf.nn.l2_normalize(a, dim=1) 54 | # Embedding are assumed to be normalized 55 | prod = tf.matmul(a, b,adjoint_b=True) # transpose second matrix 56 | return 1 - prod 57 | else: 58 | raise NotImplementedError( 59 | 'The following metric is not implemented by `cdist` yet: {}'.format(metric)) 60 | 61 | def batch_hard(embeddings, pids, margin,metric): 62 | """Computes the batch-hard loss from arxiv.org/abs/1703.07737. 63 | 64 | Args: 65 | dists (2D tensor): A square all-to-all distance matrix as given by cdist. 66 | pids (1D tensor): The identities of the entries in `batch`, shape (B,). 67 | This can be of any type that can be compared, thus also a string. 68 | margin: The value of the margin if a number, alternatively the string 69 | 'soft' for using the soft-margin formulation, or `None` for not 70 | using a margin at all. 71 | 72 | Returns: 73 | A 1D tensor of shape (B,) containing the loss value for each sample. 
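    Note:
      Despite the `dists` argument documented above, this implementation takes
      the raw `embeddings` together with a `metric` name and builds the
      all-to-all distance matrix itself via `cdist` (first line of the body).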
74 | """ 75 | with tf.name_scope("batch_hard"): 76 | dists = cdist(embeddings, embeddings, metric=metric) 77 | 78 | same_identity_mask = tf.equal(tf.expand_dims(pids, axis=1), 79 | tf.expand_dims(pids, axis=0)) 80 | # print(pids) 81 | # dists = tf.Print(dists, [dists], "Pair Dist", summarize=1000000) 82 | # same_identity_mask = tf.Print(same_identity_mask,[same_identity_mask, pids],"Hello World" ,summarize=1000000) 83 | negative_mask = tf.logical_not(same_identity_mask) 84 | positive_mask = tf.logical_xor(same_identity_mask, 85 | tf.eye(tf.shape(pids)[0], dtype=tf.bool)) 86 | 87 | furthest_dist = dists*tf.cast(positive_mask, tf.float32) 88 | furthest_positive = tf.reduce_max(furthest_dist, axis=1) 89 | closest_negative = tf.map_fn(lambda x: tf.reduce_min(tf.boolean_mask(x[0], x[1])), 90 | (dists, negative_mask), tf.float32) 91 | 92 | 93 | 94 | diff = (furthest_positive - closest_negative) 95 | diff = tf.squeeze(diff) 96 | #print(prefix,diff) 97 | # negative_idx = pids[negative_idx] 98 | if isinstance(margin, numbers.Real): 99 | diff_result = tf.maximum(diff + margin, 0.0) 100 | assert_op = tf.Assert(tf.equal(tf.rank(diff), 1), ['Rank of image must be equal to 1.']) 101 | with tf.control_dependencies([assert_op]): 102 | diff = diff_result 103 | elif margin == 'soft': 104 | diff_result = tf.nn.softplus(diff) 105 | assert_op = tf.Assert(tf.equal(tf.rank(diff), 1), ['Rank of image must be equal to 1.']) 106 | with tf.control_dependencies([assert_op]): 107 | diff = diff_result 108 | elif margin.lower() == 'none': 109 | pass 110 | else: 111 | raise NotImplementedError( 112 | 'The margin {} is not implemented in batch_hard'.format(margin)) 113 | return diff 114 | -------------------------------------------------------------------------------- /lbtoolbox.py: -------------------------------------------------------------------------------- 1 | # This file contains select utilities from Lucas Beyer's toolbox, the complete 2 | # toolbox can be found at https://github.com/lucasb-eyer/lbtoolbox. 3 | # 4 | # The content of this file is copyright Lucas Beyer. You may only re-use 5 | # parts of it by keeping the following comment above it: 6 | # 7 | # This is taken from Lucas Beyer's toolbox© found at 8 | # https://github.com/lucasb-eyer/lbtoolbox 9 | # and may only be redistributed and reused by keeping this notice. 10 | 11 | import json 12 | import signal 13 | 14 | import numpy as np 15 | 16 | 17 | def tuplize(what, lists=True, tuplize_none=False): 18 | """ 19 | If `what` is a tuple, return it as-is, otherwise put it into a tuple. 20 | If `lists` is true, also consider lists to be tuples (the default). 21 | If `tuplize_none` is true, a lone `None` results in an empty tuple, 22 | otherwise it will be returned as `None` (the default). 23 | """ 24 | if what is None: 25 | if tuplize_none: 26 | return tuple() 27 | else: 28 | return None 29 | 30 | if isinstance(what, tuple) or (lists and isinstance(what, list)): 31 | return tuple(what) 32 | else: 33 | return (what,) 34 | 35 | 36 | def create_dat(basename, dtype, shape, fillvalue=None, **meta): 37 | """ Creates a data file at `basename` and returns a writeable mem-map 38 | backed numpy array to it. 39 | Can also be passed any json-serializable keys and values in `meta`. 40 | """ 41 | # Sadly, we can't just add attributes (flush) to a numpy array, 42 | # so we need to dummy-subclass it. 
43 | class LBArray(np.ndarray): 44 | pass 45 | 46 | Xm = np.memmap(basename, mode='w+', dtype=dtype, shape=shape) 47 | Xa = np.ndarray.__new__(np.ndarray, dtype=dtype, shape=shape, buffer=Xm) 48 | # Xa = np.ndarray.__new__(LBArray, dtype=dtype, shape=shape, buffer=Xm) 49 | # Xa.flush = Xm.flush 50 | 51 | if fillvalue is not None: 52 | Xa.fill(fillvalue) 53 | Xm.flush() 54 | # Xa.flush() 55 | 56 | meta.setdefault('dtype', np.dtype(dtype).str) 57 | meta.setdefault('shape', tuplize(shape)) 58 | json.dump(meta, open(basename + '.json', 'w+')) 59 | 60 | return Xa 61 | 62 | 63 | def load_dat(basename, mode='r'): 64 | """ Returns a read-only mem-mapped numpy array to file at `basename`. 65 | If `mode` is set to `'r+'`, the data can be written, too. 66 | """ 67 | desc = json.load(open(basename + '.json', 'r')) 68 | dtype, shape = desc['dtype'], tuplize(desc['shape']) 69 | Xm = np.memmap(basename, mode=mode, dtype=dtype, shape=shape) 70 | Xa = np.ndarray.__new__(np.ndarray, dtype=dtype, shape=shape, buffer=Xm) 71 | #Xa.flush = Xm.flush # Sadly, we can't just add attributes to a numpy array, need to subclass it. 72 | return Xa 73 | 74 | 75 | def create_or_resize_dat(basename, dtype, shape, fillvalue=None, **meta): 76 | # Not cleanly possible otherwise yet, see https://github.com/numpy/numpy/issues/4198 77 | try: 78 | old_desc = json.load(open(basename + '.json', 'r')) 79 | except: 80 | return create_dat(basename, dtype, shape, fillvalue, **meta) 81 | 82 | old_dtype, old_shape = old_desc['dtype'], tuplize(old_desc['shape']) 83 | 84 | # Standarize parameters 85 | new_shape = tuplize(shape) 86 | new_dtype_str = np.dtype(dtype).str 87 | 88 | # For memory-layout and code-simplicity reasons, we only support growing 89 | # in the first dimension, which actually covers all my use-cases so far. 90 | # https://github.com/numpy/numpy/issues/4198#issuecomment-341983443 91 | assert old_shape[1:] == new_shape[1:], "Can only grow in first dimension! Old: {}, New: {}".format(old_shape, new_shape) 92 | assert old_dtype == new_dtype_str, "Can't change the dtype! Old: {}, New: {}".format(old_dtype, new_dtype_str) 93 | 94 | # Open the mem-mapped file and reshape it to what's needed. 95 | Xm = np.memmap(basename, mode='r+', dtype=dtype, shape=old_shape) 96 | Xm._mmap.resize(Xm.dtype.itemsize * np.product(new_shape)) # BYTES HERE!! 97 | 98 | Xa = np.ndarray.__new__(np.ndarray, dtype=dtype, shape=new_shape, buffer=Xm._mmap, offset=0) 99 | # Xa.flush = Xm.flush 100 | 101 | if fillvalue is not None: 102 | Xa[old_shape[0]:] = fillvalue 103 | Xm._mmap.flush() 104 | # Xa.flush() 105 | 106 | meta.setdefault('dtype', new_dtype_str) 107 | meta.setdefault('shape', new_shape) 108 | json.dump(meta, open(basename + '.json', 'w+')) # Overwrite the old one. 109 | 110 | return Xa 111 | 112 | 113 | # Based on an original idea by https://gist.github.com/nonZero/2907502 and heavily modified. 
114 | class Uninterrupt(object): 115 | """ 116 | Use as: 117 | with Uninterrupt() as u: 118 | while not u.interrupted: 119 | # train 120 | """ 121 | def __init__(self, sigs=(signal.SIGINT,), verbose=False): 122 | self.sigs = sigs 123 | self.verbose = verbose 124 | self.interrupted = False 125 | self.orig_handlers = None 126 | 127 | def __enter__(self): 128 | if self.orig_handlers is not None: 129 | raise ValueError("Can only enter `Uninterrupt` once!") 130 | 131 | self.interrupted = False 132 | self.orig_handlers = [signal.getsignal(sig) for sig in self.sigs] 133 | 134 | def handler(signum, frame): 135 | self.release() 136 | self.interrupted = True 137 | if self.verbose: 138 | print("Interruption scheduled...", flush=True) 139 | 140 | for sig in self.sigs: 141 | signal.signal(sig, handler) 142 | 143 | return self 144 | 145 | def __exit__(self, type_, value, tb): 146 | self.release() 147 | 148 | def release(self): 149 | if self.orig_handlers is not None: 150 | for sig, orig in zip(self.sigs, self.orig_handlers): 151 | signal.signal(sig, orig) 152 | self.orig_handlers = None 153 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tensorflow Retrieval Baseline 2 | This repository provides a retrieval/space embedding baseline using multiple retrieval datasets and ranking losses. This code is based on [triplet-reid](https://github.com/VisualComputingInstitute/triplet-reid) repos. 3 | 4 | ### Evaluation Metrics 5 | 1. Normalized Mutual Information (NMI) 6 | 2. Recall@K 7 | 8 | ### Deep Fashion In-shop Retrieval Evaluation 9 | All the following experiments assume a training mini-batch of size 60. The architecture employed is the one used in [In Defense of the Triplet Loss for Person Re-Identification](https://arxiv.org/abs/1703.07737) but ResNet is replaced by a DenseNet169. 10 | Optimizer: Adam, Number of iterations = 25K 11 | 12 | | Method | Normalized | Margin | NMI | R@1 | R@4 | # of classes | #samples per class | 13 | |-----------|------------|--------|-------|-------|-------|--------------|--------------------| 14 | | [Semi-Hard](https://www.tensorflow.org/api_docs/python/tf/contrib/losses/metric_learning/triplet_semihard_loss) | Yes | 0.2 | 0.902 | 87.43 | 95.42 | 10| 6| 15 | | Hard-Negative | No | 1.0 | 0.904 | 88.38 | 95.74 | 10| 6| 16 | | [Lifted Structured](https://www.tensorflow.org/api_docs/python/tf/contrib/losses/metric_learning/lifted_struct_loss) | No | 1.0 | 0.903 | 87.32 | 95.59 | 10| 6| 17 | | [N-Pair Loss](https://www.tensorflow.org/api_docs/python/tf/contrib/losses/metric_learning/npairs_loss) | No | N/A | 0.903 | 89.12 | 96.13 | 30| 2| 18 | | [Angular Loss](https://github.com/geonm/tf_angular_loss) | Yes | N/A | 0.8931 | 84.70 | 92.32 | 30| 2| 19 | | Custom [Contrastive Loss](https://www.tensorflow.org/api_docs/python/tf/contrib/losses/metric_learning/contrastive_loss) | Yes | 1.0 | 0.826 | 44.09 | 67.17 | 15| 4| 20 | 21 | ### CUB200-2011 Retrieval Evaluation 22 | Mini-batch size=120. Architecture: Inception_Net V1. 23 | Optimizer: Momentum. 
Number of iterations = 10K 24 | 25 | | Method | Normalized | Margin | NMI | R@1 | R@4 | # of classes | #samples per class | 26 | |-----------|------------|--------|-------|-------|-------|--------------|--------------------| 27 | | [Semi-Hard](https://www.tensorflow.org/api_docs/python/tf/contrib/losses/metric_learning/triplet_semihard_loss) | Yes | 0.2 | 0.587 | 49.03 | 73.43 | 20| 6| 28 | | Hard Negatives | No | 1.0 | 0.561 | 46.55 | 71.03 | 20| 6| 29 | | [Lifted Structured](https://www.tensorflow.org/api_docs/python/tf/contrib/losses/metric_learning/lifted_struct_loss) | No | 1.0 | 0.502 | 35.26 | 59.82 | 20| 6| 30 | | [N-Pair Loss](https://www.tensorflow.org/api_docs/python/tf/contrib/losses/metric_learning/npairs_loss) | No | N/A | 0.573 | 46.52 | 59.26 | 60| 2| 31 | | [Angular Loss](https://github.com/geonm/tf_angular_loss) | Yes | N/A | 0.546 | 45.50 | 68.43 | 60 | 2| 32 | | Custom [Contrastive Loss](https://www.tensorflow.org/api_docs/python/tf/contrib/losses/metric_learning/contrastive_loss) | Yes | 1.0 | 0.476 | 37.27 | 62.39 | 30| 4| 33 | 34 | ### Stanford Online Products Retrieval Evaluation 35 | Mini-batch size=120. Architecture: Inception_Net V1. 36 | Optimizer: Adam. Number of iterations = 30K 37 | 38 | | Method | Normalized | Margin | NMI | R@1 | R@4 | # of classes | #samples per class | 39 | |-----------|------------|--------|-------|-------|-------|--------------|--------------------| 40 | | [Semi-Hard](https://www.tensorflow.org/api_docs/python/tf/contrib/losses/metric_learning/triplet_semihard_loss) | Yes | 0.2 | 0.893 | 71.22 | 81.77 | 20| 6| 41 | | Hard Negatives | No | 1.0 | 0.895 | 72.03 | 82.55 | 20| 6| 42 | | [Lifted Structured](https://www.tensorflow.org/api_docs/python/tf/contrib/losses/metric_learning/lifted_struct_loss) | No | 1.0 | 0.889 | 68.26 | 79.72 | 20| 6| 43 | | [N-Pair Loss](https://www.tensorflow.org/api_docs/python/tf/contrib/losses/metric_learning/npairs_loss) | No | N/A | 0.893 | 72.60 | 82.59 | 60| 2| 44 | | [Angular Loss](https://github.com/geonm/tf_angular_loss) | Yes | N/A | 0.878 | 60.30 | 72.78 | 60 | 2| 45 | | Custom [Contrastive Loss](https://www.tensorflow.org/api_docs/python/tf/contrib/losses/metric_learning/contrastive_loss) | Yes | 1.0 | 0.825 | 19.05 | 32.28 | 30| 4| 46 | 47 | 48 | 49 | ### Requirements 50 | * Python 3+ [Tested on 3.4.7 / 3.7] 51 | * Tensorflow 1 and TF 2.0 [Tested on 1.8 / 1.14 / 2.0] 52 | 53 | ### Code Setup 54 | 1. Update the directories' paths in constants.py. 55 | 2. Use train.py and train_tf2.py for TF 1.X and TF 2.X, respectively. 56 | 3. Use embed.py and embed_tf2.py for TF 1.X and TF 2.X, respectively. 57 | 4. eval.py. 
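Note that `eval.py` is not driven by command-line flags alone: its `__main__` block sets `dataset_name` and `exp_dir`, then builds the `--gallery_dataset`, `--gallery_embeddings`, `--metric`, and `--filename` arguments internally before calling `main`, and it reports NMI (via k-means clustering) and Recall@K. Point those variables at your experiment directory before running it.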
58 | 59 | ### Supported Ranking losses 60 | * Triplet Loss with hard mining - 'hard_triplet' 61 | * Triplet Loss with semi-hard mining - 'semi_hard_triplet' 62 | * Lifted Structure Loss - 'lifted_loss' 63 | * N-pairs loss - 'npairs_loss' 64 | * Angular loss - 'angular_loss' 65 | * Contrastive loss - 'contrastive_loss' 66 | 67 | Keep an eye on `ranking/__init__.py` for new ranking loss 68 | 69 | ### Recommeneded Setting for each loss 70 | 71 | | Method | Setting | 72 | |-----------|------------| 73 | | [Semi-Hard](https://www.tensorflow.org/api_docs/python/tf/contrib/losses/metric_learning/triplet_semihard_loss) | L2-Norm Yes, Margin =0.2 | 74 | | Hard Negatives | L2-Norm No , Margin =1.0 | 75 | | [Lifted Structured](https://www.tensorflow.org/api_docs/python/tf/contrib/losses/metric_learning/lifted_struct_loss) | L2-Norm No , Margin =1.0 | 76 | | [N-Pair Loss](https://www.tensorflow.org/api_docs/python/tf/contrib/losses/metric_learning/npairs_loss) | L2-Norm No , Margin =N/A | 77 | | [Angular Loss](https://github.com/geonm/tf_angular_loss) | L2-Norm Yes, Margin =N/A | 78 | | Custom [Contrastive Loss](https://www.tensorflow.org/api_docs/python/tf/contrib/losses/metric_learning/contrastive_loss) | L2-Norm Yes, Margin =1.0 | 79 | 80 | 81 | ### Wiki 82 | * [Done] [Explain the fast contrastive loss sampling procedure](https://github.com/ahmdtaha/tf_retrieval_baseline/wiki/Contrastive-loss-with-tf.Data) 83 | * [Done] The contrastive loss in the repos is customized to avoid nan during training. When the anchor and positive belong to the same class and the distance between their embeddings is near zero, the derivative turns into nan. [Lei Mao](https://leimao.github.io/article/Siamese-Network-MNIST/) provides a nice detailed mathematical explanation for this issue. 84 | 85 | ### TODO 86 | * [TODO] bash script for train, embed and then eval 87 | * [TODO] Evaluate space embedding during training. 88 | * [TODO] After supporting TF 2.0 (eager execution), It become easier to support more losses -- Maybe add Margin loss. 89 | 90 | 91 | ### Misc Notes 92 | * I noticed that some methods depend heavily on training parameters like the optimizer and number of iterations. For example, the semi-hard negative performance drops significantly on CUB-dataset if Adam optimizer is used instead of Momentum! The number of iterations seems also matter for this small dataset. 93 | * The Tensorflow 2.0 implementation uses more memory even when disabling the eager execution. I tested the code with a smaller batch size -- ten classes and five samples per class. After training for 10K iterations, the performance achieved is NMI=0.54, R@1=42.64, R@4=66.52. 94 | 95 | ## Release History 96 | 97 | * 0.0.1 98 | * CHANGE: Jan 8, 2020. Update code to support Tensorflow 2.0 99 | * CHANGE: Dec 31, 2019. Update code to support Tensorflow 1.14 100 | * First Commit: May 24, 2019. Code tested on Tensorflow 1.8 -------------------------------------------------------------------------------- /ranking/semi_hard_triplet.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import tensorflow as tf 3 | from tensorflow.python.ops import math_ops 4 | from tensorflow.python.ops import array_ops 5 | from tensorflow.python.framework import dtypes 6 | 7 | def masked_minimum(data, mask, dim=1): 8 | """Computes the axis wise minimum over chosen elements. 9 | 10 | Args: 11 | data: 2-D float `Tensor` of size [n, m]. 12 | mask: 2-D Boolean `Tensor` of size [n, m]. 
13 | dim: The dimension over which to compute the minimum. 14 | 15 | Returns: 16 | masked_minimums: N-D `Tensor`. 17 | The minimized dimension is of size 1 after the operation. 18 | """ 19 | axis_maximums = math_ops.reduce_max(data, dim, keepdims=True) 20 | masked_minimums = math_ops.reduce_min( 21 | math_ops.multiply(data - axis_maximums, mask), dim, 22 | keepdims=True) + axis_maximums 23 | return masked_minimums 24 | 25 | 26 | 27 | def masked_maximum(data, mask, dim=1): 28 | """Computes the axis wise maximum over chosen elements. 29 | 30 | Args: 31 | data: 2-D float `Tensor` of size [n, m]. 32 | mask: 2-D Boolean `Tensor` of size [n, m]. 33 | dim: The dimension over which to compute the maximum. 34 | 35 | Returns: 36 | masked_maximums: N-D `Tensor`. 37 | The maximized dimension is of size 1 after the operation. 38 | """ 39 | axis_minimums = math_ops.reduce_min(data, dim, keepdims=True) 40 | masked_maximums = math_ops.reduce_max( 41 | math_ops.multiply(data - axis_minimums, mask), dim, 42 | keepdims=True) + axis_minimums 43 | return masked_maximums 44 | 45 | def all_diffs(a, b): 46 | """ Returns a tensor of all combinations of a - b. 47 | 48 | Args: 49 | a (2D tensor): A batch of vectors shaped (B1, F). 50 | b (2D tensor): A batch of vectors shaped (B2, F). 51 | 52 | Returns: 53 | The matrix of all pairwise differences between all vectors in `a` and in 54 | `b`, will be of shape (B1, B2). 55 | 56 | Note: 57 | For convenience, if either `a` or `b` is a `Distribution` object, its 58 | mean is used. 59 | """ 60 | return tf.expand_dims(a, axis=1) - tf.expand_dims(b, axis=0) 61 | 62 | 63 | def cdist(a, b, metric='euclidean'): 64 | """Similar to scipy.spatial's cdist, but symbolic. 65 | 66 | The currently supported metrics can be listed as `cdist.supported_metrics` and are: 67 | - 'euclidean', although with a fudge-factor epsilon. 68 | - 'sqeuclidean', the squared euclidean. 69 | - 'cityblock', the manhattan or L1 distance. 70 | 71 | Args: 72 | a (2D tensor): The left-hand side, shaped (B1, F). 73 | b (2D tensor): The right-hand side, shaped (B2, F). 74 | metric (string): Which distance metric to use, see notes. 75 | 76 | Returns: 77 | The matrix of all pairwise distances between all vectors in `a` and in 78 | `b`, will be of shape (B1, B2). 79 | 80 | Note: 81 | When a square root is taken (such as in the Euclidean case), a small 82 | epsilon is added because the gradient of the square-root at zero is 83 | undefined. Thus, it will never return exact zero in these cases. 84 | """ 85 | with tf.name_scope("cdist"): 86 | diffs = all_diffs(a, b) 87 | if metric == 'sqeuclidean': 88 | return tf.reduce_sum(tf.square(diffs), axis=-1) 89 | elif metric == 'euclidean': 90 | return tf.sqrt(tf.reduce_sum(tf.square(diffs), axis=-1) + 1e-12) 91 | elif metric == 'cityblock': 92 | return tf.reduce_sum(tf.abs(diffs), axis=-1) 93 | elif metric == 'cosine': 94 | # https://stackoverflow.com/questions/48485373/pairwise-cosine-similarity-using-tensorflow 95 | # normalized_input = tf.nn.l2_normalize(a, dim=1) 96 | # Embedding are assumed to be normalized 97 | prod = tf.matmul(a, b,adjoint_b=True) # transpose second matrix 98 | return 1 - prod 99 | else: 100 | raise NotImplementedError( 101 | 'The following metric is not implemented by `cdist` yet: {}'.format(metric)) 102 | 103 | def pairwise_distance(feature, squared=False): 104 | """Computes the pairwise distance matrix with numerical stability. 
105 | output[i, j] = || feature[i, :] - feature[j, :] ||_2 106 | Args: 107 | feature: 2-D Tensor of size [number of data, feature dimension]. 108 | squared: Boolean, whether or not to square the pairwise distances. 109 | Returns: 110 | pairwise_distances: 2-D Tensor of size [number of data, number of data]. 111 | """ 112 | pairwise_distances_squared = math_ops.add( 113 | math_ops.reduce_sum(math_ops.square(feature), axis=[1], keepdims=True), 114 | math_ops.reduce_sum( 115 | math_ops.square(array_ops.transpose(feature)), 116 | axis=[0], 117 | keepdims=True)) - 2.0 * math_ops.matmul(feature, 118 | array_ops.transpose(feature)) 119 | 120 | # Deal with numerical inaccuracies. Set small negatives to zero. 121 | pairwise_distances_squared = math_ops.maximum(pairwise_distances_squared, 0.0) 122 | # Get the mask where the zero distances are at. 123 | error_mask = math_ops.less_equal(pairwise_distances_squared, 0.0) 124 | 125 | # Optionally take the sqrt. 126 | if squared: 127 | pairwise_distances = pairwise_distances_squared 128 | else: 129 | pairwise_distances = math_ops.sqrt( 130 | pairwise_distances_squared + math_ops.to_float(error_mask) * 1e-16) 131 | 132 | # Undo conditionally adding 1e-16. 133 | pairwise_distances = math_ops.multiply( 134 | pairwise_distances, math_ops.to_float(math_ops.logical_not(error_mask))) 135 | 136 | num_data = array_ops.shape(feature)[0] 137 | # Explicitly set diagonals to zero. 138 | mask_offdiagonals = array_ops.ones_like(pairwise_distances) - array_ops.diag( 139 | array_ops.ones([num_data])) 140 | pairwise_distances = math_ops.multiply(pairwise_distances, mask_offdiagonals) 141 | return pairwise_distances 142 | 143 | def triplet_semihard_loss(embeddings,labels, margin=1.0): 144 | """Computes the triplet loss with semi-hard negative mining. 145 | 146 | The loss encourages the positive distances (between a pair of embeddings with 147 | the same labels) to be smaller than the minimum negative distance among 148 | which are at least greater than the positive distance plus the margin constant 149 | (called semi-hard negative) in the mini-batch. If no such negative exists, 150 | uses the largest negative distance instead. 151 | See: https://arxiv.org/abs/1503.03832. 152 | 153 | Args: 154 | labels: 1-D tf.int32 `Tensor` with shape [batch_size] of 155 | multiclass integer labels. 156 | embeddings: 2-D float `Tensor` of embedding vectors. Embeddings should 157 | be l2 normalized. 158 | margin: Float, margin term in the loss definition. 159 | 160 | Returns: 161 | triplet_loss: tf.float32 scalar. 162 | """ 163 | # Reshape [batch_size] label tensor to a [batch_size, 1] label tensor. 164 | #pdist_matrix = cdist(embeddings, embeddings, metric=metric) 165 | 166 | lshape = array_ops.shape(labels) 167 | assert lshape.shape == 1 168 | labels = array_ops.reshape(labels, [lshape[0], 1]) 169 | 170 | # Build pairwise squared distance matrix. 171 | pdist_matrix = pairwise_distance(embeddings, squared=True) 172 | # Build pairwise binary adjacency matrix. 173 | adjacency = math_ops.equal(labels, array_ops.transpose(labels)) 174 | # Invert so we can select negatives only. 175 | adjacency_not = math_ops.logical_not(adjacency) 176 | 177 | batch_size = array_ops.size(labels) 178 | 179 | # Compute the mask. 180 | ## Is there any element with different label and is farther than me? 
If Yes, then there exists a semi-hard negative 181 | pdist_matrix_tile = array_ops.tile(pdist_matrix, [batch_size, 1]) 182 | mask = math_ops.logical_and( 183 | array_ops.tile(adjacency_not, [batch_size, 1]), 184 | math_ops.greater( 185 | pdist_matrix_tile, array_ops.reshape( 186 | array_ops.transpose(pdist_matrix), [-1, 1]))) 187 | 188 | mask_final = array_ops.reshape( 189 | math_ops.greater( 190 | math_ops.reduce_sum( 191 | tf.cast(mask, dtype=dtypes.float32), 1, keepdims=True), 192 | 0.0), [batch_size, batch_size]) 193 | mask_final = array_ops.transpose(mask_final) 194 | 195 | adjacency_not = tf.cast(adjacency_not, dtype=dtypes.float32) 196 | 197 | mask = tf.cast(mask, dtype=dtypes.float32) 198 | 199 | # negatives_outside: smallest D_an where D_an > D_ap. 200 | negatives_outside = array_ops.reshape( 201 | masked_minimum(pdist_matrix_tile, mask), [batch_size, batch_size]) 202 | negatives_outside = array_ops.transpose(negatives_outside) 203 | 204 | # negatives_inside: largest D_an. 205 | negatives_inside = array_ops.tile( 206 | masked_maximum(pdist_matrix, adjacency_not), [1, batch_size]) 207 | 208 | 209 | semi_hard_negatives = array_ops.where( 210 | mask_final, negatives_outside, negatives_inside) 211 | 212 | 213 | if isinstance(margin, numbers.Real): 214 | # diff = tf.maximum(diff + margin, 0.0) 215 | loss_mat = pdist_matrix - semi_hard_negatives + margin 216 | elif margin == 'soft': 217 | # diff = tf.nn.softplus(diff) 218 | loss_mat = pdist_matrix - semi_hard_negatives 219 | elif margin.lower() == 'none': 220 | pass 221 | else: 222 | raise NotImplementedError( 223 | 'The margin {} is not implemented in batch_hard'.format(margin)) 224 | 225 | 226 | mask_positives = tf.cast( 227 | adjacency, dtype=dtypes.float32) - array_ops.diag( 228 | array_ops.ones([batch_size])) 229 | 230 | 231 | if isinstance(margin, numbers.Real): 232 | print('Margin is real') 233 | triplet_loss_result = math_ops.maximum(tf.boolean_mask(loss_mat, tf.cast(mask_positives, tf.bool)), 234 | 0.0) 235 | assert_op = tf.Assert(tf.equal(tf.rank(triplet_loss_result), 1), ['Rank of image must be equal to 1.']) 236 | with tf.control_dependencies([assert_op]): 237 | triplet_loss = triplet_loss_result 238 | elif margin == 'soft': 239 | triplet_loss_result = tf.nn.softplus(tf.boolean_mask(loss_mat, tf.cast(mask_positives, tf.bool))) 240 | assert_op = tf.Assert(tf.equal(tf.rank(triplet_loss_result), 1), ['Rank of image must be equal to 1.']) 241 | with tf.control_dependencies([assert_op]): 242 | triplet_loss = triplet_loss_result 243 | elif margin.lower() == 'none': 244 | pass 245 | else: 246 | raise NotImplementedError( 247 | 'The margin {} is not implemented in batch_hard'.format(margin)) 248 | 249 | return triplet_loss 250 | -------------------------------------------------------------------------------- /nets/densenet169.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 pudae. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
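# ---------------------------------------------------------------------------
# A minimal, hypothetical usage sketch for `triplet_semihard_loss` defined in
# ranking/semi_hard_triplet.py above. The names `net_output` and `batch_labels`
# are illustrative only; the sketch assumes a [batch, dim] float32 embedding
# tensor and int32 labels, and follows the README's recommended semi-hard
# setting (L2-normalized embeddings, margin = 0.2):
#
#   embeddings = tf.nn.l2_normalize(net_output, axis=1)
#   loss = triplet_semihard_loss(embeddings,
#                                tf.cast(batch_labels, tf.int32),
#                                margin=0.2)
#
# Passing margin='soft' instead applies the softplus formulation implemented above.
# ---------------------------------------------------------------------------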
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains the definition of the DenseNet architecture. 16 | 17 | As described in https://arxiv.org/abs/1608.06993. 18 | 19 | Densely Connected Convolutional Networks 20 | Gao Huang, Zhuang Liu, Kilian Q. Weinberger, Laurens van der Maaten 21 | """ 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | import tensorflow as tf 27 | 28 | slim = tf.contrib.slim 29 | #import constants as const 30 | #import configuration as config 31 | #import nets.nn_utils as nn_utils 32 | #import utils.os_utils as os_utils 33 | import os 34 | 35 | 36 | @slim.add_arg_scope 37 | def _global_avg_pool2d(inputs, data_format='NHWC', scope=None, outputs_collections=None): 38 | with tf.variable_scope(scope, 'xx', [inputs]) as sc: 39 | axis = [1, 2] if data_format == 'NHWC' else [2, 3] 40 | net = tf.reduce_mean(inputs, axis=axis, keep_dims=True) 41 | net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) 42 | return net 43 | 44 | 45 | @slim.add_arg_scope 46 | def _conv(inputs, num_filters, kernel_size, stride=1, dropout_rate=None, 47 | scope=None, outputs_collections=None): 48 | with tf.variable_scope(scope, 'xx', [inputs]) as sc: 49 | net = slim.batch_norm(inputs) 50 | net = tf.nn.relu(net) 51 | net = slim.conv2d(net, num_filters, kernel_size) 52 | 53 | if dropout_rate: 54 | net = tf.nn.dropout(net,dropout_rate) 55 | 56 | net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) 57 | 58 | return net 59 | 60 | 61 | @slim.add_arg_scope 62 | def _conv_block(inputs, num_filters, data_format='NHWC', scope=None, outputs_collections=None): 63 | with tf.variable_scope(scope, 'conv_blockx', [inputs]) as sc: 64 | net = inputs 65 | net = _conv(net, num_filters*4, 1, scope='x1') 66 | net = _conv(net, num_filters, 3, scope='x2') 67 | if data_format == 'NHWC': 68 | net = tf.concat([inputs, net], axis=3) 69 | else: # "NCHW" 70 | net = tf.concat([inputs, net], axis=1) 71 | 72 | net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) 73 | 74 | return net 75 | 76 | 77 | @slim.add_arg_scope 78 | def _dense_block(inputs, num_layers, num_filters, growth_rate, 79 | grow_num_filters=True, scope=None, outputs_collections=None): 80 | 81 | with tf.variable_scope(scope, 'dense_blockx', [inputs]) as sc: 82 | net = inputs 83 | for i in range(num_layers): 84 | branch = i + 1 85 | net = _conv_block(net, growth_rate, scope='conv_block'+str(branch)) 86 | 87 | if grow_num_filters: 88 | num_filters += growth_rate 89 | 90 | net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) 91 | 92 | return net, num_filters 93 | 94 | 95 | @slim.add_arg_scope 96 | def _transition_block(inputs, num_filters, compression=1.0, 97 | scope=None, outputs_collections=None): 98 | 99 | num_filters = int(num_filters * compression) 100 | with tf.variable_scope(scope, 'transition_blockx', [inputs]) as sc: 101 | net = inputs 102 | net = _conv(net, num_filters, 1, scope='blk') 103 | 104 | net = slim.avg_pool2d(net, 2) 105 | 106 | net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) 107 | 108 | return net, num_filters 109 | 110 | 111 | def densenet(inputs, 112 | num_classes=1000, 113 | reduction=None, 114 | growth_rate=None, 115 | num_filters=None, 116 | num_layers=None, 117 | dropout_rate=None, 118 | 
data_format='NHWC', 119 | is_training=True, 120 | reuse=None, 121 | scope=None): 122 | assert reduction is not None 123 | assert growth_rate is not None 124 | assert num_filters is not None 125 | assert num_layers is not None 126 | 127 | compression = 1.0 - reduction 128 | num_dense_blocks = len(num_layers) 129 | 130 | if data_format == 'NCHW': 131 | inputs = tf.transpose(inputs, [0, 3, 1, 2]) 132 | 133 | with tf.variable_scope(scope, 'densenetxxx', [inputs, num_classes], 134 | reuse=reuse) as sc: 135 | end_points_collection = sc.name + '_end_points' 136 | with slim.arg_scope([slim.batch_norm, slim.dropout], 137 | is_training=is_training), \ 138 | slim.arg_scope([slim.conv2d, _conv, _conv_block, 139 | _dense_block, _transition_block], 140 | outputs_collections=end_points_collection), \ 141 | slim.arg_scope([_conv], dropout_rate=dropout_rate): 142 | net = inputs 143 | 144 | # initial convolution 145 | net = slim.conv2d(net, num_filters, 7, stride=2, scope='conv1') 146 | 147 | net = slim.batch_norm(net) 148 | net = tf.nn.relu(net) 149 | net = slim.max_pool2d(net, 3, stride=2, padding='SAME') 150 | 151 | # blocks 152 | for i in range(num_dense_blocks - 1): 153 | # dense blocks 154 | net, num_filters = _dense_block(net, num_layers[i], num_filters, 155 | growth_rate, 156 | scope='dense_block' + str(i+1)) 157 | 158 | # Add transition_block 159 | net, num_filters = _transition_block(net, num_filters, 160 | compression=compression, 161 | scope='transition_block' + str(i+1)) 162 | 163 | net, num_filters = _dense_block( 164 | net, num_layers[-1], num_filters, 165 | growth_rate, 166 | scope='dense_block' + str(num_dense_blocks)) 167 | 168 | # final blocks 169 | with tf.variable_scope('final_block', [inputs]): 170 | net = slim.batch_norm(net) 171 | net = tf.nn.relu(net) 172 | net = _global_avg_pool2d(net, scope='global_avg_pool') 173 | 174 | net = slim.conv2d(net, num_classes, 1, 175 | biases_initializer=tf.zeros_initializer(), 176 | scope='logits') 177 | 178 | 179 | end_points = slim.utils.convert_collection_to_dict( 180 | end_points_collection) 181 | 182 | if num_classes is not None: 183 | end_points['predictions'] = slim.softmax(net, scope='predictions') 184 | 185 | return net, end_points 186 | 187 | 188 | def densenet121(inputs, num_classes=1000, data_format='NHWC', is_training=True, reuse=None): 189 | return densenet(inputs, 190 | num_classes=num_classes, 191 | reduction=0.5, 192 | growth_rate=32, 193 | num_filters=64, 194 | num_layers=[6,12,24,16], 195 | data_format=data_format, 196 | is_training=is_training, 197 | reuse=reuse, 198 | scope='densenet121') 199 | densenet121.default_image_size = 224 200 | 201 | 202 | def densenet161(inputs, num_classes=1000, data_format='NHWC', is_training=True, reuse=None): 203 | return densenet(inputs, 204 | num_classes=num_classes, 205 | reduction=0.5, 206 | growth_rate=48, 207 | num_filters=96, 208 | num_layers=[6,12,36,24], 209 | data_format=data_format, 210 | is_training=is_training, 211 | reuse=reuse, 212 | scope='densenet161') 213 | densenet161.default_image_size = 224 214 | 215 | 216 | def densenet169(inputs, num_classes=1000, data_format='NHWC', is_training=True, reuse=None): 217 | return densenet(inputs, 218 | num_classes=num_classes, 219 | reduction=0.5, 220 | growth_rate=32, 221 | num_filters=64, 222 | num_layers=[6,12,32,32], 223 | data_format=data_format, 224 | is_training=is_training, 225 | reuse=reuse, 226 | scope='densenet169') 227 | densenet169.default_image_size = 224 228 | 229 | 230 | def densenet_arg_scope(weight_decay=1e-4, 231 | 
batch_norm_decay=0.999, 232 | batch_norm_epsilon=1e-5, 233 | data_format='NHWC'): 234 | with slim.arg_scope([slim.conv2d, slim.batch_norm, slim.avg_pool2d, slim.max_pool2d, 235 | _conv_block, _global_avg_pool2d], 236 | data_format=data_format): 237 | with slim.arg_scope([slim.conv2d], 238 | weights_regularizer=slim.l2_regularizer(weight_decay), 239 | activation_fn=None, 240 | biases_initializer=None): 241 | with slim.arg_scope([slim.batch_norm], 242 | scale=True, 243 | decay=batch_norm_decay, 244 | epsilon=batch_norm_epsilon) as scope: 245 | return scope 246 | 247 | 248 | def endpoints(image, is_training): 249 | if image.get_shape().ndims != 4: 250 | raise ValueError('Input must be of size [batch, height, width, 3]') 251 | 252 | image = tf.divide(image, 255.0) 253 | weight_decay = 0.0001 254 | data_format = 'NHWC' 255 | 256 | with tf.contrib.slim.arg_scope(densenet_arg_scope(weight_decay=weight_decay, data_format=data_format)): 257 | _, endpoints = densenet(image, 258 | num_classes=1000, 259 | reduction=0.5, 260 | growth_rate=32, 261 | num_filters=64, 262 | dropout_rate=None, 263 | num_layers=[6,12,32,32], 264 | data_format=data_format, 265 | is_training=is_training, 266 | reuse=None, 267 | scope='densenet169') 268 | 269 | endpoints['model_output'] = endpoints['global_pool'] = tf.reduce_mean( 270 | endpoints['densenet169/dense_block4'], [1, 2], name='global_pool', keep_dims=False) 271 | 272 | return endpoints, 'densenet169' 273 | 274 | -------------------------------------------------------------------------------- /nets/resnet_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains building blocks for various versions of Residual Networks. 16 | 17 | Residual networks (ResNets) were proposed in: 18 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385, 2015 20 | 21 | More variants were introduced in: 22 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 23 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027, 2016 24 | 25 | We can obtain different ResNet variants by changing the network depth, width, 26 | and form of residual unit. This module implements the infrastructure for 27 | building them. Concrete ResNet units and full ResNet networks are implemented in 28 | the accompanying resnet_v1.py and resnet_v2.py modules. 29 | 30 | Compared to https://github.com/KaimingHe/deep-residual-networks, in the current 31 | implementation we subsample the output activations in the last residual unit of 32 | each block, instead of subsampling the input activations in the first residual 33 | unit of each block. The two implementations give identical results but our 34 | implementation is more memory efficient. 
35 | """ 36 | from __future__ import absolute_import 37 | from __future__ import division 38 | from __future__ import print_function 39 | 40 | import collections 41 | import tensorflow as tf 42 | 43 | slim = tf.contrib.slim 44 | 45 | 46 | class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])): 47 | """A named tuple describing a ResNet block. 48 | 49 | Its parts are: 50 | scope: The scope of the `Block`. 51 | unit_fn: The ResNet unit function which takes as input a `Tensor` and 52 | returns another `Tensor` with the output of the ResNet unit. 53 | args: A list of length equal to the number of units in the `Block`. The list 54 | contains one (depth, depth_bottleneck, stride) tuple for each unit in the 55 | block to serve as argument to unit_fn. 56 | """ 57 | 58 | 59 | def subsample(inputs, factor, scope=None): 60 | """Subsamples the input along the spatial dimensions. 61 | 62 | Args: 63 | inputs: A `Tensor` of size [batch, height_in, width_in, channels]. 64 | factor: The subsampling factor. 65 | scope: Optional variable_scope. 66 | 67 | Returns: 68 | output: A `Tensor` of size [batch, height_out, width_out, channels] with the 69 | input, either intact (if factor == 1) or subsampled (if factor > 1). 70 | """ 71 | if factor == 1: 72 | return inputs 73 | else: 74 | return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope) 75 | 76 | 77 | def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None): 78 | """Strided 2-D convolution with 'SAME' padding. 79 | 80 | When stride > 1, then we do explicit zero-padding, followed by conv2d with 81 | 'VALID' padding. 82 | 83 | Note that 84 | 85 | net = conv2d_same(inputs, num_outputs, 3, stride=stride) 86 | 87 | is equivalent to 88 | 89 | net = slim.conv2d(inputs, num_outputs, 3, stride=1, padding='SAME') 90 | net = subsample(net, factor=stride) 91 | 92 | whereas 93 | 94 | net = slim.conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME') 95 | 96 | is different when the input's height or width is even, which is why we add the 97 | current function. For more details, see ResnetUtilsTest.testConv2DSameEven(). 98 | 99 | Args: 100 | inputs: A 4-D tensor of size [batch, height_in, width_in, channels]. 101 | num_outputs: An integer, the number of output filters. 102 | kernel_size: An int with the kernel_size of the filters. 103 | stride: An integer, the output stride. 104 | rate: An integer, rate for atrous convolution. 105 | scope: Scope. 106 | 107 | Returns: 108 | output: A 4-D tensor of size [batch, height_out, width_out, channels] with 109 | the convolution output. 110 | """ 111 | if stride == 1: 112 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=1, rate=rate, 113 | padding='SAME', scope=scope) 114 | else: 115 | kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) 116 | pad_total = kernel_size_effective - 1 117 | pad_beg = pad_total // 2 118 | pad_end = pad_total - pad_beg 119 | inputs = tf.pad(inputs, 120 | [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]]) 121 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride, 122 | rate=rate, padding='VALID', scope=scope) 123 | 124 | 125 | @slim.add_arg_scope 126 | def stack_blocks_dense(net, blocks, output_stride=None, 127 | outputs_collections=None): 128 | """Stacks ResNet `Blocks` and controls output feature density. 129 | 130 | First, this function creates scopes for the ResNet in the form of 131 | 'block_name/unit_1', 'block_name/unit_2', etc. 
132 | 133 | Second, this function allows the user to explicitly control the ResNet 134 | output_stride, which is the ratio of the input to output spatial resolution. 135 | This is useful for dense prediction tasks such as semantic segmentation or 136 | object detection. 137 | 138 | Most ResNets consist of 4 ResNet blocks and subsample the activations by a 139 | factor of 2 when transitioning between consecutive ResNet blocks. This results 140 | to a nominal ResNet output_stride equal to 8. If we set the output_stride to 141 | half the nominal network stride (e.g., output_stride=4), then we compute 142 | responses twice. 143 | 144 | Control of the output feature density is implemented by atrous convolution. 145 | 146 | Args: 147 | net: A `Tensor` of size [batch, height, width, channels]. 148 | blocks: A list of length equal to the number of ResNet `Blocks`. Each 149 | element is a ResNet `Block` object describing the units in the `Block`. 150 | output_stride: If `None`, then the output will be computed at the nominal 151 | network stride. If output_stride is not `None`, it specifies the requested 152 | ratio of input to output spatial resolution, which needs to be equal to 153 | the product of unit strides from the start up to some level of the ResNet. 154 | For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1, 155 | then valid values for the output_stride are 1, 2, 6, 24 or None (which 156 | is equivalent to output_stride=24). 157 | outputs_collections: Collection to add the ResNet block outputs. 158 | 159 | Returns: 160 | net: Output tensor with stride equal to the specified output_stride. 161 | 162 | Raises: 163 | ValueError: If the target output_stride is not valid. 164 | """ 165 | # The current_stride variable keeps track of the effective stride of the 166 | # activations. This allows us to invoke atrous convolution whenever applying 167 | # the next residual unit would result in the activations having stride larger 168 | # than the target output_stride. 169 | current_stride = 1 170 | 171 | # The atrous convolution rate parameter. 172 | rate = 1 173 | 174 | for block in blocks: 175 | with tf.variable_scope(block.scope, 'block', [net]) as sc: 176 | for i, unit in enumerate(block.args): 177 | if output_stride is not None and current_stride > output_stride: 178 | raise ValueError('The target output_stride cannot be reached.') 179 | 180 | with tf.variable_scope('unit_%d' % (i + 1), values=[net]): 181 | # If we have reached the target output_stride, then we need to employ 182 | # atrous convolution with stride=1 and multiply the atrous rate by the 183 | # current unit's stride for use in subsequent layers. 184 | if output_stride is not None and current_stride == output_stride: 185 | net = block.unit_fn(net, rate=rate, **dict(unit, stride=1)) 186 | rate *= unit.get('stride', 1) 187 | 188 | else: 189 | net = block.unit_fn(net, rate=1, **unit) 190 | current_stride *= unit.get('stride', 1) 191 | net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) 192 | 193 | if output_stride is not None and current_stride != output_stride: 194 | raise ValueError('The target output_stride cannot be reached.') 195 | 196 | return net 197 | 198 | 199 | def resnet_arg_scope(weight_decay=0.0001, 200 | batch_norm_decay=0.997, 201 | batch_norm_epsilon=1e-5, 202 | batch_norm_scale=True, 203 | activation_fn=tf.nn.relu, 204 | use_batch_norm=True): 205 | """Defines the default ResNet arg scope. 
206 | 207 | TODO(gpapan): The batch-normalization related default values above are 208 | appropriate for use in conjunction with the reference ResNet models 209 | released at https://github.com/KaimingHe/deep-residual-networks. When 210 | training ResNets from scratch, they might need to be tuned. 211 | 212 | Args: 213 | weight_decay: The weight decay to use for regularizing the model. 214 | batch_norm_decay: The moving average decay when estimating layer activation 215 | statistics in batch normalization. 216 | batch_norm_epsilon: Small constant to prevent division by zero when 217 | normalizing activations by their variance in batch normalization. 218 | batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the 219 | activations in the batch normalization layer. 220 | activation_fn: The activation function which is used in ResNet. 221 | use_batch_norm: Whether or not to use batch normalization. 222 | 223 | Returns: 224 | An `arg_scope` to use for the resnet models. 225 | """ 226 | batch_norm_params = { 227 | 'decay': batch_norm_decay, 228 | 'epsilon': batch_norm_epsilon, 229 | 'scale': batch_norm_scale, 230 | 'updates_collections': tf.GraphKeys.UPDATE_OPS, 231 | } 232 | 233 | with slim.arg_scope( 234 | [slim.conv2d], 235 | weights_regularizer=slim.l2_regularizer(weight_decay), 236 | weights_initializer=slim.variance_scaling_initializer(), 237 | activation_fn=activation_fn, 238 | normalizer_fn=slim.batch_norm if use_batch_norm else None, 239 | normalizer_params=batch_norm_params): 240 | with slim.arg_scope([slim.batch_norm], **batch_norm_params): 241 | # The following implies padding='SAME' for pool1, which makes feature 242 | # alignment easier for dense prediction tasks. This is also used in 243 | # https://github.com/facebook/fb.resnet.torch. However the accompanying 244 | # code of 'Deep Residual Learning for Image Recognition' uses 245 | # padding='VALID' for pool1. You can switch to that choice by setting 246 | # slim.arg_scope([slim.max_pool2d], padding='VALID'). 247 | with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc: 248 | return arg_sc 249 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /embed_tf2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import h5py 4 | import json 5 | import psutil 6 | import common 7 | import numpy as np 8 | import logging.config 9 | import os.path as osp 10 | import tensorflow as tf 11 | import constants as const 12 | from itertools import count 13 | from utils import os_utils 14 | from aggregators import AGGREGATORS 15 | from argparse import ArgumentParser 16 | from model.embedding_model import EmbeddingModel 17 | 18 | 19 | 20 | 21 | parser = ArgumentParser(description='Embed a dataset using a trained network.') 22 | 23 | # Required 24 | 25 | parser.add_argument( 26 | '--experiment_root', required=True, 27 | help='Location used to store checkpoints and dumped data.') 28 | 29 | parser.add_argument( 30 | '--dataset', required=True, 31 | help='Path to the dataset csv file to be embedded.') 32 | 33 | # Optional 34 | 35 | parser.add_argument( 36 | '--image_root', type=common.readable_directory, 37 | help='Path that will be pre-pended to the filenames in the train_set csv.') 38 | 39 | parser.add_argument( 40 | '--checkpoint', default=None, 41 | help='Name of checkpoint file of the trained network within the experiment ' 42 | 'root. Uses the last checkpoint if not provided.') 43 | 44 | parser.add_argument( 45 | '--loading_threads', default=8, type=common.positive_int, 46 | help='Number of threads used for parallel data loading.') 47 | 48 | 49 | 50 | parser.add_argument( 51 | '--batch_size', default=64, type=common.positive_int, 52 | help='Batch size used during evaluation, adapt based on available memory.') 53 | 54 | parser.add_argument( 55 | '--filename', default=None, 56 | help='Name of the HDF5 file in which to store the embeddings, relative to' 57 | ' the `experiment_root` location. If omitted, appends `_embeddings.h5`' 58 | ' to the dataset name.') 59 | 60 | parser.add_argument( 61 | '--foldername', default=None, 62 | help='Name of dir to save embeds') 63 | 64 | parser.add_argument( 65 | '--flip_augment', action='store_true', default=False, 66 | help='When this flag is provided, flip augmentation is performed.') 67 | 68 | parser.add_argument( 69 | '--crop_augment', choices=['center', 'avgpool', 'five'], default=None, 70 | help='When this flag is provided, crop augmentation is performed.' 71 | '`avgpool` means the full image at the precrop size is used and ' 72 | 'the augmentation is performed by the average pooling. `center` means' 73 | 'only the center crop is used and `five` means the four corner and ' 74 | 'center crops are used. When not provided, by default the image is ' 75 | 'resized to network input size.') 76 | 77 | parser.add_argument( 78 | '--aggregator', choices=AGGREGATORS.keys(), default=None, 79 | help='The type of aggregation used to combine the different embeddings ' 80 | 'after augmentation.') 81 | 82 | parser.add_argument( 83 | '--quiet', action='store_true', default=False, 84 | help='Don\'t be so verbose.') 85 | 86 | 87 | def flip_augment(image, fid, pid): 88 | """ Returns both the original and the horizontal flip of an image. """ 89 | images = tf.stack([image, tf.reverse(image, [1])]) 90 | return images, tf.stack([fid]*2), tf.stack([pid]*2) 91 | 92 | def five_crops(image, crop_size): 93 | """ Returns the central and four corner crops of `crop_size` from `image`. 
""" 94 | image_size = tf.shape(image)[:2] 95 | crop_margin = tf.subtract(image_size, crop_size) 96 | assert_size = tf.debugging.assert_non_negative( 97 | crop_margin, message='Crop size must be smaller or equal to the image size.') 98 | with tf.control_dependencies([assert_size]): 99 | top_left = tf.compat.v1.floor_div(crop_margin, 2) 100 | bottom_right = tf.math.add(top_left, crop_size) 101 | 102 | center = image[top_left[0]:bottom_right[0], top_left[1]:bottom_right[1]] 103 | top_left = image[:-crop_margin[0], :-crop_margin[1]] 104 | top_right = image[:-crop_margin[0], crop_margin[1]:] 105 | bottom_left = image[crop_margin[0]:, :-crop_margin[1]] 106 | bottom_right = image[crop_margin[0]:, crop_margin[1]:] 107 | return center, top_left, top_right, bottom_left, bottom_right 108 | 109 | 110 | def main(argv): 111 | # Verify that parameters are set correctly. 112 | args = parser.parse_args(argv) 113 | 114 | if not os.path.exists(args.dataset): 115 | return 116 | 117 | # Possibly auto-generate the output filename. 118 | if args.filename is None: 119 | basename = os.path.basename(args.dataset) 120 | args.filename = os.path.splitext(basename)[0] + '_embeddings.h5' 121 | 122 | os_utils.touch_dir(os.path.join(args.experiment_root,args.foldername)) 123 | 124 | log_file = os.path.join(args.experiment_root,args.foldername, "embed") 125 | logging.config.dictConfig(common.get_logging_dict(log_file)) 126 | log = logging.getLogger('embed') 127 | 128 | args.filename = os.path.join(args.experiment_root,args.foldername, args.filename) 129 | var_filepath = os.path.join(args.experiment_root, args.foldername, args.filename[:-3] + '_var.txt') 130 | # Load the args from the original experiment. 131 | args_file = os.path.join(args.experiment_root, 'args.json') 132 | 133 | if os.path.isfile(args_file): 134 | if not args.quiet: 135 | print('Loading args from {}.'.format(args_file)) 136 | with open(args_file, 'r') as f: 137 | args_resumed = json.load(f) 138 | 139 | # Add arguments from training. 140 | for key, value in args_resumed.items(): 141 | args.__dict__.setdefault(key, value) 142 | 143 | # A couple special-cases and sanity checks 144 | if (args_resumed['crop_augment']) == (args.crop_augment is None): 145 | print('WARNING: crop augmentation differs between training and ' 146 | 'evaluation.') 147 | args.image_root = args.image_root or args_resumed['image_root'] 148 | else: 149 | raise IOError('`args.json` could not be found in: {}'.format(args_file)) 150 | 151 | # Check a proper aggregator is provided if augmentation is used. 152 | if args.flip_augment or args.crop_augment == 'five': 153 | if args.aggregator is None: 154 | print('ERROR: Test time augmentation is performed but no aggregator' 155 | 'was specified.') 156 | exit(1) 157 | else: 158 | if args.aggregator is not None: 159 | print('ERROR: No test time augmentation that needs aggregating is ' 160 | 'performed but an aggregator was specified.') 161 | exit(1) 162 | 163 | if not args.quiet: 164 | print('Evaluating using the following parameters:') 165 | for key, value in sorted(vars(args).items()): 166 | print('{}: {}'.format(key, value)) 167 | 168 | # Load the data from the CSV file. 169 | _, data_fids = common.load_dataset(args.dataset, args.image_root) 170 | 171 | net_input_size = (args.net_input_height, args.net_input_width) 172 | pre_crop_size = (args.pre_crop_height, args.pre_crop_width) 173 | 174 | # Setup a tf Dataset containing all images. 
175 | dataset = tf.data.Dataset.from_tensor_slices(data_fids) 176 | 177 | # Convert filenames to actual image tensors. 178 | dataset = dataset.map( 179 | lambda fid: common.fid_to_image( 180 | fid, tf.constant('dummy'), image_root=args.image_root, 181 | image_size=pre_crop_size if args.crop_augment else net_input_size), 182 | num_parallel_calls=args.loading_threads) 183 | 184 | # Augment the data if specified by the arguments. 185 | # `modifiers` is a list of strings that keeps track of which augmentations 186 | # have been applied, so that a human can understand it later on. 187 | modifiers = ['original'] 188 | if args.flip_augment: 189 | dataset = dataset.map(flip_augment) 190 | dataset = dataset.apply(tf.contrib.data.unbatch()) 191 | modifiers = [o + m for m in ['', '_flip'] for o in modifiers] 192 | 193 | if args.crop_augment == 'center': 194 | dataset = dataset.map(lambda im, fid, pid: 195 | (five_crops(im, net_input_size)[0], fid, pid)) 196 | modifiers = [o + '_center' for o in modifiers] 197 | elif args.crop_augment == 'five': 198 | dataset = dataset.map(lambda im, fid, pid: ( 199 | tf.stack(five_crops(im, net_input_size)), 200 | tf.stack([fid]*5), 201 | tf.stack([pid]*5))) 202 | dataset = dataset.apply(tf.contrib.data.unbatch()) 203 | modifiers = [o + m for o in modifiers for m in [ 204 | '_center', '_top_left', '_top_right', '_bottom_left', '_bottom_right']] 205 | elif args.crop_augment == 'avgpool': 206 | modifiers = [o + '_avgpool' for o in modifiers] 207 | else: 208 | modifiers = [o + '_resize' for o in modifiers] 209 | 210 | emb_model = EmbeddingModel(args) 211 | 212 | # Group it back into PK batches. 213 | dataset = dataset.batch(args.batch_size) 214 | dataset = dataset.map(lambda im, fid, pid: (emb_model.preprocess_input(im), fid, pid)) 215 | # Overlap producing and consuming. 216 | dataset = dataset.prefetch(1) 217 | tf.keras.backend.set_learning_phase(0) 218 | 219 | 220 | with h5py.File(args.filename, 'w') as f_out: 221 | 222 | ckpt = tf.train.Checkpoint(step=tf.Variable(1), net=emb_model) 223 | manager = tf.train.CheckpointManager(ckpt, osp.join(args.experiment_root, 'tf_ckpts'),max_to_keep=1) 224 | ckpt.restore(manager.latest_checkpoint) 225 | if manager.latest_checkpoint: 226 | print("Restored from {}".format(manager.latest_checkpoint)) 227 | else: 228 | print("Initializing from scratch.") 229 | 230 | emb_storage = np.zeros( 231 | (len(data_fids) * len(modifiers), args.embedding_dim), np.float32) 232 | 233 | # for batch_idx,batch in enumerate(dataset): 234 | dataset_iter = iter(dataset) 235 | for start_idx in count(step=args.batch_size): 236 | 237 | try: 238 | images, _, _ = next(dataset_iter) 239 | emb = emb_model(images) 240 | emb_storage[start_idx:start_idx + len(emb)] += emb 241 | print('\rEmbedded batch {}-{}/{}'.format( 242 | start_idx, start_idx + len(emb), len(emb_storage)), 243 | flush=True, end='') 244 | except StopIteration: 245 | break # This just indicates the end of the dataset. 246 | 247 | 248 | if not args.quiet: 249 | print("Done with embedding, aggregating augmentations...", flush=True) 250 | 251 | if len(modifiers) > 1: 252 | # Pull out the augmentations into a separate first dimension. 253 | emb_storage = emb_storage.reshape(len(data_fids), len(modifiers), -1) 254 | emb_storage = emb_storage.transpose((1,0,2)) # (Aug,FID,128D) 255 | 256 | # Store the embedding of all individual variants too. 257 | emb_dataset = f_out.create_dataset('emb_aug', data=emb_storage) 258 | 259 | # Aggregate according to the specified parameter. 
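# At this point `emb_storage` has shape (num_augmentations, num_images, embedding_dim);
# the selected aggregator collapses the augmentation axis (e.g. by averaging) so the
# final 'emb' dataset is (num_images, embedding_dim) again.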
260 | emb_storage = AGGREGATORS[args.aggregator](emb_storage) 261 | 262 | # Store the final embeddings. 263 | emb_dataset = f_out.create_dataset('emb', data=emb_storage) 264 | 265 | 266 | # Store information about the produced augmentation and in case no crop 267 | # augmentation was used, if the images are resized or avg pooled. 268 | f_out.create_dataset('augmentation_types', data=np.asarray(modifiers, dtype='|S')) 269 | 270 | if __name__ == '__main__': 271 | 272 | 273 | 274 | arg_experiment_root = const.experiment_root_dir 275 | 276 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 277 | for subset in ['test']: 278 | exp_dir = 'cub_densenet_direct_normalize_npairs_loss_m_0.2' 279 | folder_name = 'emb' 280 | dataset_name = 'cub' 281 | if dataset_name == 'cub': 282 | csv_file = 'cub' 283 | elif dataset_name == 'inshop': 284 | csv_file = 'deep_fashion' 285 | elif dataset_name == 'stanford': 286 | csv_file = 'stanford_online' 287 | else: 288 | raise NotImplementedError('dataset {} not valid'.format(dataset_name)) 289 | 290 | args = [ 291 | '--experiment_root', arg_experiment_root + exp_dir, 292 | '--dataset', './data/'+csv_file+'_'+subset+'.csv', 293 | '--filename', subset+'_embeddings_augmented.h5', 294 | '--foldername',folder_name, 295 | '--crop_augment','center', ## Make sure it follows the training resolution 296 | # '--batch_size','40', 297 | ] 298 | main(args) 299 | 300 | -------------------------------------------------------------------------------- /embed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import h5py 4 | import json 5 | import psutil 6 | import common 7 | import numpy as np 8 | import tensorflow as tf 9 | import constants as const 10 | from itertools import count 11 | from aggregators import AGGREGATORS 12 | from argparse import ArgumentParser 13 | from importlib import import_module 14 | from utils import os_utils 15 | from tensorflow.contrib import slim 16 | import logging.config 17 | 18 | 19 | parser = ArgumentParser(description='Embed a dataset using a trained network.') 20 | 21 | # Required 22 | 23 | parser.add_argument( 24 | '--experiment_root', required=True, 25 | help='Location used to store checkpoints and dumped data.') 26 | 27 | parser.add_argument( 28 | '--dataset', required=True, 29 | help='Path to the dataset csv file to be embedded.') 30 | 31 | # Optional 32 | 33 | parser.add_argument( 34 | '--image_root', type=common.readable_directory, 35 | help='Path that will be pre-pended to the filenames in the train_set csv.') 36 | 37 | parser.add_argument( 38 | '--checkpoint', default=None, 39 | help='Name of checkpoint file of the trained network within the experiment ' 40 | 'root. Uses the last checkpoint if not provided.') 41 | 42 | parser.add_argument( 43 | '--loading_threads', default=8, type=common.positive_int, 44 | help='Number of threads used for parallel data loading.') 45 | 46 | 47 | 48 | parser.add_argument( 49 | '--batch_size', default=256, type=common.positive_int, 50 | help='Batch size used during evaluation, adapt based on available memory.') 51 | 52 | parser.add_argument( 53 | '--filename', default=None, 54 | help='Name of the HDF5 file in which to store the embeddings, relative to' 55 | ' the `experiment_root` location. 
If omitted, appends `_embeddings.h5`' 56 | ' to the dataset name.') 57 | 58 | parser.add_argument( 59 | '--foldername', default=None, 60 | help='Name of dir to save embeds') 61 | 62 | parser.add_argument( 63 | '--flip_augment', action='store_true', default=False, 64 | help='When this flag is provided, flip augmentation is performed.') 65 | 66 | parser.add_argument( 67 | '--crop_augment', choices=['center', 'avgpool', 'five'], default=None, 68 | help='When this flag is provided, crop augmentation is performed.' 69 | '`avgpool` means the full image at the precrop size is used and ' 70 | 'the augmentation is performed by the average pooling. `center` means' 71 | 'only the center crop is used and `five` means the four corner and ' 72 | 'center crops are used. When not provided, by default the image is ' 73 | 'resized to network input size.') 74 | 75 | parser.add_argument( 76 | '--aggregator', choices=AGGREGATORS.keys(), default=None, 77 | help='The type of aggregation used to combine the different embeddings ' 78 | 'after augmentation.') 79 | 80 | parser.add_argument( 81 | '--quiet', action='store_true', default=False, 82 | help='Don\'t be so verbose.') 83 | 84 | 85 | def flip_augment(image, fid, pid): 86 | """Returns both the original and the horizontal flip of an image. 87 | 88 | Parameters 89 | --------- 90 | img: array of images 91 | fid: array of image filepaths 92 | pid: array of images ids 93 | 94 | """ 95 | 96 | images = tf.stack([image, tf.reverse(image, [1])]) 97 | return images, tf.stack([fid]*2), tf.stack([pid]*2) 98 | 99 | 100 | def five_crops(image, crop_size): 101 | """ Returns the central and four corner crops of `crop_size` from `image`. """ 102 | image_size = tf.shape(image)[:2] 103 | crop_margin = tf.subtract(image_size, crop_size) 104 | assert_size = tf.assert_non_negative( 105 | crop_margin, message='Crop size must be smaller or equal to the image size.') 106 | with tf.control_dependencies([assert_size]): 107 | top_left = tf.floor_div(crop_margin, 2) 108 | bottom_right = tf.add(top_left, crop_size) 109 | center = image[top_left[0]:bottom_right[0], top_left[1]:bottom_right[1]] 110 | top_left = image[:-crop_margin[0], :-crop_margin[1]] 111 | top_right = image[:-crop_margin[0], crop_margin[1]:] 112 | bottom_left = image[crop_margin[0]:, :-crop_margin[1]] 113 | bottom_right = image[crop_margin[0]:, crop_margin[1]:] 114 | return center, top_left, top_right, bottom_left, bottom_right 115 | 116 | 117 | def main(argv): 118 | # Verify that parameters are set correctly. 119 | args = parser.parse_args(argv) 120 | 121 | if not os.path.exists(args.dataset): 122 | return 123 | 124 | # Possibly auto-generate the output filename. 125 | if args.filename is None: 126 | basename = os.path.basename(args.dataset) 127 | args.filename = os.path.splitext(basename)[0] + '_embeddings.h5' 128 | 129 | os_utils.touch_dir(os.path.join(args.experiment_root,args.foldername)) 130 | 131 | log_file = os.path.join(args.experiment_root,args.foldername, "embed") 132 | logging.config.dictConfig(common.get_logging_dict(log_file)) 133 | log = logging.getLogger('embed') 134 | 135 | args.filename = os.path.join(args.experiment_root,args.foldername, args.filename) 136 | var_filepath = os.path.join(args.experiment_root, args.foldername, args.filename[:-3] + '_var.txt') 137 | # Load the args from the original experiment. 
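# The args.json saved by the training script is merged into the embed arguments via
# `setdefault`, so settings such as the network, head, and input sizes used here
# match the ones used during training; a warning is printed if the crop augmentation
# setting differs between training and embedding.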
138 | args_file = os.path.join(args.experiment_root, 'args.json') 139 | 140 | if os.path.isfile(args_file): 141 | if not args.quiet: 142 | print('Loading args from {}.'.format(args_file)) 143 | with open(args_file, 'r') as f: 144 | args_resumed = json.load(f) 145 | 146 | # Add arguments from training. 147 | for key, value in args_resumed.items(): 148 | args.__dict__.setdefault(key, value) 149 | 150 | # A couple special-cases and sanity checks 151 | if (args_resumed['crop_augment']) == (args.crop_augment is None): 152 | print('WARNING: crop augmentation differs between training and ' 153 | 'evaluation.') 154 | args.image_root = args.image_root or args_resumed['image_root'] 155 | else: 156 | raise IOError('`args.json` could not be found in: {}'.format(args_file)) 157 | 158 | # Check a proper aggregator is provided if augmentation is used. 159 | if args.flip_augment or args.crop_augment == 'five': 160 | if args.aggregator is None: 161 | print('ERROR: Test time augmentation is performed but no aggregator' 162 | 'was specified.') 163 | exit(1) 164 | else: 165 | if args.aggregator is not None: 166 | print('ERROR: No test time augmentation that needs aggregating is ' 167 | 'performed but an aggregator was specified.') 168 | exit(1) 169 | 170 | if not args.quiet: 171 | print('Evaluating using the following parameters:') 172 | for key, value in sorted(vars(args).items()): 173 | print('{}: {}'.format(key, value)) 174 | 175 | # Load the data from the CSV file. 176 | _, data_fids = common.load_dataset(args.dataset, args.image_root) 177 | 178 | net_input_size = (args.net_input_height, args.net_input_width) 179 | pre_crop_size = (args.pre_crop_height, args.pre_crop_width) 180 | 181 | # Setup a tf Dataset containing all images. 182 | dataset = tf.data.Dataset.from_tensor_slices(data_fids) 183 | 184 | # Convert filenames to actual image tensors. 185 | dataset = dataset.map( 186 | lambda fid: common.fid_to_image( 187 | fid, tf.constant('dummy'), image_root=args.image_root, 188 | image_size=pre_crop_size if args.crop_augment else net_input_size), 189 | num_parallel_calls=args.loading_threads) 190 | 191 | # Augment the data if specified by the arguments. 192 | # `modifiers` is a list of strings that keeps track of which augmentations 193 | # have been applied, so that a human can understand it later on. 194 | modifiers = ['original'] 195 | if args.flip_augment: 196 | dataset = dataset.map(flip_augment) 197 | dataset = dataset.apply(tf.contrib.data.unbatch()) 198 | modifiers = [o + m for m in ['', '_flip'] for o in modifiers] 199 | 200 | if args.crop_augment == 'center': 201 | dataset = dataset.map(lambda im, fid, pid: 202 | (five_crops(im, net_input_size)[0], fid, pid)) 203 | modifiers = [o + '_center' for o in modifiers] 204 | elif args.crop_augment == 'five': 205 | dataset = dataset.map(lambda im, fid, pid: ( 206 | tf.stack(five_crops(im, net_input_size)), 207 | tf.stack([fid]*5), 208 | tf.stack([pid]*5))) 209 | dataset = dataset.apply(tf.contrib.data.unbatch()) 210 | modifiers = [o + m for o in modifiers for m in [ 211 | '_center', '_top_left', '_top_right', '_bottom_left', '_bottom_right']] 212 | elif args.crop_augment == 'avgpool': 213 | modifiers = [o + '_avgpool' for o in modifiers] 214 | else: 215 | modifiers = [o + '_resize' for o in modifiers] 216 | 217 | # Group it back into PK batches. 218 | dataset = dataset.batch(args.batch_size) 219 | 220 | # Overlap producing and consuming. 
221 | dataset = dataset.prefetch(1) 222 | 223 | #images, _, _ = dataset.make_one_shot_iterator().get_next() 224 | #init_iter = dataset.make_initializable_iterator() 225 | init_iter = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes) 226 | images, _, _ = init_iter.get_next() 227 | iter_init_op = init_iter.make_initializer(dataset) 228 | # Create the model and an embedding head. 229 | model = import_module('nets.' + args.model_name) 230 | head = import_module('heads.' + args.head_name) 231 | 232 | 233 | 234 | images_ph = tf.placeholder(dataset.output_types[0], dataset.output_shapes[0]) 235 | endpoints, body_prefix = model.endpoints(images_ph, is_training=False) 236 | 237 | 238 | 239 | with tf.name_scope('head'): 240 | endpoints = head.head(endpoints, args.embedding_dim, is_training=False) 241 | 242 | 243 | 244 | gpu_options = tf.GPUOptions(allow_growth=True) 245 | gpu_config = tf.ConfigProto(gpu_options=gpu_options) 246 | with h5py.File(args.filename, 'w') as f_out, tf.Session(config=gpu_config) as sess: 247 | # Initialize the network/load the checkpoint. 248 | if args.checkpoint is None: 249 | checkpoint = tf.train.latest_checkpoint(args.experiment_root) 250 | else: 251 | checkpoint = os.path.join(args.experiment_root, args.checkpoint) 252 | if not args.quiet: 253 | print('Restoring from checkpoint: {}'.format(checkpoint)) 254 | tf.train.Saver().restore(sess, checkpoint) 255 | 256 | # Go ahead and embed the whole dataset, with all augmented versions too. 257 | emb_storage = np.zeros( 258 | (len(data_fids) * len(modifiers), args.embedding_dim), np.float32) 259 | 260 | 261 | ##sess.run(init_iter.initializer) 262 | sess.run(iter_init_op) 263 | 264 | for start_idx in count(step=args.batch_size): 265 | try: 266 | current_imgs = sess.run(images) 267 | batch_embedding = endpoints['emb'] 268 | emb = sess.run(batch_embedding,feed_dict={images_ph:current_imgs}) 269 | emb_storage[start_idx:start_idx + len(emb)] += emb 270 | print('\rEmbedded batch {}-{}/{}'.format( 271 | start_idx, start_idx + len(emb), len(emb_storage)), 272 | flush=True, end='') 273 | except tf.errors.OutOfRangeError: 274 | break # This just indicates the end of the dataset. 275 | 276 | 277 | 278 | 279 | if not args.quiet: 280 | print("Done with embedding, aggregating augmentations...", flush=True) 281 | 282 | if len(modifiers) > 1: 283 | # Pull out the augmentations into a separate first dimension. 284 | emb_storage = emb_storage.reshape(len(data_fids), len(modifiers), -1) 285 | emb_storage = emb_storage.transpose((1,0,2)) # (Aug,FID,128D) 286 | 287 | # Store the embedding of all individual variants too. 288 | emb_dataset = f_out.create_dataset('emb_aug', data=emb_storage) 289 | 290 | # Aggregate according to the specified parameter. 291 | emb_storage = AGGREGATORS[args.aggregator](emb_storage) 292 | 293 | # Store the final embeddings. 294 | emb_dataset = f_out.create_dataset('emb', data=emb_storage) 295 | 296 | 297 | # Store information about the produced augmentation and in case no crop 298 | # augmentation was used, if the images are resized or avg pooled. 
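        # A small sketch of how the resulting HDF5 file can be read back
        # (assuming the default dataset names used in this script; the
        # 'emb_aug' dataset only exists when test-time augmentation was used):
        #
        #     import h5py
        #     with h5py.File('test_embeddings_augmented.h5', 'r') as f:
        #         emb = f['emb'][:]                    # (num_images, embedding_dim)
        #         aug = [s.decode() for s in f['augmentation_types'][:]]
        #         if 'emb_aug' in f:
        #             per_variant = f['emb_aug'][:]    # (num_variants, num_images, dim)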
299 | f_out.create_dataset('augmentation_types', data=np.asarray(modifiers, dtype='|S')) 300 | 301 | 302 | if __name__ == '__main__': 303 | 304 | 305 | 306 | arg_experiment_root = const.experiment_root_dir 307 | 308 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 309 | for subset in ['test']: 310 | exp_dir = 'stanford_inc_v1_direct_npairs_loss_m_1.0' 311 | folder_name = 'emb' 312 | dataset_name = 'stanford' 313 | if dataset_name == 'cub': 314 | csv_file = 'cub' 315 | elif dataset_name == 'inshop': 316 | csv_file = 'deep_fashion' 317 | elif dataset_name == 'stanford': 318 | csv_file = 'stanford_online' 319 | else: 320 | raise NotImplementedError('dataset {} not valid'.format(dataset_name)) 321 | 322 | args = [ 323 | '--experiment_root', arg_experiment_root + exp_dir, 324 | '--dataset', './data/'+csv_file+'_'+subset+'.csv', 325 | '--filename', subset+'_embeddings_augmented.h5', 326 | '--foldername',folder_name, 327 | '--crop_augment', 'center', ## Make sure it follows the training resolution 328 | 329 | ] 330 | main(args) 331 | 332 | -------------------------------------------------------------------------------- /common.py: -------------------------------------------------------------------------------- 1 | """ A bunch of general utilities shared by train/embed/eval """ 2 | 3 | from argparse import ArgumentTypeError 4 | import logging 5 | import os 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | # Commandline argument parsing 11 | ### 12 | 13 | def check_directory(arg, access=os.W_OK, access_str="writeable"): 14 | """ Check for directory-type argument validity. 15 | 16 | Checks whether the given `arg` commandline argument is either a readable 17 | existing directory, or a createable/writeable directory. 18 | 19 | Args: 20 | arg (string): The commandline argument to check. 21 | access (constant): What access rights to the directory are requested. 22 | access_str (string): Used for the error message. 23 | 24 | Returns: 25 | The string passed din `arg` if the checks succeed. 26 | 27 | Raises: 28 | ArgumentTypeError if the checks fail. 29 | """ 30 | path_head = arg 31 | while path_head: 32 | if os.path.exists(path_head): 33 | if os.access(path_head, access): 34 | # Seems legit, but it still doesn't guarantee a valid path. 35 | # We'll just go with it for now though. 36 | return arg 37 | else: 38 | raise ArgumentTypeError( 39 | 'The provided string `{0}` is not a valid {1} path ' 40 | 'since {2} is an existing folder without {1} access.' 41 | ''.format(arg, access_str, path_head)) 42 | path_head, _ = os.path.split(path_head) 43 | 44 | # No part of the provided string exists and can be written on. 45 | raise ArgumentTypeError('The provided string `{}` is not a valid {}' 46 | ' path.'.format(arg, access_str)) 47 | 48 | 49 | def writeable_directory(arg): 50 | """ To be used as a type for `ArgumentParser.add_argument`. """ 51 | return check_directory(arg, os.W_OK, "writeable") 52 | 53 | 54 | def readable_directory(arg): 55 | """ To be used as a type for `ArgumentParser.add_argument`. 
""" 56 | return check_directory(arg, os.R_OK, "readable") 57 | 58 | 59 | def number_greater_x(arg, type_, x): 60 | try: 61 | value = type_(arg) 62 | except ValueError: 63 | raise ArgumentTypeError('The argument "{}" is not an {}.'.format( 64 | arg, type_.__name__)) 65 | 66 | if value > x: 67 | return value 68 | else: 69 | raise ArgumentTypeError('Found {} where an {} greater than {} was ' 70 | 'required'.format(arg, type_.__name__, x)) 71 | 72 | 73 | def positive_int(arg): 74 | return number_greater_x(arg, int, 0) 75 | 76 | 77 | def nonnegative_int(arg): 78 | return number_greater_x(arg, int, -1) 79 | 80 | 81 | def positive_float(arg): 82 | return number_greater_x(arg, float, 0) 83 | 84 | 85 | def float_or_string(arg): 86 | """Tries to convert the string to float, otherwise returns the string.""" 87 | try: 88 | return float(arg) 89 | except (ValueError, TypeError): 90 | return arg 91 | 92 | 93 | # Dataset handling 94 | ### 95 | 96 | 97 | def load_dataset(csv_file, image_root, fail_on_missing=True): 98 | """ Loads a dataset .csv file, returning PIDs and FIDs. 99 | 100 | PIDs are the "person IDs", i.e. class names/labels. 101 | FIDs are the "file IDs", which are individual relative filenames. 102 | 103 | Args: 104 | csv_file (string, file-like object): The csv data file to load. 105 | image_root (string): The path to which the image files as stored in the 106 | csv file are relative to. Used for verification purposes. 107 | If this is `None`, no verification at all is made. 108 | fail_on_missing (bool or None): If one or more files from the dataset 109 | are not present in the `image_root`, either raise an IOError (if 110 | True) or remove it from the returned dataset (if False). 111 | 112 | Returns: 113 | (pids, fids) a tuple of numpy string arrays corresponding to the PIDs, 114 | i.e. the identities/classes/labels and the FIDs, i.e. the filenames. 115 | 116 | Raises: 117 | IOError if any one file is missing and `fail_on_missing` is True. 118 | """ 119 | #np.fromregex(csv_file, r'(\d+),"(.+)"', np.object) 120 | 121 | if 'clothing1m' in csv_file: 122 | dataset = np.fromregex(csv_file, r'(\d+),"(.+)"', np.object) 123 | else: 124 | dataset = np.genfromtxt(csv_file, delimiter=',', dtype='|U') 125 | 126 | 127 | 128 | 129 | pids, fids = dataset.T 130 | 131 | # Possibly check if all files exist 132 | if image_root is not None: 133 | missing = np.full(len(fids), False, dtype=bool) 134 | for i, fid in enumerate(fids): 135 | missing[i] = not os.path.isfile(os.path.join(image_root, fid)) 136 | 137 | missing_count = np.sum(missing) 138 | if missing_count > 0: 139 | if fail_on_missing: 140 | raise IOError('Using the `{}` file and `{}` as an image root {}/' 141 | '{} images are missing'.format( 142 | csv_file, image_root, missing_count, len(fids))) 143 | else: 144 | print('[Warning] removing {} missing file(s) from the' 145 | ' dataset.'.format(missing_count)) 146 | # We simply remove the missing files. 147 | fids = fids[np.logical_not(missing)] 148 | pids = pids[np.logical_not(missing)] 149 | 150 | return pids, fids 151 | 152 | 153 | def fid_to_image(fid, pid, image_root, image_size): 154 | # fid = tf.Print(fid,[fid, pid],'fid ::') 155 | """ Loads and resizes an image given by FID. Pass-through the PID. """ 156 | # Since there is no symbolic path.join, we just add a '/' to be sure. 
157 | image_encoded = tf.io.read_file(tf.strings.reduce_join([image_root, '/', fid])) 158 | 159 | # tf.image.decode_image doesn't set the shape, not even the dimensionality, 160 | # because it potentially loads animated .gif files. Instead, we use either 161 | # decode_jpeg or decode_png, each of which can decode both. 162 | # Sounds ridiculous, but is true: 163 | # https://github.com/tensorflow/tensorflow/issues/9356#issuecomment-309144064 164 | image_decoded = tf.io.decode_jpeg(image_encoded, channels=3) 165 | image_resized = tf.image.resize(image_decoded, image_size) 166 | 167 | return image_resized, fid, pid 168 | 169 | 170 | def get_logging_dict(name): 171 | return { 172 | 'version': 1, 173 | 'disable_existing_loggers': False, 174 | 'formatters': { 175 | 'standard': { 176 | 'format': '%(asctime)s [%(levelname)s] %(name)s: %(message)s' 177 | }, 178 | }, 179 | 'handlers': { 180 | 'stderr': { 181 | 'level': 'INFO', 182 | 'formatter': 'standard', 183 | 'class': 'common.ColorStreamHandler', 184 | 'stream': 'ext://sys.stderr', 185 | }, 186 | 'logfile': { 187 | 'level': 'DEBUG', 188 | 'formatter': 'standard', 189 | 'class': 'logging.FileHandler', 190 | 'filename': name + '.log', 191 | 'mode': 'a', 192 | } 193 | }, 194 | 'loggers': { 195 | '': { 196 | 'handlers': ['stderr', 'logfile'], 197 | 'level': 'DEBUG', 198 | 'propagate': True 199 | }, 200 | 201 | # extra ones to shut up. 202 | 'tensorflow': { 203 | 'handlers': ['stderr', 'logfile'], 204 | 'level': 'INFO', 205 | }, 206 | } 207 | } 208 | 209 | 210 | # Source for the remainder: https://gist.github.com/mooware/a1ed40987b6cc9ab9c65 211 | # Fixed some things mentioned in the comments there. 212 | 213 | # colored stream handler for python logging framework (use the ColorStreamHandler class). 214 | # 215 | # based on: 216 | # http://stackoverflow.com/questions/384076/how-can-i-color-python-logging-output/1336640#1336640 217 | 218 | # how to use: 219 | # i used a dict-based logging configuration, not sure what else would work. 220 | # 221 | # import logging, logging.config, colorstreamhandler 222 | # 223 | # _LOGCONFIG = { 224 | # "version": 1, 225 | # "disable_existing_loggers": False, 226 | # 227 | # "handlers": { 228 | # "console": { 229 | # "class": "colorstreamhandler.ColorStreamHandler", 230 | # "stream": "ext://sys.stderr", 231 | # "level": "INFO" 232 | # } 233 | # }, 234 | # 235 | # "root": { 236 | # "level": "INFO", 237 | # "handlers": ["console"] 238 | # } 239 | # } 240 | # 241 | # logging.config.dictConfig(_LOGCONFIG) 242 | # mylogger = logging.getLogger("mylogger") 243 | # mylogger.warning("foobar") 244 | 245 | # Copyright (c) 2014 Markus Pointner 246 | # 247 | # Permission is hereby granted, free of charge, to any person obtaining a copy 248 | # of this software and associated documentation files (the "Software"), to deal 249 | # in the Software without restriction, including without limitation the rights 250 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 251 | # copies of the Software, and to permit persons to whom the Software is 252 | # furnished to do so, subject to the following conditions: 253 | # 254 | # The above copyright notice and this permission notice shall be included in 255 | # all copies or substantial portions of the Software. 256 | # 257 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 258 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 259 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 260 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 261 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 262 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 263 | # THE SOFTWARE. 264 | 265 | class _AnsiColorStreamHandler(logging.StreamHandler): 266 | DEFAULT = '\x1b[0m' 267 | RED = '\x1b[31m' 268 | GREEN = '\x1b[32m' 269 | YELLOW = '\x1b[33m' 270 | CYAN = '\x1b[36m' 271 | 272 | CRITICAL = RED 273 | ERROR = RED 274 | WARNING = YELLOW 275 | INFO = DEFAULT # GREEN 276 | DEBUG = CYAN 277 | 278 | @classmethod 279 | def _get_color(cls, level): 280 | if level >= logging.CRITICAL: return cls.CRITICAL 281 | elif level >= logging.ERROR: return cls.ERROR 282 | elif level >= logging.WARNING: return cls.WARNING 283 | elif level >= logging.INFO: return cls.INFO 284 | elif level >= logging.DEBUG: return cls.DEBUG 285 | else: return cls.DEFAULT 286 | 287 | def __init__(self, stream=None): 288 | logging.StreamHandler.__init__(self, stream) 289 | 290 | def format(self, record): 291 | text = logging.StreamHandler.format(self, record) 292 | color = self._get_color(record.levelno) 293 | return (color + text + self.DEFAULT) if self.is_tty() else text 294 | 295 | def is_tty(self): 296 | isatty = getattr(self.stream, 'isatty', None) 297 | return isatty and isatty() 298 | 299 | 300 | class _WinColorStreamHandler(logging.StreamHandler): 301 | # wincon.h 302 | FOREGROUND_BLACK = 0x0000 303 | FOREGROUND_BLUE = 0x0001 304 | FOREGROUND_GREEN = 0x0002 305 | FOREGROUND_CYAN = 0x0003 306 | FOREGROUND_RED = 0x0004 307 | FOREGROUND_MAGENTA = 0x0005 308 | FOREGROUND_YELLOW = 0x0006 309 | FOREGROUND_GREY = 0x0007 310 | FOREGROUND_INTENSITY = 0x0008 # foreground color is intensified. 311 | FOREGROUND_WHITE = FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_RED 312 | 313 | BACKGROUND_BLACK = 0x0000 314 | BACKGROUND_BLUE = 0x0010 315 | BACKGROUND_GREEN = 0x0020 316 | BACKGROUND_CYAN = 0x0030 317 | BACKGROUND_RED = 0x0040 318 | BACKGROUND_MAGENTA = 0x0050 319 | BACKGROUND_YELLOW = 0x0060 320 | BACKGROUND_GREY = 0x0070 321 | BACKGROUND_INTENSITY = 0x0080 # background color is intensified. 322 | 323 | DEFAULT = FOREGROUND_WHITE 324 | CRITICAL = BACKGROUND_YELLOW | FOREGROUND_RED | FOREGROUND_INTENSITY | BACKGROUND_INTENSITY 325 | ERROR = FOREGROUND_RED | FOREGROUND_INTENSITY 326 | WARNING = FOREGROUND_YELLOW | FOREGROUND_INTENSITY 327 | INFO = FOREGROUND_GREEN 328 | DEBUG = FOREGROUND_CYAN 329 | 330 | @classmethod 331 | def _get_color(cls, level): 332 | if level >= logging.CRITICAL: return cls.CRITICAL 333 | elif level >= logging.ERROR: return cls.ERROR 334 | elif level >= logging.WARNING: return cls.WARNING 335 | elif level >= logging.INFO: return cls.INFO 336 | elif level >= logging.DEBUG: return cls.DEBUG 337 | else: return cls.DEFAULT 338 | 339 | def _set_color(self, code): 340 | import ctypes 341 | ctypes.windll.kernel32.SetConsoleTextAttribute(self._outhdl, code) 342 | 343 | def __init__(self, stream=None): 344 | logging.StreamHandler.__init__(self, stream) 345 | # get file handle for the stream 346 | import ctypes, ctypes.util 347 | # for some reason find_msvcrt() sometimes doesn't find msvcrt.dll on my system? 
348 | crtname = ctypes.util.find_msvcrt() 349 | if not crtname: 350 | crtname = ctypes.util.find_library("msvcrt") 351 | crtlib = ctypes.cdll.LoadLibrary(crtname) 352 | self._outhdl = crtlib._get_osfhandle(self.stream.fileno()) 353 | 354 | def emit(self, record): 355 | color = self._get_color(record.levelno) 356 | self._set_color(color) 357 | logging.StreamHandler.emit(self, record) 358 | self._set_color(self.FOREGROUND_WHITE) 359 | 360 | # select ColorStreamHandler based on platform 361 | import platform 362 | if platform.system() == 'Windows': 363 | ColorStreamHandler = _WinColorStreamHandler 364 | else: 365 | ColorStreamHandler = _AnsiColorStreamHandler 366 | -------------------------------------------------------------------------------- /nets/resnet_v1.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains definitions for the original form of Residual Networks. 16 | 17 | The 'v1' residual networks (ResNets) implemented in this module were proposed 18 | by: 19 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 20 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 21 | 22 | Other variants were introduced in: 23 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 24 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027 25 | 26 | The networks defined in this module utilize the bottleneck building block of 27 | [1] with projection shortcuts only for increasing depths. They employ batch 28 | normalization *after* every weight layer. This is the architecture used by 29 | MSRA in the Imagenet and MSCOCO 2016 competition models ResNet-101 and 30 | ResNet-152. See [2; Fig. 1a] for a comparison between the current 'v1' 31 | architecture and the alternative 'v2' architecture of [2] which uses batch 32 | normalization *before* every weight layer in the so-called full pre-activation 33 | units. 
34 | 35 | Typical use: 36 | 37 | from tensorflow.contrib.slim.nets import resnet_v1 38 | 39 | ResNet-101 for image classification into 1000 classes: 40 | 41 | # inputs has shape [batch, 224, 224, 3] 42 | with slim.arg_scope(resnet_v1.resnet_arg_scope()): 43 | net, end_points = resnet_v1.resnet_v1_101(inputs, 1000, is_training=False) 44 | 45 | ResNet-101 for semantic segmentation into 21 classes: 46 | 47 | # inputs has shape [batch, 513, 513, 3] 48 | with slim.arg_scope(resnet_v1.resnet_arg_scope()): 49 | net, end_points = resnet_v1.resnet_v1_101(inputs, 50 | 21, 51 | is_training=False, 52 | global_pool=False, 53 | output_stride=16) 54 | """ 55 | from __future__ import absolute_import 56 | from __future__ import division 57 | from __future__ import print_function 58 | 59 | import tensorflow as tf 60 | 61 | from nets import resnet_utils 62 | 63 | 64 | resnet_arg_scope = resnet_utils.resnet_arg_scope 65 | slim = tf.contrib.slim 66 | 67 | 68 | @slim.add_arg_scope 69 | def bottleneck(inputs, 70 | depth, 71 | depth_bottleneck, 72 | stride, 73 | rate=1, 74 | outputs_collections=None, 75 | scope=None, 76 | use_bounded_activations=False): 77 | """Bottleneck residual unit variant with BN after convolutions. 78 | 79 | This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for 80 | its definition. Note that we use here the bottleneck variant which has an 81 | extra bottleneck layer. 82 | 83 | When putting together two consecutive ResNet blocks that use this unit, one 84 | should use stride = 2 in the last unit of the first block. 85 | 86 | Args: 87 | inputs: A tensor of size [batch, height, width, channels]. 88 | depth: The depth of the ResNet unit output. 89 | depth_bottleneck: The depth of the bottleneck layers. 90 | stride: The ResNet unit's stride. Determines the amount of downsampling of 91 | the units output compared to its input. 92 | rate: An integer, rate for atrous convolution. 93 | outputs_collections: Collection to add the ResNet unit output. 94 | scope: Optional variable_scope. 95 | use_bounded_activations: Whether or not to use bounded activations. Bounded 96 | activations better lend themselves to quantized inference. 97 | 98 | Returns: 99 | The ResNet unit's output. 100 | """ 101 | with tf.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc: 102 | depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4) 103 | if depth == depth_in: 104 | shortcut = resnet_utils.subsample(inputs, stride, 'shortcut') 105 | else: 106 | shortcut = slim.conv2d( 107 | inputs, 108 | depth, [1, 1], 109 | stride=stride, 110 | activation_fn=tf.nn.relu6 if use_bounded_activations else None, 111 | scope='shortcut') 112 | 113 | residual = slim.conv2d(inputs, depth_bottleneck, [1, 1], stride=1, 114 | scope='conv1') 115 | residual = resnet_utils.conv2d_same(residual, depth_bottleneck, 3, stride, 116 | rate=rate, scope='conv2') 117 | residual = slim.conv2d(residual, depth, [1, 1], stride=1, 118 | activation_fn=None, scope='conv3') 119 | 120 | if use_bounded_activations: 121 | # Use clip_by_value to simulate bandpass activation. 
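        # Note: relu6 on the projection shortcut plus this clip keeps the
        # pre-addition values in [-6, 6] and the unit's output in [0, 6], i.e. a
        # fixed, known activation range, which is what makes the bounded variant
        # friendlier to 8-bit quantized inference.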
122 | residual = tf.clip_by_value(residual, -6.0, 6.0) 123 | output = tf.nn.relu6(shortcut + residual) 124 | else: 125 | output = tf.nn.relu(shortcut + residual) 126 | 127 | return slim.utils.collect_named_outputs(outputs_collections, 128 | sc.original_name_scope, 129 | output) 130 | 131 | 132 | def resnet_v1(inputs, 133 | blocks, 134 | num_classes=None, 135 | is_training=True, 136 | global_pool=True, 137 | output_stride=None, 138 | include_root_block=True, 139 | spatial_squeeze=True, 140 | reuse=None, 141 | scope=None): 142 | """Generator for v1 ResNet models. 143 | 144 | This function generates a family of ResNet v1 models. See the resnet_v1_*() 145 | methods for specific model instantiations, obtained by selecting different 146 | block instantiations that produce ResNets of various depths. 147 | 148 | Training for image classification on Imagenet is usually done with [224, 224] 149 | inputs, resulting in [7, 7] feature maps at the output of the last ResNet 150 | block for the ResNets defined in [1] that have nominal stride equal to 32. 151 | However, for dense prediction tasks we advise that one uses inputs with 152 | spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In 153 | this case the feature maps at the ResNet output will have spatial shape 154 | [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1] 155 | and corners exactly aligned with the input image corners, which greatly 156 | facilitates alignment of the features to the image. Using as input [225, 225] 157 | images results in [8, 8] feature maps at the output of the last ResNet block. 158 | 159 | For dense prediction tasks, the ResNet needs to run in fully-convolutional 160 | (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all 161 | have nominal stride equal to 32 and a good choice in FCN mode is to use 162 | output_stride=16 in order to increase the density of the computed features at 163 | small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915. 164 | 165 | Args: 166 | inputs: A tensor of size [batch, height_in, width_in, channels]. 167 | blocks: A list of length equal to the number of ResNet blocks. Each element 168 | is a resnet_utils.Block object describing the units in the block. 169 | num_classes: Number of predicted classes for classification tasks. If None 170 | we return the features before the logit layer. 171 | is_training: whether is training or not. 172 | global_pool: If True, we perform global average pooling before computing the 173 | logits. Set to True for image classification, False for dense prediction. 174 | output_stride: If None, then the output will be computed at the nominal 175 | network stride. If output_stride is not None, it specifies the requested 176 | ratio of input to output spatial resolution. 177 | include_root_block: If True, include the initial convolution followed by 178 | max-pooling, if False excludes it. 179 | spatial_squeeze: if True, logits is of shape [B, C], if false logits is 180 | of shape [B, 1, 1, C], where B is batch_size and C is number of classes. 181 | To use this parameter, the input images must be smaller than 300x300 182 | pixels, in which case the output logit layer does not contain spatial 183 | information and can be removed. 184 | reuse: whether or not the network and its variables should be reused. To be 185 | able to reuse 'scope' must be given. 186 | scope: Optional variable_scope. 187 | 188 | Returns: 189 | net: A rank-4 tensor of size [batch, height_out, width_out, channels_out]. 
190 | If global_pool is False, then height_out and width_out are reduced by a 191 | factor of output_stride compared to the respective height_in and width_in, 192 | else both height_out and width_out equal one. If num_classes is None, then 193 | net is the output of the last ResNet block, potentially after global 194 | average pooling. If num_classes is not None, net contains the pre-softmax 195 | activations. 196 | end_points: A dictionary from components of the network to the corresponding 197 | activation. 198 | 199 | Raises: 200 | ValueError: If the target output_stride is not valid. 201 | """ 202 | with tf.variable_scope(scope, 'resnet_v1', [inputs], reuse=reuse) as sc: 203 | end_points_collection = sc.name + '_end_points' 204 | with slim.arg_scope([slim.conv2d, bottleneck, 205 | resnet_utils.stack_blocks_dense], 206 | outputs_collections=end_points_collection): 207 | with slim.arg_scope([slim.batch_norm], is_training=is_training): 208 | net = inputs 209 | if include_root_block: 210 | if output_stride is not None: 211 | if output_stride % 4 != 0: 212 | raise ValueError('The output_stride needs to be a multiple of 4.') 213 | output_stride /= 4 214 | net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1') 215 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1') 216 | net = resnet_utils.stack_blocks_dense(net, blocks, output_stride) 217 | if global_pool: 218 | # Global average pooling. 219 | net = tf.reduce_mean(net, [1, 2], name='pool5', keep_dims=True) 220 | if num_classes is not None: 221 | net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, 222 | normalizer_fn=None, scope='logits') 223 | if spatial_squeeze: 224 | net = tf.squeeze(net, [1, 2], name='SpatialSqueeze') 225 | # Convert end_points_collection into a dictionary of end_points. 226 | end_points = slim.utils.convert_collection_to_dict( 227 | end_points_collection) 228 | if num_classes is not None: 229 | end_points['predictions'] = slim.softmax(net, scope='predictions') 230 | return net, end_points 231 | resnet_v1.default_image_size = 224 232 | 233 | 234 | def resnet_v1_block(scope, base_depth, num_units, stride): 235 | """Helper function for creating a resnet_v1 bottleneck block. 236 | 237 | Args: 238 | scope: The scope of the block. 239 | base_depth: The depth of the bottleneck layer for each unit. 240 | num_units: The number of units in the block. 241 | stride: The stride of the block, implemented as a stride in the last unit. 242 | All other units have stride=1. 243 | 244 | Returns: 245 | A resnet_v1 bottleneck block. 246 | """ 247 | return resnet_utils.Block(scope, bottleneck, [{ 248 | 'depth': base_depth * 4, 249 | 'depth_bottleneck': base_depth, 250 | 'stride': 1 251 | }] * (num_units - 1) + [{ 252 | 'depth': base_depth * 4, 253 | 'depth_bottleneck': base_depth, 254 | 'stride': stride 255 | }]) 256 | 257 | 258 | def resnet_v1_50(inputs, 259 | num_classes=None, 260 | is_training=True, 261 | global_pool=True, 262 | output_stride=None, 263 | spatial_squeeze=True, 264 | reuse=None, 265 | scope='resnet_v1_50'): 266 | """ResNet-50 model of [1]. 
See resnet_v1() for arg and return description.""" 267 | blocks = [ 268 | resnet_v1_block('block1', base_depth=64, num_units=3, stride=2), 269 | resnet_v1_block('block2', base_depth=128, num_units=4, stride=2), 270 | resnet_v1_block('block3', base_depth=256, num_units=6, stride=2), 271 | resnet_v1_block('block4', base_depth=512, num_units=3, stride=1), 272 | ] 273 | return resnet_v1(inputs, blocks, num_classes, is_training, 274 | global_pool=global_pool, output_stride=output_stride, 275 | include_root_block=True, spatial_squeeze=spatial_squeeze, 276 | reuse=reuse, scope=scope) 277 | resnet_v1_50.default_image_size = resnet_v1.default_image_size 278 | 279 | 280 | def resnet_v1_101(inputs, 281 | num_classes=None, 282 | is_training=True, 283 | global_pool=True, 284 | output_stride=None, 285 | spatial_squeeze=True, 286 | reuse=None, 287 | scope='resnet_v1_101'): 288 | """ResNet-101 model of [1]. See resnet_v1() for arg and return description.""" 289 | blocks = [ 290 | resnet_v1_block('block1', base_depth=64, num_units=3, stride=2), 291 | resnet_v1_block('block2', base_depth=128, num_units=4, stride=2), 292 | resnet_v1_block('block3', base_depth=256, num_units=23, stride=2), 293 | resnet_v1_block('block4', base_depth=512, num_units=3, stride=1), 294 | ] 295 | return resnet_v1(inputs, blocks, num_classes, is_training, 296 | global_pool=global_pool, output_stride=output_stride, 297 | include_root_block=True, spatial_squeeze=spatial_squeeze, 298 | reuse=reuse, scope=scope) 299 | resnet_v1_101.default_image_size = resnet_v1.default_image_size 300 | 301 | 302 | def resnet_v1_152(inputs, 303 | num_classes=None, 304 | is_training=True, 305 | global_pool=True, 306 | output_stride=None, 307 | spatial_squeeze=True, 308 | reuse=None, 309 | scope='resnet_v1_152'): 310 | """ResNet-152 model of [1]. See resnet_v1() for arg and return description.""" 311 | blocks = [ 312 | resnet_v1_block('block1', base_depth=64, num_units=3, stride=2), 313 | resnet_v1_block('block2', base_depth=128, num_units=8, stride=2), 314 | resnet_v1_block('block3', base_depth=256, num_units=36, stride=2), 315 | resnet_v1_block('block4', base_depth=512, num_units=3, stride=1), 316 | ] 317 | return resnet_v1(inputs, blocks, num_classes, is_training, 318 | global_pool=global_pool, output_stride=output_stride, 319 | include_root_block=True, spatial_squeeze=spatial_squeeze, 320 | reuse=reuse, scope=scope) 321 | resnet_v1_152.default_image_size = resnet_v1.default_image_size 322 | 323 | 324 | def resnet_v1_200(inputs, 325 | num_classes=None, 326 | is_training=True, 327 | global_pool=True, 328 | output_stride=None, 329 | spatial_squeeze=True, 330 | reuse=None, 331 | scope='resnet_v1_200'): 332 | """ResNet-200 model of [2]. 
See resnet_v1() for arg and return description.""" 333 | blocks = [ 334 | resnet_v1_block('block1', base_depth=64, num_units=3, stride=2), 335 | resnet_v1_block('block2', base_depth=128, num_units=24, stride=2), 336 | resnet_v1_block('block3', base_depth=256, num_units=36, stride=2), 337 | resnet_v1_block('block4', base_depth=512, num_units=3, stride=1), 338 | ] 339 | return resnet_v1(inputs, blocks, num_classes, is_training, 340 | global_pool=global_pool, output_stride=output_stride, 341 | include_root_block=True, spatial_squeeze=spatial_squeeze, 342 | reuse=reuse, scope=scope) 343 | resnet_v1_200.default_image_size = resnet_v1.default_image_size 344 | -------------------------------------------------------------------------------- /nets/inception_v1.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains the definition for inception v1 classification network.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | from nets import inception_utils 23 | # import nets.nn_utils as nn_utils 24 | import utils.os_utils as os_utils 25 | import os 26 | 27 | slim = tf.contrib.slim 28 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 29 | 30 | 31 | def inception_v1_base(inputs, 32 | final_endpoint='Mixed_5c', 33 | include_root_block=True, 34 | scope='InceptionV1'): 35 | """Defines the Inception V1 base architecture. 36 | 37 | This architecture is defined in: 38 | Going deeper with convolutions 39 | Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, 40 | Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich. 41 | http://arxiv.org/pdf/1409.4842v1.pdf. 42 | 43 | Args: 44 | inputs: a tensor of size [batch_size, height, width, channels]. 45 | final_endpoint: specifies the endpoint to construct the network up to. It 46 | can be one of ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 47 | 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c', 48 | 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 'Mixed_4e', 49 | 'Mixed_4f', 'MaxPool_5a_2x2', 'Mixed_5b', 'Mixed_5c']. If 50 | include_root_block is False, ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 51 | 'Conv2d_2b_1x1', 'Conv2d_2c_3x3', 'MaxPool_3a_3x3'] will not be available. 52 | include_root_block: If True, include the convolution and max-pooling layers 53 | before the inception modules. If False, excludes those layers. 54 | scope: Optional variable_scope. 55 | 56 | Returns: 57 | A dictionary from components of the network to the corresponding activation. 58 | 59 | Raises: 60 | ValueError: if final_endpoint is not set to one of the predefined values. 
61 | """ 62 | end_points = {} 63 | with tf.variable_scope(scope, 'InceptionV1', [inputs]): 64 | with slim.arg_scope( 65 | [slim.conv2d, slim.fully_connected], 66 | weights_initializer=trunc_normal(0.01)): 67 | with slim.arg_scope([slim.conv2d, slim.max_pool2d], 68 | stride=1, padding='SAME'): 69 | net = inputs 70 | if include_root_block: 71 | end_point = 'Conv2d_1a_7x7' 72 | net = slim.conv2d(inputs, 64, [7, 7], stride=2, scope=end_point) 73 | end_points[end_point] = net 74 | if final_endpoint == end_point: 75 | return net, end_points 76 | end_point = 'MaxPool_2a_3x3' 77 | net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point) 78 | end_points[end_point] = net 79 | if final_endpoint == end_point: 80 | return net, end_points 81 | end_point = 'Conv2d_2b_1x1' 82 | net = slim.conv2d(net, 64, [1, 1], scope=end_point) 83 | end_points[end_point] = net 84 | if final_endpoint == end_point: 85 | return net, end_points 86 | end_point = 'Conv2d_2c_3x3' 87 | net = slim.conv2d(net, 192, [3, 3], scope=end_point) 88 | end_points[end_point] = net 89 | if final_endpoint == end_point: 90 | return net, end_points 91 | end_point = 'MaxPool_3a_3x3' 92 | net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point) 93 | end_points[end_point] = net 94 | if final_endpoint == end_point: 95 | return net, end_points 96 | 97 | end_point = 'Mixed_3b' 98 | with tf.variable_scope(end_point): 99 | with tf.variable_scope('Branch_0'): 100 | branch_0 = slim.conv2d(net, 64, [1, 1], scope='Conv2d_0a_1x1') 101 | with tf.variable_scope('Branch_1'): 102 | branch_1 = slim.conv2d(net, 96, [1, 1], scope='Conv2d_0a_1x1') 103 | branch_1 = slim.conv2d(branch_1, 128, [3, 3], scope='Conv2d_0b_3x3') 104 | with tf.variable_scope('Branch_2'): 105 | branch_2 = slim.conv2d(net, 16, [1, 1], scope='Conv2d_0a_1x1') 106 | branch_2 = slim.conv2d(branch_2, 32, [3, 3], scope='Conv2d_0b_3x3') 107 | with tf.variable_scope('Branch_3'): 108 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 109 | branch_3 = slim.conv2d(branch_3, 32, [1, 1], scope='Conv2d_0b_1x1') 110 | net = tf.concat( 111 | axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 112 | end_points[end_point] = net 113 | if final_endpoint == end_point: return net, end_points 114 | 115 | end_point = 'Mixed_3c' 116 | with tf.variable_scope(end_point): 117 | with tf.variable_scope('Branch_0'): 118 | branch_0 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1') 119 | with tf.variable_scope('Branch_1'): 120 | branch_1 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1') 121 | branch_1 = slim.conv2d(branch_1, 192, [3, 3], scope='Conv2d_0b_3x3') 122 | with tf.variable_scope('Branch_2'): 123 | branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1') 124 | branch_2 = slim.conv2d(branch_2, 96, [3, 3], scope='Conv2d_0b_3x3') 125 | with tf.variable_scope('Branch_3'): 126 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 127 | branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1') 128 | net = tf.concat( 129 | axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 130 | 131 | 132 | end_points[end_point] = net 133 | if final_endpoint == end_point: return net, end_points 134 | 135 | end_point = 'MaxPool_4a_3x3' 136 | net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point) 137 | end_points[end_point] = net 138 | if final_endpoint == end_point: return net, end_points 139 | 140 | end_point = 'Mixed_4b' 141 | with tf.variable_scope(end_point): 142 | with tf.variable_scope('Branch_0'): 143 | branch_0 = slim.conv2d(net, 192, [1, 1], 
scope='Conv2d_0a_1x1') 144 | with tf.variable_scope('Branch_1'): 145 | branch_1 = slim.conv2d(net, 96, [1, 1], scope='Conv2d_0a_1x1') 146 | branch_1 = slim.conv2d(branch_1, 208, [3, 3], scope='Conv2d_0b_3x3') 147 | with tf.variable_scope('Branch_2'): 148 | branch_2 = slim.conv2d(net, 16, [1, 1], scope='Conv2d_0a_1x1') 149 | branch_2 = slim.conv2d(branch_2, 48, [3, 3], scope='Conv2d_0b_3x3') 150 | with tf.variable_scope('Branch_3'): 151 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 152 | branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1') 153 | net = tf.concat( 154 | axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 155 | 156 | 157 | end_points[end_point] = net 158 | if final_endpoint == end_point: return net, end_points 159 | 160 | end_point = 'Mixed_4c' 161 | with tf.variable_scope(end_point): 162 | with tf.variable_scope('Branch_0'): 163 | branch_0 = slim.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1') 164 | with tf.variable_scope('Branch_1'): 165 | branch_1 = slim.conv2d(net, 112, [1, 1], scope='Conv2d_0a_1x1') 166 | branch_1 = slim.conv2d(branch_1, 224, [3, 3], scope='Conv2d_0b_3x3') 167 | with tf.variable_scope('Branch_2'): 168 | branch_2 = slim.conv2d(net, 24, [1, 1], scope='Conv2d_0a_1x1') 169 | branch_2 = slim.conv2d(branch_2, 64, [3, 3], scope='Conv2d_0b_3x3') 170 | with tf.variable_scope('Branch_3'): 171 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 172 | branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1') 173 | net = tf.concat( 174 | axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 175 | 176 | 177 | 178 | end_points[end_point] = net 179 | if final_endpoint == end_point: return net, end_points 180 | 181 | end_point = 'Mixed_4d' 182 | with tf.variable_scope(end_point): 183 | with tf.variable_scope('Branch_0'): 184 | branch_0 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1') 185 | with tf.variable_scope('Branch_1'): 186 | branch_1 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1') 187 | branch_1 = slim.conv2d(branch_1, 256, [3, 3], scope='Conv2d_0b_3x3') 188 | with tf.variable_scope('Branch_2'): 189 | branch_2 = slim.conv2d(net, 24, [1, 1], scope='Conv2d_0a_1x1') 190 | branch_2 = slim.conv2d(branch_2, 64, [3, 3], scope='Conv2d_0b_3x3') 191 | with tf.variable_scope('Branch_3'): 192 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 193 | branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1') 194 | net = tf.concat( 195 | axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 196 | 197 | 198 | end_points[end_point] = net 199 | if final_endpoint == end_point: return net, end_points 200 | 201 | end_point = 'Mixed_4e' 202 | with tf.variable_scope(end_point): 203 | with tf.variable_scope('Branch_0'): 204 | branch_0 = slim.conv2d(net, 112, [1, 1], scope='Conv2d_0a_1x1') 205 | with tf.variable_scope('Branch_1'): 206 | branch_1 = slim.conv2d(net, 144, [1, 1], scope='Conv2d_0a_1x1') 207 | branch_1 = slim.conv2d(branch_1, 288, [3, 3], scope='Conv2d_0b_3x3') 208 | with tf.variable_scope('Branch_2'): 209 | branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1') 210 | branch_2 = slim.conv2d(branch_2, 64, [3, 3], scope='Conv2d_0b_3x3') 211 | with tf.variable_scope('Branch_3'): 212 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 213 | branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1') 214 | net = tf.concat( 215 | axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 216 | 217 | 218 | end_points[end_point] = net 219 | if 
final_endpoint == end_point: return net, end_points 220 | 221 | end_point = 'Mixed_4f' 222 | with tf.variable_scope(end_point): 223 | with tf.variable_scope('Branch_0'): 224 | branch_0 = slim.conv2d(net, 256, [1, 1], scope='Conv2d_0a_1x1') 225 | with tf.variable_scope('Branch_1'): 226 | branch_1 = slim.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1') 227 | branch_1 = slim.conv2d(branch_1, 320, [3, 3], scope='Conv2d_0b_3x3') 228 | with tf.variable_scope('Branch_2'): 229 | branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1') 230 | branch_2 = slim.conv2d(branch_2, 128, [3, 3], scope='Conv2d_0b_3x3') 231 | with tf.variable_scope('Branch_3'): 232 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 233 | branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1') 234 | net = tf.concat( 235 | axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 236 | 237 | 238 | end_points[end_point] = net 239 | if final_endpoint == end_point: return net, end_points 240 | 241 | end_point = 'MaxPool_5a_2x2' 242 | net = slim.max_pool2d(net, [2, 2], stride=2, scope=end_point) 243 | end_points[end_point] = net 244 | if final_endpoint == end_point: return net, end_points 245 | 246 | end_point = 'Mixed_5b' 247 | with tf.variable_scope(end_point): 248 | with tf.variable_scope('Branch_0'): 249 | branch_0 = slim.conv2d(net, 256, [1, 1], scope='Conv2d_0a_1x1') 250 | with tf.variable_scope('Branch_1'): 251 | branch_1 = slim.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1') 252 | branch_1 = slim.conv2d(branch_1, 320, [3, 3], scope='Conv2d_0b_3x3') 253 | with tf.variable_scope('Branch_2'): 254 | branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1') 255 | branch_2 = slim.conv2d(branch_2, 128, [3, 3], scope='Conv2d_0a_3x3') 256 | with tf.variable_scope('Branch_3'): 257 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 258 | branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1') 259 | net = tf.concat( 260 | axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 261 | 262 | 263 | end_points[end_point] = net 264 | if final_endpoint == end_point: return net, end_points 265 | 266 | end_point = 'Mixed_5c' 267 | with tf.variable_scope(end_point): 268 | with tf.variable_scope('Branch_0'): 269 | branch_0 = slim.conv2d(net, 384, [1, 1], scope='Conv2d_0a_1x1') 270 | with tf.variable_scope('Branch_1'): 271 | branch_1 = slim.conv2d(net, 192, [1, 1], scope='Conv2d_0a_1x1') 272 | branch_1 = slim.conv2d(branch_1, 384, [3, 3], scope='Conv2d_0b_3x3') 273 | with tf.variable_scope('Branch_2'): 274 | branch_2 = slim.conv2d(net, 48, [1, 1], scope='Conv2d_0a_1x1') 275 | branch_2 = slim.conv2d(branch_2, 128, [3, 3], scope='Conv2d_0b_3x3') 276 | with tf.variable_scope('Branch_3'): 277 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 278 | branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1') 279 | net = tf.concat( 280 | axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 281 | 282 | 283 | end_points[end_point] = net 284 | if final_endpoint == end_point: return net, end_points 285 | raise ValueError('Unknown final endpoint %s' % final_endpoint) 286 | 287 | 288 | def inception_v1(inputs, 289 | num_classes=1000, 290 | is_training=True, 291 | dropout_keep_prob=0.8, 292 | prediction_fn=slim.softmax, 293 | spatial_squeeze=True, 294 | reuse=None, 295 | scope='InceptionV1', 296 | global_pool=False): 297 | """Defines the Inception V1 architecture. 
298 | 299 | This architecture is defined in: 300 | 301 | Going deeper with convolutions 302 | Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, 303 | Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich. 304 | http://arxiv.org/pdf/1409.4842v1.pdf. 305 | 306 | The default image size used to train this network is 224x224. 307 | 308 | Args: 309 | inputs: a tensor of size [batch_size, height, width, channels]. 310 | num_classes: number of predicted classes. If 0 or None, the logits layer 311 | is omitted and the input features to the logits layer (before dropout) 312 | are returned instead. 313 | is_training: whether is training or not. 314 | dropout_keep_prob: the percentage of activation values that are retained. 315 | prediction_fn: a function to get predictions out of logits. 316 | spatial_squeeze: if True, logits is of shape [B, C], if false logits is of 317 | shape [B, 1, 1, C], where B is batch_size and C is number of classes. 318 | reuse: whether or not the network and its variables should be reused. To be 319 | able to reuse 'scope' must be given. 320 | scope: Optional variable_scope. 321 | global_pool: Optional boolean flag to control the avgpooling before the 322 | logits layer. If false or unset, pooling is done with a fixed window 323 | that reduces default-sized inputs to 1x1, while larger inputs lead to 324 | larger outputs. If true, any input size is pooled down to 1x1. 325 | 326 | Returns: 327 | net: a Tensor with the logits (pre-softmax activations) if num_classes 328 | is a non-zero integer, or the non-dropped-out input to the logits layer 329 | if num_classes is 0 or None. 330 | end_points: a dictionary from components of the network to the corresponding 331 | activation. 332 | """ 333 | # Final pooling and prediction 334 | with tf.variable_scope(scope, 'InceptionV1', [inputs], reuse=reuse) as scope: 335 | with slim.arg_scope([slim.batch_norm, slim.dropout], 336 | is_training=is_training): 337 | net, end_points = inception_v1_base(inputs, scope=scope) 338 | 339 | 340 | 341 | with tf.variable_scope('Logits'): 342 | if global_pool: 343 | # Global average pooling. 344 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool') 345 | end_points['global_pool'] = net 346 | else: 347 | # Pooling with a fixed kernel size. 
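        # With the default 224x224 input, Mixed_5c is 7x7 spatially, so this fixed
        # 7x7 average pool reduces it to 1x1 (e.g. [B, 7, 7, 1024] -> [B, 1, 1, 1024]);
        # larger inputs leave a larger map here, which is why the endpoints()
        # wrapper at the bottom of this file calls inception_v1 with global_pool=True.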
348 | net = slim.avg_pool2d(net, [7, 7], stride=1, scope='AvgPool_0a_7x7') 349 | end_points['AvgPool_0a_7x7'] = net 350 | if not num_classes: 351 | return net, end_points 352 | net = slim.dropout(net, dropout_keep_prob, scope='Dropout_0b') 353 | logits = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, 354 | normalizer_fn=None, scope='Conv2d_0c_1x1') 355 | if spatial_squeeze: 356 | logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze') 357 | 358 | end_points['Logits'] = logits 359 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 360 | return logits, end_points 361 | inception_v1.default_image_size = 224 362 | 363 | inception_v1_arg_scope = inception_utils.inception_arg_scope 364 | 365 | def endpoints(image, is_training,weight_decay=0.0): 366 | image = tf.divide(image, 255.0) 367 | image = tf.subtract(image, 0.5) 368 | image = tf.multiply(image, 2.0) 369 | with slim.arg_scope(inception_v1_arg_scope()): 370 | _, endpoints = inception_v1(image, num_classes=None,is_training=is_training, global_pool=True) 371 | 372 | endpoints['model_output'] = endpoints['global_pool'] = tf.reduce_mean(endpoints['global_pool'], [1, 2], name='pool5') 373 | 374 | return endpoints, 'InceptionV1' 375 | -------------------------------------------------------------------------------- /nets/mobilenet_v1.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================= 15 | """MobileNet v1. 16 | 17 | MobileNet is a general architecture and can be used for multiple use cases. 18 | Depending on the use case, it can use different input layer size and different 19 | head (for example: embeddings, localization and classification). 20 | 21 | As described in https://arxiv.org/abs/1704.04861. 22 | 23 | MobileNets: Efficient Convolutional Neural Networks for 24 | Mobile Vision Applications 25 | Andrew G. 
Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, 26 | Tobias Weyand, Marco Andreetto, Hartwig Adam 27 | 28 | 100% Mobilenet V1 (base) with input size 224x224: 29 | 30 | See mobilenet_v1() 31 | 32 | Layer params macs 33 | -------------------------------------------------------------------------------- 34 | MobilenetV1/Conv2d_0/Conv2D: 864 10,838,016 35 | MobilenetV1/Conv2d_1_depthwise/depthwise: 288 3,612,672 36 | MobilenetV1/Conv2d_1_pointwise/Conv2D: 2,048 25,690,112 37 | MobilenetV1/Conv2d_2_depthwise/depthwise: 576 1,806,336 38 | MobilenetV1/Conv2d_2_pointwise/Conv2D: 8,192 25,690,112 39 | MobilenetV1/Conv2d_3_depthwise/depthwise: 1,152 3,612,672 40 | MobilenetV1/Conv2d_3_pointwise/Conv2D: 16,384 51,380,224 41 | MobilenetV1/Conv2d_4_depthwise/depthwise: 1,152 903,168 42 | MobilenetV1/Conv2d_4_pointwise/Conv2D: 32,768 25,690,112 43 | MobilenetV1/Conv2d_5_depthwise/depthwise: 2,304 1,806,336 44 | MobilenetV1/Conv2d_5_pointwise/Conv2D: 65,536 51,380,224 45 | MobilenetV1/Conv2d_6_depthwise/depthwise: 2,304 451,584 46 | MobilenetV1/Conv2d_6_pointwise/Conv2D: 131,072 25,690,112 47 | MobilenetV1/Conv2d_7_depthwise/depthwise: 4,608 903,168 48 | MobilenetV1/Conv2d_7_pointwise/Conv2D: 262,144 51,380,224 49 | MobilenetV1/Conv2d_8_depthwise/depthwise: 4,608 903,168 50 | MobilenetV1/Conv2d_8_pointwise/Conv2D: 262,144 51,380,224 51 | MobilenetV1/Conv2d_9_depthwise/depthwise: 4,608 903,168 52 | MobilenetV1/Conv2d_9_pointwise/Conv2D: 262,144 51,380,224 53 | MobilenetV1/Conv2d_10_depthwise/depthwise: 4,608 903,168 54 | MobilenetV1/Conv2d_10_pointwise/Conv2D: 262,144 51,380,224 55 | MobilenetV1/Conv2d_11_depthwise/depthwise: 4,608 903,168 56 | MobilenetV1/Conv2d_11_pointwise/Conv2D: 262,144 51,380,224 57 | MobilenetV1/Conv2d_12_depthwise/depthwise: 4,608 225,792 58 | MobilenetV1/Conv2d_12_pointwise/Conv2D: 524,288 25,690,112 59 | MobilenetV1/Conv2d_13_depthwise/depthwise: 9,216 451,584 60 | MobilenetV1/Conv2d_13_pointwise/Conv2D: 1,048,576 51,380,224 61 | -------------------------------------------------------------------------------- 62 | Total: 3,185,088 567,716,352 63 | 64 | 65 | 75% Mobilenet V1 (base) with input size 128x128: 66 | 67 | See mobilenet_v1_075() 68 | 69 | Layer params macs 70 | -------------------------------------------------------------------------------- 71 | MobilenetV1/Conv2d_0/Conv2D: 648 2,654,208 72 | MobilenetV1/Conv2d_1_depthwise/depthwise: 216 884,736 73 | MobilenetV1/Conv2d_1_pointwise/Conv2D: 1,152 4,718,592 74 | MobilenetV1/Conv2d_2_depthwise/depthwise: 432 442,368 75 | MobilenetV1/Conv2d_2_pointwise/Conv2D: 4,608 4,718,592 76 | MobilenetV1/Conv2d_3_depthwise/depthwise: 864 884,736 77 | MobilenetV1/Conv2d_3_pointwise/Conv2D: 9,216 9,437,184 78 | MobilenetV1/Conv2d_4_depthwise/depthwise: 864 221,184 79 | MobilenetV1/Conv2d_4_pointwise/Conv2D: 18,432 4,718,592 80 | MobilenetV1/Conv2d_5_depthwise/depthwise: 1,728 442,368 81 | MobilenetV1/Conv2d_5_pointwise/Conv2D: 36,864 9,437,184 82 | MobilenetV1/Conv2d_6_depthwise/depthwise: 1,728 110,592 83 | MobilenetV1/Conv2d_6_pointwise/Conv2D: 73,728 4,718,592 84 | MobilenetV1/Conv2d_7_depthwise/depthwise: 3,456 221,184 85 | MobilenetV1/Conv2d_7_pointwise/Conv2D: 147,456 9,437,184 86 | MobilenetV1/Conv2d_8_depthwise/depthwise: 3,456 221,184 87 | MobilenetV1/Conv2d_8_pointwise/Conv2D: 147,456 9,437,184 88 | MobilenetV1/Conv2d_9_depthwise/depthwise: 3,456 221,184 89 | MobilenetV1/Conv2d_9_pointwise/Conv2D: 147,456 9,437,184 90 | MobilenetV1/Conv2d_10_depthwise/depthwise: 3,456 221,184 91 | 
MobilenetV1/Conv2d_10_pointwise/Conv2D: 147,456 9,437,184 92 | MobilenetV1/Conv2d_11_depthwise/depthwise: 3,456 221,184 93 | MobilenetV1/Conv2d_11_pointwise/Conv2D: 147,456 9,437,184 94 | MobilenetV1/Conv2d_12_depthwise/depthwise: 3,456 55,296 95 | MobilenetV1/Conv2d_12_pointwise/Conv2D: 294,912 4,718,592 96 | MobilenetV1/Conv2d_13_depthwise/depthwise: 6,912 110,592 97 | MobilenetV1/Conv2d_13_pointwise/Conv2D: 589,824 9,437,184 98 | -------------------------------------------------------------------------------- 99 | Total: 1,800,144 106,002,432 100 | 101 | """ 102 | 103 | # Tensorflow mandates these. 104 | from __future__ import absolute_import 105 | from __future__ import division 106 | from __future__ import print_function 107 | 108 | from collections import namedtuple 109 | import functools 110 | 111 | import tensorflow as tf 112 | 113 | slim = tf.contrib.slim 114 | 115 | # Conv and DepthSepConv namedtuple define layers of the MobileNet architecture 116 | # Conv defines 3x3 convolution layers 117 | # DepthSepConv defines 3x3 depthwise convolution followed by 1x1 convolution. 118 | # stride is the stride of the convolution 119 | # depth is the number of channels or filters in a layer 120 | Conv = namedtuple('Conv', ['kernel', 'stride', 'depth']) 121 | DepthSepConv = namedtuple('DepthSepConv', ['kernel', 'stride', 'depth']) 122 | 123 | # _CONV_DEFS specifies the MobileNet body 124 | _CONV_DEFS = [ 125 | Conv(kernel=[3, 3], stride=2, depth=32), 126 | DepthSepConv(kernel=[3, 3], stride=1, depth=64), 127 | DepthSepConv(kernel=[3, 3], stride=2, depth=128), 128 | DepthSepConv(kernel=[3, 3], stride=1, depth=128), 129 | DepthSepConv(kernel=[3, 3], stride=2, depth=256), 130 | DepthSepConv(kernel=[3, 3], stride=1, depth=256), 131 | DepthSepConv(kernel=[3, 3], stride=2, depth=512), 132 | DepthSepConv(kernel=[3, 3], stride=1, depth=512), 133 | DepthSepConv(kernel=[3, 3], stride=1, depth=512), 134 | DepthSepConv(kernel=[3, 3], stride=1, depth=512), 135 | DepthSepConv(kernel=[3, 3], stride=1, depth=512), 136 | DepthSepConv(kernel=[3, 3], stride=1, depth=512), 137 | DepthSepConv(kernel=[3, 3], stride=2, depth=1024), 138 | DepthSepConv(kernel=[3, 3], stride=1, depth=1024) 139 | ] 140 | 141 | 142 | def mobilenet_v1_base(inputs, 143 | final_endpoint='Conv2d_13_pointwise', 144 | min_depth=8, 145 | depth_multiplier=1.0, 146 | conv_defs=None, 147 | output_stride=None, 148 | scope=None): 149 | """Mobilenet v1. 150 | 151 | Constructs a Mobilenet v1 network from inputs to the given final endpoint. 152 | 153 | Args: 154 | inputs: a tensor of shape [batch_size, height, width, channels]. 155 | final_endpoint: specifies the endpoint to construct the network up to. It 156 | can be one of ['Conv2d_0', 'Conv2d_1_pointwise', 'Conv2d_2_pointwise', 157 | 'Conv2d_3_pointwise', 'Conv2d_4_pointwise', 'Conv2d_5_pointwise', 158 | 'Conv2d_6_pointwise', 'Conv2d_7_pointwise', 'Conv2d_8_pointwise', 159 | 'Conv2d_9_pointwise', 'Conv2d_10_pointwise', 'Conv2d_11_pointwise', 160 | 'Conv2d_12_pointwise', 'Conv2d_13_pointwise']. 161 | min_depth: Minimum depth value (number of channels) for all convolution ops. 162 | Enforced when depth_multiplier < 1, and not an active constraint when 163 | depth_multiplier >= 1. 164 | depth_multiplier: Float multiplier for the depth (number of channels) 165 | for all convolution ops. The value must be greater than zero. Typical 166 | usage will be to set this value in (0, 1) to reduce the number of 167 | parameters or computation cost of the model. 
168 | conv_defs: A list of ConvDef namedtuples specifying the net architecture. 169 | output_stride: An integer that specifies the requested ratio of input to 170 | output spatial resolution. If not None, then we invoke atrous convolution 171 | if necessary to prevent the network from reducing the spatial resolution 172 | of the activation maps. Allowed values are 8 (accurate fully convolutional 173 | mode), 16 (fast fully convolutional mode), 32 (classification mode). 174 | scope: Optional variable_scope. 175 | 176 | Returns: 177 | tensor_out: output tensor corresponding to the final_endpoint. 178 | end_points: a set of activations for external use, for example summaries or 179 | losses. 180 | 181 | Raises: 182 | ValueError: if final_endpoint is not set to one of the predefined values, 183 | or depth_multiplier <= 0, or the target output_stride is not 184 | allowed. 185 | """ 186 | depth = lambda d: max(int(d * depth_multiplier), min_depth) 187 | end_points = {} 188 | 189 | # Used to find thinned depths for each layer. 190 | if depth_multiplier <= 0: 191 | raise ValueError('depth_multiplier is not greater than zero.') 192 | 193 | if conv_defs is None: 194 | conv_defs = _CONV_DEFS 195 | 196 | if output_stride is not None and output_stride not in [8, 16, 32]: 197 | raise ValueError('Only allowed output_stride values are 8, 16, 32.') 198 | 199 | with tf.variable_scope(scope, 'MobilenetV1', [inputs]): 200 | with slim.arg_scope([slim.conv2d, slim.separable_conv2d], padding='SAME'): 201 | # The current_stride variable keeps track of the output stride of the 202 | # activations, i.e., the running product of convolution strides up to the 203 | # current network layer. This allows us to invoke atrous convolution 204 | # whenever applying the next convolution would result in the activations 205 | # having output stride larger than the target output_stride. 206 | current_stride = 1 207 | 208 | # The atrous convolution rate parameter. 209 | rate = 1 210 | 211 | net = inputs 212 | for i, conv_def in enumerate(conv_defs): 213 | end_point_base = 'Conv2d_%d' % i 214 | 215 | if output_stride is not None and current_stride == output_stride: 216 | # If we have reached the target output_stride, then we need to employ 217 | # atrous convolution with stride=1 and multiply the atrous rate by the 218 | # current unit's stride for use in subsequent layers. 
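# For example, with the default _CONV_DEFS and output_stride=16, the running
# product of strides reaches 16 at Conv2d_12; that layer is then applied with
# stride 1 instead of 2, and the following depthwise layer (Conv2d_13) uses
# atrous rate 2, so the feature map stays at 1/16 of the input resolution
# rather than 1/32.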
219 | layer_stride = 1 220 | layer_rate = rate 221 | rate *= conv_def.stride 222 | else: 223 | layer_stride = conv_def.stride 224 | layer_rate = 1 225 | current_stride *= conv_def.stride 226 | 227 | if isinstance(conv_def, Conv): 228 | end_point = end_point_base 229 | net = slim.conv2d(net, depth(conv_def.depth), conv_def.kernel, 230 | stride=conv_def.stride, 231 | normalizer_fn=slim.batch_norm, 232 | scope=end_point) 233 | end_points[end_point] = net 234 | if end_point == final_endpoint: 235 | return net, end_points 236 | 237 | elif isinstance(conv_def, DepthSepConv): 238 | end_point = end_point_base + '_depthwise' 239 | 240 | # By passing filters=None 241 | # separable_conv2d produces only a depthwise convolution layer 242 | net = slim.separable_conv2d(net, None, conv_def.kernel, 243 | depth_multiplier=1, 244 | stride=layer_stride, 245 | rate=layer_rate, 246 | normalizer_fn=slim.batch_norm, 247 | scope=end_point) 248 | 249 | end_points[end_point] = net 250 | if end_point == final_endpoint: 251 | return net, end_points 252 | 253 | end_point = end_point_base + '_pointwise' 254 | 255 | net = slim.conv2d(net, depth(conv_def.depth), [1, 1], 256 | stride=1, 257 | normalizer_fn=slim.batch_norm, 258 | scope=end_point) 259 | 260 | end_points[end_point] = net 261 | if end_point == final_endpoint: 262 | return net, end_points 263 | else: 264 | raise ValueError('Unknown convolution type %s for layer %d' 265 | % (conv_def.ltype, i)) 266 | raise ValueError('Unknown final endpoint %s' % final_endpoint) 267 | 268 | 269 | def mobilenet_v1(inputs, 270 | num_classes=1000, 271 | dropout_keep_prob=0.999, 272 | is_training=True, 273 | min_depth=8, 274 | depth_multiplier=1.0, 275 | conv_defs=None, 276 | prediction_fn=tf.contrib.layers.softmax, 277 | spatial_squeeze=True, 278 | reuse=None, 279 | scope='MobilenetV1'): 280 | """Mobilenet v1 model for classification. 281 | 282 | Args: 283 | inputs: a tensor of shape [batch_size, height, width, channels]. 284 | num_classes: number of predicted classes. 285 | dropout_keep_prob: the percentage of activation values that are retained. 286 | is_training: whether is training or not. 287 | min_depth: Minimum depth value (number of channels) for all convolution ops. 288 | Enforced when depth_multiplier < 1, and not an active constraint when 289 | depth_multiplier >= 1. 290 | depth_multiplier: Float multiplier for the depth (number of channels) 291 | for all convolution ops. The value must be greater than zero. Typical 292 | usage will be to set this value in (0, 1) to reduce the number of 293 | parameters or computation cost of the model. 294 | conv_defs: A list of ConvDef namedtuples specifying the net architecture. 295 | prediction_fn: a function to get predictions out of logits. 296 | spatial_squeeze: if True, logits is of shape is [B, C], if false logits is 297 | of shape [B, 1, 1, C], where B is batch_size and C is number of classes. 298 | reuse: whether or not the network and its variables should be reused. To be 299 | able to reuse 'scope' must be given. 300 | scope: Optional variable_scope. 301 | 302 | Returns: 303 | logits: the pre-softmax activations, a tensor of size 304 | [batch_size, num_classes] 305 | end_points: a dictionary from components of the network to the corresponding 306 | activation. 307 | 308 | Raises: 309 | ValueError: Input rank is invalid. 
310 | """ 311 | input_shape = inputs.get_shape().as_list() 312 | if len(input_shape) != 4: 313 | raise ValueError('Invalid input tensor rank, expected 4, was: %d' % 314 | len(input_shape)) 315 | 316 | with tf.variable_scope(scope, 'MobilenetV1', [inputs, num_classes], 317 | reuse=reuse) as scope: 318 | with slim.arg_scope([slim.batch_norm, slim.dropout], 319 | is_training=is_training): 320 | net, end_points = mobilenet_v1_base(inputs, scope=scope, 321 | min_depth=min_depth, 322 | depth_multiplier=depth_multiplier, 323 | conv_defs=conv_defs) 324 | with tf.variable_scope('Logits'): 325 | kernel_size = _reduced_kernel_size_for_small_input(net, [7, 7]) 326 | net = slim.avg_pool2d(net, kernel_size, padding='VALID', 327 | scope='AvgPool_1a') 328 | end_points['AvgPool_1a'] = net 329 | # 1 x 1 x 1024 330 | net = slim.dropout(net, keep_prob=dropout_keep_prob, scope='Dropout_1b') 331 | logits = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, 332 | normalizer_fn=None, scope='Conv2d_1c_1x1') 333 | if spatial_squeeze: 334 | logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze') 335 | end_points['Logits'] = logits 336 | if prediction_fn: 337 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 338 | return logits, end_points 339 | 340 | mobilenet_v1.default_image_size = 224 341 | 342 | 343 | def wrapped_partial(func, *args, **kwargs): 344 | partial_func = functools.partial(func, *args, **kwargs) 345 | functools.update_wrapper(partial_func, func) 346 | return partial_func 347 | 348 | 349 | mobilenet_v1_075 = wrapped_partial(mobilenet_v1, depth_multiplier=0.75) 350 | mobilenet_v1_050 = wrapped_partial(mobilenet_v1, depth_multiplier=0.50) 351 | mobilenet_v1_025 = wrapped_partial(mobilenet_v1, depth_multiplier=0.25) 352 | 353 | 354 | def _reduced_kernel_size_for_small_input(input_tensor, kernel_size): 355 | """Define kernel size which is automatically reduced for small input. 356 | 357 | If the shape of the input images is unknown at graph construction time this 358 | function assumes that the input images are large enough. 359 | 360 | Args: 361 | input_tensor: input tensor of size [batch_size, height, width, channels]. 362 | kernel_size: desired kernel size of length 2: [kernel_height, kernel_width] 363 | 364 | Returns: 365 | a tensor with the kernel size. 366 | """ 367 | shape = input_tensor.get_shape().as_list() 368 | if shape[1] is None or shape[2] is None: 369 | kernel_size_out = kernel_size 370 | else: 371 | kernel_size_out = [min(shape[1], kernel_size[0]), 372 | min(shape[2], kernel_size[1])] 373 | return kernel_size_out 374 | 375 | 376 | def mobilenet_v1_arg_scope(is_training=True, 377 | weight_decay=0.00004, 378 | stddev=0.09, 379 | regularize_depthwise=False): 380 | """Defines the default MobilenetV1 arg scope. 381 | 382 | Args: 383 | is_training: Whether or not we're training the model. 384 | weight_decay: The weight decay to use for regularizing the model. 385 | stddev: The standard deviation of the trunctated normal weight initializer. 386 | regularize_depthwise: Whether or not apply regularization on depthwise. 387 | 388 | Returns: 389 | An `arg_scope` to use for the mobilenet v1 model. 390 | """ 391 | batch_norm_params = { 392 | 'is_training': is_training, 393 | 'center': True, 394 | 'scale': True, 395 | 'decay': 0.9997, 396 | 'epsilon': 0.001, 397 | } 398 | 399 | # Set weight_decay for weights in Conv and DepthSepConv layers. 
400 | weights_init = tf.truncated_normal_initializer(stddev=stddev)
401 | regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
402 | if regularize_depthwise:
403 | depthwise_regularizer = regularizer
404 | else:
405 | depthwise_regularizer = None
406 | with slim.arg_scope([slim.conv2d, slim.separable_conv2d],
407 | weights_initializer=weights_init,
408 | activation_fn=tf.nn.relu6, normalizer_fn=slim.batch_norm):
409 | with slim.arg_scope([slim.batch_norm], **batch_norm_params):
410 | with slim.arg_scope([slim.conv2d], weights_regularizer=regularizer):
411 | with slim.arg_scope([slim.separable_conv2d],
412 | weights_regularizer=depthwise_regularizer) as sc:
413 | return sc
414 |
--------------------------------------------------------------------------------
/train_tf2.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 |
4 | import os
5 | import sys
6 | import time
7 | import json
8 | import numpy as np
9 | import os.path as osp
10 | import logging.config
11 | import tensorflow as tf
12 | from datetime import timedelta
13 | from signal import SIGINT, SIGTERM
14 | from argparse import ArgumentParser
15 | from importlib import import_module
16 | # from tensorflow.contrib import slim
17 |
18 | import matplotlib
19 | import constants as const
20 | matplotlib.use('Agg')
21 |
22 | import common
23 | import lbtoolbox as lb
24 | from nets import NET_CHOICES
25 | from heads import HEAD_CHOICES
26 | from ranking.npair import npairs_loss
27 | from ranking.angular import angular_loss
28 | from ranking.hard_triplet import batch_hard
29 | from ranking import LOSS_CHOICES, METRIC_CHOICES
30 | from model.embedding_model import EmbeddingModel
31 | from ranking.contrastive import contrastive_loss
32 | from ranking.lifted_structured import lifted_loss
33 | from ranking.semi_hard_triplet import triplet_semihard_loss
34 |
35 |
36 |
37 | OPTIMIZER_CHOICES = (
38 | 'adam',
39 | 'momentum',
40 | )
41 |
42 |
43 | parser = ArgumentParser(description='Train a ReID network.')
44 |
45 | # Required.
46 |
47 | parser.add_argument(
48 | '--experiment_root', required=True, type=common.writeable_directory,
49 | help='Location used to store checkpoints and dumped data.')
50 |
51 | parser.add_argument(
52 | '--train_set',
53 | help='Path to the train_set csv file.')
54 |
55 | parser.add_argument(
56 | '--image_root', type=common.readable_directory,
57 | help='Path that will be pre-pended to the filenames in the train_set csv.')
58 |
59 | # Optional with sane defaults.
60 |
61 | parser.add_argument(
62 | '--resume', action='store_true', default=False,
63 | help='When this flag is provided, all other arguments apart from the '
64 | 'experiment_root are ignored and a previously saved set of arguments '
65 | 'is loaded.')
66 |
67 | parser.add_argument(
68 | '--model_name', default='resnet_v1_50', choices=NET_CHOICES,
69 | help='Name of the model to use.')
70 |
71 | parser.add_argument(
72 | '--head_name', default='fc1024', choices=HEAD_CHOICES,
73 | help='Name of the head to use.')
74 |
75 | parser.add_argument(
76 | '--optimizer', default='adam', choices=OPTIMIZER_CHOICES,
77 | help='Name of the optimizer to use.')
78 |
79 | parser.add_argument(
80 | '--embedding_dim', default=128, type=common.positive_int,
81 | help='Dimensionality of the embedding space.')
82 |
83 | parser.add_argument(
84 | '--initial_checkpoint', default=None,
85 | help='Path to the checkpoint file of the pretrained network.')
86 |
87 | # TODO move these defaults to the .sh script?
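# A PK-batch is built by drawing P identities (batch_p) and K images per
# identity (batch_k), so the effective batch size is P*K; with the defaults
# below that is 32 * 4 = 128 images per iteration.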
88 | parser.add_argument(
89 | '--batch_p', default=32, type=common.positive_int,
90 | help='The number P used in the PK-batches.')
91 |
92 | parser.add_argument(
93 | '--batch_k', default=4, type=common.positive_int,
94 | help='The number K used in the PK-batches.')
95 |
96 | parser.add_argument(
97 | '--net_input_height', default=256, type=common.positive_int,
98 | help='Height of the input directly fed into the network.')
99 |
100 | parser.add_argument(
101 | '--net_input_width', default=128, type=common.positive_int,
102 | help='Width of the input directly fed into the network.')
103 |
104 | parser.add_argument(
105 | '--pre_crop_height', default=288, type=common.positive_int,
106 | help='Height used to resize a loaded image. This is ignored when no crop '
107 | 'augmentation is applied.')
108 |
109 | parser.add_argument(
110 | '--pre_crop_width', default=144, type=common.positive_int,
111 | help='Width used to resize a loaded image. This is ignored when no crop '
112 | 'augmentation is applied.')
113 | # TODO end
114 |
115 | parser.add_argument(
116 | '--loading_threads', default=8, type=common.positive_int,
117 | help='Number of threads used for parallel loading.')
118 |
119 | parser.add_argument(
120 | '--margin', default='soft', type=common.float_or_string,
121 | help='What margin to use: a float value for hard-margin, "soft" for '
122 | 'soft-margin, or no margin if "none".')
123 |
124 | parser.add_argument(
125 | '--metric', default='euclidean', choices=METRIC_CHOICES,
126 | help='Which metric to use for the distance between embeddings.')
127 |
128 | parser.add_argument(
129 | '--loss', default='hard_triplet', choices=LOSS_CHOICES,
130 | help='Which ranking loss to use for training the embedding.')
131 |
132 | parser.add_argument(
133 | '--learning_rate', default=3e-4, type=common.positive_float,
134 | help='The initial value of the learning-rate, before decay kicks in.')
135 |
136 | parser.add_argument(
137 | '--train_iterations', default=25000, type=common.positive_int,
138 | help='Number of training iterations.')
139 |
140 | parser.add_argument(
141 | '--decay_start_iteration', default=15000, type=int,
142 | help='At which iteration the learning-rate decay should kick-in. '
143 | 'Set to -1 to disable decay completely.')
144 |
145 | parser.add_argument(
146 | '--gpu', default='0', type=str,
147 | help='Which GPU to use.')
148 |
149 | parser.add_argument(
150 | '--checkpoint_frequency', default=1000, type=common.nonnegative_int,
151 | help='After how many iterations a checkpoint is stored. Set this to 0 to '
152 | 'disable intermediate storing. This will result in only one final '
153 | 'checkpoint.')
154 |
155 | parser.add_argument(
156 | '--flip_augment', action='store_true', default=False,
157 | help='When this flag is provided, flip augmentation is performed.')
158 |
159 | parser.add_argument(
160 | '--crop_augment', action='store_true', default=False,
161 | help='When this flag is provided, crop augmentation is performed, based on '
162 | 'the `pre_crop_height` and `pre_crop_width` parameters. Changing this flag '
163 | 'thus likely changes the network input size!')
164 |
165 | parser.add_argument(
166 | '--detailed_logs', action='store_true', default=False,
167 | help='Store very detailed logs of the training in addition to TensorBoard'
168 | ' summaries. These are mem-mapped numpy files containing the'
169 | ' embeddings, losses and FIDs seen in each batch during training.'
170 | ' Everything can be re-constructed and analyzed that way.') 171 | 172 | parser.add_argument( 173 | '--augment', action='store_true', default=False, help='Data augmentation with imgaug') 174 | 175 | 176 | def sample_k_fids_for_pid(pid, all_fids, all_pids, batch_k): 177 | """ Given a PID, select K FIDs of that specific PID. """ 178 | possible_fids = tf.boolean_mask(all_fids, tf.math.equal(all_pids, pid)) 179 | 180 | # The following simply uses a subset of K of the possible FIDs 181 | # if more than, or exactly K are available. Otherwise, we first 182 | # create a padded list of indices which contain a multiple of the 183 | # original FID count such that all of them will be sampled equally likely. 184 | count = tf.shape(possible_fids)[0] 185 | padded_count = tf.cast(tf.math.ceil(batch_k / tf.dtypes.cast(count, tf.dtypes.float32)), tf.dtypes.int32) * count 186 | full_range = tf.math.mod(tf.range(padded_count), count) 187 | 188 | # Sampling is always performed by shuffling and taking the first k. 189 | shuffled = tf.random.shuffle(full_range) 190 | selected_fids = tf.gather(possible_fids, shuffled[:batch_k]) 191 | 192 | return selected_fids, tf.fill([batch_k], pid) 193 | 194 | 195 | 196 | 197 | def main(argv): 198 | 199 | 200 | args = parser.parse_args(argv) 201 | 202 | if args.gpu: 203 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 204 | 205 | # tf.compat.v1.disable_eager_execution() 206 | 207 | # physical_devices = tf.config.experimental.list_physical_devices('GPU') 208 | # tf.config.experimental.set_memory_growth(physical_devices[0], True) 209 | 210 | 211 | # We store all arguments in a json file. This has two advantages: 212 | # 1. We can always get back and see what exactly that experiment was 213 | # 2. We can resume an experiment as-is without needing to remember all flags. 214 | args_file = os.path.join(args.experiment_root, 'args.json') 215 | if args.resume: 216 | if not os.path.isfile(args_file): 217 | raise IOError('`args.json` not found in {}'.format(args_file)) 218 | 219 | print('Loading args from {}.'.format(args_file)) 220 | with open(args_file, 'r') as f: 221 | args_resumed = json.load(f) 222 | args_resumed['resume'] = True # This would be overwritten. 223 | 224 | # When resuming, we not only want to populate the args object with the 225 | # values from the file, but we also want to check for some possible 226 | # conflicts between loaded and given arguments. 227 | for key, value in args.__dict__.items(): 228 | if key in args_resumed: 229 | resumed_value = args_resumed[key] 230 | if resumed_value != value: 231 | print('Warning: For the argument `{}` we are using the' 232 | ' loaded value `{}`. The provided value was `{}`' 233 | '.'.format(key, resumed_value, value)) 234 | args.__dict__[key] = resumed_value 235 | else: 236 | print('Warning: A new argument was added since the last run:' 237 | ' `{}`. Using the new value: `{}`.'.format(key, value)) 238 | 239 | else: 240 | # If the experiment directory exists already, we bail in fear. 241 | if os.path.exists(args.experiment_root): 242 | if os.listdir(args.experiment_root): 243 | print('The directory {} already exists and is not empty.' 244 | ' If you want to resume training, append --resume to' 245 | ' your call.'.format(args.experiment_root)) 246 | exit(1) 247 | else: 248 | os.makedirs(args.experiment_root) 249 | 250 | # Store the passed arguments for later resuming and grepping in a nice 251 | # and readable format. 
252 | with open(args_file, 'w') as f:
253 | json.dump(vars(args), f, ensure_ascii=False, indent=2, sort_keys=True)
254 |
255 | log_file = os.path.join(args.experiment_root, "train")
256 | logging.config.dictConfig(common.get_logging_dict(log_file))
257 | log = logging.getLogger('train')
258 |
259 | # Also show all parameter values at the start, for ease of reading logs.
260 | log.info('Training using the following parameters:')
261 | for key, value in sorted(vars(args).items()):
262 | log.info('{}: {}'.format(key, value))
263 |
264 | # Check them here, so they are not required when --resume-ing.
265 | if not args.train_set:
266 | parser.print_help()
267 | log.error("You did not specify the `train_set` argument!")
268 | sys.exit(1)
269 | if not args.image_root:
270 | parser.print_help()
271 | log.error("You did not specify the required `image_root` argument!")
272 | sys.exit(1)
273 |
274 | # Load the data from the CSV file.
275 | pids, fids = common.load_dataset(args.train_set, args.image_root)
276 | max_fid_len = max(map(len, fids)) # We'll need this later for logfiles.
277 |
278 | # Setup a tf.Dataset where one "epoch" loops over all PIDS.
279 | # PIDS are shuffled after every epoch and continue indefinitely.
280 | unique_pids = np.unique(pids)
281 | if len(unique_pids) < args.batch_p:
282 | unique_pids = np.tile(unique_pids, int(np.ceil(args.batch_p / len(unique_pids))))
283 | dataset = tf.data.Dataset.from_tensor_slices(unique_pids)
284 | dataset = dataset.shuffle(len(unique_pids))
285 |
286 | # Constrain the dataset size to a multiple of the batch-size, so that
287 | # we don't get overlap at the end of each epoch.
288 | dataset = dataset.take((len(unique_pids) // args.batch_p) * args.batch_p)
289 | dataset = dataset.repeat(None) # Repeat forever. Funny way of stating it.
290 |
291 | # For every PID, get K images.
292 | dataset = dataset.map(lambda pid: sample_k_fids_for_pid(
293 | pid, all_fids=fids, all_pids=pids, batch_k=args.batch_k))
294 |
295 | # Ungroup/flatten the batches for easy loading of the files.
296 | dataset = dataset.unbatch()
297 |
298 | # Convert filenames to actual image tensors.
299 | net_input_size = (args.net_input_height, args.net_input_width)
300 | pre_crop_size = (args.pre_crop_height, args.pre_crop_width)
301 |
302 | dataset = dataset.map(
303 | lambda fid, pid: common.fid_to_image(
304 | fid, pid, image_root=args.image_root,
305 | image_size=pre_crop_size if args.crop_augment else net_input_size),
306 | num_parallel_calls=args.loading_threads)
307 |
308 |
309 | # Augment the data if specified by the arguments.
310 |
317 | if args.flip_augment:
318 | dataset = dataset.map(
319 | lambda im, fid, pid: (tf.image.random_flip_left_right(im), fid, pid))
320 | if args.crop_augment:
321 | dataset = dataset.map(
322 | lambda im, fid, pid: (tf.image.random_crop(im, net_input_size + (3,)), fid, pid))
323 |
324 | # Create the model and an embedding head.
325 | tf.keras.backend.set_learning_phase(1)
326 | emb_model = EmbeddingModel(args)
327 |
328 | # Group it back into PK batches.
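# Since unbatch() above flattened the per-identity groups in order, taking
# P*K consecutive elements here reconstructs a batch containing exactly K
# images for each of P identities, which is the layout the PK-based losses
# below rely on.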
329 | batch_size = args.batch_p * args.batch_k 330 | dataset = dataset.map(lambda im, fid, pid: (emb_model.preprocess_input(im), fid, pid)) 331 | dataset = dataset.batch(batch_size) 332 | 333 | # Overlap producing and consuming for parallelism. 334 | dataset = dataset.prefetch(1) 335 | 336 | # Since we repeat the data infinitely, we only need a one-shot iterator. 337 | 338 | 339 | 340 | # Feed the image through the model. The returned `body_prefix` will be used 341 | # further down to load the pre-trained weights for all variables with this 342 | # prefix. 343 | 344 | 345 | 346 | 347 | # all_trainable_variables = embedding_head.trainable_variables+base_model.trainable_variables 348 | 349 | # Define the optimizer and the learning-rate schedule. 350 | # Unfortunately, we get NaNs if we don't handle no-decay separately. 351 | if 0 <= args.decay_start_iteration < args.train_iterations: 352 | learning_rate = tf.optimizers.schedules.PolynomialDecay(args.learning_rate, args.train_iterations, 353 | end_learning_rate=1e-7) 354 | else: 355 | learning_rate = args.learning_rate 356 | 357 | if args.optimizer == 'adam': 358 | optimizer = tf.keras.optimizers.Adam(learning_rate) 359 | elif args.optimizer == 'momentum': 360 | optimizer = tf.keras.optimizers.SGD(learning_rate, momentum=0.9) 361 | else: 362 | raise NotImplementedError('Invalid optimizer {}'.format(args.optimizer)) 363 | 364 | @tf.function 365 | def train_step(images, pids ): 366 | 367 | with tf.GradientTape() as tape: 368 | batch_embedding = emb_model(images) 369 | if args.loss == 'semi_hard_triplet': 370 | embedding_loss = triplet_semihard_loss(batch_embedding, pids, args.margin) 371 | elif args.loss == 'hard_triplet': 372 | embedding_loss = batch_hard(batch_embedding, pids, args.margin, args.metric) 373 | elif args.loss == 'lifted_loss': 374 | embedding_loss = lifted_loss(pids, batch_embedding, margin=args.margin) 375 | elif args.loss == 'contrastive_loss': 376 | assert batch_size % 2 == 0 377 | assert args.batch_k == 4 ## Can work with other number but will need tuning 378 | 379 | contrastive_idx = np.tile([0, 1, 4, 3, 2, 5, 6, 7], args.batch_p // 2) 380 | for i in range(args.batch_p // 2): 381 | contrastive_idx[i * 8:i * 8 + 8] += i * 8 382 | 383 | contrastive_idx = np.expand_dims(contrastive_idx, 1) 384 | batch_embedding_ordered = tf.gather_nd(batch_embedding, contrastive_idx) 385 | pids_ordered = tf.gather_nd(pids, contrastive_idx) 386 | # batch_embedding_ordered = tf.Print(batch_embedding_ordered,[pids_ordered],'pids_ordered :: ',summarize=1000) 387 | embeddings_anchor, embeddings_positive = tf.unstack( 388 | tf.reshape(batch_embedding_ordered, [-1, 2, args.embedding_dim]), 2, 389 | 1) 390 | # embeddings_anchor = tf.Print(embeddings_anchor,[pids_ordered,embeddings_anchor,embeddings_positive,batch_embedding,batch_embedding_ordered],"Tensors ", summarize=1000) 391 | 392 | fixed_labels = np.tile([1, 0, 0, 1], args.batch_p // 2) 393 | # fixed_labels = np.reshape(fixed_labels,(len(fixed_labels),1)) 394 | # print(fixed_labels) 395 | labels = tf.constant(fixed_labels) 396 | # labels = tf.Print(labels,[labels],'labels ',summarize=1000) 397 | embedding_loss = contrastive_loss(labels, embeddings_anchor, embeddings_positive, 398 | margin=args.margin) 399 | elif args.loss == 'angular_loss': 400 | embeddings_anchor, embeddings_positive = tf.unstack( 401 | tf.reshape(batch_embedding, [-1, 2, args.embedding_dim]), 2, 402 | 1) 403 | # pids = tf.Print(pids, [pids], 'pids:: ', summarize=100) 404 | pids, _ = tf.unstack(tf.reshape(pids, [-1, 2, 1]), 2, 
1) 405 | # pids = tf.Print(pids,[pids],'pids:: ',summarize=100) 406 | embedding_loss = angular_loss(pids, embeddings_anchor, embeddings_positive, 407 | batch_size=args.batch_p, with_l2reg=True) 408 | 409 | elif args.loss == 'npairs_loss': 410 | assert args.batch_k == 2 ## Single positive pair per class 411 | embeddings_anchor, embeddings_positive = tf.unstack( 412 | tf.reshape(batch_embedding, [-1, 2, args.embedding_dim]), 2, 1) 413 | pids, _ = tf.unstack(tf.reshape(pids, [-1, 2, 1]), 2, 1) 414 | pids = tf.reshape(pids, [-1]) 415 | embedding_loss = npairs_loss(pids, embeddings_anchor, embeddings_positive) 416 | 417 | else: 418 | raise NotImplementedError('Invalid Loss {}'.format(args.loss)) 419 | loss_mean = tf.reduce_mean(embedding_loss) 420 | 421 | 422 | gradients = tape.gradient(loss_mean, emb_model.trainable_variables) 423 | optimizer.apply_gradients(zip(gradients, emb_model.trainable_variables)) 424 | 425 | return embedding_loss 426 | 427 | # sess = tf.compat.v1.Session() 428 | # start_step = sess.run(global_step) 429 | # checkpoint_saver = tf.train.Saver(max_to_keep=2) 430 | start_step = 0 431 | log.info('Starting training from iteration {}.'.format(start_step)) 432 | dataset_iter = iter(dataset) 433 | 434 | ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optimizer, net=emb_model) 435 | manager = tf.train.CheckpointManager(ckpt, osp.join(args.experiment_root,'tf_ckpts'), max_to_keep=3) 436 | 437 | ckpt.restore(manager.latest_checkpoint) 438 | if manager.latest_checkpoint: 439 | print("Restored from {}".format(manager.latest_checkpoint)) 440 | else: 441 | print("Initializing from scratch.") 442 | 443 | with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u: 444 | for i in range(ckpt.step.numpy(), args.train_iterations): 445 | # for batch_idx, batch in enumerate(): 446 | start_time = time.time() 447 | images, fids, pids = next(dataset_iter) 448 | batch_loss = train_step(images, pids) 449 | elapsed_time = time.time() - start_time 450 | seconds_todo = (args.train_iterations - i) * elapsed_time 451 | # print(tf.reduce_min(batch_loss).numpy(),tf.reduce_mean(batch_loss).numpy(),tf.reduce_max(batch_loss).numpy()) 452 | log.info('iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, ETA: {} ({:.2f}s/it)'.format( 453 | i, 454 | tf.reduce_min(batch_loss).numpy(),tf.reduce_mean(batch_loss).numpy(),tf.reduce_max(batch_loss).numpy(), 455 | # args.batch_k - 1, float(b_prec_at_k), 456 | timedelta(seconds=int(seconds_todo)), 457 | elapsed_time)) 458 | 459 | ckpt.step.assign_add(1) 460 | if (args.checkpoint_frequency > 0 and i % args.checkpoint_frequency == 0): 461 | 462 | # uncomment if you want to save the model weight separately 463 | # emb_model.save_weights(os.path.join(args.experiment_root, 'model_weights_{0:04d}.w'.format(i))) 464 | 465 | manager.save() 466 | 467 | # Stop the main-loop at the end of the step, if requested. 
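# lb.Uninterrupt only records SIGINT/SIGTERM instead of terminating the
# process, so the current iteration (including any checkpoint save above)
# completes before the loop is exited here.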
468 | if u.interrupted: 469 | log.info("Interrupted on request!") 470 | break 471 | 472 | # print(fids) 473 | 474 | 475 | if __name__ == '__main__': 476 | 477 | dataset_dir = const.dataset_dir 478 | trained_models_dir = const.trained_models_dir 479 | experiment_root_dir = const.experiment_root_dir 480 | 481 | dataset_name = 'cub' 482 | 483 | if dataset_name == 'cub': 484 | db_dir = 'CUB_200_2011/images' 485 | train_file = 'cub_train.csv' 486 | extra_args = [ 487 | '--batch_p', '20', 488 | '--batch_k', '2', 489 | '--train_iterations','10000', 490 | '--optimizer', 'momentum', 491 | ] 492 | elif dataset_name == 'inshop': 493 | db_dir = 'In_shop_Clothes_Retrieval_Benchmark' 494 | train_file = 'deep_fashion_train.csv' 495 | extra_args = [ 496 | # p_10,k_6 497 | '--batch_p', '10', 498 | '--batch_k', '6', 499 | '--optimizer', 'adam', 500 | ] 501 | elif dataset_name == 'stanford': 502 | db_dir = 'Stanford_Online_Products' 503 | train_file = 'stanford_online_train.csv' 504 | extra_args = [ 505 | # p_10,k_6 506 | '--batch_p', '20', 507 | '--batch_k', '2', 508 | '--train_iterations', '30000', 509 | '--optimizer', 'adam', 510 | ] 511 | else: 512 | raise NotImplementedError('invalid dataset {}'.format(dataset_name)) 513 | 514 | arg_loss = 'npairs_loss' 515 | arg_head = 'direct_normalize' 516 | arg_margin = '0.2' 517 | arg_arch = 'densenet' 518 | 519 | 520 | exp_name = [dataset_name, arg_arch, arg_head, arg_loss, 'm_{}'.format(arg_margin)] 521 | exp_name = '_'.join(exp_name) 522 | 523 | 524 | args = [ 525 | '--image_root', dataset_dir + db_dir, 526 | '--experiment_root', experiment_root_dir + exp_name, 527 | 528 | 529 | '--train_set', './data/' + train_file, 530 | 531 | '--net_input_height', '224', 532 | '--net_input_width', '224', 533 | '--pre_crop_height', '256', 534 | '--pre_crop_width', '256', 535 | 536 | '--flip_augment', 537 | '--crop_augment', 538 | 539 | # '--resume', 540 | '--head_name', arg_head, 541 | '--margin', arg_margin, 542 | '--loss', arg_loss, 543 | '--gpu', '0', 544 | ] 545 | args.extend([ 546 | 547 | ]) 548 | if arg_arch == 'resnet': 549 | args.extend( 550 | [ 551 | '--initial_checkpoint', trained_models_dir + 'resnet_v1_50/resnet_v1_50.ckpt', 552 | '--model_name', 'resnet_v1_50', 553 | ] 554 | ) 555 | if arg_arch == 'inc_v1': 556 | args.extend( 557 | [ 558 | '--initial_checkpoint', trained_models_dir + 'inception_v1/inception_v1.ckpt', 559 | '--model_name', 'inception_v1', 560 | ] 561 | ) 562 | elif arg_arch == 'densenet': 563 | args.extend( 564 | [ 565 | '--initial_checkpoint', trained_models_dir + 'tf-densenet169/tf-densenet169.ckpt', 566 | '--model_name', 'densenet169', 567 | ] 568 | ) 569 | 570 | 571 | args.extend(extra_args) 572 | 573 | main(args) 574 | 575 | --------------------------------------------------------------------------------
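As a usage note: the `__main__` block above hard-codes one configuration per dataset; the following is a minimal sketch of driving the same entry point programmatically with explicit flags. Every flag shown is defined by the argument parser in `train_tf2.py`, but the experiment, dataset, and checkpoint paths are hypothetical placeholders that need to be adapted to your setup.

```python
# Sketch: invoke train_tf2.main() with an explicit flag list.
# All absolute paths below are placeholders, not paths shipped with the repo.
from train_tf2 import main

args = [
    '--experiment_root', '/tmp/experiments/cub_resnet_hard_triplet',   # placeholder
    '--train_set', './data/cub_train.csv',
    '--image_root', '/path/to/CUB_200_2011/images',                    # placeholder
    '--initial_checkpoint', '/path/to/resnet_v1_50/resnet_v1_50.ckpt', # placeholder
    '--model_name', 'resnet_v1_50',
    '--head_name', 'fc1024_normalize',
    '--loss', 'hard_triplet',
    '--metric', 'euclidean',
    '--margin', 'soft',
    '--batch_p', '18',
    '--batch_k', '4',
    '--net_input_height', '224',
    '--net_input_width', '224',
    '--flip_augment',
    '--gpu', '0',
]

main(args)
```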