├── README.md
├── activitynet_category_name.txt
├── aher_adv.py
├── aher_anet.py
├── aher_common.py
├── aher_k600_test.py
├── aher_k600_train.py
├── custom_layers.py
├── data_loader.py
├── evaluation
│   ├── eval_detection.py
│   ├── eval_proposal.py
│   ├── get_detection_performance.py
│   ├── get_proposal_performance.py
│   └── utils.py
├── k600_category.txt
├── k600_val_annotation.txt
├── pic
│   └── eccv_framework.JPG
├── tf_extended
│   ├── __init__.py
│   ├── bboxes.py
│   ├── math.py
│   ├── metrics.py
│   └── tensors.py
└── tf_utils.py


/README.md:
--------------------------------------------------------------------------------
1 | # AherNet: Learning to Localize Actions from Moments
2 | 
3 | This repository includes the code and configuration files for the transfer setting "from ActivityNet v1.3 to Kinetics-600, i.e., ANet -> K600" in the paper.
4 | 
5 | The training/validation features and related meta files are not included in this repository because they are too large; they will be released on Google Drive or Baidu Yun in the future.
6 | 
7 | The file `k600_val_annotation.txt` contains the manual temporal annotations of 6,459 videos in the validation set of Kinetics-600.
8 | 
9 | # Update
10 | * 2020.8.25: Repository for AherNet and annotations of sampled validation videos in Kinetics-600
11 | 
12 | # Contents:
13 | 
14 | * [Paper Introduction](#paper-introduction)
15 | * [Environment](#environment)
16 | * [Training of AherNet](#training-of-ahernet)
17 | * [Testing of AherNet on Kinetics-600](#testing-of-ahernet-on-kinetics-600)
18 | * [Citation](#citation)
19 | 
20 | # Paper Introduction
21 | 
22 |
23 | 
24 | With the knowledge of action moments (i.e., trimmed video clips that each contain an action instance), humans can routinely localize an action temporally in an untrimmed video. Nevertheless, most practical methods still require all training videos to be labeled with temporal annotations (action category and temporal boundary) and develop the models in a fully-supervised manner. This is despite the expensive labeling effort and the fact that such models do not apply to new categories. In this paper, we introduce a new transfer-learning design that learns action localization for a large set of action categories. It uses only action moments from the categories of interest and temporal annotations of untrimmed videos from a small set of action classes. Specifically, we present Action Herald Networks (AherNet), which integrate this design into a one-stage action localization framework. Technically, a weight transfer function is uniquely devised to build the transformation between classification of action moments or foreground video segments and action localization in synthetic contextual moments or untrimmed videos. The context of each moment is learnt through an adversarial mechanism that differentiates the generated features from those of the background in untrimmed videos. Extensive experiments are conducted on transfer both across the splits of ActivityNet v1.3 and from THUMOS14 to ActivityNet v1.3. AherNet demonstrates superiority even when compared with most fully-supervised action localization methods. More remarkably, we train AherNet to localize actions from 600 categories by leveraging action moments in Kinetics-600 and temporal annotations from 200 classes in ActivityNet v1.3.
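The synthetic contextual moments mentioned above are assembled by `Context_Train` in `aher_adv.py`. A 256-step moment feature is resized to `ratio * 512` steps and placed between generated start and end context, so every training sequence has 512 steps. The pure-Python sketch below reproduces only that index bookkeeping. It lets you check which (start, end) ground-truth segment a given `ratio_id`/`pos_id` pair produces.

`synth_layout` is a hypothetical helper, not part of the repository. It rests on these assumptions:

* it uses the 0.05 scaling from the code;
* it uses integer truncation as in `tf.cast(..., tf.int32)`;
* it takes the end boundary as `start + moment_len` in both the even and odd branches.

Results may differ from the TensorFlow graph by floating-point rounding.

```python
def synth_layout(ratio_id: int, pos_id: int, sample_idx: int, total: int = 512):
    """Return the (start, end) steps of a moment inside a synthetic sequence.

    Hypothetical helper that mirrors the index arithmetic of Context_Train:
    ratio = 0.05 * ratio_id sets the moment length and pos = 0.05 * pos_id
    sets where it is inserted inside one of the two generated context parts.
    """
    ratio = 0.05 * ratio_id
    pos = 0.05 * pos_id
    moment_len = int(total * ratio)                   # resize_fea_len
    start_ctx_len = (total - moment_len) // 2         # net1_len
    end_ctx_len = total - moment_len - start_ctx_len  # net2_len
    if sample_idx % 2 == 0:
        # Even samples: the moment is inserted into the start context.
        start = int((1.0 - ratio) * start_ctx_len * pos)
    else:
        # Odd samples: the whole start context comes first, then the moment
        # is inserted into the end context.
        start = start_ctx_len + int((1.0 - ratio) * end_ctx_len * pos)
    end = start + moment_len
    assert 0 <= start < end <= total
    return start, end


if __name__ == "__main__":
    print(synth_layout(10, 10, 0))
    print(synth_layout(10, 10, 1))
```

For example, with `ratio_id = pos_id = 10` the moment is 256 steps long. It lands at `(32, 288)` for an even sample and at `(160, 416)` for an odd one, because the full 128-step start context comes first.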
25 | 
26 | # Environment
27 | 
28 | TensorFlow version: 1.12.3 (GPU build)
29 | 
30 | Operating system (Docker image): Ubuntu 16.04
31 | 
32 | Python version: 3.6.7
33 | 
34 | GPU: NVIDIA Tesla P40 (23GB)
35 | 
36 | CUDA and cuDNN versions: CUDA 9.0 and cuDNN 7.3.1
37 | 
38 | 
39 | # Training of AherNet
40 | 
41 | The training script is `aher_k600_train.py`.
42 | Once you have all the training features and information files, you can run
43 | 
44 | ```
45 | CUDA_VISIBLE_DEVICES=0 python3 aher_k600_train.py
46 | ```
47 | 
48 | # Testing of AherNet on Kinetics-600
49 | 
50 | The testing script is `aher_k600_test.py`.
51 | Once you have all the testing features and information files, you can run
52 | 
53 | ```
54 | CUDA_VISIBLE_DEVICES=0 python3 aher_k600_test.py
55 | ```
56 | You can evaluate several snapshots (models) during testing to find the best one.
57 | 
58 | Testing writes the results to a `.json` file. Evaluate it with `./evaluation/get_proposal_performance.py` for temporal action proposals or `./evaluation/get_detection_performance.py` for temporal action localization.
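The NumPy sketch below can help when reading the evaluation output. It shows two things:

* How a predicted segment is recovered from anchor offsets. This mirrors `tf_aher_bboxes_decode_layer` in `aher_common.py`: a center offset scaled by the anchor length, a log-length ratio, clipping to the 512-step window, and rescaling to the video duration.
* How two segments are compared by temporal IoU (tIoU). This is the overlap measure computed by `jaccard_with_anchors`, and ActivityNet-style proposal/detection evaluation also thresholds on it.

The names `decode_segments` and `temporal_iou` are hypothetical. This is an illustrative sketch, not the repository's evaluation code.

```python
import numpy as np


def decode_segments(loc, yref, tref, duration, steps=512, prior_scaling=(1.0, 1.0)):
    """Decode anchor offsets into [start, end] segments in units of `duration`.

    loc[..., 0] is the center offset (in anchor lengths) and loc[..., 1] is
    log(segment length / anchor length), following the arithmetic of
    tf_aher_bboxes_decode_layer. yref/tref are anchor centers/lengths in steps.
    """
    loc = np.asarray(loc, dtype=float)
    cy = loc[..., 0] * tref * prior_scaling[0] + yref
    length = tref * np.exp(loc[..., 1] * prior_scaling[1])
    # Clip to the feature window, then rescale from feature steps to duration.
    ymin = np.maximum(cy - length / 2.0, 0.0) / steps * duration
    ymax = np.minimum(cy + length / 2.0, steps) / steps * duration
    return np.stack([ymin, ymax], axis=-1)


def temporal_iou(segment, candidates):
    """Temporal IoU between one [start, end] segment and an (N, 2) array of segments."""
    candidates = np.asarray(candidates, dtype=float)
    inter = np.clip(
        np.minimum(segment[1], candidates[:, 1]) - np.maximum(segment[0], candidates[:, 0]),
        0.0, None)
    # Union = length(a) + length(b) - intersection, as in jaccard_with_anchors.
    union = (candidates[:, 1] - candidates[:, 0]) + (segment[1] - segment[0]) - inter
    return inter / union


if __name__ == "__main__":
    seg = decode_segments(np.zeros((1, 2)), np.array([256.0]), np.array([128.0]), duration=10.0)
    print(seg)
    print(temporal_iou([0.0, 10.0], [[0, 10], [5, 15], [20, 30]]))
```

With zero offsets, a decoded segment is simply its anchor (center 256, length 128 steps). That anchor rescaled to a 10-second video spans 3.75 s to 6.25 s.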
59 | 
60 | 
61 | # Citation
62 | 
63 | If you use these models in your research, please cite:
64 | 
65 | @inproceedings{Long:ECCV20,
66 |   title={Learning To Localize Actions from Moments},
67 |   author={Fuchen Long and Ting Yao and Zhaofan Qiu and Xinmei Tian and Jiebo Luo and Tao Mei},
68 |   booktitle={ECCV},
69 |   year={2020}
70 | }
71 | 
--------------------------------------------------------------------------------
/activitynet_category_name.txt:
--------------------------------------------------------------------------------
1 | Drinking coffee
2 | Zumba
3 | Doing kickboxing
4 | Tango
5 | Playing polo
6 | Putting on makeup
7 | Snatch
8 | Long jump
9 | Cricket
10 | Ironing clothes
11 | Clean and jerk
12 | Using parallel bars
13 | Bathing dog
14 | Discus throw
15 | Playing field hockey
16 | Grooming horse
17 | Preparing salad
18 | Doing karate
19 | Playing harmonica
20 | Pole vault
21 | Playing saxophone
22 | Chopping wood
23 | Washing face
24 | Using the pommel horse
25 | Javelin throw
26 | Spinning
27 | Ping-pong
28 | Volleyball
29 | Brushing hair
30 | Making a sandwich
31 | Playing water polo
32 | Cleaning shoes
33 | Drinking beer
34 | Playing bagpipes
35 | Cheerleading
36 | Paintball
37 | Cleaning windows
38 | Brushing teeth
39 | Playing flauta
40 | Tennis serve with ball bouncing
41 | Starting a campfire
42 | Bungee jumping
43 | Triple jump
44 | Polishing forniture
45 | Layup drill in basketball
46 | Vacuuming floor
47 | Dodgeball
48 | Doing nails
49 | Shot put
50 | Fixing bicycle
51 | Washing hands
52 | Horseback riding
53 | Skateboarding
54 | Wrapping presents
55 | Using the balance beam
56 | Shoveling snow
57 | Preparing pasta
58 | Getting a tattoo
59 | Rock climbing
60 | Smoking hookah
61 | Shaving
62 | Using uneven bars
63 | Springboard diving
64 | Playing squash
65 | High jump
66 | Playing piano
67 | Shaving legs
68 | Smoking a cigarette
69 | Getting a haircut
70 | Playing lacrosse
71 | Cumbia
72 | Washing dishes
73 | Getting a piercing
74 | Painting
75 | Mowing the lawn
76 | Walking the dog
77 | Hammer throw
78 | Polishing shoes
79 | Ballet
80 | Doing step aerobics
81 | Hand washing clothes
82 | Plataform diving
83 | Playing violin
84 | Breakdancing
85 | Windsurfing
86 | Sailing
87 | Hopscotch
88 | Doing motocross
89 | Mixing drinks
90 | Belly dance
91 | Removing curlers
92 | Archery
93 | Playing guitarra
94 | Playing racquetball
95 | Kayaking
96 | Playing kickball
97 | Tumbling
98 | Tai chi
99 | Playing accordion
100 | Playing badminton
101 | Arm wrestling
102 | Assembling bicycle
103 | BMX
104 | Baking cookies
105 | Baton twirling
106 | Beach soccer
107 | Beer pong
108 | Blow-drying hair
109 | Blowing leaves
110 | Playing ten pins
111 | Braiding hair
112 | Building sandcastles
113 | Bullfighting
114 | Calf roping
115 | Camel ride
116 | Canoeing
117 | Capoeira
118 | Carving jack-o-lanterns
119 | Changing car wheel
120 | Cleaning sink
121 | Clipping cat claws
122 | Croquet
123 | Curling
124 | Cutting the grass
125 | Decorating the Christmas tree
126 | Disc dog
127 | Doing a powerbomb
128 | Doing crunches
129 | Drum corps
130 | Elliptical trainer
131 | Doing fencing
132 | Fixing the roof
133 | Fun sliding down
134 | Futsal
135 | Gargling mouthwash
136 | Grooming dog
137 | Hand car wash
138 | Hanging wallpaper
139 | Having an ice cream
140 | Hitting a pinata
141 | Hula hoop
142 | Hurling
143 | Ice fishing
144 | Installing carpet
145 | Kite flying
146 | Kneeling
147 | Knitting
148 | Laying tile
149 | Longboarding
150 | Making a cake
151 | Making a lemonade
152 | Making an omelette
153 | Mooping floor
154 | Painting fence
155 | Painting furniture
156 | Peeling potatoes
157 | Plastering
158 | Playing beach volleyball
159 | Playing blackjack
160 | Playing congas
161 | Playing drums
162 | Playing ice hockey
163 | Playing pool
164 | Playing rubik cube
165 | Powerbocking
166 | Putting in contact lenses
167 | Putting on shoes
168 | Rafting
169 | Raking leaves
170 | Removing ice from car
171 | Riding bumper cars
172 | River tubing
173 | Rock-paper-scissors
174 | Rollerblading
175 | Roof shingle removal
176 | Rope skipping
177 | Running a marathon
178 | Scuba diving
179 | Sharpening knives
180 | Shuffleboard
181 | Skiing
182 | Slacklining
183 | Snow tubing
184 | Snowboarding
185 | Spread mulch
186 | Sumo
187 | Surfing
188 | Swimming
189 | Swinging at the playground
190 | Table soccer
191 | Throwing darts
192 | Trimming branches or hedges
193 | Tug of war
194 | Using the monkey bar
195 | Using the rowing machine
196 | Wakeboarding
197 | Waterskiing
198 | Waxing skis
199 | Welding
200 | Applying sunscreen
201 | 
--------------------------------------------------------------------------------
/aher_adv.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import tensorflow as tf
3 | import numpy as np
4 | import aher_anet
5 | from aher_anet import aher_multibox_adv_layer
6 | import tf_utils
7 | from datetime import datetime
8 | import time
9 | import data_loader
10 | import tf_extended as tfe
11 | import os
12 | import sys
13 | import pandas as pd
14 | from multiprocessing import Process,Queue,JoinableQueue
15 | import multiprocessing
16 | import math
17 | import random
18 | from tf_utils import *
19 | 
20 | FLAGS = tf.app.flags.FLAGS
21 | tf.app.flags.DEFINE_float('dis_weights', 0.1, 'The weight for the discriminator.')
22 | tf.app.flags.DEFINE_float('gen_weights', 0.1, 'The weight for the generator.')
23 | 
24 | class Config(object):
25 |     def __init__(self):
26 |         self.learning_rates=[0.001]*100+[0.0001]*100
27 |         #self.training_epochs = len(self.learning_rates)
28 |         self.training_epochs = 1
29 |         self.total_batch_num = 15000
30 |         self.n_inputs = 2048
31 |         self.batch_size = 16
32 |         self.input_steps=512
33 |         self.input_moment_steps=256
34 |         self.gt_hold_num = 25
35 |         self.gt_hold_num_th = 25
36 |         self.batch_size_val=1
37 | 
38 | # generate context feature
39 | def Context_Train(m_feature,ratio_id,pos_id):
40 |     """ Context generation network: embeds each moment in generated start/end context
41 |     input: m_feature: batch_size x 256 x 2048
42 |     input: ratio_id, pos_id: batch_size (length-ratio and position indices, each scaled by 0.05)
43 |     output: net_c: batch_size x 512 x 2048, temp_gt: batch_size x 1 x 2
44 |     """
45 |     config = Config()
46 | 
47 |     # The start context generator
48 |     net1_i=tf.contrib.layers.conv1d(inputs=m_feature[:,:,],num_outputs=1024,kernel_size=3, \
49 |                 stride=1,padding='same',scope='g_conv_s1')
50 | 
51 |     net1=tf.contrib.layers.conv1d(inputs=net1_i,num_outputs=2048,kernel_size=3, \
52 |                 stride=2,padding='same',scope='g_conv_s2')
53 | 
54 |     # The end context generator
55 |     net2_i=tf.contrib.layers.conv1d(inputs=m_feature[:,:,],num_outputs=1024,kernel_size=3,
56 |                 stride=1,padding='same',scope='g_conv_e1')
57 | 
58 |     net2=tf.contrib.layers.conv1d(inputs=net2_i,num_outputs=2048,kernel_size=3, \
59 |                 stride=2,padding='same',scope='g_conv_e2')
60 | 
61 | 
62 |     # random crop and select temporal gt
63 |     net_res = []
64 |     temporal_gt = []
65 |     for i in range(config.batch_size):
66 |         ratio = tf.cast(ratio_id[i],tf.float32) * tf.constant(0.05)
67 |         posi = tf.cast(pos_id[i],tf.float32) * tf.constant(0.05)
68 | 
69 |         resize_fea_len = tf.cast(tf.constant(512.0)*ratio, tf.int32)
70 |         temp_feature = tf.expand_dims(m_feature,2)
71 |         resize_fea = tf.image.resize_images(temp_feature,[resize_fea_len,1])
72 |         reduce_fea = tf.squeeze(resize_fea,2)
73 | 
74 |         net1_len = tf.cast((tf.constant(512)-resize_fea_len)/tf.constant(2),tf.int32)
75 |         net2_len = tf.constant(512)-resize_fea_len-net1_len
76 | 
77 |         temp_net1 = tf.expand_dims(net1,2)
78 |         resize_net1 = tf.image.resize_images(temp_net1,[net1_len,1])
79 |         inc_net1 = tf.squeeze(resize_net1,2)
80 | 
81 |         temp_net2 = tf.expand_dims(net2,2)
82 |         resize_net2 = tf.image.resize_images(temp_net2,[net2_len,1])
83 |         inc_net2 = tf.squeeze(resize_net2,2)
84 | 
85 |         if i % 2 == 0:
86 |             start = tf.cast((tf.constant(1.0)-ratio) * tf.cast(net1_len,tf.float32) * posi, tf.int32)
87 |             net_a = inc_net1[i,:start,]
88 |             net_b = inc_net1[i,start:,]
89 |             net_3 = tf.keras.layers.concatenate(inputs=[net_a,reduce_fea[i,:,],net_b,inc_net2[i,:,]],axis=0)
90 |             net_res.append(tf.reshape(net_3,[1,config.input_steps,config.n_inputs]))
91 |             temporal_gt.append(tf.reshape(tf.cast(start,tf.float32),[1]))
92 |             temporal_gt.append(tf.reshape(tf.cast(start + resize_fea_len,tf.float32),[1]))
93 |         else:
94 |             start = tf.cast((tf.constant(1.0)-ratio) * tf.cast(net2_len,tf.float32) * posi, tf.int32)
95 |             net_a = inc_net2[i,:start,]
96 |             net_b = inc_net2[i,start:,]
97 |             net_3 = tf.keras.layers.concatenate(inputs=[inc_net1[i,:,],net_a,reduce_fea[i,:,],net_b],axis=0)
98 |             net_res.append(tf.reshape(net_3,[1,config.input_steps,config.n_inputs]))
99 |             temporal_gt.append(tf.reshape(tf.cast(start + net1_len,tf.float32),[1]))
100 |             temporal_gt.append(tf.reshape(tf.cast(start + net1_len + resize_fea_len,tf.float32),[1]))  # end = moment start + moment length
101 | 
102 |     net_c = tf.concat(net_res,axis=0)
103 |     temp_gt = tf.concat(temporal_gt,axis=0)
104 |     temp_gt = tf.reshape(temp_gt,[config.batch_size,1,2])
105 | 
106 |     return net_c,temp_gt
107 | 
108 | # Discriminator in each anchor layer
109 | def Context_Back_Discriminator(input_points,
110 |                                feat_layers=aher_anet.AHERNet.default_params.feat_layers,
111 |                                anchor_sizes=aher_anet.AHERNet.default_params.anchor_sizes,
112 |                                anchor_ratios=aher_anet.AHERNet.default_params.anchor_ratios,
113 |                                normalizations=aher_anet.AHERNet.default_params.normalizations,
114 |                                reuse = None):
115 | 
116 |     num_classes = 2
117 |     D_logits = []
118 |     D = []
119 |     for i, layer in enumerate(feat_layers):
120 |         with tf.variable_scope(layer + '_adv',reuse=reuse):
121 |             adv_logits = aher_multibox_adv_layer(input_points[layer],
122 |                                                  num_classes,
123 |                                                  anchor_sizes[i],
124 |                                                  anchor_ratios[i],
125 |                                                  normalizations[i])
126 |         D_logits.append(adv_logits)
127 |         D.append(tf.math.sigmoid(adv_logits))
128 |     return D,D_logits
129 | 
130 | # Background Discriminator
131 | def Adversary_Back_Train(D, D_logits, D_, D_logits_,
132 |                          gscore_untrim, gscore_gene,scope=None):
133 | 
134 |     with tf.name_scope(scope,'aher_adv_losses'):
135 |         lshape = tfe.get_shape(D_logits[0], 8)
136 |         batch_size = lshape[0]
137 |         fgscore_untrim = []
138 |         fgscore_gene = []
139 |         f_D_logits = []
140 |         f_D_logits_ = []
141 |         f_D = []
142 |         f_D_ = []
143 |         for i in range(len(D_logits)):
144 |             fgscore_untrim.append(tf.reshape(gscore_untrim[i],[-1]))
145 |             fgscore_gene.append(tf.reshape(gscore_gene[i],[-1]))
146 |             f_D_logits.append(tf.reshape(D_logits[i],[-1]))
147 |             f_D_logits_.append(tf.reshape(D_logits_[i],[-1]))
148 |             f_D.append(tf.reshape(D[i],[-1]))
149 |             f_D_.append(tf.reshape(D_[i],[-1]))
150 |         gscore_untrim = tf.concat(fgscore_untrim, axis=0)
151 |         gscore_gene = tf.concat(fgscore_gene, axis=0)
152 |         D_logits = tf.concat(f_D_logits, axis=0)
153 |         D_logits_ = tf.concat(f_D_logits_, axis=0)
154 |         D = tf.concat(f_D, axis=0)
155 |         D_ = tf.concat(f_D_, axis=0)
156 |         dtype = D_logits.dtype
157 | 
158 |         # select the background position and logits
159 |         pos_mask_untrim = gscore_untrim > 0.70
160 |         nmask_untrim = tf.logical_and(tf.logical_not(pos_mask_untrim),gscore_untrim < 0.3)
161 | 
162 |         pos_mask_gene = gscore_gene > 0.70
163 |         nmask_gene = tf.logical_and(tf.logical_not(pos_mask_gene),gscore_gene < 0.3)
164 | 
165 |         nmask = tf.logical_and(nmask_untrim,nmask_gene)
166 |         fnmask = tf.cast(nmask, dtype)
167 |         fnmask_num = tf.reduce_sum(fnmask)
168 | 
169 |         # compute the sigmoid cross entropy loss
170 |         d_loss_real=sigmoid_cross_entropy_with_logits(D_logits, tf.ones_like(D))
171 |         d_loss_real=tf.div(tf.reduce_sum(d_loss_real*fnmask), fnmask_num/FLAGS.dis_weights, name='d_loss_real')
172 | 
173 |         d_loss_fake=sigmoid_cross_entropy_with_logits(D_logits_, tf.zeros_like(D_))
174 |         d_loss_fake=tf.div(tf.reduce_sum(d_loss_fake*fnmask), fnmask_num/FLAGS.dis_weights, name='d_loss_fake')
175 | 
176 |         g_loss=sigmoid_cross_entropy_with_logits(D_logits_, tf.ones_like(D_))
177 |         g_loss=tf.div(tf.reduce_sum(g_loss*fnmask), fnmask_num/FLAGS.gen_weights, name='g_loss')
178 | 
179 |         return d_loss_real,d_loss_fake,g_loss
180 | 
--------------------------------------------------------------------------------
/aher_common.py:
--------------------------------------------------------------------------------
1 | """Shared functions between different AherNet implementations.
2 | """
3 | import numpy as np
4 | import tensorflow as tf
5 | import tf_extended as tfe
6 | 
7 | batch_size_fix=16
8 | 
9 | # =========================================================================== #
10 | # TensorFlow implementation of boxes Aher encoding / decoding.
11 | # =========================================================================== #
12 | def tf_aher_bboxes_encode_layer(labels,
13 |                                 bboxes,
14 |                                 anchors_layer,
15 |                                 num_classes,
16 |                                 temporal_shape,
17 |                                 matching_threshold=0.70,
18 |                                 prior_scaling=[1.0,1.0],
19 |                                 dtype=tf.float32):
20 |     """Encode groundtruth labels and bounding boxes using 1D anchors from
21 |     one layer.
22 | 
23 |     Arguments:
24 |       labels: 1D Tensor(int32) containing groundtruth labels;
25 |       bboxes: batch_size x N x 2 Tensor(float) with bboxes relative coordinates;
26 |       anchors_layer: Numpy array with layer anchors;
27 |       matching_threshold: Threshold for positive match with groundtruth bboxes (currently unused);
28 |       prior_scaling: Scaling of encoded coordinates.
29 | 
30 |     Return:
31 |       (target_labels, target_localizations, target_scores, target_max_iou): Target Tensors.
32 |     """
33 |     # Anchors coordinates and volume.
34 |     yref, tref = anchors_layer
35 |     ymin = yref - tref / 2.
36 |     ymax = yref + tref / 2.
37 | 
38 |     ymin = tf.maximum(0.0, ymin)
39 |     ymax = tf.minimum(ymax, temporal_shape - 1.0)
40 | 
41 |     vol_anchors = ymax - ymin
42 | 
43 |     if batch_size_fix == 1:
44 |         bboxes = tf.reshape(bboxes,[1,bboxes.shape[0],bboxes.shape[1]])
45 |         labels = tf.reshape(labels,[1])
46 | 
47 |     # Initialize tensors...
48 |     shape = (bboxes.shape[0], yref.shape[0], tref.size)
49 |     s_shape = (yref.shape[0], tref.size)
50 |     feat_scores = tf.zeros(shape, dtype=dtype)
51 |     feat_max_iou = tf.zeros(shape, dtype=dtype)
52 | 
53 |     feat_ymin = tf.zeros(shape, dtype=dtype)
54 |     feat_ymax = tf.ones(shape, dtype=dtype)
55 |     mask_minus = feat_ymin - feat_ymax
56 |     s_max_one = tf.ones(s_shape, dtype=tf.int32)
57 |     s_max_one_f = tf.ones(s_shape, dtype=dtype)
58 | 
59 |     label_shape = (bboxes.shape[0], yref.shape[0], 1)
60 |     s_label_shape = (yref.shape[0], 1)
61 |     s_label_max_one = tf.ones(s_label_shape, dtype=tf.int32)
62 |     feat_labels = tf.zeros(label_shape, dtype=tf.int32)
63 | 
64 | 
65 |     def jaccard_with_anchors(bbox):
66 |         """Compute the jaccard score (temporal IoU) between a box and the anchors.
67 |         """
68 |         int_ymin = tf.maximum(ymin, bbox[0])
69 |         int_ymax = tf.minimum(ymax, bbox[1])
70 |         t = tf.maximum(int_ymax - int_ymin, 0.)
71 | 
72 |         # Volumes.
73 |         inter_vol = t
74 |         union_vol = vol_anchors - inter_vol \
75 |             + (bbox[1] - bbox[0])
76 |         jaccard = tf.div(inter_vol, union_vol)
77 |         return jaccard
78 | 
79 |     def intersection_with_anchors(bbox):
80 |         """Compute the intersection between a box and the anchors, normalized by the anchor length.
81 |         """
82 |         int_ymin = tf.maximum(ymin, bbox[0])
83 |         int_ymax = tf.minimum(ymax, bbox[1])
84 |         t = tf.maximum(int_ymax - int_ymin, 0.)
85 |         inter_vol = t
86 |         scores = tf.div(inter_vol, vol_anchors)
87 |         return scores
88 | 
89 |     def condition(i,ik, feat_labels, feat_scores,
90 |                   feat_ymin, feat_ymax):
91 |         """Condition: check label index.
92 |         """
93 |         # remove unusable gt
94 |         bbox = bboxes[i]
95 |         bbox = tf.reshape(tf.boolean_mask(bbox,tf.not_equal([-1.0,-1.0],bbox)),[-1,2])
96 |         r = tf.less(ik, tf.shape(bbox)[0])
97 |         return r
98 | 
99 |     def body(batch_id, i, feat_labels, feat_scores,
100 |              feat_ymin, feat_ymax):
101 |         """Body: update feature labels, scores and bboxes.
102 |         - assign values where the jaccard score beats the current score (no fixed threshold);
103 |         - only update if it beats the score of other bboxes.
104 |         """
105 |         # Jaccard score.
106 |         label = labels[batch_id]
107 |         bbox = bboxes[batch_id]
108 | 
109 |         bbox = tf.reshape(tf.boolean_mask(bbox,tf.not_equal([-1.0,-1.0],bbox)),[-1,2])[i]
110 | 
111 |         jaccard = jaccard_with_anchors(bbox)
112 |         # Mask: anchors where this gt box beats the current best jaccard score.
113 |         mask = tf.greater(jaccard, feat_scores)
114 |         imask = tf.cast(mask, tf.int32)
115 |         fmask = tf.cast(mask, dtype)
116 |         # Update values using mask.
117 | 
118 |         #feat_labels = imask * label + (1 - imask) * feat_labels
119 |         feat_labels = s_label_max_one * label
120 | 
121 |         feat_scores = tf.where(mask, jaccard, feat_scores)
122 | 
123 |         feat_ymin = fmask * bbox[0] + (1 - fmask) * feat_ymin
124 |         feat_ymax = fmask * bbox[1] + (1 - fmask) * feat_ymax
125 | 
126 |         # Check no annotation label: ignore these anchors...
127 |         # interscts = intersection_with_anchors(bbox)
128 |         # mask = tf.logical_and(interscts > ignore_threshold,
129 |         #                       label == no_annotation_label)
130 |         # # Replace scores by -1.
131 |         # feat_scores = tf.where(mask, -tf.cast(mask, dtype), feat_scores)
132 | 
133 |         return [batch_id, i+1, feat_labels, feat_scores,feat_ymin, feat_ymax]
134 | 
135 |     def batch_condition(i, feat_labels, feat_scores,
136 |                         feat_ymin, feat_ymax):
137 |         r = tf.less(i, tf.shape(bboxes)[0])
138 |         return r
139 | 
140 |     def batch_body(i, feat_labels, feat_scores,
141 |                    feat_ymin,feat_ymax):
142 |         ik = 0
143 |         s_feat_labels = tf.zeros(s_label_shape, dtype=tf.int32)
144 |         s_feat_scores = tf.zeros(s_shape, dtype=dtype)
145 |         s_feat_ymin = tf.zeros(s_shape, dtype=dtype)
146 |         s_feat_ymax = tf.ones(s_shape, dtype=dtype)
147 |         [i, ik, s_feat_labels, s_feat_scores, s_feat_ymin,s_feat_ymax] = tf.while_loop(condition, body,
148 |                     [i, ik, s_feat_labels, s_feat_scores, s_feat_ymin, s_feat_ymax])
149 | 
150 |         s_feat_labels = tf.reshape(s_feat_labels,[1,s_label_shape[0],s_label_shape[1]])
151 |         s_feat_scores = tf.reshape(s_feat_scores,[1,s_shape[0],s_shape[1]])
152 |         s_feat_ymin = tf.reshape(s_feat_ymin,[1,s_shape[0],s_shape[1]])
153 | 
154 | 
155 | 
156 |         # labels
157 |         if batch_size_fix != 1:
158 |             s_feat_labels_1 = tf.zeros([i,s_label_shape[0],s_label_shape[1]], dtype=tf.int32)
159 |             s_feat_labels_2 = tf.zeros([batch_size_fix-i-1,s_label_shape[0],s_label_shape[1]], dtype=tf.int32)
160 |             s_feat_labels = tf.concat([s_feat_labels_1,s_feat_labels,s_feat_labels_2],0)
161 |             feat_labels = feat_labels + s_feat_labels
162 |         else:
163 |             feat_labels = s_feat_labels
164 | 
165 |         # scores
166 |         if batch_size_fix != 1:
167 |             s_feat_scores_1 = tf.zeros([i,s_shape[0],s_shape[1]], dtype=dtype)
168 |             s_feat_scores_2 = tf.zeros([batch_size_fix-i-1,s_shape[0],s_shape[1]], dtype=dtype)
169 |             s_feat_scores = tf.concat([s_feat_scores_1,s_feat_scores,s_feat_scores_2],0)
170 |             feat_scores = feat_scores + s_feat_scores
171 |         else:
172 |             feat_scores = s_feat_scores
173 | 
174 |         # ymin
175 |         if batch_size_fix != 1:
176 |             s_feat_ymin_1 = tf.zeros([i,s_shape[0],s_shape[1]], dtype=dtype)
177 |             s_feat_ymin_2 = tf.zeros([batch_size_fix-i-1,s_shape[0],s_shape[1]], dtype=dtype)
178 |             s_feat_ymin = tf.concat([s_feat_ymin_1,s_feat_ymin,s_feat_ymin_2],0)
179 |             feat_ymin = feat_ymin + s_feat_ymin
180 |         else:
181 |             feat_ymin = s_feat_ymin
182 | 
183 |         # ymax
184 |         if batch_size_fix != 1:
185 |             s_feat_ymax = s_feat_ymax - s_max_one_f
186 |             s_feat_ymax = tf.reshape(s_feat_ymax,[1,s_shape[0],s_shape[1]])
187 |             s_feat_ymax_1 = tf.zeros([i,s_shape[0],s_shape[1]], dtype=dtype)
188 |             s_feat_ymax_2 = tf.zeros([batch_size_fix-i-1,s_shape[0],s_shape[1]], dtype=dtype)
189 |             s_feat_ymax = tf.concat([s_feat_ymax_1,s_feat_ymax,s_feat_ymax_2],0)
190 |             feat_ymax = feat_ymax + s_feat_ymax
191 |         else:
192 |             s_feat_ymax = tf.reshape(s_feat_ymax,[1,s_shape[0],s_shape[1]])
193 |             feat_ymax = s_feat_ymax
194 | 
195 |         #feat_labels = tf.concat([feat_labels,s_feat_labels],0)
196 |         #feat_scores = tf.concat([feat_scores,s_feat_scores],0)
197 |         #feat_ymin = tf.concat([feat_ymin,s_feat_ymin],0)
198 |         #feat_ymax = tf.concat([feat_ymax,s_feat_ymax],0)
199 |         #feat_labels[i] = s_feat_labels
200 |         #feat_scores[i] = s_feat_scores
201 |         #feat_ymin[i] = s_feat_ymin
202 |         #feat_ymax[i] = s_feat_ymax
203 |         return [i+1,feat_labels,feat_scores,feat_ymin,feat_ymax]
204 | 
205 |     # Main loop definition.
206 |     i = 0
207 |     #[i, feat_labels, feat_scores,feat_ymin, feat_ymax] = tf.while_loop(condition, body,
208 |     #                                       [i, feat_labels, feat_scores, feat_ymin,feat_ymax])
209 |     [i, feat_labels, feat_scores,feat_ymin, feat_ymax] = tf.while_loop(batch_condition, batch_body,
210 |                                            [i, feat_labels, feat_scores, feat_ymin,feat_ymax])
211 | 
212 |     #feat_labels = labels
213 |     mask_zero = tf.equal(feat_scores,0.0)
214 |     fmask_zero = tf.cast(mask_zero,dtype)
215 |     feat_max_iou = fmask_zero * mask_minus + (1-fmask_zero)*feat_scores
216 | 
217 |     # Transform to center / size.
218 |     feat_cy = (feat_ymax + feat_ymin) / 2.
219 |     feat_t = feat_ymax - feat_ymin
220 |     # Encode features.
221 |     feat_cy = (feat_cy - yref) / tref / prior_scaling[0]
222 |     feat_t = tf.log(feat_t / tref) / prior_scaling[1]
223 |     # Stack (center, length) along the last axis.
224 |     feat_localizations = tf.stack([feat_cy, feat_t], axis=-1)
225 | 
226 |     return feat_labels, feat_localizations, feat_scores, feat_max_iou
227 | 
228 | def tf_aher_bboxes_encode(labels,
229 |                           bboxes,
230 |                           anchors,
231 |                           num_classes,
232 |                           temporal_shape,
233 |                           matching_threshold=0.70,
234 |                           prior_scaling=[1.0,1.0],
235 |                           dtype=tf.float32,
236 |                           scope='aher_bboxes_encode'):
237 |     """Encode groundtruth labels and bounding boxes using 1D net anchors.
238 |     Encoding boxes for all feature layers.
239 | 
240 |     Arguments:
241 |       labels: 1D Tensor(int32) containing groundtruth labels;
242 |       bboxes: batch_size x N x 2 Tensor(float) with bboxes relative coordinates;
243 |       anchors: List of Numpy array with layer anchors;
244 |       matching_threshold: Threshold for positive match with groundtruth bboxes;
245 |       prior_scaling: Scaling of encoded coordinates.
246 | 
247 |     Return:
248 |       (target_labels, target_localizations, target_scores, target_iou):
249 |         Each element is a list of target Tensors.
250 |     """
251 |     with tf.name_scope(scope):
252 |         target_labels = []
253 |         target_localizations = []
254 |         target_scores = []
255 |         target_iou = []
256 |         for i, anchors_layer in enumerate(anchors):
257 |             with tf.name_scope('bboxes_encode_block_%i' % i):
258 |                 t_labels, t_loc, t_scores, t_iou = \
259 |                     tf_aher_bboxes_encode_layer(labels, bboxes, anchors_layer,
260 |                                                 num_classes, temporal_shape,
261 |                                                 matching_threshold,
262 |                                                 prior_scaling, dtype)
263 |                 target_labels.append(t_labels)
264 |                 target_localizations.append(t_loc)
265 |                 target_scores.append(t_scores)
266 |                 target_iou.append(t_iou)
267 |         return target_labels, target_localizations, target_scores, target_iou
268 | 
269 | def tf_aher_bboxes_decode_layer(feat_localizations,
270 |                                 duration,
271 |                                 anchors_layer,
272 |                                 prior_scaling=[1.0, 1.0]):
273 |     """Compute the relative bounding boxes from the layer features and
274 |     reference anchor bounding boxes.
275 | 
276 |     Arguments:
277 |       feat_localizations: Tensor containing localization features.
278 |       anchors_layer: (yref, tref) numpy arrays with the layer anchors.
279 | 
280 |     Return:
281 |       Tensor Nx2: ymin, ymax, rescaled from 512 feature steps to `duration`.
282 |     """
283 |     yref, tref = anchors_layer
284 | 
285 |     # Compute center and length
286 |     cy = feat_localizations[:, :, :, 0] * tref * prior_scaling[0] + yref # tf.expand_dims(yref,-1)
287 |     t = tref * tf.exp(feat_localizations[:, :, :, 1] * prior_scaling[1])
288 |     # Boxes coordinates.
289 |     ymin = cy - t / 2.
290 |     ymax = cy + t / 2.
291 |     ymin = tf.maximum(ymin, 0.0) / 512.0 * duration
292 |     ymax = tf.minimum(ymax, 512.0) / 512.0 * duration
293 |     bboxes = tf.stack([ymin, ymax], axis=-1)
294 |     return bboxes
295 | 
296 | def tf_aher_bboxes_decode_logits_layer(feat_localizations,
297 |                                        duration,
298 |                                        anchors_layer,
299 |                                        logitsprediction,
300 |                                        prior_scaling=[1.0, 1.0]):
301 |     """Compute the relative bounding boxes from the layer features and
302 |     reference anchor bounding boxes.
303 | 
304 |     Arguments:
305 |       feat_localizations: Tensor containing localization features.
306 |       anchors_layer: (yref, tref) numpy arrays with the layer anchors.
307 | 
308 |     Return:
309 |       Tensor Nx4: ymin, ymax, argmax class id, max logit
310 |     """
311 |     yref, tref = anchors_layer
312 | 
313 |     # Compute center and length
314 |     cy = feat_localizations[:, :, :, 0] * tref * prior_scaling[0] + yref # tf.expand_dims(yref,-1)
315 |     t = tref * tf.exp(feat_localizations[:, :, :, 1] * prior_scaling[1])
316 |     # Boxes coordinates.
317 |     ymin = cy - t / 2.
318 |     ymax = cy + t / 2.
319 |     ymin = tf.maximum(ymin, 0.0) / 512.0 * duration
320 |     ymax = tf.minimum(ymax, 512.0) / 512.0 * duration
321 | 
322 |     maxid = tf.argmax(logitsprediction,axis=-1)
323 |     maxidtile = tf.tile(maxid,(1,1,3))
324 |     maxidtilef = tf.cast(maxidtile,tf.float32)
325 | 
326 |     maxscore = tf.reduce_max(logitsprediction,axis=-1)
327 |     maxscoretile = tf.tile(maxscore,(1,1,3))
328 | 
329 |     bboxes = tf.stack([ymin, ymax, maxidtilef, maxscoretile], axis=-1)
330 | 
331 |     return bboxes
332 | 
333 | def tf_aher_bboxes_decode_detect_layer(feat_localizations,
334 |                                        duration,
335 |                                        anchors_layer,
336 |                                        propredictions,
337 |                                        prior_scaling=[1.0, 1.0]):
338 |     """Compute the relative bounding boxes from the layer features and
339 |     reference anchor bounding boxes.
340 | 
341 |     Arguments:
342 |       feat_localizations: Tensor containing localization features.
343 |       anchors_layer: (yref, tref) numpy arrays with the layer anchors.
344 | 
345 | 
346 |     Return:
347 |       Tensor Nx(2+C): ymin, ymax, followed by the C proposal predictions
348 |     """
349 |     yref, tref = anchors_layer
350 | 
351 |     # Compute center and length
352 |     cy = feat_localizations[:, :, :, 0] * tref * prior_scaling[0] + yref # tf.expand_dims(yref,-1)
353 |     t = tref * tf.exp(feat_localizations[:, :, :, 1] * prior_scaling[1])
354 |     # Boxes coordinates.
355 |     ymin = cy - t / 2.
356 |     ymax = cy + t / 2.
357 |     ymin = tf.maximum(ymin, 0.0) / 512.0 * duration
358 |     ymax = tf.minimum(ymax, 512.0) / 512.0 * duration
359 | 
360 |     #maxid = tf.argmax(propredictions,axis=-1)
361 |     #maxidtile = tf.tile(maxid,(1,1,3))
362 |     #maxidf = tf.cast(maxid,tf.float32)
363 | 
364 |     #maxscore = tf.reduce_max(propredictions,axis=-1)
365 |     #maxscoretile = tf.tile(maxscore,(1,1,3))
366 | 
367 |     bboxes = tf.stack([ymin, ymax], axis=-1)
368 |     bboxes = tf.concat([bboxes, propredictions],axis=-1)
369 | 
370 |     return bboxes
371 | 
372 | def tf_aher_bboxes_decode(feat_localizations,
373 |                           duration,
374 |                           anchors,
375 |                           prior_scaling=[1.0,1.0],
376 |                           scope='aher_bboxes_decode'):
377 |     """Compute the relative bounding boxes from the net features and
378 |     reference anchors bounding boxes.
379 | 
380 |     Arguments:
381 |       feat_localizations: List of Tensors containing localization features.
382 |       anchors: List of numpy array containing anchor boxes.
383 | 
384 |     Return:
385 |       List of Tensors Nx2: ymin, ymax
386 |     """
387 |     with tf.name_scope(scope):
388 |         bboxes = []
389 |         for i, anchors_layer in enumerate(anchors):
390 |             bboxes.append(
391 |                 tf_aher_bboxes_decode_layer(feat_localizations[i],
392 |                                             duration,
393 |                                             anchors_layer,
394 |                                             prior_scaling))
395 |         return bboxes
396 | 
397 | def tf_aher_bboxes_decode_logits(feat_localizations,
398 |                                  duration,
399 |                                  anchors,
400 |                                  logitsprediction,
401 |                                  prior_scaling=[1.0,1.0],
402 |                                  scope='aher_bboxes_decode'):
403 |     """Compute the relative bounding boxes from the net features and
404 |     reference anchors bounding boxes.
405 | 
406 |     Arguments:
407 |       feat_localizations: List of Tensors containing localization features.
408 |       anchors: List of numpy array containing anchor boxes.
409 | 
410 |     Return:
411 |       List of Tensors Nx4: ymin, ymax, argmax class id, max logit
412 |     """
413 |     with tf.name_scope(scope):
414 |         bboxes = []
415 |         for i, anchors_layer in enumerate(anchors):
416 |             bboxes.append(
417 |                 tf_aher_bboxes_decode_logits_layer(feat_localizations[i],
418 |                                                    duration,
419 |                                                    anchors_layer,
420 |                                                    logitsprediction[i],
421 |                                                    prior_scaling))
422 |         return bboxes
423 | 
424 | def tf_aher_bboxes_decode_detect(feat_localizations,
425 |                                  duration,
426 |                                  anchors,
427 |                                  propprediction,
428 |                                  prior_scaling=[1.0,1.0],
429 |                                  scope='aher_bboxes_decode'):
430 |     """Compute the relative bounding boxes from the net features and
431 |     reference anchors bounding boxes.
432 | 
433 |     Arguments:
434 |       feat_localizations: List of Tensors containing localization features.
435 |       anchors: List of numpy array containing anchor boxes.
436 | 
437 |     Return:
438 |       List of Tensors Nx(2+C): ymin, ymax, followed by the C proposal predictions
439 |     """
440 |     with tf.name_scope(scope):
441 |         bboxes = []
442 |         for i, anchors_layer in enumerate(anchors):
443 |             bboxes.append(
444 |                 tf_aher_bboxes_decode_detect_layer(feat_localizations[i],
445 |                                                    duration,
446 |                                                    anchors_layer,
447 |                                                    propprediction[i],
448 |                                                    prior_scaling))
449 |         return bboxes
450 | 
451 | # =========================================================================== #
452 | # temporal boxes selection.
453 | # =========================================================================== #
454 | def tf_aher_bboxes_select_layer(predictions_layer, localizations_layer,
455 |                                 select_threshold=None,
456 |                                 num_classes=21,
457 |                                 ignore_class=0,
458 |                                 scope=None,
459 |                                 IoU_flag=False):
460 |     """Extract classes, scores and bounding boxes from features in one layer.
461 |     Batch-compatible: inputs are supposed to have batch-type shapes.
462 | 463 | Args: 464 | predictions_layer: A prediction layer; 465 | localizations_layer: A localization layer; 466 | select_threshold: Classification threshold for selecting a box. All boxes 467 | under the threshold are set to 'zero'. If None, a threshold of 0 is used. 468 | Return: 469 | d_scores, d_bboxes: Dictionary of scores and bboxes Tensors of 470 | size Batches x N x 1 | 2. Each key corresponds to a class. 471 | """ 472 | select_threshold = 0.0 if select_threshold is None else select_threshold 473 | with tf.name_scope(scope, 'aher_bboxes_select_layer', 474 | [predictions_layer, localizations_layer]): 475 | # Reshape features: Batches x N x N_labels | 4 476 | p_shape = tfe.get_shape(predictions_layer) 477 | predictions_layer = tf.reshape(predictions_layer, 478 | tf.stack([p_shape[0], -1, p_shape[-1]])) 479 | if IoU_flag: 480 | zeros_m = tf.zeros([predictions_layer.shape[1],1]) 481 | predictions_layer = tf.reshape(tf.stack([zeros_m, 482 | tf.reshape(predictions_layer,[predictions_layer.shape[1],1])],axis=1), 483 | [predictions_layer.shape[0],predictions_layer.shape[1],2]) 484 | l_shape = tfe.get_shape(localizations_layer) 485 | localizations_layer = tf.reshape(localizations_layer, 486 | tf.stack([l_shape[0], -1, l_shape[-1]])) 487 | 488 | d_scores = {} 489 | d_bboxes = {} 490 | for c in range(0, num_classes): 491 | if c != ignore_class: 492 | # Remove boxes under the threshold. 493 | scores = predictions_layer[:, :, c] 494 | fmask = tf.cast(tf.greater_equal(scores, select_threshold), scores.dtype) 495 | scores = scores * fmask 496 | bboxes = localizations_layer * tf.expand_dims(fmask, axis=-1) 497 | # Append to dictionary.
498 | d_scores[c] = scores 499 | d_bboxes[c] = bboxes 500 | 501 | return d_scores, d_bboxes 502 | 503 | def tf_aher_bboxes_select(predictions_net, localizations_net, 504 | select_threshold=None, 505 | num_classes=21, 506 | ignore_class=0, 507 | scope=None, 508 | IoU_flag = False): 509 | """Extract classes, scores and bounding boxes from network output layers. 510 | Batch-compatible: inputs are supposed to have batch-type shapes. 511 | 512 | Args: 513 | predictions_net: List of prediction layers; 514 | localizations_net: List of localization layers; 515 | select_threshold: Classification threshold for selecting a box. All boxes 516 | under the threshold are set to 'zero'. If None, a threshold of 0 is used. 517 | Return: 518 | d_scores, d_bboxes: Dictionary of scores and bboxes Tensors of 519 | size Batches x N x 1 | 2. Each key corresponds to a class. 520 | """ 521 | with tf.name_scope(scope, 'aher_bboxes_select', 522 | [predictions_net, localizations_net]): 523 | l_scores = [] 524 | l_bboxes = [] 525 | for i in range(len(predictions_net)): 526 | scores, bboxes = tf_aher_bboxes_select_layer(predictions_net[i], 527 | localizations_net[i], 528 | select_threshold, 529 | num_classes, 530 | ignore_class, 531 | IoU_flag = IoU_flag) 532 | l_scores.append(scores) 533 | l_bboxes.append(bboxes) 534 | # Concat results. 535 | d_scores = {} 536 | d_bboxes = {} 537 | for c in l_scores[0].keys(): 538 | ls = [s[c] for s in l_scores] 539 | lb = [b[c] for b in l_bboxes] 540 | d_scores[c] = tf.concat(ls, axis=1) 541 | d_bboxes[c] = tf.concat(lb, axis=1) 542 | return d_scores, d_bboxes 543 | 544 | def tf_aher_bboxes_select_layer_all_classes(predictions_layer, localizations_layer, 545 | select_threshold=None): 546 | """Extract classes, scores and bounding boxes from features in one layer. 547 | Batch-compatible: inputs are supposed to have batch-type shapes.
548 | 549 | Args: 550 | predictions_layer: A prediction layer; 551 | localizations_layer: A localization layer; 552 | select_threshold: Classification threshold for selecting a box. If None, 553 | select boxes whose classification score is higher than 'no class'. 554 | Return: 555 | classes, scores, bboxes: Input Tensors. 556 | """ 557 | # Reshape features: Batches x N x N_labels | 4 558 | p_shape = tfe.get_shape(predictions_layer) 559 | predictions_layer = tf.reshape(predictions_layer, 560 | tf.stack([p_shape[0], -1, p_shape[-1]])) 561 | l_shape = tfe.get_shape(localizations_layer) 562 | localizations_layer = tf.reshape(localizations_layer, 563 | tf.stack([l_shape[0], -1, l_shape[-1]])) 564 | # Boxes selection: use threshold or score > no-label criteria. 565 | if select_threshold is None or select_threshold == 0: 566 | # Class prediction and scores: assign 0. to 0-class 567 | classes = tf.argmax(predictions_layer, axis=2) 568 | scores = tf.reduce_max(predictions_layer, axis=2) 569 | scores = scores * tf.cast(classes > 0, scores.dtype) 570 | else: 571 | sub_predictions = predictions_layer[:, :, 1:] 572 | classes = tf.argmax(sub_predictions, axis=2) + 1 573 | scores = tf.reduce_max(sub_predictions, axis=2) 574 | # Only keep predictions higher than threshold. 575 | mask = tf.greater(scores, select_threshold) 576 | classes = classes * tf.cast(mask, classes.dtype) 577 | scores = scores * tf.cast(mask, scores.dtype) 578 | # Assume localization layer already decoded. 579 | bboxes = localizations_layer 580 | return classes, scores, bboxes 581 | 582 | def tf_aher_bboxes_select_all_classes(predictions_net, localizations_net, 583 | select_threshold=None, 584 | scope=None): 585 | """Extract classes, scores and bounding boxes from network output layers. 586 | Batch-compatible: inputs are supposed to have batch-type shapes. 
587 | 588 | Args: 589 | predictions_net: List of prediction layers; 590 | localizations_net: List of localization layers; 591 | select_threshold: Classification threshold for selecting a box. If None, 592 | select boxes whose classification score is higher than 'no class'. 593 | Return: 594 | classes, scores, bboxes: Tensors. 595 | """ 596 | with tf.name_scope(scope, 'aher_bboxes_select', 597 | [predictions_net, localizations_net]): 598 | l_classes = [] 599 | l_scores = [] 600 | l_bboxes = [] 601 | for i in range(len(predictions_net)): 602 | classes, scores, bboxes = \ 603 | tf_aher_bboxes_select_layer_all_classes(predictions_net[i], 604 | localizations_net[i], 605 | select_threshold) 606 | l_classes.append(classes) 607 | l_scores.append(scores) 608 | l_bboxes.append(bboxes) 609 | 610 | classes = tf.concat(l_classes, axis=1) 611 | scores = tf.concat(l_scores, axis=1) 612 | bboxes = tf.concat(l_bboxes, axis=1) 613 | return classes, scores, bboxes 614 | 615 | -------------------------------------------------------------------------------- /aher_k600_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import tensorflow as tf 3 | import numpy as np 4 | import aher_anet 5 | import tf_utils 6 | from datetime import datetime 7 | import time 8 | from data_loader import * 9 | import tf_extended as tfe 10 | import os 11 | import sys 12 | import pandas as pd 13 | from multiprocessing import Process,Queue 14 | import multiprocessing 15 | import math 16 | import random 17 | from evaluation import get_proposal_performance 18 | 19 | FLAGS = tf.app.flags.FLAGS 20 | os.environ["TF_CPP_MIN_LOG_LEVEL"]='2' 21 | 22 | tf.app.flags.DEFINE_float('loss_alpha', 1., 'Alpha parameter in the loss function.') 23 | tf.app.flags.DEFINE_float('match_threshold', 0.65, 'Matching threshold in the loss function.') # 0.65 24 | tf.app.flags.DEFINE_float('neg_match_threshold', 0.35, 'Negative threshold in the loss function.') # 0.3 25 
| tf.app.flags.DEFINE_float('negative_ratio', 1., 'Negative ratio in the loss function.') 26 | tf.app.flags.DEFINE_float('label_smoothing', 0.0, 'The amount of label smoothing.') 27 | tf.app.flags.DEFINE_integer('output_idx', 1, 'The output index.') 28 | tf.app.flags.DEFINE_integer('training_epochs',7,'Number of training epochs.') 29 | tf.app.flags.DEFINE_float('cls_weights', 2.0, 'The weight for the classification.') # 0.1 30 | tf.app.flags.DEFINE_float('iou_weights', 25.0, 'The weight for the iou prediction.') # 5.0 31 | tf.app.flags.DEFINE_float('initial_learning_rate', 0.0001,'Initial learning rate.') 32 | # Flags for box generation 33 | tf.app.flags.DEFINE_float('select_threshold', 0.0, 'Selection threshold.') 34 | tf.app.flags.DEFINE_float('nms_threshold', 0.90, 'Non-Maximum Suppression threshold.') 35 | tf.app.flags.DEFINE_integer('select_top_k', 765, 'Select top-k detected bounding boxes.') 36 | tf.app.flags.DEFINE_integer('keep_top_k', 100, 'Keep top-k detected objects.') 37 | tf.app.flags.DEFINE_bool('cls_flag', False, 'Utilize classification score.') 38 | 39 | def AHER_init(): 40 | # AHER parameter and Model 41 | aher_anet_param = aher_anet.AHERNet.default_params 42 | aher_anet_model = aher_anet.AHERNet(aher_anet_param) 43 | aher_anet_temporal_shape = aher_anet_param.temporal_shape 44 | aher_anet_anchor = aher_anet_model.anchors(aher_anet_temporal_shape) 45 | return aher_anet_model,aher_anet_anchor 46 | 47 | def AHER_Predictor_Cls(aher_anet_model,aher_anet_anchor, 48 | feature,temporal_gt,vname,label,duration,reuse,class_num=600,cls_suffix='_anchor'): 49 | """ Network predictions and encoded ground truth for single-shot action localization 50 | feature: batch_size x 512 x 2048 51 | temporal_gt: batch_size x 25 x 2 52 | vname: batch_size x 1 53 | label: batch_size x 1 54 | duration: batch_size x 1 55 | """ 56 | 57 | # Encode groundtruth labels and bboxes.
58 | gclasses, glocalisations, gscores, giou = \ 59 | aher_anet_model.bboxes_encode(label, temporal_gt, aher_anet_anchor) 60 | 61 | # predict location and iou 62 | predictions, localisation, logits, proplogits, proppredictions, iouprediction, clsweights, clsbias, end_points = \ 63 | aher_anet_model.net_pool_cls(feature, num_classes= class_num, untrim_num=class_num, is_training=True,reuse=reuse,cls_suffix=cls_suffix) 64 | 65 | return predictions, localisation, logits, proplogits, iouprediction, \ 66 | clsweights, clsbias, gscores, giou, gclasses, glocalisations, end_points 67 | 68 | def AHER_Detection_Inference(aher_anet_model,aher_anet_anchor, 69 | feature,vname,label,duration, clsweights, clsbias, reuse, n_class,cls_suffix='_anet'): 70 | """ Infer temporal boxes for single-shot action localization 71 | feature: batch_size x 512 x 2048 72 | vname: batch_size x 1 73 | label: batch_size x 1 74 | duration: batch_size x 1 75 | """ 76 | 77 | predictions, localisation, logits, proplogits, proppredictions, iouprediction, end_points \ 78 | = aher_anet_model.net_prop_iou(feature,clsweights,clsbias,is_training=True,reuse=reuse,num_classes=n_class,cls_suffix=cls_suffix) 79 | 80 | # decode bounding box and get scores 81 | localisation = aher_anet_model.bboxes_decode_logits(localisation, duration ,aher_anet_anchor, predictions) 82 | if FLAGS.cls_flag: 83 | rscores, rbboxes = aher_anet_model.detected_bboxes_classwise( 84 | proppredictions, localisation, 85 | select_threshold=FLAGS.select_threshold, 86 | nms_threshold=FLAGS.nms_threshold, 87 | clipping_bbox=None, 88 | top_k=FLAGS.select_top_k, 89 | keep_top_k=FLAGS.keep_top_k, 90 | iou_flag=False) 91 | else: 92 | rscores, rbboxes = aher_anet_model.detected_bboxes_classwise( 93 | iouprediction, localisation, 94 | select_threshold=FLAGS.select_threshold, 95 | nms_threshold=FLAGS.nms_threshold, 96 | clipping_bbox=None, 97 | top_k=FLAGS.select_top_k, 98 | keep_top_k=FLAGS.keep_top_k, 99 | iou_flag=True) 100 | # compute pooling score
lshape = tfe.get_shape(predictions[0], 8) 102 | num_classes = lshape[-1] 103 | batch_size = lshape[0] 104 | fprediction = [] 105 | for i in range(len(predictions)): 106 | fprediction.append(tf.reshape(predictions[i], [-1, num_classes])) 107 | predictions = tf.concat(fprediction, axis=0) 108 | avergeprediction = tf.reduce_mean(predictions,axis=0) 109 | labelid = tf.argmax(avergeprediction, 0) 110 | argmaxid = tf.argmax(predictions, 1) 111 | 112 | prebbox={"rscores":rscores,"rbboxes":rbboxes,"label":labelid,"avescore":avergeprediction, \ 113 | "rawscore":predictions,"argmaxid":argmaxid} 114 | return prebbox 115 | 116 | class Config(object): 117 | def __init__(self): 118 | self.learning_rates=[0.001]*100+[0.0001]*100 119 | #self.training_epochs = 15 120 | self.training_epochs = 12 121 | self.n_inputs = 2048 122 | self.batch_size = 8 123 | self.input_steps=512 124 | self.input_resize_steps=512 125 | self.gt_hold_num = 25 126 | self.batch_size_val=1 127 | 128 | if __name__ == "__main__": 129 | """ define the input and the network""" 130 | config = Config() 131 | LR= tf.placeholder(tf.float32) 132 | 133 | #--------------------------------------------Restore Folder and Feature Folder----------------------------------------------# 134 | eva_step_id = [8000,9000,10000] # Checkpoint steps to evaluate 135 | output_folder_dir = 'p3d_k600_models/AHER_k600_gen_adv' # Folder holding the checkpoints; localization results are written here too 136 | csv_dir = '/data/Kinetics-600/csv_p3d_clip_val_untrim_512' # Kinetics-600 validation clip features 137 | csv_oris_dir = '' 138 | #------------------------------------------------------------------------------------------------------------------------------# 139 | 140 | 141 | 142 | 143 | #-------------------------------------------------------Test Placeholder-----------------------------------------------------# 144 | # train placeholders 145 | feature_seg = tf.placeholder(tf.float32, shape=(config.batch_size,config.input_steps,config.n_inputs)) 146 |
temporal_gt_seg = tf.placeholder(tf.float32, shape=(config.batch_size, config.gt_hold_num, 2)) 147 | vname_seg = tf.placeholder(tf.string, shape=(config.batch_size)) 148 | label_seg = tf.placeholder(tf.int32, shape=(config.batch_size)) 149 | duration_seg = tf.placeholder(tf.float32,shape=(config.batch_size)) 150 | 151 | # val placeholders 152 | feature_val = tf.placeholder(tf.float32, shape=(config.batch_size_val,config.input_steps,config.n_inputs)) 153 | temporal_gt_val = tf.placeholder(tf.float32, shape=(config.batch_size_val, config.gt_hold_num, 2)) 154 | vname_val = tf.placeholder(tf.string, shape=(config.batch_size_val)) 155 | label_val = tf.placeholder(tf.int32, shape=(config.batch_size_val)) 156 | duration_val = tf.placeholder(tf.float32,shape=(config.batch_size_val)) 157 | #------------------------------------------------------------------------------------------------------------------------------# 158 | 159 | 160 | 161 | 162 | #----------------------------------------------- AherNet Structure ------------------------------------------------------------# 163 | aher_anet_model,aher_anet_anchor = AHER_init() 164 | # Initialize the backbone 165 | predictions_gene_seg, localisation_gene_seg, logits_gene_seg, proplogits_gene_seg, iouprediction_gene_seg, \ 166 | clsweights_ki, clsbias_ki, gscore_gene_seg, \ 167 | giou_gene_seg, gclasses_gene_seg, glocalisations_gene_seg, \ 168 | end_points_gene_seg \ 169 | = AHER_Predictor_Cls(aher_anet_model,aher_anet_anchor,feature_seg, \ 170 | temporal_gt_seg,vname_seg,label_seg,duration_seg,reuse=False,class_num=600,cls_suffix='_anchor') 171 | 172 | # localization inference 173 | bbox = AHER_Detection_Inference(aher_anet_model,aher_anet_anchor,feature_val, \ 174 | vname_val,label_val,duration_val, clsweights_ki, clsbias_ki, reuse=True, n_class=600,cls_suffix='_anchor') 175 | #------------------------------------------------------------------------------------------------------------------------------# 176 | 177 | 178 | 179 |
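The localization inference built above relies on the anchor decoding in the `tf_aher_bboxes_decode*` helpers earlier in this dump. As a reading aid, here is a minimal NumPy sketch of that mapping, assuming 1-D anchors given as (center, length) on the 512-step feature axis, as in `yref, tref = anchors_layer`. The center offset is scaled by the anchor length, the length offset is exponentiated, and the box is clipped to [0, 512] and rescaled by the video duration. `decode_temporal_boxes` is a hypothetical name, and the sketch leaves out the concatenation of class predictions that the TF code performs.

```python
import numpy as np


def decode_temporal_boxes(loc, yref, tref, duration,
                          prior_scaling=(1.0, 1.0), num_steps=512.0):
    """Decode 1-D anchor offsets into [start, end] times (hypothetical helper).

    loc:      predicted offsets, shape [..., 2] = (center offset, log-length offset)
    yref:     anchor centers on the feature axis, broadcastable to loc[..., 0]
    tref:     anchor lengths on the feature axis, broadcastable to loc[..., 0]
    duration: video duration used to rescale from feature steps
    """
    loc = np.asarray(loc, dtype=np.float64)
    yref = np.asarray(yref, dtype=np.float64)
    tref = np.asarray(tref, dtype=np.float64)
    # Shift the anchor center by an offset proportional to the anchor length.
    cy = loc[..., 0] * tref * prior_scaling[0] + yref
    # Rescale the anchor length exponentially.
    t = tref * np.exp(loc[..., 1] * prior_scaling[1])
    # Clip to the feature axis [0, num_steps], then map to the video duration.
    ymin = np.maximum(cy - t / 2.0, 0.0) / num_steps * duration
    ymax = np.minimum(cy + t / 2.0, num_steps) / num_steps * duration
    return np.stack([ymin, ymax], axis=-1)
```

For example, a zero offset on an anchor centered at step 256 with length 128, in a 100-second video, gives `decode_temporal_boxes([[0.0, 0.0]], [256.0], [128.0], duration=100.0)` equal to `[[37.5, 62.5]]`.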
180 | AHER_trainable_variables=tf.trainable_variables() 181 | 182 | """ Init tf""" 183 | model_saver=tf.train.Saver(var_list=AHER_trainable_variables,max_to_keep=80) 184 | full_vars = [] 185 | tf_config = tf.ConfigProto() 186 | tf_config.gpu_options.allow_growth = True 187 | tf_config.log_device_placement =True 188 | sess=tf.InteractiveSession(config=tf_config) 189 | 190 | 191 | 192 | #---------------------------------------------------- Test Data Input List -------------------------------------------------------# 193 | data_generate_val = KineticsDatasetLoadTrimList('val_moment',csv_dir,csv_oris_dir,25,512,df_file='./cvs_records_k600/k600_train_val_info.csv') 194 | print('Load train feature, wait...') 195 | while True: 196 | if len(data_generate_val.train_feature_list) == 6459: 197 | data_generate_val.stop_process() 198 | time.sleep(30) 199 | break 200 | print('Wait done.') 201 | print('The val feature len is %d'%(len(data_generate_val.train_feature_list))) 202 | dataset_val = tf.data.Dataset.from_generator(data_generate_val.gen, 203 | (tf.float32, tf.float32, tf.string, tf.int32, tf.float32), 204 | (tf.TensorShape([512, 2048]),tf.TensorShape([25, 2]),tf.TensorShape([]),tf.TensorShape([]),tf.TensorShape([]))) 205 | dataset_val = dataset_val.batch(config.batch_size_val) 206 | batch_num_val = int(len(data_generate_val.train_feature_list) / config.batch_size_val) 207 | iterator_val = dataset_val.make_one_shot_iterator() 208 | feature_g_val, \ 209 | video_gt_g_val, \ 210 | video_name_g_val, \ 211 | video_label_g_val, \ 212 | video_duration_g_val = iterator_val.get_next() 213 | #---------------------------------------------------------------------------------------------------------------------------------# 214 | 215 | 216 | 217 | #---------------------------------------------------- Category Information --------------------------------------------------------# 218 | # load category information 219 | cateidx_name = {} 220 | category_info = 
pd.read_csv('k600_category.txt',sep='\n',header=None) 221 | for i in range(len(category_info)): 222 | name = category_info.loc[i][0] 223 | cateidx_name[i] = name 224 | #---------------------------------------------------------------------------------------------------------------------------------------# 225 | 226 | 227 | 228 | 229 | #---------------------------------------------------- Detection Result Extraction --------------------------------------------------------# 230 | fw_log = open('%s/log_res_k600_gene_dec.txt'%(output_folder_dir),'w',1) 231 | with tf.Session() as sess: 232 | tf.global_variables_initializer().run() 233 | tf.local_variables_initializer().run() 234 | for epoch in eva_step_id: 235 | print('Restore model %s/aher_adv_model_checkpoint-step_%d'%(output_folder_dir,epoch)) 236 | model_saver.restore(sess,"%s/aher_adv_model_checkpoint-step_%d"%(output_folder_dir,epoch)) 237 | """ Validation""" 238 | hit_video_num = 0 239 | total_video_num = 0 240 | print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + \ 241 | "----Epoch-%d Val set validation started."
%(epoch)) 242 | fw_proposal = open('%s/results_aher_k600_gendec_step-%d.json'%(output_folder_dir,epoch),'w') 243 | fw_proposal.write('{\"version\": \"VERSION 1.3\", \"results\": {') 244 | for idx in range(batch_num_val): 245 | feature_batch_val, \ 246 | video_gt_batch_val, \ 247 | video_name_batch_val, \ 248 | video_label_batch_val, \ 249 | video_duration_batch_val = sess.run([feature_g_val,video_gt_g_val,video_name_g_val,video_label_g_val,video_duration_g_val]) 250 | out_bbox=sess.run(bbox,feed_dict={feature_val:feature_batch_val, 251 | vname_val:video_name_batch_val, 252 | label_val:video_label_batch_val, 253 | duration_val:video_duration_batch_val}) 254 | predict_score = out_bbox['rscores'] 255 | predict_bbox = out_bbox['rbboxes'] 256 | average_cls_score = out_bbox['avescore'] 257 | cls_id = out_bbox['label'] 258 | raw_score = out_bbox['rawscore'] 259 | argmaxid = out_bbox['argmaxid'] 260 | real_label = video_label_batch_val[0] 261 | if real_label == cls_id: hit_video_num += 1 262 | total_video_num += 1 263 | cateids = 1 264 | write_first = 0 265 | # write the json file 266 | if idx!=0: fw_proposal.write(', ') 267 | fw_proposal.write('\"%s\":['%(video_name_batch_val[0].decode("utf-8"))) 268 | for kk in range(len(predict_score[1][0])): 269 | # compute the ranking score for each localization result 270 | score = predict_score[cateids][0][kk] * average_cls_score[cls_id] 271 | start_time = predict_bbox[cateids][0][kk][0] 272 | end_time = predict_bbox[cateids][0][kk][1] 273 | if score > 0 and end_time - start_time > 0: 274 | if write_first == 0: 275 | fw_proposal.write('{\"score\": %f, \"segment\": [%f, %f], \"label\": \"%s\"}'% \ 276 | (score,start_time,end_time,cateidx_name[cls_id])) # 277 | else: 278 | fw_proposal.write(', {\"score\": %f, \"segment\": [%f, %f], \"label\": \"%s\"}'% \ 279 | (score,start_time,end_time,cateidx_name[cls_id])) 280 | write_first += 1 281 | fw_proposal.write(']') 282 | fw_proposal.write('}, \"external_data\": {}}') 283 | 
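The results file above is assembled by hand with string formatting. Below is a sketch of the same output built with Python's `json` module. It assumes `rscores`/`rbboxes` have already been indexed to one class and one video (as in `predict_score[cateids][0]` and `predict_bbox[cateids][0]`). It keeps the script's ranking rule (box score times the video-level class score `average_cls_score[cls_id]`) and its filter (positive score, non-empty segment). `video_entries`, `results_payload` and `write_results` are hypothetical names, not functions from the repository.

```python
import json


def video_entries(box_scores, boxes, avg_cls_score, cls_id, label_name):
    """Build result entries for one video (hypothetical helper).

    box_scores:    per-box scores of one class after NMS, shape [K]
    boxes:         matching [start, end] pairs in seconds, shape [K, 2]
    avg_cls_score: video-level class scores averaged over all anchors
    cls_id:        index of the predicted video class
    label_name:    category name for cls_id
    """
    entries = []
    for s, (start, end) in zip(box_scores, boxes):
        # Ranking score: box score weighted by the video-level class score.
        score = float(s) * float(avg_cls_score[cls_id])
        # Keep only positive scores and non-empty segments, as the script does.
        if score > 0 and end - start > 0:
            entries.append({"score": score,
                            "segment": [float(start), float(end)],
                            "label": label_name})
    return entries


def results_payload(results):
    """Wrap {video_name: entries} in the top-level layout the script writes."""
    return {"version": "VERSION 1.3", "results": results, "external_data": {}}


def write_results(results, path):
    """Serialize all videos at once; json handles commas and escaping."""
    with open(path, "w") as f:
        json.dump(results_payload(results), f)
```

Collecting `{video_name: video_entries(...)}` in one dict and calling `write_results` once per checkpoint removes the manual comma bookkeeping (`idx != 0`, `write_first`). It also escapes label strings correctly.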
fw_proposal.close() 284 | # evaluation for temporal action proposal 285 | an_ar,recall = get_proposal_performance.evaluate_return_area('data_split/k600_val_action.json', \ 286 | '%s/results_aher_k600_gendec_step-%d.json'%(output_folder_dir,epoch)) 287 | print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + \ 288 | "----Epoch-%d Val set validation finished." %(epoch)) 289 | accuracy_video = hit_video_num / total_video_num 290 | 291 | print('*********************************************************') 292 | print("Epoch-%d Val AUC: %.04f AverageRecall: %.04f Accuracy: %.04f" %(epoch,an_ar,recall,accuracy_video)) 293 | print('*********************************************************') 294 | fw_log.write('*********************************************************\n') 295 | fw_log.write("Epoch-%d Val AUC: %.04f AverageRecall: %.04f Accuracy: %.04f\n" %(epoch,an_ar,recall,accuracy_video)) 296 | fw_log.write('*********************************************************\n') 297 | 298 | fw_log.close() 299 | -------------------------------------------------------------------------------- /custom_layers.py: -------------------------------------------------------------------------------- 1 | """Implement some custom layers, not provided by TensorFlow. 2 | 3 | Trying to follow as much as possible the style/standards used in 4 | tf.contrib.layers 5 | """ 6 | import tensorflow as tf 7 | 8 | from tensorflow.contrib.framework.python.ops import add_arg_scope 9 | from tensorflow.contrib.layers.python.layers import initializers 10 | from tensorflow.contrib.framework.python.ops import variables 11 | from tensorflow.contrib.layers.python.layers import utils 12 | from tensorflow.python.ops import nn 13 | from tensorflow.python.ops import init_ops 14 | from tensorflow.python.ops import variable_scope 15 | 16 | 17 | def abs_smooth(x): 18 | """Smoothed absolute function. Useful to compute an L1 smooth error.
19 | 20 | Defined as: 21 | x^2 / 2 if abs(x) < 1 22 | abs(x) - 0.5 if abs(x) >= 1 23 | Implemented with min(abs(x), 1) and abs(x) so that it stays differentiable. Clearly 24 | not optimal, but good enough for our purpose. 25 | """ 26 | absx = tf.abs(x) 27 | minx = tf.minimum(absx, 1) 28 | r = 0.5 * ((absx - 1) * minx + absx) 29 | return r 30 | 31 | def single_hinge_loss(x,label,hinge=1.0): 32 | hinge_value = tf.constant(hinge) 33 | mini_hinge = tf.constant(0.0) 34 | r = tf.maximum(hinge_value - label * x, mini_hinge) 35 | return r 36 | 37 | @add_arg_scope 38 | def l2_normalization( 39 | inputs, 40 | scaling=False, 41 | scale_initializer=init_ops.ones_initializer(), 42 | reuse=None, 43 | variables_collections=None, 44 | outputs_collections=None, 45 | data_format='NHWC', 46 | trainable=True, 47 | scope=None): 48 | """Implement L2 normalization on every feature vector (i.e. along the channel axis). 49 | 50 | Could be extended in the future to other dimensions, providing a more 51 | flexible normalization framework. 52 | 53 | Args: 54 | inputs: a 4-D tensor with dimensions [batch_size, height, width, channels]. 55 | scaling: whether or not to add a post scaling operation along the dimensions 56 | which have been normalized. 57 | scale_initializer: An initializer for the weights. 58 | reuse: whether or not the layer and its variables should be reused. To 59 | reuse the layer, `scope` must be given. 60 | variables_collections: optional list of collections for all the variables or 61 | a dictionary containing a different list of collection per variable. 62 | outputs_collections: collection to add the outputs. 63 | data_format: NHWC or NCHW data format. 64 | trainable: If `True` also add variables to the graph collection 65 | `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). 66 | scope: Optional scope for `variable_scope`. 67 | Returns: 68 | A `Tensor` representing the output of the operation.
69 | """ 70 | 71 | with variable_scope.variable_scope( 72 | scope, 'L2Normalization', [inputs], reuse=reuse) as sc: 73 | inputs_shape = inputs.get_shape() 74 | inputs_rank = inputs_shape.ndims 75 | dtype = inputs.dtype.base_dtype 76 | if data_format == 'NHWC': 77 | # norm_dim = tf.range(1, inputs_rank-1) 78 | norm_dim = tf.range(inputs_rank-1, inputs_rank) 79 | params_shape = inputs_shape[-1:] 80 | elif data_format == 'NCHW': 81 | # norm_dim = tf.range(2, inputs_rank) 82 | norm_dim = tf.range(1, 2) 83 | params_shape = (inputs_shape[1]) 84 | 85 | # Normalize along the channel dimension. 86 | outputs = nn.l2_normalize(inputs, norm_dim, epsilon=1e-12) 87 | # Additional scaling. 88 | if scaling: 89 | scale_collections = utils.get_variable_collections( 90 | variables_collections, 'scale') 91 | scale = variables.model_variable('gamma', 92 | shape=params_shape, 93 | dtype=dtype, 94 | initializer=scale_initializer, 95 | collections=scale_collections, 96 | trainable=trainable) 97 | if data_format == 'NHWC': 98 | outputs = tf.multiply(outputs, scale) 99 | elif data_format == 'NCHW': 100 | scale = tf.expand_dims(scale, axis=-1) 101 | scale = tf.expand_dims(scale, axis=-1) 102 | outputs = tf.multiply(outputs, scale) 103 | # outputs = tf.transpose(outputs, perm=(0, 2, 3, 1)) 104 | 105 | return utils.collect_named_outputs(outputs_collections, 106 | sc.original_name_scope, outputs) 107 | 108 | @add_arg_scope 109 | def pad2d(inputs, 110 | pad=(0, 0), 111 | mode='CONSTANT', 112 | data_format='NHWC', 113 | trainable=True, 114 | scope=None): 115 | """2D Padding layer, adding a symmetric padding to H and W dimensions. 116 | 117 | Aims to mimic padding in Caffe and MXNet, helping to port models to 118 | TensorFlow. Tries to follow the naming convention of `tf.contrib.layers`. 119 | 120 | Args: 121 | inputs: 4D input Tensor; 122 | pad: 2-Tuple with padding values for H and W dimensions; 123 | mode: Padding mode. See `tf.pad`. 124 | data_format: NHWC or NCHW data format.
125 | """ 126 | with tf.name_scope(scope, 'pad2d', [inputs]): 127 | # Padding shape. 128 | if data_format == 'NHWC': 129 | paddings = [[0, 0], [pad[0], pad[0]], [pad[1], pad[1]], [0, 0]] 130 | elif data_format == 'NCHW': 131 | paddings = [[0, 0], [0, 0], [pad[0], pad[0]], [pad[1], pad[1]]] 132 | net = tf.pad(inputs, paddings, mode=mode) 133 | return net 134 | 135 | @add_arg_scope 136 | def channel_to_last(inputs, 137 | data_format='NHWC', 138 | scope=None): 139 | """Move the channel axis to the last dimension, giving 140 | a single output format regardless of the input data format. 141 | 142 | Args: 143 | inputs: Input Tensor; 144 | data_format: NHWC or NCHW. 145 | Return: 146 | Input in NHWC format. 147 | """ 148 | with tf.name_scope(scope, 'channel_to_last', [inputs]): 149 | if data_format == 'NHWC': 150 | net = inputs 151 | elif data_format == 'NCHW': 152 | net = tf.transpose(inputs, perm=(0, 2, 3, 1)) 153 | return net 154 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import pandas as pd 4 | import json 5 | import scipy.interpolate 6 | import time 7 | import tensorflow as tf 8 | from multiprocessing import Process,Queue,JoinableQueue 9 | import multiprocessing 10 | 11 | x_tdim = 512 12 | fea_dim_th = 2048 13 | 14 | def pool_fea_interpolate_th(fea_temp,feature_count,duration,fnumber,resize_temporal_dim=100,pool_type="mean"): 15 | 16 | num_bin = 1 17 | num_sample_bin=3 18 | 19 | num_prop = resize_temporal_dim 20 | video_frame = fnumber 21 | video_second = duration 22 | 23 | data = fea_temp[0:feature_count] 24 | 25 | 26 | feature_frame=feature_count*8 27 | corrected_second=float(feature_frame)/video_frame*video_second 28 | fps=float(video_frame)/video_second 29 | st=8/fps 30 | 31 | if feature_count==1: 32 | video_feature=np.stack([data]*num_prop) 33 |
video_feature=np.reshape(video_feature,[num_prop,fea_dim_th]) 34 | return video_feature 35 | 36 | x=[st/2+ii*st for ii in range(feature_count)] 37 | f=scipy.interpolate.interp1d(x,data,axis=0) 38 | 39 | video_feature=[] 40 | zero_sample=np.zeros(num_bin*fea_dim_th) 41 | tmp_anchor_xmin=[1.0/num_prop*i for i in range(num_prop)] 42 | tmp_anchor_xmax=[1.0/num_prop*i for i in range(1,num_prop+1)] 43 | 44 | num_sample=num_bin*num_sample_bin 45 | for idx in range(num_prop): 46 | xmin=max(x[0]+0.0001,tmp_anchor_xmin[idx]*corrected_second) 47 | xmax=min(x[-1]-0.0001,tmp_anchor_xmax[idx]*corrected_second) 48 | if xmaxx[-1]: 52 | video_feature.append(zero_sample) 53 | continue 54 | 55 | plen=(xmax-xmin)/(num_sample-1) 56 | x_new=[xmin+plen*ii for ii in range(num_sample)] 57 | y_new=f(x_new) 58 | y_new_pool=[] 59 | for b in range(num_bin): 60 | tmp_y_new=y_new[num_sample_bin*b:num_sample_bin*(b+1)] 61 | if pool_type=="mean": 62 | tmp_y_new=np.mean(y_new,axis=0) 63 | elif pool_type=="max": 64 | tmp_y_new=np.max(y_new,axis=0) 65 | y_new_pool.append(tmp_y_new) 66 | y_new_pool=np.stack(y_new_pool) 67 | y_new_pool=np.reshape(y_new_pool,[-1]) 68 | video_feature.append(y_new_pool) 69 | video_feature=np.stack(video_feature) 70 | return video_feature 71 | 72 | def load_json(file): 73 | with open(file) as json_file: 74 | data = json.load(json_file) 75 | return data 76 | 77 | def getDatasetDict(df_file='./cvs_records/anet_train_val_info.csv'): 78 | """Load dataset file 79 | """ 80 | df=pd.read_csv(df_file) 81 | json_data= load_json("./cvs_records_anet/anet_annotation_json.json") 82 | database=json_data 83 | train_dict={} 84 | val_dict={} 85 | test_dict={} 86 | miss_count = 0 87 | for i in range(len(df)): 88 | video_name=df.video.values[i] 89 | if video_name not in database.keys(): 90 | miss_count += 1 91 | continue 92 | video_info=database[video_name] 93 | video_new_info={} 94 | video_new_info['duration_frame']=video_info['duration_frame'] 95 | 
video_new_info['duration_second']=video_info['duration_second'] 96 | video_new_info["feature_frame"]=video_info['feature_frame'] 97 | video_subset=df.subset.values[i] 98 | video_new_info['annotations']=video_info['annotations'] 99 | if video_subset=="training": 100 | train_dict[video_name]=video_new_info 101 | elif video_subset=="validation": 102 | val_dict[video_name]=video_new_info 103 | elif video_subset=="testing": 104 | test_dict[video_name]=video_new_info 105 | print('Miss video annotation: %d'%(miss_count)) 106 | return train_dict,val_dict,test_dict 107 | 108 | def getDatasetDictK600(df_file='./cvs_records_k600/k600_train_val_info.csv'): 109 | """Load dataset file 110 | """ 111 | df=pd.read_csv(df_file) 112 | json_data= load_json("./cvs_records_k600/k600_annotation_json.json") 113 | database=json_data 114 | train_dict={} 115 | val_dict={} 116 | test_dict={} 117 | miss_count = 0 118 | for i in range(len(df)): 119 | video_name=df.video.values[i] 120 | if video_name not in database.keys(): 121 | miss_count += 1 122 | continue 123 | video_info=database[video_name] 124 | video_new_info={} 125 | video_new_info['duration_frame']=video_info['duration_frame'] 126 | video_new_info['duration_second']=video_info['duration_second'] 127 | video_new_info["feature_frame"]=video_info['feature_frame'] 128 | video_subset=df.subset.values[i] 129 | video_new_info['annotations']=video_info['annotations'] 130 | if video_subset=="training": 131 | train_dict[video_name]=video_new_info 132 | elif video_subset=="validation": 133 | val_dict[video_name]=video_new_info 134 | elif video_subset=="testing": 135 | test_dict[video_name]=video_new_info 136 | print('Miss video annotation: %d'%(miss_count)) 137 | return train_dict,val_dict,test_dict 138 | 139 | def getSplitsetDict(df_file='./cvs_records/anet_train_val_info.csv'): 140 | """Load dataset file 141 | ActivityNet: 142 | untrim set class number: 87 143 | moment set class number: 113 144 | """ 145 | cate_info = 
pd.read_csv('activitynet_train_val_gt.txt',sep = ' ',header = None) 146 | cate_dict = {} 147 | for v_id in range(len(cate_info)): 148 | v = cate_info.loc[v_id][0] 149 | cate = cate_info.loc[v_id][2] 150 | cate_dict[v] = cate 151 | 152 | moment_info = pd.read_csv('data_split/label_map_moment.txt',sep = '\n',header = None) 153 | moment_cate = [] 154 | for idm in range(len(moment_info)): 155 | moment_cate.append(moment_info.loc[idm][0]) 156 | 157 | untrim_info = pd.read_csv('data_split/label_map_untrim.txt',sep = '\n',header = None) 158 | untrim_cate = [] 159 | for idm in range(len(untrim_info)): 160 | untrim_cate.append(untrim_info.loc[idm][0]) 161 | 162 | df=pd.read_csv("./cvs_records/anet_train_val_info.csv") 163 | json_data= load_json("./cvs_records/anet_annotation_json.json") 164 | database=json_data 165 | train_untrim_dict={} 166 | train_moment_dict={} 167 | val_untrim_dict={} 168 | val_moment_dict={} 169 | miss_count = 0 170 | for i in range(len(df)): 171 | video_name=df.video.values[i] 172 | if video_name not in database.keys(): 173 | miss_count += 1 174 | continue 175 | video_info=database[video_name] 176 | video_cate=cate_dict[video_name] 177 | video_new_info={} 178 | video_new_info['duration_frame']=video_info['duration_frame'] 179 | video_new_info['duration_second']=video_info['duration_second'] 180 | video_new_info["feature_frame"]=video_info['feature_frame'] 181 | video_subset=df.subset.values[i] 182 | video_new_info['annotations']=video_info['annotations'] 183 | 184 | if video_subset=="training" and video_cate in untrim_cate: 185 | train_untrim_dict[video_name]=video_new_info 186 | if video_subset=="training" and video_cate in moment_cate: 187 | train_moment_dict[video_name]=video_new_info 188 | if video_subset=="validation" and video_cate in untrim_cate: 189 | val_untrim_dict[video_name]=video_new_info 190 | if video_subset=="validation" and video_cate in moment_cate: 191 | val_moment_dict[video_name]=video_new_info 192 | 193 | print('Miss video 
annotation: %d'%(miss_count)) 194 | return train_untrim_dict,train_moment_dict,val_untrim_dict,val_moment_dict 195 | 196 | class AnetDatasetLoadForeQueue(): 197 | def __init__(self,dataSet,csv_dir,csv_oris_dir,gt_hold_num,x_tdim,df_file='./cvs_records_anet/anet_train_val_reduce_info.csv',untrim_start=87,resize_dim=512): 198 | self.train_dict,self.val_dict,self.test_dict=getDatasetDict(df_file=df_file) 199 | if dataSet == 'train': 200 | self.train_list = list(self.train_dict.keys()) 201 | else: 202 | self.train_list = list(self.val_dict.keys()) 203 | self.train_list_copy = self.train_list 204 | cate_info = pd.read_csv('activitynet_train_val_gt.txt',sep = ' ',header = None) 205 | self.cate_dict = {} 206 | for v_id in range(len(cate_info)): 207 | v = cate_info.loc[v_id][0] 208 | cate = cate_info.loc[v_id][2] 209 | self.cate_dict[v] = cate 210 | self.shuffle() 211 | self.resize_dim = resize_dim 212 | self.dataSet = dataSet 213 | self.csv_dir = csv_dir 214 | self.csv_oris_dir = csv_oris_dir 215 | self.gt_hold_num = gt_hold_num 216 | self.x_tdim = x_tdim 217 | self.num_samples = len(self.train_list) 218 | self.batch_size = 1 219 | self.queue = Queue(maxsize=1536) #Joinable 220 | self.train_feature_list = multiprocessing.Manager().list() 221 | self.process_num = 32 222 | self.process_array = [] 223 | self.start_train_idx = 0 224 | for pid in range(self.process_num): 225 | pid_num = int(self.num_samples / self.process_num) 226 | start_idx = pid*pid_num 227 | end_idx = (pid+1)*pid_num 228 | if pid == (self.process_num - 1): 229 | if end_idx < self.num_samples: end_idx = self.num_samples 230 | t = Process(target=self.load_queue,args=(start_idx,end_idx)) 231 | self.process_array.append(t) 232 | t.start() 233 | def shuffle(self): 234 | randperm = np.random.permutation(len(self.train_list)) 235 | self.train_list = [self.train_list_copy[int(randperm[idx])] for idx in range(len(randperm))] 236 | def load_single_feature(self,start): 237 | result = [] 238 | name = 
self.train_list[start] 239 | s_bbox = [] 240 | r_bbox = [] 241 | video_info=self.train_dict[name] if self.dataSet == 'train' else self.val_dict[name] 242 | video_frame=video_info['duration_frame'] 243 | video_second=video_info['duration_second'] 244 | feature_frame=video_info['feature_frame'] 245 | feature_num = int(feature_frame / 8) 246 | corrected_second=float(feature_frame)/video_frame*video_second 247 | video_labels=video_info['annotations'] 248 | vlabels = self.cate_dict[name] 249 | for j in range(len(video_labels)): 250 | tmp_info=video_labels[j] 251 | tmp_start=tmp_info['segment'][0] 252 | tmp_end=tmp_info['segment'][1] 253 | tmp_start=max(min(1,tmp_start/corrected_second),0) 254 | tmp_end=max(min(1,tmp_end/corrected_second),0) 255 | tmp_fea_start = tmp_start*self.x_tdim 256 | tmp_fea_end = tmp_end*self.x_tdim 257 | s_bbox.append([tmp_fea_start,tmp_fea_end]) 258 | tmp_len = tmp_end - tmp_start 259 | if tmp_len*feature_num >= 32: r_bbox.append([tmp_start,tmp_end]) 260 | 261 | tmp_df=pd.read_csv(self.csv_oris_dir+"/"+name+".csv",skiprows=1,header=None,sep=',') 262 | tmp_values = tmp_df.values[:,:] 263 | for m in range(len(r_bbox)): 264 | tmp_start_fea,tmp_end_fea = int(r_bbox[m][0]*feature_num),int(r_bbox[m][1]*feature_num) 265 | proposal_fea = tmp_values[tmp_start_fea:tmp_end_fea,:] 266 | proposal_duration,proposal_frame = (r_bbox[m][1]-r_bbox[m][0])*video_second,(r_bbox[m][1]-r_bbox[m][0])*video_frame 267 | proposal_inter_fea = pool_fea_interpolate_th(proposal_fea,len(proposal_fea),proposal_duration,proposal_frame,self.resize_dim) 268 | result.append([proposal_inter_fea,[[128.0,384.0]],name,vlabels,proposal_duration*2]) 269 | return result 270 | def load_queue(self,start,end): 271 | bindex = 0 272 | lendata = end - start 273 | while True: 274 | t = self.load_single_feature(bindex+start) 275 | if self.dataSet == 'train_untrim' or \ 276 | self.dataSet == 'train_momwhole' or \ 277 | self.dataSet == 'val_untrim' or \ 278 | self.dataSet == 'val_moment': 279 | self.queue.put(t) 280 | else: 281 | for z in
range(len(t)): 282 | self.queue.put(t[z]) 283 | bindex += 1 284 | if bindex == lendata: 285 | bindex = 0 286 | def gen(self): 287 | while True: 288 | t = self.queue.get() 289 | yield t[0],t[1],t[2],t[3],t[4] 290 | def stop_process(self): 291 | for t in self.process_array: 292 | t.terminate() 293 | t.join() 294 | print('Stopped the processes successfully.') 295 | 296 | class KineticsDatasetLoadTrimMultiProcess(): 297 | def __init__(self,dataSet,csv_dir,csv_oris_dir,gt_hold_num,x_tdim,df_file='./cvs_records_k600/k600_train_val_info.csv',resize_dim=256,ratio_sid=1,ratio_eid=9): 298 | train_video_dict, val_video_dict, test_video_dict=getDatasetDictK600(df_file=df_file) 299 | if "train" in dataSet: 300 | self.train_dict=train_video_dict 301 | else: 302 | self.train_dict=val_video_dict 303 | self.dataSet = dataSet 304 | self.resize_dim = resize_dim 305 | self.ratio_sid = ratio_sid 306 | self.ratio_eid = ratio_eid 307 | self.train_list = list(self.train_dict.keys()) 308 | self.train_list_copy = self.train_list 309 | print('The %s video num is: %d'%(self.dataSet,len(self.train_list))) 310 | print('Load k600 train val gt info...') 311 | cate_info = pd.read_csv('k600_train_val_gt.txt',sep = ' ',header = None) 312 | self.cate_dict = {} 313 | for v_id in range(len(cate_info)): 314 | v = cate_info.loc[v_id][0] 315 | cate = cate_info.loc[v_id][2] 316 | self.cate_dict[v] = cate 317 | print('Load Done.') 318 | self.shuffle() 319 | self.csv_dir = csv_dir 320 | self.csv_oris_dir = csv_oris_dir 321 | self.gt_hold_num = gt_hold_num 322 | self.x_tdim = x_tdim 323 | self.num_samples = len(self.train_list) 324 | self.batch_size = 1 325 | #self.train_feature_list = multiprocessing.Manager().list() 326 | self.process_num = 32 327 | self.process_array = [] 328 | self.queue = Queue(maxsize=1536) #Joinable 329 | self.start_train_idx = 0 330 | for pid in range(self.process_num): 331 | pid_num = int(self.num_samples / self.process_num) 332 | start_idx = pid*pid_num 333 | end_idx = (pid+1)*pid_num 334 | if
pid == (self.process_num - 1): 335 | if end_idx < self.num_samples: end_idx = self.num_samples 336 | t = Process(target=self.load_queue,args=(start_idx,end_idx)) 337 | self.process_array.append(t) 338 | t.start() 339 | def shuffle(self): 340 | randperm = np.random.permutation(len(self.train_list)) 341 | self.train_list = [self.train_list_copy[int(randperm[idx])] for idx in range(len(randperm))] 342 | def load_single_feature(self,start): 343 | result = [] 344 | name = self.train_list[start] 345 | s_bbox = [] 346 | r_bbox = [] 347 | label_bbox=[] 348 | slabel_box = [] 349 | rlabel_bbox = [] 350 | video_info=self.train_dict[name] 351 | video_frame=video_info['duration_frame'] 352 | video_second=video_info['duration_second'] 353 | feature_frame=video_info['feature_frame'] 354 | feature_num = int(feature_frame / 8) 355 | corrected_second=float(feature_frame)/video_frame*video_second 356 | video_labels=video_info['annotations'] 357 | vlabels = self.cate_dict[name] 358 | for j in range(len(video_labels)): 359 | tmp_info=video_labels[j] 360 | tmp_start=tmp_info['segment'][0] 361 | tmp_end=tmp_info['segment'][1] 362 | tmp_label=tmp_info['label'] 363 | tmp_start=max(min(1,tmp_start/corrected_second),0) 364 | tmp_end=max(min(1,tmp_end/corrected_second),0) 365 | tmp_fea_start = tmp_start*self.x_tdim 366 | tmp_fea_end = tmp_end*self.x_tdim 367 | s_bbox.append([tmp_fea_start,tmp_fea_end]) 368 | tmp_len = tmp_end - tmp_start 369 | if tmp_len*feature_num >= 8: 370 | r_bbox.append([tmp_start,tmp_end]) 371 | if self.dataSet == 'val_moment': 372 | tmp_df=pd.read_csv(self.csv_dir+"/"+name+".csv",skiprows=1,header=None,sep=',') 373 | anchor_feature = tmp_df.values[:,:] 374 | for kk in range(len(s_bbox),self.gt_hold_num): 375 | s_bbox.append([-1.0,-1.0]) 376 | result.append(anchor_feature) 377 | result.append(s_bbox) 378 | result.append(name) 379 | result.append(vlabels) 380 | result.append(video_second) 381 | return result 382 | else: 383 | 
tmp_df=pd.read_csv(self.csv_oris_dir+"/"+name+".csv",skiprows=1,header=None,sep=',') 384 | tmp_values = tmp_df.values[:,:] 385 | for m in range(len(r_bbox)): 386 | tmp_start_fea,tmp_end_fea = int(r_bbox[m][0]*feature_num),int(r_bbox[m][1]*feature_num) 387 | proposal_fea = tmp_values[tmp_start_fea:tmp_end_fea,:] 388 | proposal_duration,proposal_frame = (r_bbox[m][1]-r_bbox[m][0])*video_second,(r_bbox[m][1]-r_bbox[m][0])*video_frame 389 | proposal_inter_fea = pool_fea_interpolate_th(proposal_fea,len(proposal_fea),proposal_duration,proposal_frame,self.resize_dim) 390 | result.append([proposal_inter_fea,[[128.0,384.0]],name,vlabels,proposal_duration*2]) 391 | return result 392 | def load_queue(self,start,end): 393 | bindex = 0 394 | lendata = end - start 395 | while True: 396 | t = self.load_single_feature(bindex+start) 397 | if self.dataSet == 'val_moment': 398 | self.queue.put(t) 399 | else: 400 | for z in range(len(t)): 401 | self.queue.put(t[z]) 402 | bindex += 1 403 | if bindex == lendata: bindex = 0 404 | def gen(self): 405 | while True: 406 | if self.queue.empty(): 407 | print('The Kinetics %s Queue is empty, loading feature...'%(self.dataSet)) 408 | time.sleep(10) 409 | else: 410 | t = self.queue.get_nowait() 411 | yield t[0],t[1],t[2],t[3],t[4] 412 | def stop_process(self): 413 | for t in self.process_array: 414 | t.terminate() 415 | t.join() 416 | print('Stopped the processes successfully.') 417 | def gen_pos_random(self): 418 | while True: 419 | t = self.queue.get() 420 | yield t[0],t[1],t[2],t[3],t[4],random.randint(self.ratio_sid,self.ratio_eid),random.randint(1,19) 421 | 422 | class AnetDatasetLoadMultiProcessQueue(): 423 | def __init__(self,dataSet,csv_dir,gt_hold_num,x_tdim,df_file='./cvs_records_anet/anet_train_val_reduce_info.csv'): 424 | self.train_dict,self.val_dict,self.test_dict=getDatasetDict(df_file=df_file) 425 | if dataSet == 'train': 426 | self.train_list = list(self.train_dict.keys()) 427 | else: 428 | self.train_list = list(self.val_dict.keys()) 429 |
self.train_list_copy = self.train_list 430 | cate_info = pd.read_csv('activitynet_train_val_gt.txt',sep = ' ',header = None) 431 | self.cate_dict = {} 432 | for v_id in range(len(cate_info)): 433 | v = cate_info.loc[v_id][0] 434 | cate = cate_info.loc[v_id][2] 435 | self.cate_dict[v] = cate 436 | self.shuffle() 437 | self.dataSet = dataSet 438 | self.csv_dir = csv_dir 439 | self.gt_hold_num = gt_hold_num 440 | self.x_tdim = x_tdim 441 | self.num_samples = len(self.train_list) 442 | self.batch_size = 1 443 | self.queue = Queue(maxsize=1536) #Joinable 444 | self.train_feature_list = multiprocessing.Manager().list() 445 | self.process_num = 32 446 | self.process_array = [] 447 | self.start_train_idx = 0 448 | for pid in range(self.process_num): 449 | pid_num = int(self.num_samples / self.process_num) 450 | start_idx = pid*pid_num 451 | end_idx = (pid+1)*pid_num 452 | if pid == (self.process_num - 1): 453 | if end_idx < self.num_samples: end_idx = self.num_samples 454 | t = Process(target=self.load_queue,args=(start_idx,end_idx)) 455 | self.process_array.append(t) 456 | t.start() 457 | def shuffle(self): 458 | randperm = np.random.permutation(len(self.train_list)) 459 | self.train_list = [self.train_list_copy[int(randperm[idx])] for idx in range(len(randperm))] 460 | def load_single_feature(self,start): 461 | result = [] 462 | name = self.train_list[start] 463 | s_bbox = [] 464 | if self.dataSet == 'train': 465 | video_info=self.train_dict[name] 466 | else: 467 | video_info=self.val_dict[name] 468 | video_frame=video_info['duration_frame'] 469 | video_second=video_info['duration_second'] 470 | feature_frame=video_info['feature_frame'] 471 | corrected_second=float(feature_frame)/video_frame*video_second 472 | video_labels=video_info['annotations'] 473 | vlabels = self.cate_dict[name] 474 | for j in range(len(video_labels)): 475 | tmp_info=video_labels[j] 476 | tmp_start=tmp_info['segment'][0] 477 | tmp_end=tmp_info['segment'][1] 478 | 
tmp_start=max(min(1,tmp_start/corrected_second),0) 479 | tmp_end=max(min(1,tmp_end/corrected_second),0) 480 | tmp_fea_start = tmp_start*self.x_tdim 481 | tmp_fea_end = tmp_end*self.x_tdim 482 | s_bbox.append([tmp_fea_start,tmp_fea_end]) 483 | for kk in range(len(s_bbox),self.gt_hold_num): 484 | s_bbox.append([-1.0,-1.0]) 485 | tmp_df=pd.read_csv(self.csv_dir+"/"+name+".csv",skiprows=1,header=None,sep=',') 486 | anchor_feature = tmp_df.values[:,:] 487 | result.append(anchor_feature) 488 | result.append(s_bbox) 489 | result.append(name) 490 | result.append(vlabels) 491 | result.append(video_second) 492 | return result 493 | def load_queue(self,start,end): 494 | bindex = 0 495 | lendata = end - start 496 | while True: 497 | t = self.load_single_feature(bindex+start) 498 | self.queue.put(t) 499 | bindex += 1 500 | if bindex == lendata: 501 | bindex = 0 502 | def gen(self): 503 | while True: 504 | t = self.queue.get() 505 | yield t[0],t[1],t[2],t[3],t[4] 506 | def stop_process(self): 507 | for t in self.process_array: 508 | t.terminate() 509 | t.join() 510 | print('Stopped the processes successfully.') 511 | 512 | #-------------------------- Test Data Loader -----------------------------------# 513 | class KineticsDatasetLoadTrimList(): 514 | def __init__(self,dataSet,csv_dir,csv_oris_dir,gt_hold_num,x_tdim,df_file='./cvs_records_k600/k600_train_val_info.csv',resize_dim=256): 515 | train_video_dict, val_video_dict, test_video_dict=getDatasetDictK600(df_file=df_file) 516 | if "train" in dataSet: 517 | self.train_dict=train_video_dict 518 | else: 519 | self.train_dict=val_video_dict 520 | self.dataSet = dataSet 521 | self.resize_dim = resize_dim 522 | self.train_list = list(self.train_dict.keys()) 523 | self.train_list_copy = self.train_list 524 | print('The %s video num is: %d'%(self.dataSet,len(self.train_list))) 525 | print('Load k600 train val gt info...') 526 | cate_info = pd.read_csv('k600_train_val_gt.txt',sep = ' ',header = None) 527 | self.cate_dict = {} 528 | for v_id in
range(len(cate_info)): 529 | v = cate_info.loc[v_id][0] 530 | cate = cate_info.loc[v_id][2] 531 | self.cate_dict[v] = cate 532 | print('Load Done.') 533 | self.shuffle() 534 | self.csv_dir = csv_dir 535 | self.csv_oris_dir = csv_oris_dir 536 | self.gt_hold_num = gt_hold_num 537 | self.x_tdim = x_tdim 538 | self.num_samples = len(self.train_list) 539 | self.batch_size = 1 540 | self.train_feature_list = multiprocessing.Manager().list() 541 | self.process_num = 32 542 | self.process_array = [] 543 | self.start_train_idx = 0 544 | for pid in range(self.process_num): 545 | pid_num = int(self.num_samples / self.process_num) 546 | start_idx = pid*pid_num 547 | end_idx = (pid+1)*pid_num 548 | if pid == (self.process_num - 1): 549 | if end_idx < self.num_samples: end_idx = self.num_samples 550 | t = Process(target=self.load_queue,args=(start_idx,end_idx)) 551 | self.process_array.append(t) 552 | t.start() 553 | def shuffle(self): 554 | randperm = np.random.permutation(len(self.train_list)) 555 | self.train_list = [self.train_list_copy[int(randperm[idx])] for idx in range(len(randperm))] 556 | def load_single_feature(self,start): 557 | result = [] 558 | name = self.train_list[start] 559 | s_bbox = [] 560 | r_bbox = [] 561 | label_bbox=[] 562 | slabel_box = [] 563 | rlabel_bbox = [] 564 | video_info=self.train_dict[name] 565 | video_frame=video_info['duration_frame'] 566 | video_second=video_info['duration_second'] 567 | feature_frame=video_info['feature_frame'] 568 | feature_num = int(feature_frame / 8) 569 | corrected_second=float(feature_frame)/video_frame*video_second 570 | video_labels=video_info['annotations'] 571 | vlabels = self.cate_dict[name] 572 | for j in range(len(video_labels)): 573 | tmp_info=video_labels[j] 574 | tmp_start=tmp_info['segment'][0] 575 | tmp_end=tmp_info['segment'][1] 576 | tmp_label=tmp_info['label'] 577 | tmp_start=max(min(1,tmp_start/corrected_second),0) 578 | tmp_end=max(min(1,tmp_end/corrected_second),0) 579 | tmp_fea_start = 
tmp_start*self.x_tdim 580 | tmp_fea_end = tmp_end*self.x_tdim 581 | s_bbox.append([tmp_fea_start,tmp_fea_end]) 582 | tmp_len = tmp_end - tmp_start 583 | if tmp_len*feature_num >= 8: 584 | r_bbox.append([tmp_start,tmp_end]) 585 | if self.dataSet == 'val_moment': 586 | tmp_df=pd.read_csv(self.csv_dir+"/"+name+".csv",skiprows=1,header=None,sep=',') 587 | anchor_feature = tmp_df.values[:,:] 588 | for kk in range(len(s_bbox),self.gt_hold_num): 589 | s_bbox.append([-1.0,-1.0]) 590 | result.append(anchor_feature) 591 | result.append(s_bbox) 592 | result.append(name) 593 | result.append(vlabels) 594 | result.append(video_second) 595 | return result 596 | else: 597 | tmp_df=pd.read_csv(self.csv_oris_dir+"/"+name+".csv",skiprows=1,header=None,sep=',') 598 | tmp_values = tmp_df.values[:,:] 599 | for m in range(len(r_bbox)): 600 | tmp_start_fea,tmp_end_fea = int(r_bbox[m][0]*feature_num),int(r_bbox[m][1]*feature_num) 601 | proposal_fea = tmp_values[tmp_start_fea:tmp_end_fea,:] 602 | proposal_duration,proposal_frame = (r_bbox[m][1]-r_bbox[m][0])*video_second,(r_bbox[m][1]-r_bbox[m][0])*video_frame 603 | proposal_inter_fea = pool_fea_interpolate_th(proposal_fea,len(proposal_fea),proposal_duration,proposal_frame,self.resize_dim) 604 | result.append([proposal_inter_fea,[[128.0,384.0]],name,vlabels,proposal_duration*2]) 605 | return result 606 | def load_queue(self,start,end): 607 | bindex = 0 608 | lendata = end - start 609 | for i in range(lendata): 610 | t = self.load_single_feature(bindex+start) 611 | if self.dataSet == 'val_moment': 612 | self.train_feature_list.append(t) 613 | else: 614 | for z in range(len(t)): 615 | self.train_feature_list.append(t[z]) 616 | bindex += 1 617 | if bindex == lendata: bindex = 0 618 | def gen(self): 619 | self.num_samples = len(self.train_feature_list) 620 | print('The sample number of generator %s is %d'%(self.dataSet,self.num_samples)) 621 | while True: 622 | t = self.train_feature_list[self.start_train_idx] 623 | self.start_train_idx += 1 
624 | if self.start_train_idx == self.num_samples: 625 | self.start_train_idx = 0 626 | yield t[0],t[1],t[2],t[3],t[4] 627 | def stop_process(self): 628 | for t in self.process_array: 629 | t.terminate() 630 | t.join() 631 | print('Stopped the processes successfully.') 632 | -------------------------------------------------------------------------------- /evaluation/eval_detection.py: -------------------------------------------------------------------------------- 1 | import json 2 | #import urllib2 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | from utils import get_blocked_videos 8 | from utils import interpolated_prec_rec 9 | from utils import segment_iou 10 | 11 | class ANETdetection(object): 12 | 13 | GROUND_TRUTH_FIELDS = ['database', 'version'] #'taxonomy', 14 | PREDICTION_FIELDS = ['results', 'version', 'external_data'] 15 | 16 | def __init__(self, ground_truth_filename=None, prediction_filename=None, 17 | ground_truth_fields=GROUND_TRUTH_FIELDS, 18 | prediction_fields=PREDICTION_FIELDS, 19 | tiou_thresholds=np.linspace(0.5, 0.95, 10), 20 | subset='validation', verbose=False, 21 | check_status=True): 22 | if not ground_truth_filename: 23 | raise IOError('Please input a valid ground truth file.') 24 | if not prediction_filename: 25 | raise IOError('Please input a valid prediction file.') 26 | self.subset = subset 27 | self.tiou_thresholds = tiou_thresholds 28 | self.verbose = verbose 29 | self.gt_fields = ground_truth_fields 30 | self.pred_fields = prediction_fields 31 | self.ap = None 32 | self.check_status = check_status 33 | 34 | self.blocked_videos = {} 35 | # Retrieve blocked videos from server. 36 | #if self.check_status: 37 | # self.blocked_videos = get_blocked_videos() 38 | #else: 39 | # self.blocked_videos = list() 40 | # Import ground truth and predictions.
41 | self.ground_truth, self.activity_index = self._import_ground_truth( 42 | ground_truth_filename) 43 | self.prediction = self._import_prediction(prediction_filename) 44 | 45 | if self.verbose: 46 | print('[INIT] Loaded annotations from {} subset.'.format(subset)) 47 | nr_gt = len(self.ground_truth) 48 | print('\tNumber of ground truth instances: {}'.format(nr_gt)) 49 | nr_pred = len(self.prediction) 50 | print('\tNumber of predictions: {}'.format(nr_pred)) 51 | print('\tFixed threshold for tiou score: {}'.format(self.tiou_thresholds)) 52 | 53 | def _import_ground_truth(self, ground_truth_filename): 54 | """Reads ground truth file, checks if it is well formatted, and returns 55 | the ground truth instances and the activity classes. 56 | 57 | Parameters 58 | ---------- 59 | ground_truth_filename : str 60 | Full path to the ground truth json file. 61 | 62 | Outputs 63 | ------- 64 | ground_truth : df 65 | Data frame containing the ground truth instances. 66 | activity_index : dict 67 | Dictionary containing class index. 68 | """ 69 | with open(ground_truth_filename, 'r') as fobj: 70 | data = json.load(fobj) 71 | # Checking format 72 | if not all([field in data.keys() for field in self.gt_fields]): 73 | raise IOError('Please input a valid ground truth file.') 74 | 75 | # Read ground truth data. 
76 | activity_index, cidx = {}, 0 77 | video_lst, t_start_lst, t_end_lst, label_lst = [], [], [], [] 78 | for videoid, v in data['database'].items(): 79 | if self.subset != v['subset']: 80 | continue 81 | if videoid in self.blocked_videos: 82 | continue 83 | for ann in v['annotations']: 84 | if ann['label'] not in activity_index: 85 | activity_index[ann['label']] = cidx 86 | cidx += 1 87 | video_lst.append(videoid) 88 | t_start_lst.append(ann['segment'][0]) 89 | t_end_lst.append(ann['segment'][1]) 90 | label_lst.append(activity_index[ann['label']]) 91 | 92 | ground_truth = pd.DataFrame({'video-id': video_lst, 93 | 't-start': t_start_lst, 94 | 't-end': t_end_lst, 95 | 'label': label_lst}) 96 | return ground_truth, activity_index 97 | 98 | def _import_prediction(self, prediction_filename): 99 | """Reads prediction file, checks if it is well formatted, and returns 100 | the prediction instances. 101 | 102 | Parameters 103 | ---------- 104 | prediction_filename : str 105 | Full path to the prediction json file. 106 | 107 | Outputs 108 | ------- 109 | prediction : df 110 | Data frame containing the prediction instances. 111 | """ 112 | with open(prediction_filename, 'r') as fobj: 113 | data = json.load(fobj) 114 | # Checking format... 115 | if not all([field in data.keys() for field in self.pred_fields]): 116 | raise IOError('Please input a valid prediction file.') 117 | 118 | # Read predictions.
119 | video_lst, t_start_lst, t_end_lst = [], [], [] 120 | label_lst, score_lst = [], [] 121 | for videoid, v in data['results'].items(): 122 | if videoid not in self.ground_truth['video-id'].values: continue 123 | if videoid in self.blocked_videos: 124 | continue 125 | for result in v: 126 | if result['label'] not in self.activity_index.keys(): continue 127 | label = self.activity_index[result['label']] 128 | video_lst.append(videoid) 129 | t_start_lst.append(result['segment'][0]) 130 | t_end_lst.append(result['segment'][1]) 131 | label_lst.append(label) 132 | score_lst.append(result['score']) 133 | prediction = pd.DataFrame({'video-id': video_lst, 134 | 't-start': t_start_lst, 135 | 't-end': t_end_lst, 136 | 'label': label_lst, 137 | 'score': score_lst}) 138 | return prediction 139 | 140 | def wrapper_compute_average_precision(self): 141 | """Computes average precision for each class in the subset. 142 | """ 143 | ap = np.zeros((len(self.tiou_thresholds), len(self.activity_index.items()))) 144 | for activity, cidx in self.activity_index.items(): 145 | gt_idx = self.ground_truth['label'] == cidx 146 | pred_idx = self.prediction['label'] == cidx 147 | ap[:,cidx] = compute_average_precision_detection( 148 | self.ground_truth.loc[gt_idx].reset_index(drop=True), 149 | self.prediction.loc[pred_idx].reset_index(drop=True), 150 | tiou_thresholds=self.tiou_thresholds) 151 | return ap 152 | 153 | def evaluate(self,output_log): 154 | """Evaluates a prediction file. For the detection task we measure the 155 | interpolated mean average precision to measure the performance of a 156 | method. 
157 | """ 158 | self.ap = self.wrapper_compute_average_precision() 159 | self.mAP = self.ap.mean(axis=1) 160 | if self.verbose: 161 | print('[RESULTS] Performance on ActivityNet detection task.') 162 | print('AP in each IOU: {}'.format(self.mAP)) 163 | print('\tAverage-mAP: {}'.format(self.mAP.mean())) 164 | if len(output_log) != 0: 165 | with open(output_log,'w') as fw: 166 | mAPc = self.ap.mean(axis=0) 167 | for activity, cidx in self.activity_index.items(): 168 | #cname = self.activity_index[i] 169 | apc = mAPc[cidx] 170 | fw.write('%s %0.4f\n'%(activity,apc)) 171 | 172 | def compute_average_precision_detection(ground_truth, prediction, tiou_thresholds=np.linspace(0.5, 0.95, 10)): 173 | """Compute average precision (detection task) between ground truth and 174 | predictions data frames. If multiple predictions match the same 175 | ground-truth segment, only the one with the highest score is counted as a 176 | true positive. This code is greatly inspired by Pascal VOC devkit. 177 | 178 | Parameters 179 | ---------- 180 | ground_truth : df 181 | Data frame containing the ground truth instances. 182 | Required fields: ['video-id', 't-start', 't-end'] 183 | prediction : df 184 | Data frame containing the prediction instances. 185 | Required fields: ['video-id', 't-start', 't-end', 'score'] 186 | tiou_thresholds : 1darray, optional 187 | Temporal intersection over union thresholds. 188 | 189 | Outputs 190 | ------- 191 | ap : 1darray 192 | Average precision score at each tiou threshold. 193 | """ 194 | npos = float(len(ground_truth)) 195 | lock_gt = np.ones((len(tiou_thresholds),len(ground_truth))) * -1 196 | # Sort predictions by decreasing score order. 197 | sort_idx = prediction['score'].values.argsort()[::-1] 198 | prediction = prediction.loc[sort_idx].reset_index(drop=True) 199 | 200 | # Initialize true positive and false positive vectors.
201 | tp = np.zeros((len(tiou_thresholds), len(prediction))) 202 | fp = np.zeros((len(tiou_thresholds), len(prediction))) 203 | 204 | # Adaptation to query faster 205 | ground_truth_gbvn = ground_truth.groupby('video-id') 206 | 207 | # Assign true positives to matched ground-truth instances. 208 | for idx, this_pred in prediction.iterrows(): 209 | 210 | try: 211 | # Check if there is at least one ground truth in the video associated. 212 | ground_truth_videoid = ground_truth_gbvn.get_group(this_pred['video-id']) 213 | except KeyError: 214 | fp[:, idx] = 1 215 | continue 216 | 217 | this_gt = ground_truth_videoid.reset_index() 218 | tiou_arr = segment_iou(this_pred[['t-start', 't-end']].values, 219 | this_gt[['t-start', 't-end']].values) 220 | # Visit the ground-truth segments in order of decreasing tiou with this prediction. 221 | tiou_sorted_idx = tiou_arr.argsort()[::-1] 222 | for tidx, tiou_thr in enumerate(tiou_thresholds): 223 | for jdx in tiou_sorted_idx: 224 | if tiou_arr[jdx] < tiou_thr: 225 | fp[tidx, idx] = 1 226 | break 227 | if lock_gt[tidx, this_gt.loc[jdx]['index']] >= 0: 228 | continue 229 | # Assign as true positive after the filters above.
230 | tp[tidx, idx] = 1 231 | lock_gt[tidx, this_gt.loc[jdx]['index']] = idx 232 | break 233 | 234 | if fp[tidx, idx] == 0 and tp[tidx, idx] == 0: 235 | fp[tidx, idx] = 1 236 | 237 | ap = np.zeros(len(tiou_thresholds)) 238 | 239 | for tidx in range(len(tiou_thresholds)): 240 | # Computing prec-rec 241 | this_tp = np.cumsum(tp[tidx,:]).astype(np.float64) 242 | this_fp = np.cumsum(fp[tidx,:]).astype(np.float64) 243 | rec = this_tp / npos 244 | prec = this_tp / (this_tp + this_fp) 245 | ap[tidx] = interpolated_prec_rec(prec, rec) 246 | 247 | return ap 248 | -------------------------------------------------------------------------------- /evaluation/eval_proposal.py: -------------------------------------------------------------------------------- 1 | import json 2 | #import urllib2 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | from utils import get_blocked_videos 8 | from utils import interpolated_prec_rec 9 | from utils import segment_iou 10 | from utils import wrapper_segment_iou 11 | 12 | class ANETproposal(object): 13 | 14 | GROUND_TRUTH_FIELDS = ['database', 'version'] 15 | PROPOSAL_FIELDS = ['results', 'version', 'external_data'] 16 | 17 | def __init__(self, ground_truth_filename=None, proposal_filename=None, 18 | ground_truth_fields=GROUND_TRUTH_FIELDS, 19 | proposal_fields=PROPOSAL_FIELDS, 20 | tiou_thresholds=np.linspace(0.5, 0.95, 10), 21 | max_avg_nr_proposals=None, 22 | subset='validation', verbose=False, 23 | check_status=True): 24 | if not ground_truth_filename: 25 | raise IOError('Please input a valid ground truth file.') 26 | if not proposal_filename: 27 | raise IOError('Please input a valid proposal file.') 28 | self.subset = subset 29 | self.tiou_thresholds = tiou_thresholds 30 | self.max_avg_nr_proposals = max_avg_nr_proposals 31 | self.verbose = verbose 32 | self.gt_fields = ground_truth_fields 33 | self.pred_fields = proposal_fields 34 | self.recall = None 35 | self.avg_recall = None 36 | self.proposals_per_video = None 37 |
self.check_status = check_status 38 | 39 | # Include the Block video in server (modified in 6/3/2017) 40 | # Retrieve blocked videos from server. 41 | #if self.check_status: 42 | # self.blocked_videos = get_blocked_videos() 43 | #else: 44 | # self.blocked_videos = list() 45 | self.blocked_videos = list() 46 | #print('Block Video Num: %d'%(len(self.blocked_videos))) 47 | 48 | #with open('block_video_list.txt','w') as fw: 49 | # for i in range(len(self.blocked_videos)): 50 | # fw.write('%s\n'%self.blocked_videos[i]) 51 | 52 | self.lesstime_videos = list() 53 | print('The less video number: %d'%(len(self.lesstime_videos))) 54 | # Import ground truth and proposals. 55 | self.ground_truth, self.activity_index = self._import_ground_truth( 56 | ground_truth_filename) 57 | self.proposal = self._import_proposal(proposal_filename) 58 | 59 | #if self.verbose: 60 | # print '[INIT] Loaded annotations from {} subset.'.format(subset) 61 | # nr_gt = len(self.ground_truth) 62 | # print '\tNumber of ground truth instances: {}'.format(nr_gt) 63 | # nr_pred = len(self.proposal) 64 | # print '\tNumber of proposals: {}'.format(nr_pred) 65 | # print '\tFixed threshold for tiou score: {}'.format(self.tiou_thresholds) 66 | 67 | def _import_ground_truth(self, ground_truth_filename): 68 | """Reads ground truth file, checks if it is well formatted, and returns 69 | the ground truth instances and the activity classes. 70 | 71 | Parameters 72 | ---------- 73 | ground_truth_filename : str 74 | Full path to the ground truth json file. 75 | 76 | Outputs 77 | ------- 78 | ground_truth : df 79 | Data frame containing the ground truth instances. 80 | activity_index : dict 81 | Dictionary containing class index. 82 | """ 83 | with open(ground_truth_filename, 'r') as fobj: 84 | data = json.load(fobj) 85 | # Checking format 86 | if not all([field in data.keys() for field in self.gt_fields]): 87 | raise IOError('Please input a valid ground truth file.') 88 | 89 | # Read ground truth data. 
90 | activity_index, cidx, block_video_n = {}, 0, 0 91 | video_lst, t_start_lst, t_end_lst, label_lst = [], [], [], [] 92 | for videoid, v in data['database'].items(): 93 | if self.subset != v['subset']: 94 | continue 95 | if videoid in self.blocked_videos: 96 | block_video_n += 1 97 | continue 98 | if videoid in self.lesstime_videos: 99 | continue 100 | for ann in v['annotations']: 101 | if ann['label'] not in activity_index: 102 | activity_index[ann['label']] = cidx 103 | cidx += 1 104 | video_lst.append(videoid) 105 | t_start_lst.append(ann['segment'][0]) 106 | t_end_lst.append(ann['segment'][1]) 107 | label_lst.append(activity_index[ann['label']]) 108 | print('time less video: %d'%(len(video_lst))) 109 | print('block video number: %d'%(block_video_n)) 110 | ground_truth = pd.DataFrame({'video-id': video_lst, 111 | 't-start': t_start_lst, 112 | 't-end': t_end_lst, 113 | 'label': label_lst}) 114 | return ground_truth, activity_index 115 | 116 | def _import_proposal(self, proposal_filename): 117 | """Reads proposal file, checks if it is well formatted, and returns 118 | the proposal instances. 119 | 120 | Parameters 121 | ---------- 122 | proposal_filename : str 123 | Full path to the proposal json file. 124 | 125 | Outputs 126 | ------- 127 | proposal : df 128 | Data frame containing the proposal instances. 129 | """ 130 | with open(proposal_filename, 'r') as fobj: 131 | data = json.load(fobj) 132 | # Checking format... 133 | if not all([field in data.keys() for field in self.pred_fields]): 134 | raise IOError('Please input a valid proposal file.') 135 | 136 | # Read predictions. 
137 | video_lst, t_start_lst, t_end_lst = [], [], [] 138 | score_lst = [] 139 | for videoid, v in data['results'].items(): 140 | if videoid in self.blocked_videos: 141 | continue 142 | for result in v: 143 | video_lst.append(videoid) 144 | t_start_lst.append(result['segment'][0]) 145 | t_end_lst.append(result['segment'][1]) 146 | score_lst.append(result['score']) 147 | proposal = pd.DataFrame({'video-id': video_lst, 148 | 't-start': t_start_lst, 149 | 't-end': t_end_lst, 150 | 'score': score_lst}) 151 | return proposal 152 | 153 | def evaluate(self): 154 | """Evaluates a proposal file. To measure the performance of a 155 | method for the proposal task, we compute the area under the 156 | average recall vs. average number of proposals per video curve. 157 | """ 158 | recall, avg_recall, proposals_per_video = average_recall_vs_avg_nr_proposals( 159 | self.ground_truth, self.proposal, 160 | max_avg_nr_proposals=self.max_avg_nr_proposals, 161 | tiou_thresholds=self.tiou_thresholds) 162 | 163 | area_under_curve = np.trapz(avg_recall, proposals_per_video) 164 | 165 | if self.verbose: 166 | print('[RESULTS] Performance on ActivityNet proposal task.') 167 | print('\tArea Under the AR vs AN curve: {}%'.format(100.*float(area_under_curve)/proposals_per_video[-1])) 168 | print('\tProposal 100, IOU from 0.5:0.05:0.95, Average Recall: {}%'.format(100.*float(avg_recall[-1]))) 169 | 170 | self.recall = recall 171 | self.avg_recall = avg_recall 172 | self.proposals_per_video = proposals_per_video 173 | 174 | def average_recall_vs_avg_nr_proposals(ground_truth, proposals, 175 | max_avg_nr_proposals=None, 176 | tiou_thresholds=np.linspace(0.5, 0.95, 10)): 177 | """ Computes the average recall given an average number 178 | of proposals per video. 179 | 180 | Parameters 181 | ---------- 182 | ground_truth : df 183 | Data frame containing the ground truth instances.
184 | Required fields: ['video-id', 't-start', 't-end'] 185 | proposals : df 186 | Data frame containing the proposal instances. 187 | Required fields: ['video-id', 't-start', 't-end', 'score'] 188 | tiou_thresholds : 1darray, optional 189 | array with tiou thresholds. 190 | 191 | Outputs 192 | ------- 193 | recall : 2darray 194 | recall[i,j] is the recall at the ith tiou threshold and the jth average number of proposals per video. 195 | average_recall : 1darray 196 | recall averaged over the list of tiou thresholds. This is equivalent to recall.mean(axis=0). 197 | proposals_per_video : 1darray 198 | average number of proposals per video. 199 | """ 200 | 201 | # Get list of videos. 202 | video_lst = ground_truth['video-id'].unique() 203 | 204 | if not max_avg_nr_proposals: 205 | max_avg_nr_proposals = float(proposals.shape[0])/video_lst.shape[0] 206 | 207 | ratio = max_avg_nr_proposals*float(video_lst.shape[0])/proposals.shape[0] 208 | print('max avg p: %d video_lst shape0: %d proposals_shape0: %d'%(max_avg_nr_proposals,video_lst.shape[0],proposals.shape[0])) 209 | print('\ttotal proposal: %d ratio: %0.4f'%(proposals.shape[0],ratio)) 210 | 211 | # Group by video id for faster lookup. 212 | ground_truth_gbvn = ground_truth.groupby('video-id') 213 | proposals_gbvn = proposals.groupby('video-id') 214 | 215 | # For each video, compute tiou scores among the retrieved proposals. 216 | score_lst = [] 217 | total_nr_proposals = 0 218 | for videoid in video_lst: 219 | 220 | # Get ground-truth instances associated with this video. 221 | ground_truth_videoid = ground_truth_gbvn.get_group(videoid) 222 | this_video_ground_truth = ground_truth_videoid.loc[:,['t-start', 't-end']].values 223 | 224 | # Get proposals for this video.
225 | try: 226 | proposals_videoid = proposals_gbvn.get_group(videoid) 227 | this_video_proposals = proposals_videoid.loc[:, ['t-start', 't-end']].values 228 | except KeyError: # no proposals were submitted for this video 229 | n = this_video_ground_truth.shape[0] 230 | score_lst.append(np.zeros((n, 1))) 231 | continue 232 | 233 | # Sort proposals by score. 234 | sort_idx = proposals_videoid['score'].argsort()[::-1] 235 | this_video_proposals = this_video_proposals[sort_idx, :] 236 | 237 | if this_video_proposals.shape[0] == 0: 238 | n = this_video_ground_truth.shape[0] 239 | score_lst.append(np.zeros((n, 1))) 240 | continue 241 | 242 | if this_video_proposals.ndim != 2: 243 | this_video_proposals = np.expand_dims(this_video_proposals, axis=0) 244 | if this_video_ground_truth.ndim != 2: 245 | this_video_ground_truth = np.expand_dims(this_video_ground_truth, axis=0) 246 | 247 | nr_proposals = np.minimum(int(this_video_proposals.shape[0] * ratio), this_video_proposals.shape[0]) 248 | total_nr_proposals += nr_proposals 249 | this_video_proposals = this_video_proposals[:nr_proposals, :] 250 | 251 | # Compute tiou scores. 252 | tiou = wrapper_segment_iou(this_video_proposals, this_video_ground_truth) 253 | score_lst.append(tiou) 254 | 255 | # Given that video length varies widely, we 256 | # compute the number of proposals in terms of a ratio of the total 257 | # proposals retrieved, i.e. average recall at a percentage of proposals 258 | # retrieved per video. 259 | 260 | print('\ttotal proposal: %d'%total_nr_proposals) 261 | print('\tvideo list number: %d'%video_lst.shape[0]) 262 | # Compute average recall. 263 | pcn_lst = np.arange(1, 101) / 100.0 *(max_avg_nr_proposals*float(video_lst.shape[0])/total_nr_proposals) 264 | matches = np.empty((video_lst.shape[0], pcn_lst.shape[0])) 265 | positives = np.empty(video_lst.shape[0]) 266 | recall = np.empty((tiou_thresholds.shape[0], pcn_lst.shape[0])) 267 | # Iterate over each tiou threshold.
268 | for ridx, tiou in enumerate(tiou_thresholds): 269 | 270 | # Inspect positives retrieved per video at different 271 | # numbers of proposals (percentage of the total retrieved). 272 | for i, score in enumerate(score_lst): 273 | # Total positives per video. 274 | positives[i] = score.shape[0] 275 | # Find proposals that satisfy the minimum tiou threshold. 276 | true_positives_tiou = score >= tiou 277 | # Get number of proposals as a percentage of total retrieved. 278 | pcn_proposals = np.minimum((score.shape[1] * pcn_lst).astype(int), score.shape[1]) 279 | 280 | for j, nr_proposals in enumerate(pcn_proposals): 281 | # Compute the number of matches for each percentage of the proposals 282 | matches[i, j] = np.count_nonzero((true_positives_tiou[:, :nr_proposals]).sum(axis=1)) 283 | # Compute recall given the set of matches per video. 284 | recall[ridx, :] = matches.sum(axis=0) / positives.sum() 285 | 286 | # Recall is averaged. 287 | avg_recall = recall.mean(axis=0) 288 | 289 | # Get the average number of proposals per video.
290 | proposals_per_video = pcn_lst * (float(total_nr_proposals) / video_lst.shape[0]) 291 | 292 | return recall, avg_recall, proposals_per_video 293 | 294 | -------------------------------------------------------------------------------- /evaluation/get_detection_performance.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | 4 | from eval_detection import ANETdetection 5 | 6 | def main(ground_truth_filename, prediction_filename, 7 | subset='validation', tiou_thresholds=np.linspace(0.5, 0.95, 10), 8 | verbose=True, check_status=True): 9 | 10 | anet_detection = ANETdetection(ground_truth_filename, prediction_filename, 11 | subset=subset, tiou_thresholds=np.asarray(tiou_thresholds), 12 | verbose=verbose, check_status=check_status) 13 | anet_detection.evaluate() 14 | 15 | def detect(prediction_filename, ground_truth_filename = 'activity_net.v1-3.min.json',output_log=''): 16 | anet_detection = ANETdetection(ground_truth_filename, prediction_filename, 17 | subset='validation', tiou_thresholds=np.linspace(0.5, 0.95, 10), 18 | verbose=True, check_status=True) 19 | anet_detection.evaluate(output_log) 20 | return anet_detection.mAP 21 | 22 | 23 | 24 | def parse_input(): 25 | description = ('This script allows you to evaluate the ActivityNet ' 26 | 'detection task which is intended to evaluate the ability ' 27 | 'of algorithms to temporally localize activities in ' 28 | 'untrimmed video sequences.') 29 | p = argparse.ArgumentParser(description=description) 30 | p.add_argument('ground_truth_filename', 31 | help='Full path to json file containing the ground truth.') 32 | p.add_argument('prediction_filename', 33 | help='Full path to json file containing the predictions.') 34 | p.add_argument('--subset', default='validation', 35 | help=('String indicating subset to evaluate: ' 36 | '(training, validation)')) 37 | p.add_argument('--tiou_thresholds', type=float, nargs='+', default=np.linspace(0.5, 0.95, 10), 38 | help='Temporal
intersection over union threshold.') 39 | p.add_argument('--verbose', type=bool, default=True) 40 | p.add_argument('--check_status', type=bool, default=True) 41 | return p.parse_args() 42 | 43 | if __name__ == '__main__': 44 | args = parse_input() 45 | main(**vars(args)) 46 | -------------------------------------------------------------------------------- /evaluation/get_proposal_performance.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import sys 4 | sys.path.append('./evaluation') 5 | import matplotlib.pyplot as plt 6 | import json 7 | 8 | from eval_proposal import ANETproposal 9 | 10 | def main(ground_truth_filename, proposal_filename, max_avg_nr_proposals=100, 11 | tiou_thresholds=np.linspace(0.5, 0.95, 10), 12 | subset='validation', verbose=True, check_status=True): 13 | 14 | anet_proposal = ANETproposal(ground_truth_filename, proposal_filename, 15 | tiou_thresholds=tiou_thresholds, 16 | max_avg_nr_proposals=max_avg_nr_proposals, 17 | subset=subset, verbose=verbose, check_status=check_status) 18 | anet_proposal.evaluate() 19 | 20 | def run_evaluation(ground_truth_filename, proposal_filename, 21 | max_avg_nr_proposals=100, 22 | tiou_thresholds=np.linspace(0.5, 0.95, 10), 23 | subset='validation'): 24 | 25 | anet_proposal = ANETproposal(ground_truth_filename, proposal_filename, 26 | tiou_thresholds=tiou_thresholds, 27 | max_avg_nr_proposals=max_avg_nr_proposals, 28 | subset=subset, verbose=True, check_status=True) 29 | anet_proposal.evaluate() 30 | 31 | recall = anet_proposal.recall 32 | average_recall = anet_proposal.avg_recall 33 | average_nr_proposals = anet_proposal.proposals_per_video 34 | 35 | return (average_nr_proposals, average_recall, recall) 36 | 37 | def plot_metric(average_nr_proposals, average_recall, recall, 38 | tiou_thresholds=np.linspace(0.5, 0.95, 10)): 39 | fn_size = 14 40 | plt.figure(num=None, figsize=(6, 5)) 41 | ax = plt.subplot(1,1,1) 42 | 43 | colors = ['C0',
'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9'] 44 | 45 | area_under_curve = np.zeros_like(tiou_thresholds) 46 | for i in range(recall.shape[0]): 47 | area_under_curve[i] = np.trapz(recall[i], average_nr_proposals) 48 | 49 | for idx, tiou in enumerate(tiou_thresholds[::2]): 50 | ax.plot(average_nr_proposals, recall[2*idx,:], color=colors[idx+1], 51 | label="tiou=[" + str(tiou) + "], area=" + str(int(area_under_curve[2*idx]*100)/100.), 52 | linewidth=4, linestyle='--', marker=None) 53 | 54 | # Plots Average Recall vs Average number of proposals. 55 | ax.plot(average_nr_proposals, average_recall, color=colors[0], 56 | label="tiou = 0.5:0.05:0.95," + " area=" + str(int(np.trapz(average_recall, average_nr_proposals)*100)/100.), 57 | linewidth=4, linestyle='-', marker=None) 58 | 59 | handles, labels = ax.get_legend_handles_labels() 60 | ax.legend([handles[-1]] + handles[:-1], [labels[-1]] + labels[:-1], loc='best') 61 | 62 | plt.ylabel('Average Recall', fontsize=fn_size) 63 | plt.xlabel('Average Number of Proposals per Video', fontsize=fn_size) 64 | ax.grid(True, which="both") 65 | plt.ylim([0, 1.0]) 66 | plt.setp(ax.get_xticklabels(), fontsize=fn_size) 67 | plt.setp(ax.get_yticklabels(), fontsize=fn_size) 68 | 69 | plt.show() 70 | 71 | def plot_figure(ground_truth_filename, proposal_filename, max_avg_nr_proposals=100, 72 | tiou_thresholds=np.linspace(0.5, 0.95, 10), 73 | subset='validation', verbose=True, check_status=True): 74 | uniform_average_nr_proposals_valid, uniform_average_recall_valid, uniform_recall_valid = run_evaluation( 75 | ground_truth_filename, 76 | proposal_filename, 77 | max_avg_nr_proposals=max_avg_nr_proposals, 78 | tiou_thresholds=tiou_thresholds, 79 | subset=subset) 80 | plot_metric(uniform_average_nr_proposals_valid, uniform_average_recall_valid, uniform_recall_valid, tiou_thresholds=tiou_thresholds) 81 | 82 | def evaluate_return_area(ground_truth_filename,proposal_filename,max_avg_nr_proposals=100, 83 | tiou_thresholds=np.linspace(0.5,0.95,10), 84 |
subset='validation',verbose = True, check_status = True): 85 | anet_proposal = ANETproposal(ground_truth_filename, proposal_filename, 86 | tiou_thresholds=tiou_thresholds, 87 | max_avg_nr_proposals=max_avg_nr_proposals, 88 | subset=subset, verbose=verbose, check_status=check_status) 89 | anet_proposal.evaluate() 90 | 91 | recall = anet_proposal.recall 92 | average_recall = anet_proposal.avg_recall 93 | average_nr_proposals = anet_proposal.proposals_per_video 94 | 95 | area_under_curve = np.trapz(average_recall, average_nr_proposals) 96 | AR_AN = 100.*float(area_under_curve)/average_nr_proposals[-1] 97 | Recall_all = 100.*float(average_recall[-1]) 98 | 99 | return (AR_AN,Recall_all) 100 | 101 | def parse_input(): 102 | description = ('This script allows you to evaluate the ActivityNet ' 103 | 'proposal task which is intended to evaluate the ability ' 104 | 'of algorithms to generate activity proposals that temporally ' 105 | 'localize activities in untrimmed video sequences.') 106 | p = argparse.ArgumentParser(description=description) 107 | p.add_argument('ground_truth_filename', 108 | help='Full path to json file containing the ground truth.') 109 | p.add_argument('proposal_filename', 110 | help='Full path to json file containing the proposals.') 111 | p.add_argument('--subset', default='validation', 112 | help=('String indicating subset to evaluate: ' 113 | '(training, validation)')) 114 | p.add_argument('--verbose', type=bool, default=True) 115 | p.add_argument('--check_status', type=bool, default=True) 116 | return p.parse_args() 117 | 118 | def write_ar_an(gt, json_path, out_file): 119 | uniform_average_nr_proposals_valid, uniform_average_recall_valid, uniform_recall_valid = run_evaluation( 120 | gt,json_path, 121 | max_avg_nr_proposals=100, 122 | tiou_thresholds=np.linspace(0.5, 0.95, 10), 123 | subset='validation') 124 | with open(out_file,'w') as fw: 125 | for k in range(len(uniform_average_nr_proposals_valid)): 126 | fw.write('%f
%f\n'%(uniform_average_nr_proposals_valid[k],uniform_average_recall_valid[k])) 127 | -------------------------------------------------------------------------------- /evaluation/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import urllib.request 3 | 4 | import numpy as np 5 | 6 | API = 'http://ec2-52-11-11-89.us-west-2.compute.amazonaws.com/challenge16/api.py' 7 | 8 | def get_blocked_videos(api=API): 9 | api_url = '{}?action=get_blocked'.format(api) 10 | req = urllib.request.Request(api_url) 11 | response = urllib.request.urlopen(req) 12 | return json.loads(response.read()) 13 | 14 | def interpolated_prec_rec(prec, rec): 15 | """Interpolated AP - VOCdevkit from VOC 2011. 16 | """ 17 | mprec = np.hstack([[0], prec, [0]]) 18 | mrec = np.hstack([[0], rec, [1]]) 19 | for i in range(len(mprec) - 1)[::-1]: 20 | mprec[i] = max(mprec[i], mprec[i + 1]) 21 | idx = np.where(mrec[1::] != mrec[0:-1])[0] + 1 22 | ap = np.sum((mrec[idx] - mrec[idx - 1]) * mprec[idx]) 23 | return ap 24 | 25 | def segment_iou(target_segment, candidate_segments): 26 | """Compute the temporal intersection over union between a 27 | target segment and all the candidate segments. 28 | 29 | Parameters 30 | ---------- 31 | target_segment : 1d array 32 | Temporal target segment containing [starting, ending] times. 33 | candidate_segments : 2d array 34 | Temporal candidate segments containing N x [starting, ending] times. 35 | 36 | Outputs 37 | ------- 38 | tiou : 1d array 39 | Temporal intersection over union score of each of the N candidate segments. 40 | """ 41 | tt1 = np.maximum(target_segment[0], candidate_segments[:, 0]) 42 | tt2 = np.minimum(target_segment[1], candidate_segments[:, 1]) 43 | # Intersection, clipped at zero so non-overlapping segments score 0. 44 | segments_intersection = (tt2 - tt1).clip(0) 45 | # Segment union.
46 | segments_union = (candidate_segments[:, 1] - candidate_segments[:, 0]) \ 47 | + (target_segment[1] - target_segment[0]) - segments_intersection 48 | # Compute overlap as the ratio of the intersection 49 | # over union of two segments. 50 | tIoU = segments_intersection.astype(float) / segments_union 51 | return tIoU 52 | 53 | def wrapper_segment_iou(target_segments, candidate_segments): 54 | """Compute intersection over union between segments. 55 | Parameters 56 | ---------- 57 | target_segments : ndarray 58 | 2-dim array in format [m x 2:=[init, end]] 59 | candidate_segments : ndarray 60 | 2-dim array in format [n x 2:=[init, end]] 61 | Outputs 62 | ------- 63 | tiou : ndarray 64 | 2-dim array [n x m] with IOU ratio. 65 | Note: It assumes that candidate segments are scarcer than target segments. 66 | """ 67 | if candidate_segments.ndim != 2 or target_segments.ndim != 2: 68 | raise ValueError('Dimension of arguments is incorrect') 69 | 70 | n, m = candidate_segments.shape[0], target_segments.shape[0] 71 | tiou = np.empty((n, m)) 72 | for i in range(m): 73 | tiou[:, i] = segment_iou(target_segments[i,:], candidate_segments) 74 | 75 | return tiou 76 | -------------------------------------------------------------------------------- /k600_category.txt: -------------------------------------------------------------------------------- 1 | abseiling 2 | acting_in_play 3 | adjusting_glasses 4 | air_drumming 5 | alligator_wrestling 6 | answering_questions 7 | applauding 8 | applying_cream 9 | archaeological_excavation 10 | archery 11 | arguing 12 | arm_wrestling 13 | arranging_flowers 14 | assembling_bicycle 15 | assembling_computer 16 | attending_conference 17 | auctioning 18 | backflip_(human) 19 | baking_cookies 20 | bandaging 21 | barbequing 22 | bartending 23 | base_jumping 24 | bathing_dog 25 | battle_rope_training 26 | beatboxing 27 | bee_keeping 28 | belly_dancing 29 | bench_pressing 30 | bending_back 31 | bending_metal 32 | biking_through_snow 33 | blasting_sand
34 | blowdrying_hair 35 | blowing_bubble_gum 36 | blowing_glass 37 | blowing_leaves 38 | blowing_nose 39 | blowing_out_candles 40 | bobsledding 41 | bodysurfing 42 | bookbinding 43 | bottling 44 | bouncing_on_bouncy_castle 45 | bouncing_on_trampoline 46 | bowling 47 | braiding_hair 48 | breading_or_breadcrumbing 49 | breakdancing 50 | breaking_boards 51 | breathing_fire 52 | brush_painting 53 | brushing_hair 54 | brushing_teeth 55 | building_cabinet 56 | building_lego 57 | building_sandcastle 58 | building_shed 59 | bull_fighting 60 | bulldozing 61 | bungee_jumping 62 | burping 63 | busking 64 | calculating 65 | calligraphy 66 | canoeing_or_kayaking 67 | capoeira 68 | capsizing 69 | card_stacking 70 | card_throwing 71 | carrying_baby 72 | cartwheeling 73 | carving_ice 74 | carving_pumpkin 75 | casting_fishing_line 76 | catching_fish 77 | catching_or_throwing_baseball 78 | catching_or_throwing_frisbee 79 | catching_or_throwing_softball 80 | celebrating 81 | changing_gear_in_car 82 | changing_oil 83 | changing_wheel_(not_on_bike) 84 | checking_tires 85 | cheerleading 86 | chewing_gum 87 | chiseling_stone 88 | chiseling_wood 89 | chopping_meat 90 | chopping_vegetables 91 | chopping_wood 92 | clam_digging 93 | clapping 94 | clay_pottery_making 95 | clean_and_jerk 96 | cleaning_gutters 97 | cleaning_pool 98 | cleaning_shoes 99 | cleaning_toilet 100 | cleaning_windows 101 | climbing_a_rope 102 | climbing_ladder 103 | climbing_tree 104 | coloring_in 105 | combing_hair 106 | contact_juggling 107 | contorting 108 | cooking_egg 109 | cooking_on_campfire 110 | cooking_sausages_(not_on_barbeque) 111 | cooking_scallops 112 | cosplaying 113 | counting_money 114 | country_line_dancing 115 | cracking_back 116 | cracking_knuckles 117 | cracking_neck 118 | crawling_baby 119 | crossing_eyes 120 | crossing_river 121 | crying 122 | cumbia 123 | curling_(sport) 124 | curling_hair 125 | cutting_apple 126 | cutting_nails 127 | cutting_orange 128 | cutting_pineapple 129 | 
cutting_watermelon 130 | dancing_ballet 131 | dancing_charleston 132 | dancing_gangnam_style 133 | dancing_macarena 134 | deadlifting 135 | decorating_the_christmas_tree 136 | delivering_mail 137 | dining 138 | directing_traffic 139 | disc_golfing 140 | diving_cliff 141 | docking_boat 142 | dodgeball 143 | doing_aerobics 144 | doing_jigsaw_puzzle 145 | doing_laundry 146 | doing_nails 147 | drawing 148 | dribbling_basketball 149 | drinking_shots 150 | driving_car 151 | driving_tractor 152 | drooling 153 | drop_kicking 154 | drumming_fingers 155 | dumpster_diving 156 | dunking_basketball 157 | dyeing_eyebrows 158 | dyeing_hair 159 | eating_burger 160 | eating_cake 161 | eating_carrots 162 | eating_chips 163 | eating_doughnuts 164 | eating_hotdog 165 | eating_ice_cream 166 | eating_spaghetti 167 | eating_watermelon 168 | egg_hunting 169 | embroidering 170 | exercising_with_an_exercise_ball 171 | extinguishing_fire 172 | faceplanting 173 | falling_off_bike 174 | falling_off_chair 175 | feeding_birds 176 | feeding_fish 177 | feeding_goats 178 | fencing_(sport) 179 | fidgeting 180 | finger_snapping 181 | fixing_bicycle 182 | fixing_hair 183 | flint_knapping 184 | flipping_pancake 185 | fly_tying 186 | flying_kite 187 | folding_clothes 188 | folding_napkins 189 | folding_paper 190 | front_raises 191 | frying_vegetables 192 | geocaching 193 | getting_a_haircut 194 | getting_a_piercing 195 | getting_a_tattoo 196 | giving_or_receiving_award 197 | gold_panning 198 | golf_chipping 199 | golf_driving 200 | golf_putting 201 | gospel_singing_in_church 202 | grinding_meat 203 | grooming_dog 204 | grooming_horse 205 | gymnastics_tumbling 206 | hammer_throw 207 | hand_washing_clothes 208 | head_stand 209 | headbanging 210 | headbutting 211 | high_jump 212 | high_kick 213 | historical_reenactment 214 | hitting_baseball 215 | hockey_stop 216 | holding_snake 217 | home_roasting_coffee 218 | hopscotch 219 | hoverboarding 220 | huddling 221 | hugging_(not_baby) 222 | hugging_baby 223 | 
hula_hooping 224 | hurdling 225 | hurling_(sport) 226 | ice_climbing 227 | ice_fishing 228 | ice_skating 229 | ice_swimming 230 | inflating_balloons 231 | installing_carpet 232 | ironing_hair 233 | ironing 234 | javelin_throw 235 | jaywalking 236 | jetskiing 237 | jogging 238 | juggling_balls 239 | juggling_fire 240 | juggling_soccer_ball 241 | jumping_bicycle 242 | jumping_into_pool 243 | jumping_jacks 244 | jumpstyle_dancing 245 | karaoke 246 | kicking_field_goal 247 | kicking_soccer_ball 248 | kissing 249 | kitesurfing 250 | knitting 251 | krumping 252 | land_sailing 253 | laughing 254 | lawn_mower_racing 255 | laying_bricks 256 | laying_concrete 257 | laying_stone 258 | laying_tiles 259 | leatherworking 260 | licking 261 | lifting_hat 262 | lighting_fire 263 | lock_picking 264 | long_jump 265 | longboarding 266 | looking_at_phone 267 | luge 268 | lunge 269 | making_a_cake 270 | making_a_sandwich 271 | making_balloon_shapes 272 | making_bubbles 273 | making_cheese 274 | making_horseshoes 275 | making_jewelry 276 | making_paper_aeroplanes 277 | making_pizza 278 | making_snowman 279 | making_sushi 280 | making_tea 281 | making_the_bed 282 | marching 283 | marriage_proposal 284 | massaging_back 285 | massaging_feet 286 | massaging_legs 287 | massaging_neck 288 | massaging_person's_head 289 | milking_cow 290 | moon_walking 291 | mopping_floor 292 | mosh_pit_dancing 293 | motorcycling 294 | mountain_climber_(exercise) 295 | moving_furniture 296 | mowing_lawn 297 | mushroom_foraging 298 | needle_felting 299 | news_anchoring 300 | opening_bottle_(not_wine) 301 | opening_door 302 | opening_present 303 | opening_refrigerator 304 | opening_wine_bottle 305 | packing 306 | paragliding 307 | parasailing 308 | parkour 309 | passing_American_football_(in_game) 310 | passing_american_football_(not_in_game) 311 | passing_soccer_ball 312 | peeling_apples 313 | peeling_potatoes 314 | person_collecting_garbage 315 | petting_animal_(not_cat) 316 | petting_cat 317 | photobombing 318 
| photocopying 319 | picking_fruit 320 | pillow_fight 321 | pinching 322 | pirouetting 323 | planing_wood 324 | planting_trees 325 | plastering 326 | playing_accordion 327 | playing_badminton 328 | playing_bagpipes 329 | playing_basketball 330 | playing_bass_guitar 331 | playing_beer_pong 332 | playing_blackjack 333 | playing_cello 334 | playing_chess 335 | playing_clarinet 336 | playing_controller 337 | playing_cricket 338 | playing_cymbals 339 | playing_darts 340 | playing_didgeridoo 341 | playing_dominoes 342 | playing_drums 343 | playing_field_hockey 344 | playing_flute 345 | playing_gong 346 | playing_guitar 347 | playing_hand_clapping_games 348 | playing_harmonica 349 | playing_harp 350 | playing_ice_hockey 351 | playing_keyboard 352 | playing_kickball 353 | playing_laser_tag 354 | playing_lute 355 | playing_maracas 356 | playing_marbles 357 | playing_monopoly 358 | playing_netball 359 | playing_ocarina 360 | playing_organ 361 | playing_paintball 362 | playing_pan_pipes 363 | playing_piano 364 | playing_pinball 365 | playing_ping_pong 366 | playing_poker 367 | playing_polo 368 | playing_recorder 369 | playing_rubiks_cube 370 | playing_saxophone 371 | playing_scrabble 372 | playing_squash_or_racquetball 373 | playing_tennis 374 | playing_trombone 375 | playing_trumpet 376 | playing_ukulele 377 | playing_violin 378 | playing_volleyball 379 | playing_with_trains 380 | playing_xylophone 381 | poking_bellybutton 382 | pole_vault 383 | polishing_metal 384 | popping_balloons 385 | pouring_beer 386 | preparing_salad 387 | presenting_weather_forecast 388 | pull_ups 389 | pumping_fist 390 | pumping_gas 391 | punching_bag 392 | punching_person_(boxing) 393 | push_up 394 | pushing_car 395 | pushing_cart 396 | pushing_wheelbarrow 397 | pushing_wheelchair 398 | putting_in_contact_lenses 399 | putting_on_eyeliner 400 | putting_on_foundation 401 | putting_on_lipstick 402 | putting_on_mascara 403 | putting_on_sari 404 | putting_on_shoes 405 | raising_eyebrows 406 | 
reading_book 407 | reading_newspaper 408 | recording_music 409 | repairing_puncture 410 | riding_a_bike 411 | riding_camel 412 | riding_elephant 413 | riding_mechanical_bull 414 | riding_mule 415 | riding_or_walking_with_horse 416 | riding_scooter 417 | riding_snow_blower 418 | riding_unicycle 419 | ripping_paper 420 | roasting_marshmallows 421 | roasting_pig 422 | robot_dancing 423 | rock_climbing 424 | rock_scissors_paper 425 | roller_skating 426 | rolling_pastry 427 | rope_pushdown 428 | running_on_treadmill 429 | sailing 430 | salsa_dancing 431 | sanding_floor 432 | sausage_making 433 | sawing_wood 434 | scrambling_eggs 435 | scrapbooking 436 | scrubbing_face 437 | scuba_diving 438 | separating_eggs 439 | setting_table 440 | sewing 441 | shaking_hands 442 | shaking_head 443 | shaping_bread_dough 444 | sharpening_knives 445 | sharpening_pencil 446 | shaving_head 447 | shaving_legs 448 | shearing_sheep 449 | shining_flashlight 450 | shining_shoes 451 | shooting_basketball 452 | shooting_goal_(soccer) 453 | shopping 454 | shot_put 455 | shoveling_snow 456 | shucking_oysters 457 | shuffling_cards 458 | shuffling_feet 459 | side_kick 460 | sign_language_interpreting 461 | singing 462 | sipping_cup 463 | situp 464 | skateboarding 465 | ski_jumping 466 | skiing_crosscountry 467 | skiing_mono 468 | skiing_slalom 469 | skipping_rope 470 | skipping_stone 471 | skydiving 472 | slacklining 473 | slapping 474 | sled_dog_racing 475 | sleeping 476 | smashing 477 | smelling_feet 478 | smoking_hookah 479 | smoking_pipe 480 | smoking 481 | snatch_weight_lifting 482 | sneezing 483 | snorkeling 484 | snowboarding 485 | snowkiting 486 | snowmobiling 487 | somersaulting 488 | spelunking 489 | spinning_poi 490 | spray_painting 491 | springboard_diving 492 | square_dancing 493 | squat 494 | standing_on_hands 495 | staring 496 | steer_roping 497 | sticking_tongue_out 498 | stomping_grapes 499 | stretching_arm 500 | stretching_leg 501 | sucking_lolly 502 | surfing_crowd 503 | 
surfing_water 504 | sweeping_floor 505 | swimming_backstroke 506 | swimming_breast_stroke 507 | swimming_butterfly_stroke 508 | swimming_front_crawl 509 | swing_dancing 510 | swinging_baseball_bat 511 | swinging_on_something 512 | sword_fighting 513 | sword_swallowing 514 | tackling 515 | tagging_graffiti 516 | tai_chi 517 | talking_on_cell_phone 518 | tango_dancing 519 | tap_dancing 520 | tapping_guitar 521 | tapping_pen 522 | tasting_beer 523 | tasting_food 524 | tasting_wine 525 | testifying 526 | texting 527 | threading_needle 528 | throwing_axe 529 | throwing_ball_(not_baseball_or_American_football) 530 | throwing_discus 531 | throwing_knife 532 | throwing_snowballs 533 | throwing_tantrum 534 | throwing_water_balloon 535 | tickling 536 | tie_dying 537 | tightrope_walking 538 | tiptoeing 539 | tobogganing 540 | tossing_coin 541 | training_dog 542 | trapezing 543 | trimming_or_shaving_beard 544 | trimming_shrubs 545 | trimming_trees 546 | triple_jump 547 | twiddling_fingers 548 | tying_bow_tie 549 | tying_knot_(not_on_a_tie) 550 | tying_necktie 551 | tying_shoe_laces 552 | unboxing 553 | unloading_truck 554 | using_a_microscope 555 | using_a_paint_roller 556 | using_a_power_drill 557 | using_a_sledge_hammer 558 | using_a_wrench 559 | using_atm 560 | using_bagging_machine 561 | using_circular_saw 562 | using_inhaler 563 | using_puppets 564 | using_remote_controller_(not_gaming) 565 | using_segway 566 | vacuuming_floor 567 | visiting_the_zoo 568 | wading_through_mud 569 | wading_through_water 570 | waiting_in_line 571 | waking_up 572 | walking_the_dog 573 | walking_through_snow 574 | washing_dishes 575 | washing_feet 576 | washing_hair 577 | washing_hands 578 | watching_tv 579 | water_skiing 580 | water_sliding 581 | watering_plants 582 | waving_hand 583 | waxing_back 584 | waxing_chest 585 | waxing_eyebrows 586 | waxing_legs 587 | weaving_basket 588 | weaving_fabric 589 | welding 590 | whistling 591 | windsurfing 592 | winking 593 | wood_burning_(art) 594 | 
594 | wrapping_present 595 | wrestling 596 | writing 597 | yarn_spinning 598 | yawning 599 | yoga 600 | zumba 601 | -------------------------------------------------------------------------------- /pic/eccv_framework.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuchenUSTC/AherNet/d3a34053b419345b79f2e395d05d60ac5e1cac98/pic/eccv_framework.JPG -------------------------------------------------------------------------------- /tf_extended/__init__.py: -------------------------------------------------------------------------------- 1 | """TF Extended: additional metrics. 2 | """ 3 | 4 | # pylint: disable=unused-import,line-too-long,g-importing-member,wildcard-import 5 | from tf_extended.metrics import * 6 | from tf_extended.tensors import * 7 | from tf_extended.bboxes import * 8 | from tf_extended.math import * 9 | 10 | -------------------------------------------------------------------------------- /tf_extended/bboxes.py: -------------------------------------------------------------------------------- 1 | """TF Extended: additional bounding box methods. 2 | """ 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | from tf_extended import tensors as tfe_tensors 7 | from tf_extended import math as tfe_math 8 | 9 | 10 | # =========================================================================== # 11 | # Standard boxes algorithms. 12 | # =========================================================================== # 13 | def bboxes_sort_all_classes(classes, scores, bboxes, top_k=400, scope=None): 14 | """Sort bounding boxes in decreasing order of score and keep only the top_k. 15 | Assume the input Tensors mix objects of different classes. 16 | Assume a batch-type input. 17 | 18 | Args: 19 | classes: Batch x N Tensor containing integer classes. 20 | scores: Batch x N Tensor containing float scores. 21 | bboxes: Batch x N x 4 Tensor containing box coordinates. 22 | top_k: Top_k boxes to keep.
23 | Return: 24 | classes, scores, bboxes: Sorted tensors of shape Batch x Top_k. 25 | """ 26 | with tf.name_scope(scope, 'bboxes_sort', [classes, scores, bboxes]): 27 | scores, idxes = tf.nn.top_k(scores, k=top_k, sorted=True) 28 | 29 | # Trick to be able to use tf.gather: map for each element in the batch. 30 | def fn_gather(classes, bboxes, idxes): 31 | cl = tf.gather(classes, idxes) 32 | bb = tf.gather(bboxes, idxes) 33 | return [cl, bb] 34 | r = tf.map_fn(lambda x: fn_gather(x[0], x[1], x[2]), 35 | [classes, bboxes, idxes], 36 | dtype=[classes.dtype, bboxes.dtype], 37 | parallel_iterations=10, 38 | back_prop=False, 39 | swap_memory=False, 40 | infer_shape=True) 41 | classes = r[0] 42 | bboxes = r[1] 43 | return classes, scores, bboxes 44 | 45 | def bboxes_sort(scores, bboxes, top_k=400, scope=None): 46 | """Sort bounding boxes by decreasing score and keep only the top_k. 47 | If inputs are dictionaries, assume every key is a different class. 48 | Assume a batch-type input. 49 | 50 | Args: 51 | scores: Batch x N Tensor/Dictionary containing float scores. 52 | bboxes: Batch x N x 4 Tensor/Dictionary containing boxes coordinates. 53 | top_k: Top_k boxes to keep. 54 | Return: 55 | scores, bboxes: Sorted Tensors/Dictionaries of shape Batch x Top_k x 1|4. 56 | """ 57 | # Dictionaries as inputs. 58 | if isinstance(scores, dict) or isinstance(bboxes, dict): 59 | with tf.name_scope(scope, 'bboxes_sort_dict'): 60 | d_scores = {} 61 | d_bboxes = {} 62 | for c in scores.keys(): 63 | s, b = bboxes_sort(scores[c], bboxes[c], top_k=top_k) 64 | d_scores[c] = s 65 | d_bboxes[c] = b 66 | return d_scores, d_bboxes 67 | 68 | # Tensors inputs. 69 | with tf.name_scope(scope, 'bboxes_sort', [scores, bboxes]): 70 | # Sort scores... 71 | scores, idxes = tf.nn.top_k(scores, k=top_k, sorted=True) 72 | 73 | # Trick to be able to use tf.gather: map for each element in the first dim.
74 | def fn_gather(bboxes, idxes): 75 | bb = tf.gather(bboxes, idxes) 76 | return [bb] 77 | r = tf.map_fn(lambda x: fn_gather(x[0], x[1]), 78 | [bboxes, idxes], 79 | dtype=[bboxes.dtype], 80 | parallel_iterations=10, 81 | back_prop=False, 82 | swap_memory=False, 83 | infer_shape=True) 84 | bboxes = r[0] 85 | return scores, bboxes 86 | 87 | def bboxes_clip(bbox_ref, bboxes, scope=None): 88 | """Clip bounding boxes to a reference box. 89 | Batch-compatible if the first dimension of `bbox_ref` and `bboxes` 90 | can be broadcasted. 91 | 92 | Args: 93 | bbox_ref: Reference bounding box. Nx4 or 4 shaped-Tensor; 94 | bboxes: Bounding boxes to clip. Nx4 or 4 shaped-Tensor or dictionary. 95 | Return: 96 | Clipped bboxes. 97 | """ 98 | # Bboxes is dictionary. 99 | if isinstance(bboxes, dict): 100 | with tf.name_scope(scope, 'bboxes_clip_dict'): 101 | d_bboxes = {} 102 | for c in bboxes.keys(): 103 | d_bboxes[c] = bboxes_clip(bbox_ref, bboxes[c]) 104 | return d_bboxes 105 | 106 | # Tensors inputs. 107 | with tf.name_scope(scope, 'bboxes_clip'): 108 | # Easier with transposed bboxes. Especially for broadcasting. 109 | bbox_ref = tf.transpose(bbox_ref) 110 | bboxes = tf.transpose(bboxes) 111 | # Intersection bboxes and reference bbox. 112 | ymin = tf.maximum(bboxes[0], bbox_ref[0]) 113 | xmin = tf.maximum(bboxes[1], bbox_ref[1]) 114 | ymax = tf.minimum(bboxes[2], bbox_ref[2]) 115 | xmax = tf.minimum(bboxes[3], bbox_ref[3]) 116 | # Double check! Empty boxes when no-intersection. 117 | ymin = tf.minimum(ymin, ymax) 118 | xmin = tf.minimum(xmin, xmax) 119 | bboxes = tf.transpose(tf.stack([ymin, xmin, ymax, xmax], axis=0)) 120 | return bboxes 121 | 122 | def bboxes_resize(bbox_ref, bboxes, name=None): 123 | """Resize bounding boxes based on a reference bounding box, 124 | assuming that the latter is [0, 0, 1, 1] after transform. Useful for 125 | updating a collection of boxes after cropping an image. 126 | """ 127 | # Bboxes is dictionary. 
128 | if isinstance(bboxes, dict): 129 | with tf.name_scope(name, 'bboxes_resize_dict'): 130 | d_bboxes = {} 131 | for c in bboxes.keys(): 132 | d_bboxes[c] = bboxes_resize(bbox_ref, bboxes[c]) 133 | return d_bboxes 134 | 135 | # Tensors inputs. 136 | with tf.name_scope(name, 'bboxes_resize'): 137 | # Translate. 138 | v = tf.stack([bbox_ref[0], bbox_ref[1], bbox_ref[0], bbox_ref[1]]) 139 | bboxes = bboxes - v 140 | # Scale. 141 | s = tf.stack([bbox_ref[2] - bbox_ref[0], 142 | bbox_ref[3] - bbox_ref[1], 143 | bbox_ref[2] - bbox_ref[0], 144 | bbox_ref[3] - bbox_ref[1]]) 145 | bboxes = bboxes / s 146 | return bboxes 147 | 148 | def bbox_1d_temporal_nms(bboxes, scores, keep_top_k, nms_threshold): 149 | t1 = bboxes[:,0] 150 | t2 = bboxes[:,1] 151 | score = scores 152 | union = t2 - t1 153 | order = tf.contrib.framework.argsort(score,direction='DESCENDING') 154 | keep = [] 155 | 156 | def condition(keep, order): 157 | r = tf.greater(tf.size(order),0) 158 | return r 159 | 160 | def body(keep, order): 161 | i = order[0] 162 | keep.append(i) 163 | tt1 = tf.maximum(t1[i],tf.gather(t1,order[1:])) 164 | tt2 = tf.minimum(t2[i],tf.gather(t2,order[1:])) 165 | w = tf.maximum(0.0,tt2-tt1) 166 | inte = w 167 | ovr = inte/(union[i]+tf.gather(union,order[1:]) - inte) 168 | inds = tf.where(ovr<=nms_threshold)[:,-1] 169 | order = tf.gather(order,inds+1) 170 | return keep, order 171 | 172 | #[keep, order] = tf.while_loop(condition, body, [keep, order]) 173 | topk = tf.minimum(len(keep), keep_top_k) 174 | return keep[:topk] 175 | 176 | def bboxes_nms(scores, bboxes, nms_threshold=0.5, keep_top_k=200, scope=None): 177 | """Apply non-maximum selection to bounding boxes. In comparison to TF 178 | implementation, use classes information for matching. 179 | Should only be used on single-entries. Use batch version otherwise. 180 | 181 | Args: 182 | scores: N Tensor containing float scores. 183 | bboxes: N x 4 Tensor containing boxes coordinates. 
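`bboxes_clip` intersects each box with a reference box (collapsing a non-overlapping box to zero area), and `bboxes_resize` re-expresses boxes in the frame where the reference box becomes [0, 0, 1, 1], for example after cropping. The plain-Python sketch below mirrors that arithmetic for a single box in the [ymin, xmin, ymax, xmax] layout used in this file; `clip_box` and `resize_box` are hypothetical names, not part of `tf_extended`.

```python
from typing import List, Sequence

# Boxes use the [ymin, xmin, ymax, xmax] layout assumed throughout bboxes.py.
Box = List[float]


def clip_box(ref: Sequence[float], box: Sequence[float]) -> Box:
    """Intersect `box` with `ref` (hypothetical single-box analogue of bboxes_clip)."""
    ymin = max(box[0], ref[0])
    xmin = max(box[1], ref[1])
    ymax = min(box[2], ref[2])
    xmax = min(box[3], ref[3])
    # As in bboxes_clip: if there is no intersection, collapse to an empty box.
    return [min(ymin, ymax), min(xmin, xmax), ymax, xmax]


def resize_box(ref: Sequence[float], box: Sequence[float]) -> Box:
    """Map `box` into the frame where `ref` becomes [0, 0, 1, 1] (cf. bboxes_resize)."""
    height = ref[2] - ref[0]
    width = ref[3] - ref[1]
    # Translate by the reference origin, then scale by the reference size.
    return [(box[0] - ref[0]) / height, (box[1] - ref[1]) / width,
            (box[2] - ref[0]) / height, (box[3] - ref[1]) / width]


if __name__ == "__main__":
    crop = [0.25, 0.25, 0.75, 0.75]
    box = [0.5, 0.0, 1.0, 0.5]
    clipped = clip_box(crop, box)
    print(clipped)                    # [0.5, 0.25, 0.75, 0.5]
    print(resize_box(crop, clipped))  # [0.5, 0.0, 1.0, 0.5]
```

Clipping before resizing keeps every box inside the new unit frame, which is the order a crop-based augmentation would typically apply them in.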
184 | nms_threshold: Matching threshold in NMS algorithm; 185 | keep_top_k: Maximum number of objects to keep after NMS. 186 | Return: 187 | scores, bboxes Tensors, sorted by score. 188 | Padded with zero if necessary. 189 | """ 190 | with tf.name_scope(scope, 'bboxes_nms_single', [scores, bboxes]): 191 | # Apply NMS algorithm. 192 | # Embed each segment [t1, t2] as the 2D box [t1, 0, t2, 1], so that 2D IoU equals temporal IoU. 193 | 194 | shape1= (bboxes.shape[0]) 195 | ones_m = tf.ones(shape1, dtype=tf.float32) 196 | zeros_m = tf.zeros(shape1,dtype=tf.float32) 197 | boxxes_n = tf.stack([bboxes[:,0],zeros_m,bboxes[:,1],ones_m],axis=1) 198 | 199 | idxes = tf.image.non_max_suppression(boxxes_n, scores, 200 | keep_top_k, nms_threshold) 201 | #idxes = bbox_1d_temporal_nms(bboxes, scores, keep_top_k, nms_threshold) 202 | scores = tf.gather(scores, idxes) 203 | bboxes = tf.gather(bboxes, idxes) 204 | # Pad results. 205 | scores = tfe_tensors.pad_axis(scores, 0, keep_top_k, axis=0) 206 | bboxes = tfe_tensors.pad_axis(bboxes, 0, keep_top_k, axis=0) 207 | return scores, bboxes 208 | 209 | def bboxes_nms_batch(scores, bboxes, nms_threshold=0.5, keep_top_k=200, 210 | scope=None): 211 | """Apply non-maximum suppression to bounding boxes. In comparison to TF 212 | implementation, use classes information for matching. 213 | Use only on batched-inputs. Use zero-padding in order to batch output 214 | results. 215 | 216 | Args: 217 | scores: Batch x N Tensor/Dictionary containing float scores. 218 | bboxes: Batch x N x 4 Tensor/Dictionary containing boxes coordinates. 219 | nms_threshold: Matching threshold in NMS algorithm; 220 | keep_top_k: Maximum number of objects to keep after NMS. 221 | Return: 222 | scores, bboxes Tensors/Dictionaries, sorted by score. 223 | Padded with zero if necessary. 224 | """ 225 | # Dictionaries as inputs.
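`bboxes_nms` reuses the 2D `tf.image.non_max_suppression` for 1D temporal segments: each segment [t1, t2] is stacked into the box [t1, 0, t2, 1]. Because every embedded box spans the full unit width, its 2D IoU reduces to the temporal IoU of the two segments. The commented-out `bbox_1d_temporal_nms` attempts the same greedy procedure directly in 1D. The pure-Python sketch below checks the embedding identity and gives a reference greedy NMS. It assumes segments are (start, end) pairs with start <= end; all function names are hypothetical.

```python
from typing import List, Sequence, Tuple

Segment = Tuple[float, float]  # (start, end), assumed start <= end


def temporal_iou(a: Segment, b: Segment) -> float:
    """Intersection-over-union of two 1D segments."""
    inter = max(0.0, min(a[1], b[1]) - max(a[0], b[0]))
    union = (a[1] - a[0]) + (b[1] - b[0]) - inter
    return inter / union if union > 0 else 0.0


def box_iou(a: Sequence[float], b: Sequence[float]) -> float:
    """Intersection-over-union of two [y1, x1, y2, x2] boxes."""
    h = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    w = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = h * w
    union = ((a[2] - a[0]) * (a[3] - a[1])
             + (b[2] - b[0]) * (b[3] - b[1]) - inter)
    return inter / union if union > 0 else 0.0


def embed(seg: Segment) -> List[float]:
    """The same [t1, 0, t2, 1] embedding that bboxes_nms builds with tf.stack."""
    return [seg[0], 0.0, seg[1], 1.0]


def temporal_nms(segments: Sequence[Segment], scores: Sequence[float],
                 iou_threshold: float = 0.5, keep_top_k: int = 200) -> List[int]:
    """Greedy NMS: keep the best remaining segment, drop those overlapping it."""
    order = sorted(range(len(segments)), key=lambda i: scores[i], reverse=True)
    keep: List[int] = []
    while order and len(keep) < keep_top_k:
        best = order.pop(0)
        keep.append(best)
        # Suppress segments whose IoU with the kept one exceeds the threshold.
        order = [i for i in order
                 if temporal_iou(segments[best], segments[i]) <= iou_threshold]
    return keep


if __name__ == "__main__":
    segs = [(0.0, 10.0), (1.0, 11.0), (20.0, 30.0), (21.0, 29.0)]
    print(temporal_nms(segs, [0.9, 0.8, 0.7, 0.95]))  # [3, 0]
    print(abs(box_iou(embed(segs[0]), embed(segs[1]))
              - temporal_iou(segs[0], segs[1])) < 1e-12)  # True
```

The embedding lets the repository use a well-tested graph op instead of a hand-written `tf.while_loop`, at the cost of carrying two constant coordinates per segment.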
226 | if isinstance(scores, dict) or isinstance(bboxes, dict): 227 | with tf.name_scope(scope, 'bboxes_nms_batch_dict'): 228 | d_scores = {} 229 | d_bboxes = {} 230 | for c in scores.keys(): 231 | s, b = bboxes_nms_batch(scores[c], bboxes[c], 232 | nms_threshold=nms_threshold, 233 | keep_top_k=keep_top_k) 234 | d_scores[c] = s 235 | d_bboxes[c] = b 236 | return d_scores, d_bboxes 237 | 238 | # Tensors inputs. 239 | with tf.name_scope(scope, 'bboxes_nms_batch'): 240 | r = tf.map_fn(lambda x: bboxes_nms(x[0], x[1], 241 | nms_threshold, keep_top_k), 242 | (scores, bboxes), 243 | dtype=(scores.dtype, bboxes.dtype), 244 | parallel_iterations=10, 245 | back_prop=False, 246 | swap_memory=False, 247 | infer_shape=True) 248 | scores, bboxes = r 249 | return scores, bboxes 250 | 251 | def bboxes_matching(label, scores, bboxes, 252 | glabels, gbboxes, gdifficults, 253 | matching_threshold=0.5, scope=None): 254 | """Match a collection of detected boxes with groundtruth values. 255 | Does not accept batched-inputs. 256 | The algorithm goes as follows: for every detected box, check 257 | whether a groundtruth box matches it. If none does, the detection is a False Positive. 258 | If the best-matching groundtruth box is already matched with another detection, the detection 259 | also counts as a False Positive. See the Pascal VOC documentation for details. 260 | 261 | Args: 262 | label: Class label to evaluate; scores, bboxes: N(x4) Tensors. Detected objects, sorted by score; 263 | glabels, gbboxes, gdifficults: Groundtruth labels, boxes and difficult flags. May be zero padded, hence 264 | zero-class objects are ignored. 265 | matching_threshold: Threshold for a positive match. 266 | Return: Tuple of: 267 | n_gbboxes: Scalar Tensor with the number of non-difficult groundtruth boxes of class `label` (may differ from 268 | size because of zero padding). 269 | tp_match: (N,)-shaped boolean Tensor flagging True Positives. 270 | fp_match: (N,)-shaped boolean Tensor flagging False Positives.
271 | """ 272 | with tf.name_scope(scope, 'bboxes_matching_single', 273 | [scores, bboxes, glabels, gbboxes]): 274 | rsize = tf.size(scores) 275 | rshape = tf.shape(scores) 276 | rlabel = tf.cast(label, glabels.dtype) 277 | # Number of non-difficult groundtruth boxes of this class. 278 | gdifficults = tf.cast(gdifficults, tf.bool) 279 | n_gbboxes = tf.count_nonzero(tf.logical_and(tf.equal(glabels, label), 280 | tf.logical_not(gdifficults))) 281 | # Groundtruth matching arrays. 282 | gmatch = tf.zeros(tf.shape(glabels), dtype=tf.bool) 283 | grange = tf.range(tf.size(glabels), dtype=tf.int32) 284 | # True/False positive matching TensorArrays. 285 | sdtype = tf.bool 286 | ta_tp_bool = tf.TensorArray(sdtype, size=rsize, dynamic_size=False, infer_shape=True) 287 | ta_fp_bool = tf.TensorArray(sdtype, size=rsize, dynamic_size=False, infer_shape=True) 288 | 289 | # Loop over returned objects. 290 | def m_condition(i, ta_tp, ta_fp, gmatch): 291 | r = tf.less(i, rsize) 292 | return r 293 | 294 | def m_body(i, ta_tp, ta_fp, gmatch): 295 | # Jaccard score with groundtruth bboxes. 296 | rbbox = bboxes[i] 297 | jaccard = bboxes_jaccard(rbbox, gbboxes) 298 | jaccard = jaccard * tf.cast(tf.equal(glabels, rlabel), dtype=jaccard.dtype) 299 | 300 | # Best fit, checking it's above threshold. 301 | idxmax = tf.cast(tf.argmax(jaccard, axis=0), tf.int32) 302 | jcdmax = jaccard[idxmax] 303 | match = jcdmax > matching_threshold 304 | existing_match = gmatch[idxmax] 305 | not_difficult = tf.logical_not(gdifficults[idxmax]) 306 | 307 | # TP: match & no previous match and FP: previous match | no match. 308 | # If difficult: no record, i.e. FP=False and TP=False. 309 | tp = tf.logical_and(not_difficult, 310 | tf.logical_and(match, tf.logical_not(existing_match))) 311 | ta_tp = ta_tp.write(i, tp) 312 | fp = tf.logical_and(not_difficult, 313 | tf.logical_or(existing_match, tf.logical_not(match))) 314 | ta_fp = ta_fp.write(i, fp) 315 | # Update groundtruth match.
316 | mask = tf.logical_and(tf.equal(grange, idxmax), 317 | tf.logical_and(not_difficult, match)) 318 | gmatch = tf.logical_or(gmatch, mask) 319 | 320 | return [i+1, ta_tp, ta_fp, gmatch] 321 | # Main loop definition. 322 | i = 0 323 | [i, ta_tp_bool, ta_fp_bool, gmatch] = \ 324 | tf.while_loop(m_condition, m_body, 325 | [i, ta_tp_bool, ta_fp_bool, gmatch], 326 | parallel_iterations=1, 327 | back_prop=False) 328 | # TensorArrays to Tensors and reshape. 329 | tp_match = tf.reshape(ta_tp_bool.stack(), rshape) 330 | fp_match = tf.reshape(ta_fp_bool.stack(), rshape) 331 | 332 | # Some debugging information... 333 | # tp_match = tf.Print(tp_match, 334 | # [n_gbboxes, 335 | # tf.reduce_sum(tf.cast(tp_match, tf.int64)), 336 | # tf.reduce_sum(tf.cast(fp_match, tf.int64)), 337 | # tf.reduce_sum(tf.cast(gmatch, tf.int64))], 338 | # 'Matching (NG, TP, FP, GM): ') 339 | return n_gbboxes, tp_match, fp_match 340 | 341 | def bboxes_matching_batch(labels, scores, bboxes, 342 | glabels, gbboxes, gdifficults, 343 | matching_threshold=0.5, scope=None): 344 | """Match a collection of detected boxes with groundtruth values. 345 | Batched-inputs version. 346 | 347 | Args: 348 | labels: Class label to evaluate (iterable of class keys for dictionary inputs); scores, bboxes: BxN(x4) Tensors. Detected objects, sorted by score; 349 | glabels, gbboxes, gdifficults: Groundtruth labels, boxes and difficult flags. May be zero padded, hence 350 | zero-class objects are ignored. 351 | matching_threshold: Threshold for a positive match. 352 | Return: Tuple or Dictionaries with: 353 | n_gbboxes: (B,)-shaped Tensor with the number of groundtruth boxes per element (may differ from 354 | size because of zero padding). 355 | tp: (B, N)-shaped boolean Tensor flagging True Positives. 356 | fp: (B, N)-shaped boolean Tensor flagging False Positives; the input scores are returned as a fourth element. 357 | """ 358 | # Dictionaries as inputs.
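`bboxes_matching` walks the detections in descending score order and, for each one, picks the groundtruth box of the same class with the highest Jaccard overlap. A match above `matching_threshold` on a not-yet-matched groundtruth is a True Positive; anything else is a False Positive; a detection whose best groundtruth is flagged difficult is counted as neither. The sketch below replays that loop in plain Python for a single class. It assumes 1D temporal segments and uses a temporal IoU in place of `bboxes_jaccard`; `match_detections` and `iou_1d` are hypothetical names.

```python
from typing import List, Optional, Sequence, Tuple

Segment = Tuple[float, float]  # (start, end), assumed start <= end


def iou_1d(a: Segment, b: Segment) -> float:
    """Temporal IoU, standing in for bboxes_jaccard in this sketch."""
    inter = max(0.0, min(a[1], b[1]) - max(a[0], b[0]))
    union = (a[1] - a[0]) + (b[1] - b[0]) - inter
    return inter / union if union > 0 else 0.0


def match_detections(dets: Sequence[Segment], gts: Sequence[Segment],
                     difficult: Optional[Sequence[bool]] = None,
                     threshold: float = 0.5) -> Tuple[int, List[bool], List[bool]]:
    """Greedy VOC-style matching; `dets` must already be sorted by descending score."""
    if difficult is None:
        difficult = [False] * len(gts)
    n_gt = sum(1 for d in difficult if not d)  # difficult boxes are not counted
    matched = [False] * len(gts)
    tp: List[bool] = []
    fp: List[bool] = []
    for det in dets:
        if not gts:
            tp.append(False)
            fp.append(True)
            continue
        ious = [iou_1d(det, g) for g in gts]
        best = max(range(len(gts)), key=ious.__getitem__)
        if difficult[best]:
            # Best match is a difficult box: neither TP nor FP.
            tp.append(False)
            fp.append(False)
            continue
        if ious[best] > threshold and not matched[best]:
            matched[best] = True
            tp.append(True)
            fp.append(False)
        else:
            # Either no overlap above threshold, or a duplicate detection.
            tp.append(False)
            fp.append(True)
    return n_gt, tp, fp


if __name__ == "__main__":
    gts = [(0.0, 10.0), (20.0, 30.0)]
    dets = [(1.0, 10.0), (0.0, 9.0), (40.0, 50.0)]
    print(match_detections(dets, gts))
    # (2, [True, False, False], [False, True, True])
```

The second detection overlaps the same groundtruth as the first, so it is penalized as a duplicate even though its IoU is above the threshold.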
359 | if isinstance(scores, dict) or isinstance(bboxes, dict): 360 | with tf.name_scope(scope, 'bboxes_matching_batch_dict'): 361 | d_n_gbboxes = {} 362 | d_tp = {} 363 | d_fp = {} 364 | for c in labels: 365 | n, tp, fp, _ = bboxes_matching_batch(c, scores[c], bboxes[c], 366 | glabels, gbboxes, gdifficults, 367 | matching_threshold) 368 | d_n_gbboxes[c] = n 369 | d_tp[c] = tp 370 | d_fp[c] = fp 371 | return d_n_gbboxes, d_tp, d_fp, scores 372 | 373 | with tf.name_scope(scope, 'bboxes_matching_batch', 374 | [scores, bboxes, glabels, gbboxes]): 375 | r = tf.map_fn(lambda x: bboxes_matching(labels, x[0], x[1], 376 | x[2], x[3], x[4], 377 | matching_threshold), 378 | (scores, bboxes, glabels, gbboxes, gdifficults), 379 | dtype=(tf.int64, tf.bool, tf.bool), 380 | parallel_iterations=10, 381 | back_prop=False, 382 | swap_memory=True, 383 | infer_shape=True) 384 | return r[0], r[1], r[2], scores 385 | 386 | 387 | # =========================================================================== # 388 | # Some filtering methods. 389 | # =========================================================================== # 390 | def bboxes_filter_center(labels, bboxes, margins=[0., 0., 0., 0.], 391 | scope=None): 392 | """Filter out bounding boxes whose centers are not in 393 | the rectangle [0, 0, 1, 1] + margins. The margin Tensor 394 | can be used to enforce or loosen this condition. 395 | 396 | Return: 397 | labels, bboxes: Filtered elements. 398 | """ 399 | with tf.name_scope(scope, 'bboxes_filter', [labels, bboxes]): 400 | cy = (bboxes[:, 0] + bboxes[:, 2]) / 2. 401 | cx = (bboxes[:, 1] + bboxes[:, 3]) / 2. 402 | mask = tf.greater(cy, margins[0]) 403 | mask = tf.logical_and(mask, tf.greater(cx, margins[1])) 404 | mask = tf.logical_and(mask, tf.less(cy, 1. + margins[2])) 405 | mask = tf.logical_and(mask, tf.less(cx, 1. + margins[3])) 406 | # Boolean masking...
407 | labels = tf.boolean_mask(labels, mask) 408 | bboxes = tf.boolean_mask(bboxes, mask) 409 | return labels, bboxes 410 | 411 | def bboxes_filter_overlap(labels, bboxes, 412 | threshold=0.5, assign_negative=False, 413 | scope=None): 414 | """Filter out bounding boxes based on their (relative) overlap with the reference 415 | box [0, 0, 1, 1]. Boxes whose overlap is not above `threshold` are either removed, 416 | or, if `assign_negative` is set, kept with negated labels (useful for later processing). 417 | 418 | Return: 419 | labels, bboxes: Filtered (or newly assigned) elements. 420 | """ 421 | with tf.name_scope(scope, 'bboxes_filter', [labels, bboxes]): 422 | scores = bboxes_intersection(tf.constant([0, 0, 1, 1], bboxes.dtype), 423 | bboxes) 424 | mask = scores > threshold 425 | if assign_negative: 426 | labels = tf.where(mask, labels, -labels) 427 | # bboxes = tf.where(mask, bboxes, bboxes) 428 | else: 429 | labels = tf.boolean_mask(labels, mask) 430 | bboxes = tf.boolean_mask(bboxes, mask) 431 | return labels, bboxes 432 | 433 | def bboxes_filter_labels(labels, bboxes, 434 | out_labels=[], num_classes=np.inf, 435 | scope=None): 436 | """Filter out labels from a collection. Typically used to get rid 437 | of DontCare elements. Also remove elements whose label is >= num_classes. 438 | 439 | Return: 440 | labels, bboxes: Filtered elements. 441 | """ 442 | with tf.name_scope(scope, 'bboxes_filter_labels', [labels, bboxes]): 443 | mask = tf.less(tf.cast(labels, tf.float64), num_classes) 444 | for l in out_labels: 445 | mask = tf.logical_and(mask, tf.not_equal(labels, l)) 446 | labels = tf.boolean_mask(labels, mask) 447 | bboxes = tf.boolean_mask(bboxes, mask) 448 | return labels, bboxes 449 | 450 | # =========================================================================== # 451 | # Standard boxes computation.
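The intended semantics of `bboxes_filter_center` and `bboxes_filter_labels` are easy to state on plain lists. A box is kept only if its center (cy, cx) lies strictly inside (margins[0], 1 + margins[2]) x (margins[1], 1 + margins[3]); a label is kept only if it is below `num_classes` and not listed in `out_labels`. The sketch below is a hypothetical list-based reference (the names `filter_center` and `filter_labels` are illustrative) that could serve as a unit-test oracle for the TensorFlow versions. In the example, the second box has its vertical center at 1.1 and is dropped.

```python
import math
from typing import List, Sequence, Tuple

Box = List[float]  # [ymin, xmin, ymax, xmax]


def filter_center(labels: Sequence[int], boxes: Sequence[Box],
                  margins: Sequence[float] = (0.0, 0.0, 0.0, 0.0)
                  ) -> Tuple[List[int], List[Box]]:
    """Keep boxes whose center lies inside [0, 0, 1, 1] grown by `margins`."""
    out_l: List[int] = []
    out_b: List[Box] = []
    for label, box in zip(labels, boxes):
        cy = (box[0] + box[2]) / 2.0
        cx = (box[1] + box[3]) / 2.0
        # Vertical bounds use margins[0] and margins[2]; horizontal ones [1] and [3].
        if margins[0] < cy < 1.0 + margins[2] and margins[1] < cx < 1.0 + margins[3]:
            out_l.append(label)
            out_b.append(list(box))
    return out_l, out_b


def filter_labels(labels: Sequence[int], boxes: Sequence[Box],
                  out_labels: Sequence[int] = (), num_classes: float = math.inf
                  ) -> Tuple[List[int], List[Box]]:
    """Drop labels listed in `out_labels` or not below `num_classes`."""
    keep = [i for i, l in enumerate(labels) if l < num_classes and l not in out_labels]
    return [labels[i] for i in keep], [list(boxes[i]) for i in keep]


if __name__ == "__main__":
    labels = [1, 2, 3]
    boxes = [[0.1, 0.1, 0.3, 0.3], [0.8, 0.2, 1.4, 0.4], [0.2, 0.2, 0.4, 0.4]]
    print(filter_center(labels, boxes)[0])                                 # [1, 3]
    print(filter_labels(labels, boxes, out_labels=[3], num_classes=3)[0])  # [1, 2]
```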
452 | # =========================================================================== # 453 | def bboxes_jaccard(bbox_ref, bboxes, name=None): 454 | """Compute jaccard score between a reference box and a collection 455 | of bounding boxes. 456 | 457 | Args: 458 | bbox_ref: (N, 4) or (4,) Tensor with reference bounding box(es). 459 | bboxes: (N, 4) Tensor, collection of bounding boxes. 460 | Return: 461 | (N,) Tensor with Jaccard scores. 462 | """ 463 | with tf.name_scope(name, 'bboxes_jaccard'): 464 | # Should be more efficient to first transpose. 465 | bboxes = tf.transpose(bboxes) 466 | bbox_ref = tf.transpose(bbox_ref) 467 | # Intersection bbox and volume. 468 | int_ymin = tf.maximum(bboxes[0], bbox_ref[0]) 469 | int_xmin = tf.maximum(bboxes[1], bbox_ref[1]) 470 | int_ymax = tf.minimum(bboxes[2], bbox_ref[2]) 471 | int_xmax = tf.minimum(bboxes[3], bbox_ref[3]) 472 | h = tf.maximum(int_ymax - int_ymin, 0.) 473 | w = tf.maximum(int_xmax - int_xmin, 0.) 474 | # Volumes. 475 | inter_vol = h * w 476 | union_vol = -inter_vol \ 477 | + (bboxes[2] - bboxes[0]) * (bboxes[3] - bboxes[1]) \ 478 | + (bbox_ref[2] - bbox_ref[0]) * (bbox_ref[3] - bbox_ref[1]) 479 | jaccard = tfe_math.safe_divide(inter_vol, union_vol, 'jaccard') 480 | return jaccard 481 | 482 | def bboxes_intersection(bbox_ref, bboxes, name=None): 483 | """Compute relative intersection between a reference box and a 484 | collection of bounding boxes. Namely, compute the quotient between 485 | intersection area and box area. 486 | 487 | Args: 488 | bbox_ref: (N, 4) or (4,) Tensor with reference bounding box(es). 489 | bboxes: (N, 4) Tensor, collection of bounding boxes. 490 | Return: 491 | (N,) Tensor with relative intersection. 492 | """ 493 | with tf.name_scope(name, 'bboxes_intersection'): 494 | # Should be more efficient to first transpose. 495 | bboxes = tf.transpose(bboxes) 496 | bbox_ref = tf.transpose(bbox_ref) 497 | # Intersection bbox and volume. 
498 | int_ymin = tf.maximum(bboxes[0], bbox_ref[0]) 499 | int_xmin = tf.maximum(bboxes[1], bbox_ref[1]) 500 | int_ymax = tf.minimum(bboxes[2], bbox_ref[2]) 501 | int_xmax = tf.minimum(bboxes[3], bbox_ref[3]) 502 | h = tf.maximum(int_ymax - int_ymin, 0.) 503 | w = tf.maximum(int_xmax - int_xmin, 0.) 504 | # Volumes. 505 | inter_vol = h * w 506 | bboxes_vol = (bboxes[2] - bboxes[0]) * (bboxes[3] - bboxes[1]) 507 | scores = tfe_math.safe_divide(inter_vol, bboxes_vol, 'intersection') 508 | return scores 509 | -------------------------------------------------------------------------------- /tf_extended/math.py: -------------------------------------------------------------------------------- 1 | """TF Extended: additional math functions. 2 | """ 3 | import tensorflow as tf 4 | 5 | from tensorflow.python.ops import array_ops 6 | from tensorflow.python.ops import math_ops 7 | from tensorflow.python.framework import dtypes 8 | from tensorflow.python.framework import ops 9 | 10 | def safe_divide(numerator, denominator, name): 11 | """Divides two values, returning 0 if the denominator is <= 0. 12 | Args: 13 | numerator: A real `Tensor`. 14 | denominator: A real `Tensor`, with dtype matching `numerator`. 15 | name: Name for the returned op. 16 | Returns: 17 | 0 if `denominator` <= 0, else `numerator` / `denominator` 18 | """ 19 | return tf.where( 20 | math_ops.greater(denominator, 0), 21 | math_ops.divide(numerator, denominator), 22 | tf.zeros_like(numerator), 23 | name=name) 24 | 25 | def cummax(x, reverse=False, name=None): 26 | """Compute the cumulative maximum of the tensor `x` along axis 0. This 27 | operation is similar to the more classic `cumsum`. Only 1D Tensors are 28 | supported for now. 29 | 30 | Args: 31 | x: A `Tensor`. Must be one of the following types: `float32`, `float64`, 32 | `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`, 33 | `complex128`, `qint8`, `quint8`, `qint32`, `half`. 34 | 
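`bboxes_jaccard` divides the intersection area by the union area, while `bboxes_intersection` divides it by the candidate box's own area (this relative score is what `bboxes_filter_overlap` thresholds against [0, 0, 1, 1]). Both go through `safe_divide`, which returns 0 instead of dividing by a non-positive denominator. A plain-Python sketch of the three for single boxes in [ymin, xmin, ymax, xmax] layout follows; the function names are hypothetical.

```python
from typing import Sequence


def safe_divide(num: float, den: float) -> float:
    """Return num / den, or 0.0 when den <= 0 (same rule as tfe_math.safe_divide)."""
    return num / den if den > 0 else 0.0


def area(b: Sequence[float]) -> float:
    return (b[2] - b[0]) * (b[3] - b[1])


def intersection_area(ref: Sequence[float], box: Sequence[float]) -> float:
    # Negative extents mean no overlap; clamp them to zero.
    h = max(min(box[2], ref[2]) - max(box[0], ref[0]), 0.0)
    w = max(min(box[3], ref[3]) - max(box[1], ref[1]), 0.0)
    return h * w


def jaccard(ref: Sequence[float], box: Sequence[float]) -> float:
    """Intersection over union, as in bboxes_jaccard."""
    inter = intersection_area(ref, box)
    return safe_divide(inter, area(ref) + area(box) - inter)


def relative_intersection(ref: Sequence[float], box: Sequence[float]) -> float:
    """Intersection over the area of `box`, as in bboxes_intersection."""
    return safe_divide(intersection_area(ref, box), area(box))


if __name__ == "__main__":
    unit = [0.0, 0.0, 1.0, 1.0]
    box = [0.5, 0.5, 1.5, 1.5]
    print(round(jaccard(unit, box), 4))                       # 0.1429
    print(relative_intersection(unit, box))                   # 0.25
    print(relative_intersection(unit, [0.2, 0.2, 0.2, 0.6]))  # 0.0 (zero-area box)
```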
35 | reverse: A `bool` (default: False). If True, accumulate from the last element backwards. 36 | name: A name for the operation (optional). 37 | Returns: 38 | A `Tensor`. Has the same type as `x`. 39 | """ 40 | with ops.name_scope(name, "Cummax", [x]) as name: 41 | x = ops.convert_to_tensor(x, name="x") 42 | # Not very optimal: should directly integrate reverse into tf.scan. 43 | if reverse: 44 | x = tf.reverse(x, axis=[0]) 45 | # 'Accumulating' maximum: ensures the output is non-decreasing along the scan. 46 | cmax = tf.scan(lambda a, y: tf.maximum(a, y), x, 47 | initializer=None, parallel_iterations=1, 48 | back_prop=False, swap_memory=False) 49 | if reverse: 50 | cmax = tf.reverse(cmax, axis=[0]) 51 | return cmax 52 | -------------------------------------------------------------------------------- /tf_extended/metrics.py: -------------------------------------------------------------------------------- 1 | """TF Extended: additional metrics. 2 | """ 3 | import tensorflow as tf 4 | import numpy as np 5 | 6 | from tensorflow.contrib.framework.python.ops import variables as contrib_variables 7 | from tensorflow.python.framework import dtypes 8 | from tensorflow.python.framework import ops 9 | from tensorflow.python.ops import array_ops 10 | from tensorflow.python.ops import math_ops 11 | from tensorflow.python.ops import nn 12 | from tensorflow.python.ops import state_ops 13 | from tensorflow.python.ops import variable_scope 14 | from tensorflow.python.ops import variables 15 | 16 | from tf_extended import math as tfe_math 17 | 18 | 19 | # =========================================================================== # 20 | # TensorFlow utils 21 | # =========================================================================== # 22 | def _create_local(name, shape, collections=None, validate_shape=True, 23 | dtype=dtypes.float32): 24 | """Creates a new local variable. 25 | Args: 26 | name: The name of the new or existing variable. 27 | shape: Shape of the new or existing variable.
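`cummax` is called with `reverse=True` by the average-precision code to turn the raw precision curve into its interpolated envelope: each entry becomes the maximum of itself and everything after it. The same operation in plain Python with `itertools.accumulate`, as a sketch (`cummax_list` is a hypothetical name):

```python
from itertools import accumulate
from typing import List, Sequence


def cummax_list(xs: Sequence[float], reverse: bool = False) -> List[float]:
    """Cumulative maximum of a 1D sequence; reverse=True scans from the end."""
    if reverse:
        # Scan the reversed sequence, then flip back, like tf.reverse + tf.scan.
        return list(accumulate(reversed(xs), max))[::-1]
    return list(accumulate(xs, max))


if __name__ == "__main__":
    print(cummax_list([1, 3, 2, 5, 4]))                # [1, 3, 3, 5, 5]
    print(cummax_list([1, 3, 2, 5, 4], reverse=True))  # [5, 5, 5, 5, 4]
```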
28 | collections: A list of collection names to which the Variable will be added. 29 | validate_shape: Whether to validate the shape of the variable. 30 | dtype: Data type of the variables. 31 | Returns: 32 | The created variable. 33 | """ 34 | # Make sure local variables are added to tf.GraphKeys.LOCAL_VARIABLES 35 | collections = list(collections or []) 36 | collections += [ops.GraphKeys.LOCAL_VARIABLES] 37 | return variables.Variable( 38 | initial_value=array_ops.zeros(shape, dtype=dtype), 39 | name=name, 40 | trainable=False, 41 | collections=collections, 42 | validate_shape=validate_shape) 43 | 44 | def _safe_div(numerator, denominator, name): 45 | """Divides two values, returning 0 if the denominator is <= 0. 46 | Args: 47 | numerator: A real `Tensor`. 48 | denominator: A real `Tensor`, with dtype matching `numerator`. 49 | name: Name for the returned op. 50 | Returns: 51 | 0 if `denominator` <= 0, else `numerator` / `denominator` 52 | """ 53 | return tf.where( 54 | math_ops.greater(denominator, 0), 55 | math_ops.divide(numerator, denominator), 56 | tf.zeros_like(numerator), 57 | name=name) 58 | 59 | def _broadcast_weights(weights, values): 60 | """Broadcast `weights` to the same shape as `values`. 61 | This returns a version of `weights` following the same broadcast rules as 62 | `mul(weights, values)`. When computing a weighted average, use this function 63 | to broadcast `weights` before summing them; e.g., 64 | `reduce_sum(w * v) / reduce_sum(_broadcast_weights(w, v))`. 65 | Args: 66 | weights: `Tensor` whose shape is broadcastable to `values`. 67 | values: `Tensor` of any shape. 68 | Returns: 69 | `weights` broadcast to `values` shape. 
70 | """ 71 | weights_shape = weights.get_shape() 72 | values_shape = values.get_shape() 73 | if(weights_shape.is_fully_defined() and 74 | values_shape.is_fully_defined() and 75 | weights_shape.is_compatible_with(values_shape)): 76 | return weights 77 | return math_ops.multiply( 78 | weights, array_ops.ones_like(values), name='broadcast_weights') 79 | 80 | # =========================================================================== # 81 | # TF Extended metrics: TP and FP arrays. 82 | # =========================================================================== # 83 | def precision_recall(num_gbboxes, num_detections, tp, fp, scores, 84 | dtype=tf.float64, scope=None): 85 | """Compute precision and recall from scores, true positives and false 86 | positives boolean arrays. 87 | """ 88 | # Input dictionaries: dict outputs as streaming metrics. 89 | if isinstance(scores, dict): 90 | d_precision = {} 91 | d_recall = {} 92 | for c in num_gbboxes.keys(): 93 | scope = 'precision_recall_%s' % c 94 | p, r = precision_recall(num_gbboxes[c], num_detections[c], 95 | tp[c], fp[c], scores[c], 96 | dtype, scope) 97 | d_precision[c] = p 98 | d_recall[c] = r 99 | return d_precision, d_recall 100 | 101 | # Sort by score. 102 | with tf.name_scope(scope, 'precision_recall', 103 | [num_gbboxes, num_detections, tp, fp, scores]): 104 | # Sort detections by score. 105 | scores, idxes = tf.nn.top_k(scores, k=num_detections, sorted=True) 106 | tp = tf.gather(tp, idxes) 107 | fp = tf.gather(fp, idxes) 108 | # Compute recall and precision.
109 | tp = tf.cumsum(tf.cast(tp, dtype), axis=0) 110 | fp = tf.cumsum(tf.cast(fp, dtype), axis=0) 111 | recall = _safe_div(tp, tf.cast(num_gbboxes, dtype), 'recall') 112 | precision = _safe_div(tp, tp + fp, 'precision') 113 | return tf.tuple([precision, recall]) 114 | 115 | def streaming_tp_fp_arrays(num_gbboxes, tp, fp, scores, 116 | remove_zero_scores=True, 117 | metrics_collections=None, 118 | updates_collections=None, 119 | name=None): 120 | """Streaming computation of True and False Positive arrays. This metric 121 | also keeps track of the scores and the number of groundtruth objects. 122 | """ 123 | # Input dictionaries: dict outputs as streaming metrics. 124 | if isinstance(scores, dict) or isinstance(fp, dict): 125 | d_values = {} 126 | d_update_ops = {} 127 | for c in num_gbboxes.keys(): 128 | scope = 'streaming_tp_fp_%s' % c 129 | v, up = streaming_tp_fp_arrays(num_gbboxes[c], tp[c], fp[c], scores[c], 130 | remove_zero_scores, 131 | metrics_collections, 132 | updates_collections, 133 | name=scope) 134 | d_values[c] = v 135 | d_update_ops[c] = up 136 | return d_values, d_update_ops 137 | 138 | # Input Tensors... 139 | with variable_scope.variable_scope(name, 'streaming_tp_fp', 140 | [num_gbboxes, tp, fp, scores]): 141 | num_gbboxes = math_ops.to_int64(num_gbboxes) 142 | scores = math_ops.to_float(scores) 143 | stype = tf.bool 144 | tp = tf.cast(tp, stype) 145 | fp = tf.cast(fp, stype) 146 | # Flatten the scores, TP and FP tensors. 147 | scores = tf.reshape(scores, [-1]) 148 | tp = tf.reshape(tp, [-1]) 149 | fp = tf.reshape(fp, [-1]) 150 | # Remove TP and FP both false. 151 | mask = tf.logical_or(tp, fp) 152 | if remove_zero_scores: 153 | rm_threshold = 1e-4 154 | mask = tf.logical_and(mask, tf.greater(scores, rm_threshold)) 155 | scores = tf.boolean_mask(scores, mask) 156 | tp = tf.boolean_mask(tp, mask) 157 | fp = tf.boolean_mask(fp, mask) 158 | 159 | # Local variables accumulating information over batches.
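`precision_recall` sorts detections by score, takes running sums of the TP and FP flags, and divides. Recall is cumulative TP over the number of groundtruth objects; precision is cumulative TP over cumulative detections, with `_safe_div` guarding empty denominators. A plain-Python sketch of the same computation (the name `precision_recall_lists` is hypothetical):

```python
from typing import List, Sequence, Tuple


def precision_recall_lists(n_gt: int, tp: Sequence[bool], fp: Sequence[bool],
                           scores: Sequence[float]) -> Tuple[List[float], List[float]]:
    """Precision/recall curve from per-detection TP/FP flags."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    ctp = cfp = 0
    precision: List[float] = []
    recall: List[float] = []
    for i in order:
        ctp += int(tp[i])  # running number of true positives
        cfp += int(fp[i])  # running number of false positives
        recall.append(ctp / n_gt if n_gt > 0 else 0.0)
        precision.append(ctp / (ctp + cfp) if ctp + cfp > 0 else 0.0)
    return precision, recall


if __name__ == "__main__":
    p, r = precision_recall_lists(2, [True, False, True], [False, True, False],
                                  [0.9, 0.8, 0.7])
    print(r)  # [0.5, 0.5, 1.0]
    print(p)  # [1.0, 0.5, 0.6666666666666666]
```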
160 | v_nobjects = _create_local('v_num_gbboxes', shape=[], dtype=tf.int64) 161 | v_ndetections = _create_local('v_num_detections', shape=[], dtype=tf.int32) 162 | v_scores = _create_local('v_scores', shape=[0, ]) 163 | v_tp = _create_local('v_tp', shape=[0, ], dtype=stype) 164 | v_fp = _create_local('v_fp', shape=[0, ], dtype=stype) 165 | 166 | # Update operations. 167 | nobjects_op = state_ops.assign_add(v_nobjects, 168 | tf.reduce_sum(num_gbboxes)) 169 | ndetections_op = state_ops.assign_add(v_ndetections, 170 | tf.size(scores, out_type=tf.int32)) 171 | scores_op = state_ops.assign(v_scores, tf.concat([v_scores, scores], axis=0), 172 | validate_shape=False) 173 | tp_op = state_ops.assign(v_tp, tf.concat([v_tp, tp], axis=0), 174 | validate_shape=False) 175 | fp_op = state_ops.assign(v_fp, tf.concat([v_fp, fp], axis=0), 176 | validate_shape=False) 177 | 178 | # Value and update ops. 179 | val = (v_nobjects, v_ndetections, v_tp, v_fp, v_scores) 180 | with ops.control_dependencies([nobjects_op, ndetections_op, 181 | scores_op, tp_op, fp_op]): 182 | update_op = (nobjects_op, ndetections_op, tp_op, fp_op, scores_op) 183 | 184 | if metrics_collections: 185 | ops.add_to_collections(metrics_collections, val) 186 | if updates_collections: 187 | ops.add_to_collections(updates_collections, update_op) 188 | return val, update_op 189 | 190 | # =========================================================================== # 191 | # Average precision computations. 192 | # =========================================================================== # 193 | def average_precision_voc12(precision, recall, name=None): 194 | """Compute (interpolated) average precision from precision and recall Tensors. 195 | 196 | The implementation follows Pascal 2012 and ILSVRC guidelines. 
197 | See also: https://sanchom.wordpress.com/tag/average-precision/ 198 | """ 199 | with tf.name_scope(name, 'average_precision_voc12', [precision, recall]): 200 | # Convert to float64 to decrease error on Riemann sums. 201 | precision = tf.cast(precision, dtype=tf.float64) 202 | recall = tf.cast(recall, dtype=tf.float64) 203 | 204 | # Add bounds values to precision and recall. 205 | precision = tf.concat([[0.], precision, [0.]], axis=0) 206 | recall = tf.concat([[0.], recall, [1.]], axis=0) 207 | # Ensures precision is increasing in reverse order. 208 | precision = tfe_math.cummax(precision, reverse=True) 209 | 210 | # Riemann sums for estimating the integral. 211 | # mean_pre = (precision[1:] + precision[:-1]) / 2. 212 | mean_pre = precision[1:] 213 | diff_rec = recall[1:] - recall[:-1] 214 | ap = tf.reduce_sum(mean_pre * diff_rec) 215 | return ap 216 | 217 | def average_precision_voc07(precision, recall, name=None): 218 | """Compute (interpolated) average precision from precision and recall Tensors. 219 | 220 | The implementation follows Pascal 2007 guidelines. 221 | See also: https://sanchom.wordpress.com/tag/average-precision/ 222 | """ 223 | with tf.name_scope(name, 'average_precision_voc07', [precision, recall]): 224 | # Convert to float64 to decrease error on cumulated sums. 225 | precision = tf.cast(precision, dtype=tf.float64) 226 | recall = tf.cast(recall, dtype=tf.float64) 227 | # Add zero-limit value to avoid any boundary problem... 228 | precision = tf.concat([precision, [0.]], axis=0) 229 | recall = tf.concat([recall, [np.inf]], axis=0) 230 | 231 | # Split the integral into 10 bins. 232 | l_aps = [] 233 | for t in np.arange(0., 1.1, 0.1): 234 | mask = tf.greater_equal(recall, t) 235 | v = tf.reduce_max(tf.boolean_mask(precision, mask)) 236 | l_aps.append(v / 11.) 237 | ap = tf.add_n(l_aps) 238 | return ap 239 | 240 | def precision_recall_values(xvals, precision, recall, name=None): 241 | """Compute values on the precision/recall curve. 
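The two AP variants differ only in how they integrate the precision/recall curve. `average_precision_voc12` pads the curve with sentinels, replaces precision by its reverse cumulative maximum, and sums precision times recall increments. `average_precision_voc07` averages, over the 11 recall thresholds 0.0, 0.1, ..., 1.0, the best precision reached at recall >= t. Below is a plain-Python sketch of both, applied to the curve precision [1, 0.5, 2/3] at recall [0.5, 0.5, 1.0]. The function names are hypothetical, and the VOC07 thresholds are computed as k / 10 rather than with `np.arange`, which can differ in the last floating-point digit.

```python
from itertools import accumulate
from typing import List, Sequence


def ap_voc12(precision: Sequence[float], recall: Sequence[float]) -> float:
    """All-point interpolated AP, mirroring average_precision_voc12."""
    p: List[float] = [0.0] + list(precision) + [0.0]
    r: List[float] = [0.0] + list(recall) + [1.0]
    # Interpolated precision: max precision at this or any later point.
    p = list(accumulate(reversed(p), max))[::-1]
    return sum(p[i] * (r[i] - r[i - 1]) for i in range(1, len(r)))


def ap_voc07(precision: Sequence[float], recall: Sequence[float]) -> float:
    """11-point interpolated AP, mirroring average_precision_voc07."""
    p = list(precision) + [0.0]
    r = list(recall) + [float("inf")]  # sentinel: every threshold hits something
    total = 0.0
    for k in range(11):
        t = k / 10.0
        total += max(pi for pi, ri in zip(p, r) if ri >= t) / 11.0
    return total


if __name__ == "__main__":
    prec, rec = [1.0, 0.5, 2 / 3], [0.5, 0.5, 1.0]
    print(round(ap_voc12(prec, rec), 4))  # 0.8333 (= 5/6)
    print(round(ap_voc07(prec, rec), 4))  # 0.8485 (= 28/33)
```

The inf sentinel in the VOC07 variant guarantees that the masked maximum is never taken over an empty set, which is the same reason the TensorFlow version appends `np.inf` to the recall tensor.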
242 | 243 | Args: 244 | xvals: Python list of recall levels (floats); 245 | precision: 1D Tensor decreasing. 246 | recall: 1D Tensor increasing. 247 | Return: 248 | list of precision values. 249 | """ 250 | with ops.name_scope(name, "precision_recall_values", 251 | [precision, recall]) as name: 252 | # Add bounds values to precision and recall. 253 | precision = tf.concat([[0.], precision, [0.]], axis=0) 254 | recall = tf.concat([[0.], recall, [1.]], axis=0) 255 | precision = tfe_math.cummax(precision, reverse=True) 256 | 257 | prec_values = [] 258 | for x in xvals: 259 | mask = tf.less_equal(recall, x) 260 | val = tf.reduce_min(tf.boolean_mask(precision, mask)) 261 | prec_values.append(val) 262 | return tf.tuple(prec_values) 263 | 264 | # =========================================================================== # 265 | # TF Extended metrics: old stuff! 266 | # =========================================================================== # 267 | def _precision_recall(n_gbboxes, n_detections, scores, tp, fp, scope=None): 268 | """Compute precision and recall from scores, true positives and false 269 | positives boolean arrays. 270 | """ 271 | # Sort by score. 272 | with tf.name_scope(scope, 'prec_rec', [n_gbboxes, scores, tp, fp]): 273 | # Sort detections by score. 274 | scores, idxes = tf.nn.top_k(scores, k=n_detections, sorted=True) 275 | tp = tf.gather(tp, idxes) 276 | fp = tf.gather(fp, idxes) 277 | # Compute recall and precision.
        dtype = tf.float64
        tp = tf.cumsum(tf.cast(tp, dtype), axis=0)
        fp = tf.cumsum(tf.cast(fp, dtype), axis=0)
        recall = _safe_div(tp, tf.cast(n_gbboxes, dtype), 'recall')
        precision = _safe_div(tp, tp + fp, 'precision')

        return tf.tuple([precision, recall])


def streaming_precision_recall_arrays(n_gbboxes, rclasses, rscores,
                                      tp_tensor, fp_tensor,
                                      remove_zero_labels=True,
                                      metrics_collections=None,
                                      updates_collections=None,
                                      name=None):
    """Streaming computation of precision / recall arrays. This metric
    keeps track of boolean true-positive and false-positive arrays.
    """
    with variable_scope.variable_scope(name, 'stream_precision_recall',
                                       [n_gbboxes, rclasses, tp_tensor, fp_tensor]):
        n_gbboxes = math_ops.to_int64(n_gbboxes)
        rclasses = math_ops.to_int64(rclasses)
        rscores = math_ops.to_float(rscores)

        stype = tf.int32
        tp_tensor = tf.cast(tp_tensor, stype)
        fp_tensor = tf.cast(fp_tensor, stype)

        # Reshape TP and FP tensors and clean away 0 class values.
        rclasses = tf.reshape(rclasses, [-1])
        rscores = tf.reshape(rscores, [-1])
        tp_tensor = tf.reshape(tp_tensor, [-1])
        fp_tensor = tf.reshape(fp_tensor, [-1])
        if remove_zero_labels:
            mask = tf.greater(rclasses, 0)
            rclasses = tf.boolean_mask(rclasses, mask)
            rscores = tf.boolean_mask(rscores, mask)
            tp_tensor = tf.boolean_mask(tp_tensor, mask)
            fp_tensor = tf.boolean_mask(fp_tensor, mask)

        # Local variables accumulating information over batches.
        v_nobjects = _create_local('v_nobjects', shape=[], dtype=tf.int64)
        v_ndetections = _create_local('v_ndetections', shape=[], dtype=tf.int32)
        v_scores = _create_local('v_scores', shape=[0, ])
        v_tp = _create_local('v_tp', shape=[0, ], dtype=stype)
        v_fp = _create_local('v_fp', shape=[0, ], dtype=stype)

        # Update operations.
        nobjects_op = state_ops.assign_add(v_nobjects,
                                           tf.reduce_sum(n_gbboxes))
        ndetections_op = state_ops.assign_add(v_ndetections,
                                              tf.size(rscores, out_type=tf.int32))
        scores_op = state_ops.assign(v_scores, tf.concat([v_scores, rscores], axis=0),
                                     validate_shape=False)
        tp_op = state_ops.assign(v_tp, tf.concat([v_tp, tp_tensor], axis=0),
                                 validate_shape=False)
        fp_op = state_ops.assign(v_fp, tf.concat([v_fp, fp_tensor], axis=0),
                                 validate_shape=False)

        # Precision and recall computations.
        # r = _precision_recall(nobjects_op, scores_op, tp_op, fp_op, 'value')
        r = _precision_recall(v_nobjects, v_ndetections, v_scores,
                              v_tp, v_fp, 'value')

        with ops.control_dependencies([nobjects_op, ndetections_op,
                                       scores_op, tp_op, fp_op]):
            update_op = _precision_recall(nobjects_op, ndetections_op,
                                          scores_op, tp_op, fp_op, 'update_op')

            # update_op = tf.Print(update_op,
            #                      [tf.reduce_sum(tf.cast(mask, tf.int64)),
            #                       tf.reduce_sum(tf.cast(mask2, tf.int64)),
            #                       tf.reduce_min(rscores),
            #                       tf.reduce_sum(n_gbboxes)],
            #                      'Metric: ')
            # Some debugging stuff!
            # update_op = tf.Print(update_op,
            #                      [tf.shape(tp_op),
            #                       tf.reduce_sum(tf.cast(tp_op, tf.int64), axis=0)],
            #                      'TP and FP shape: ')
            # update_op[0] = tf.Print(update_op,
            #                         [nobjects_op],
            #                         '# Groundtruth bboxes: ')
            # update_op = tf.Print(update_op,
            #                      [update_op[0][0],
            #                       update_op[0][-1],
            #                       tf.reduce_min(update_op[0]),
            #                       tf.reduce_max(update_op[0]),
            #                       tf.reduce_min(update_op[1]),
            #                       tf.reduce_max(update_op[1])],
            #                      'Precision and recall :')

        if metrics_collections:
            ops.add_to_collections(metrics_collections, r)
        if updates_collections:
            ops.add_to_collections(updates_collections, update_op)
        return r, update_op


--------------------------------------------------------------------------------
/tf_extended/tensors.py:
--------------------------------------------------------------------------------
"""TF Extended: additional tensor operations.
"""
import tensorflow as tf

from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.metrics.python.ops import set_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables


def get_shape(x, rank=None):
    """Returns the dimensions of a Tensor as a list of integers or scalar tensors.

    Args:
        x: N-d Tensor;
        rank: Rank of the Tensor. If None, will try to guess it.
    Returns:
        A list of `[d1, d2, ..., dN]` corresponding to the dimensions of the
        input tensor. Dimensions that are statically known are python integers,
        otherwise they are integer scalar tensors.
    """
    if x.get_shape().is_fully_defined():
        return x.get_shape().as_list()
    else:
        static_shape = x.get_shape()
        if rank is None:
            static_shape = static_shape.as_list()
            rank = len(static_shape)
        else:
            static_shape = x.get_shape().with_rank(rank).as_list()
        dynamic_shape = tf.unstack(tf.shape(x), rank)
        return [s if s is not None else d
                for s, d in zip(static_shape, dynamic_shape)]

def pad_axis(x, offset, size, axis=0, name=None):
    """Pad a tensor on an axis, with a given offset and output size.
    The tensor is padded with zeros (i.e. CONSTANT mode). Note that if
    `size` is smaller than the existing size + `offset`, the padded
    dimension is the latter, and the final reshape to `size` then fails.

    Args:
        x: Tensor to pad;
        offset: Offset to add on the dimension chosen;
        size: Final size of the dimension.
        axis: Axis to pad (default 0).
        name: Optional name scope.
    Return:
        Padded tensor whose dimension on `axis` is `size`.
    """
    with tf.name_scope(name, 'pad_axis'):
        shape = get_shape(x)
        rank = len(shape)
        # Padding description.
        new_size = tf.maximum(size-offset-shape[axis], 0)
        pad1 = tf.stack([0]*axis + [offset] + [0]*(rank-axis-1))
        pad2 = tf.stack([0]*axis + [new_size] + [0]*(rank-axis-1))
        paddings = tf.stack([pad1, pad2], axis=1)
        x = tf.pad(x, paddings, mode='CONSTANT')
        # Reshape, to get fully defined shape if possible.
        # TODO: fix with tf.slice
        shape[axis] = size
        x = tf.reshape(x, tf.stack(shape))
        return x


--------------------------------------------------------------------------------
/tf_utils.py:
--------------------------------------------------------------------------------
"""Miscellaneous TensorFlow utilities for training, evaluation and so on.
"""
import os
from pprint import pprint
import tf_extended as tfe

import tensorflow as tf
from tensorflow.contrib.slim.python.slim.data import parallel_reader

slim = tf.contrib.slim


# =========================================================================== #
# General tools.
# =========================================================================== #
def reshape_list(l, shape=None):
    """Reshape a list (of lists): 1D to 2D or the other way around.

    Args:
        l: List or list of lists.
        shape: 1D or 2D shape.
    Return:
        Reshaped list.
    """
    r = []
    if shape is None:
        # Flatten everything.
        for a in l:
            if isinstance(a, (list, tuple)):
                r = r + list(a)
            else:
                r.append(a)
    else:
        # Reshape to list of list.
        i = 0
        for s in shape:
            if s == 1:
                r.append(l[i])
            else:
                r.append(l[i:i+s])
            i += s
    return r


# =========================================================================== #
# Training utils.
# =========================================================================== #
def print_configuration(flags, aher_params, data_sources, save_dir=None):
    """Print the training configuration.
    """
    def print_config(stream=None):
        print('\n# =========================================================================== #', file=stream)
        print('# Training | Evaluation flags:', file=stream)
        print('# =========================================================================== #', file=stream)
        pprint(flags, stream=stream)

        print('\n# =========================================================================== #', file=stream)
        print('# AHER net parameters:', file=stream)
        print('# =========================================================================== #', file=stream)
        pprint(dict(aher_params._asdict()), stream=stream)

        print('\n# =========================================================================== #', file=stream)
        print('# Training | Evaluation dataset files:', file=stream)
        print('# =========================================================================== #', file=stream)
        data_files = parallel_reader.get_data_files(data_sources)
        pprint(sorted(data_files), stream=stream)
        print('', file=stream)

    print_config(None)
    # Save to a text file as well.
    if save_dir is not None:
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        path = os.path.join(save_dir, 'training_config.txt')
        with open(path, "w") as out:
            print_config(out)

def configure_learning_rate(flags, num_samples_per_epoch, global_step):
    """Configures the learning rate.

    Args:
        flags: Training flags (batch size, decay type and learning-rate settings).
        num_samples_per_epoch: The number of samples in each epoch of training.
        global_step: The global_step tensor.
    Returns:
        A `Tensor` representing the learning rate.
    """
    decay_steps = int(num_samples_per_epoch / flags.batch_size *
                      flags.num_epochs_per_decay)

    if flags.learning_rate_decay_type == 'exponential':
        return tf.train.exponential_decay(flags.learning_rate,
                                          global_step,
                                          decay_steps,
                                          flags.learning_rate_decay_factor,
                                          staircase=True,
                                          name='exponential_decay_learning_rate')
    elif flags.learning_rate_decay_type == 'fixed':
        return tf.constant(flags.learning_rate, name='fixed_learning_rate')
    elif flags.learning_rate_decay_type == 'polynomial':
        return tf.train.polynomial_decay(flags.learning_rate,
                                         global_step,
                                         decay_steps,
                                         flags.end_learning_rate,
                                         power=1.0,
                                         cycle=False,
                                         name='polynomial_decay_learning_rate')
    else:
        # Format the message; passing the value as a second argument would not.
        raise ValueError('learning_rate_decay_type [%s] was not recognized' %
                         flags.learning_rate_decay_type)

def configure_optimizer(flags, learning_rate):
    """Configures the optimizer used for training.

    Args:
        flags: Training flags; `flags.optimizer` selects the optimizer and
            the remaining flags supply its hyperparameters.
        learning_rate: A scalar or `Tensor` learning rate.
    Returns:
        An instance of an optimizer.
    """
    if flags.optimizer == 'adadelta':
        optimizer = tf.train.AdadeltaOptimizer(
            learning_rate,
            rho=flags.adadelta_rho,
            epsilon=flags.opt_epsilon)
    elif flags.optimizer == 'adagrad':
        optimizer = tf.train.AdagradOptimizer(
            learning_rate,
            initial_accumulator_value=flags.adagrad_initial_accumulator_value)
    elif flags.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(
            learning_rate,
            beta1=flags.adam_beta1,
            beta2=flags.adam_beta2,
            epsilon=flags.opt_epsilon)
    elif flags.optimizer == 'ftrl':
        optimizer = tf.train.FtrlOptimizer(
            learning_rate,
            learning_rate_power=flags.ftrl_learning_rate_power,
            initial_accumulator_value=flags.ftrl_initial_accumulator_value,
            l1_regularization_strength=flags.ftrl_l1,
            l2_regularization_strength=flags.ftrl_l2)
    elif flags.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(
            learning_rate,
            momentum=flags.momentum,
            name='Momentum')
    elif flags.optimizer == 'rmsprop':
        optimizer = tf.train.RMSPropOptimizer(
            learning_rate,
            decay=flags.rmsprop_decay,
            momentum=flags.rmsprop_momentum,
            epsilon=flags.opt_epsilon)
    elif flags.optimizer == 'sgd':
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    else:
        raise ValueError('Optimizer [%s] was not recognized' % flags.optimizer)
    return optimizer

def add_variables_summaries(learning_rate):
    summaries = []
    for variable in slim.get_model_variables():
        summaries.append(tf.summary.histogram(variable.op.name, variable))
    summaries.append(tf.summary.scalar('training/Learning Rate', learning_rate))
    return summaries

def update_model_scope(var, ckpt_scope, new_scope):
    # Map a variable name from the current model scope to the checkpoint scope
    # (previously hard-coded to 'vgg_16', which ignored `ckpt_scope`).
    return var.op.name.replace(new_scope, ckpt_scope)

def get_init_fn(flags):
    """Returns a function run by the chief worker to warm-start the training.
    Note that the init_fn is only run when initializing the model during the very
    first global step.

    Returns:
        An init function run by the supervisor.
    """
    if flags.checkpoint_path is None:
        return None
    # Warn the user if a checkpoint exists in the train_dir. Then ignore.
    if tf.train.latest_checkpoint(flags.train_dir):
        tf.logging.info(
            'Ignoring --checkpoint_path because a checkpoint already exists in %s'
            % flags.train_dir)
        return None

    exclusions = []
    if flags.checkpoint_exclude_scopes:
        exclusions = [scope.strip()
                      for scope in flags.checkpoint_exclude_scopes.split(',')]

    # TODO(sguada) variables.filter_variables()
    variables_to_restore = []
    for var in slim.get_model_variables():
        excluded = False
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
                break
        if not excluded:
            variables_to_restore.append(var)
    # Change model scope if necessary.
    if flags.checkpoint_model_scope is not None:
        variables_to_restore = \
            {var.op.name.replace(flags.model_name,
                                 flags.checkpoint_model_scope): var
             for var in variables_to_restore}

    if tf.gfile.IsDirectory(flags.checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(flags.checkpoint_path)
    else:
        checkpoint_path = flags.checkpoint_path
    tf.logging.info('Fine-tuning from %s. Ignoring missing vars: %s'
                    % (checkpoint_path, flags.ignore_missing_vars))

    return slim.assign_from_checkpoint_fn(
        checkpoint_path,
        variables_to_restore,
        ignore_missing_vars=flags.ignore_missing_vars)

def get_variables_to_train(flags):
    """Returns a list of variables to train.

    Returns:
        A list of variables to train by the optimizer.
    """
    if flags.trainable_scopes is None:
        return tf.trainable_variables()
    else:
        scopes = [scope.strip() for scope in flags.trainable_scopes.split(',')]

    variables_to_train = []
    for scope in scopes:
        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
        variables_to_train.extend(variables)
    return variables_to_train

def sigmoid_cross_entropy_with_logits(x, y):
    # TF >= 1.0 names the label argument `labels`; earlier releases used
    # `targets` and reject `labels` with a TypeError.
    try:
        return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, labels=y)
    except TypeError:
        return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, targets=y)


# =========================================================================== #
# depth-wise convolution
# =========================================================================== #

# The original 1-D input data format is NWC; it is reshaped to NHWC
# (with H = 1) for the depthwise 2-D convolution.
def depthwise_conv1d(input, k_h=1, k_w=3, channel_multiplier=1, strides=1,
                     padding='SAME', stddev=0.02, name='depthwise_conv1d',
                     bias=True, weight_decay=0.0001):
    lshape = tfe.get_shape(input, 3)
    input = tf.reshape(input, [lshape[0], 1, lshape[1], lshape[2]])
    with tf.variable_scope(name):
        in_channel = input.get_shape().as_list()[-1]
        w = tf.get_variable('w', [k_h, k_w, in_channel, channel_multiplier],
                            regularizer=tf.contrib.layers.l2_regularizer(weight_decay),
                            initializer=tf.truncated_normal_initializer(stddev=stddev))
        conv = tf.nn.depthwise_conv2d(input, w, [1, strides, strides, 1], padding,
                                      rate=None, name=None, data_format=None)
        if bias:
            biases = tf.get_variable('bias', [in_channel * channel_multiplier],
                                     initializer=tf.constant_initializer(0.0))
            conv = tf.nn.bias_add(conv, biases)
        cshape = tfe.get_shape(conv, 4)
        # Convert back to the original 1-D (NWC) layout.
        conv = tf.reshape(conv, [cshape[0], cshape[2], cshape[3]])
        return conv

--------------------------------------------------------------------------------
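This is a minimal NumPy cross-check of the two AP routines in `tf_extended/metrics.py`, `average_precision_voc12` (all-point) and `average_precision_voc07` (11-point). It is a sketch, not repository code:

* `ap_voc12` and `ap_voc07` are hypothetical names.
* NumPy stands in for the TensorFlow ops.
* The 11 recall thresholds are built as `np.arange(11) / 10.0` instead of `np.arange(0., 1.1, 0.1)`, so each threshold is the double closest to k/10.

```python
import numpy as np


def ap_voc12(precision, recall):
    """All-point interpolated AP, following average_precision_voc12 (sketch)."""
    # Add the same bounds as the TF version: precision [0, p, 0], recall [0, r, 1].
    p = np.concatenate([[0.0], np.asarray(precision, dtype=np.float64), [0.0]])
    r = np.concatenate([[0.0], np.asarray(recall, dtype=np.float64), [1.0]])
    # Cumulative max from the right makes precision non-increasing in recall.
    p = np.maximum.accumulate(p[::-1])[::-1]
    # Riemann sum: each recall step is weighted by the precision to its right.
    return float(np.sum(p[1:] * (r[1:] - r[:-1])))


def ap_voc07(precision, recall):
    """11-point interpolated AP, following average_precision_voc07 (sketch)."""
    # Sentinel point (precision 0, recall inf) keeps every mask non-empty.
    p = np.concatenate([np.asarray(precision, dtype=np.float64), [0.0]])
    r = np.concatenate([np.asarray(recall, dtype=np.float64), [np.inf]])
    ap = 0.0
    for t in np.arange(11) / 10.0:
        # Highest precision among points whose recall is at least t.
        ap += p[r >= t].max() / 11.0
    return float(ap)


if __name__ == "__main__":
    prec, rec = [1.0, 0.5], [0.5, 1.0]
    print(ap_voc12(prec, rec), ap_voc07(prec, rec))
```

Take the toy curve precision = [1.0, 0.5], recall = [0.5, 1.0]. The all-point rule gives 0.75. The 11-point rule gives 8.5 / 11 (about 0.773), because it assigns precision 1.0 to six of its eleven thresholds (0.0 through 0.5), slightly more than half of the recall range.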