├── README.md
├── activitynet_category_name.txt
├── aher_adv.py
├── aher_anet.py
├── aher_common.py
├── aher_k600_test.py
├── aher_k600_train.py
├── custom_layers.py
├── data_loader.py
├── evaluation
│   ├── eval_detection.py
│   ├── eval_proposal.py
│   ├── get_detection_performance.py
│   ├── get_proposal_performance.py
│   └── utils.py
├── k600_category.txt
├── k600_val_annotation.txt
├── pic
│   └── eccv_framework.JPG
├── tf_extended
│   ├── __init__.py
│   ├── bboxes.py
│   ├── math.py
│   ├── metrics.py
│   └── tensors.py
└── tf_utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # AherNet: Learning to Localize Actions from Moments
2 |
3 | This repository includes the code and configuration files for the transfer setting "From ActivityNet v1.3 to Kinetics-600 (ANet -> K600)" described in the paper.
4 |
5 | The training/validation features and the related meta files of the datasets are not included in this repository because they are too large; they will be released on Google Drive or Baidu Yun in the future.
6 |
7 | The file `k600_val_annotation.txt` contains the manual temporal annotations of 6,459 videos in the validation set of Kinetics-600.
8 |
9 | # Update
10 | * 2020.8.25: Released the repository for AherNet and the annotations of the sampled validation videos in Kinetics-600
11 |
12 | # Contents:
13 |
14 | * [Paper Introduction](#paper-introduction)
15 | * [Environment](#environment)
16 | * [Training of AherNet](#training-of-ahernet)
17 | * [Testing of AherNet on Kinetics-600](#testing-of-ahernet-on-kinetics-600)
18 | * [Citation](#citation)
19 |
20 | # Paper Introduction
21 |
22 | ![AherNet framework](./pic/eccv_framework.JPG)
23 |
24 | With the knowledge of action moments (i.e., trimmed video clips that each contain an action instance), humans can routinely localize an action temporally in an untrimmed video. Nevertheless, most practical methods still require all training videos to be labeled with temporal annotations (action category and temporal boundary) and develop the models in a fully-supervised manner, despite the expensive labeling efforts and the inapplicability to new categories. In this paper, we introduce a new design of transfer learning that learns action localization for a large set of action categories, using only action moments from the categories of interest and temporal annotations of untrimmed videos from a small set of action classes. Specifically, we present Action Herald Networks (AherNet), which integrate such a design into a one-stage action localization framework. Technically, a weight transfer function is uniquely devised to build the transformation between classification of action moments or foreground video segments and action localization in synthetic contextual moments or untrimmed videos. The context of each moment is learnt through an adversarial mechanism to differentiate the generated features from those of background in untrimmed videos. Extensive experiments are conducted on learning both across the splits of ActivityNet v1.3 and from THUMOS14 to ActivityNet v1.3. Our AherNet demonstrates superiority even compared to most fully-supervised action localization methods. More remarkably, we train AherNet to localize actions from 600 categories by leveraging action moments in Kinetics-600 and temporal annotations from 200 classes in ActivityNet v1.3.
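
As a rough illustration of the weight transfer idea only (a minimal sketch with assumed names and an assumed two-layer form, not the released implementation), a learned transformation maps the per-class classification weights obtained from action moments into the weights of the anchor classification head used for localization:

```
import tensorflow as tf

def transfer_cls_weights(w_cls, scope='weight_transfer', reuse=None):
    # w_cls: [num_classes, feat_dim] classifier weights learned on trimmed action moments.
    # Returns weights of the same shape for the anchor classification (localization) head.
    with tf.variable_scope(scope, reuse=reuse):
        h = tf.layers.dense(w_cls, 1024, activation=tf.nn.relu, name='fc1')
        w_anchor = tf.layers.dense(h, int(w_cls.shape[-1]), name='fc2')
    return w_anchor
```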
25 |
26 | # Environment
27 |
28 | TensorFlow version: 1.12.3 (GPU build)
29 |
30 | Operating system (Docker image): Ubuntu 16.04
31 |
32 | Python version: 3.6.7
33 |
34 | GPU: NVIDIA Tesla P40 (23GB)
35 |
36 | CUDA and cuDNN versions: CUDA 9.0 and cuDNN 7.3.1
37 |
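
For reference, a matching Python environment can be set up with pip, for example (an illustrative command given the versions above; `pandas` and `numpy` are required by the data loading and evaluation code):

```
pip install tensorflow-gpu==1.12.3 pandas numpy
```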
38 |
39 | # Training of AherNet
40 |
41 | The training script is `aher_k600_train.py`.
42 | Once you have all the training features and information files, you can run
43 |
44 | ```
45 | CUDA_VISIBLE_DEVICES=0 python3 aher_k600_train.py
46 | ```
47 |
48 | # Testing of AherNet on Kinetics-600
49 |
50 | The testing script is `aher_k600_test.py`.
51 | Once you have all the testing features and information files, you can run
52 |
53 | ```
54 | CUDA_VISIBLE_DEVICES=0 python3 aher_k600_test.py
55 | ```
56 | You can evaluate several snapshots (models) during the testing stage to find the best-performing model.
57 |
58 | After testing, the results are written to a `.json` file, which can be evaluated with `./evaluation/get_proposal_performance.py` (temporal action proposal) or `./evaluation/get_detection_performance.py` (temporal action localization).
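
The testing script already invokes the proposal evaluation programmatically; a minimal standalone sketch of the same call (the ground-truth and result paths below are examples taken from `aher_k600_test.py`; adjust them to your own files) is:

```
from evaluation import get_proposal_performance

# Returns the AUC of the recall curve and the average recall, as used in aher_k600_test.py.
auc, avg_recall = get_proposal_performance.evaluate_return_area(
    'data_split/k600_val_action.json',
    'p3d_k600_models/AHER_k600_gen_adv/results_aher_k600_gendec_step-10000.json')
print(auc, avg_recall)
```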
59 |
60 |
61 | # Citation
62 |
63 | If you use these models in your research, please cite:
64 |
65 |     @inproceedings{Long:ECCV20,
66 |       title={Learning to Localize Actions from Moments},
67 |       author={Fuchen Long and Ting Yao and Zhaofan Qiu and Xinmei Tian and Jiebo Luo and Tao Mei},
68 |       booktitle={ECCV},
69 |       year={2020}
70 |     }
71 |
--------------------------------------------------------------------------------
/activitynet_category_name.txt:
--------------------------------------------------------------------------------
1 | Drinking coffee
2 | Zumba
3 | Doing kickboxing
4 | Tango
5 | Playing polo
6 | Putting on makeup
7 | Snatch
8 | Long jump
9 | Cricket
10 | Ironing clothes
11 | Clean and jerk
12 | Using parallel bars
13 | Bathing dog
14 | Discus throw
15 | Playing field hockey
16 | Grooming horse
17 | Preparing salad
18 | Doing karate
19 | Playing harmonica
20 | Pole vault
21 | Playing saxophone
22 | Chopping wood
23 | Washing face
24 | Using the pommel horse
25 | Javelin throw
26 | Spinning
27 | Ping-pong
28 | Volleyball
29 | Brushing hair
30 | Making a sandwich
31 | Playing water polo
32 | Cleaning shoes
33 | Drinking beer
34 | Playing bagpipes
35 | Cheerleading
36 | Paintball
37 | Cleaning windows
38 | Brushing teeth
39 | Playing flauta
40 | Tennis serve with ball bouncing
41 | Starting a campfire
42 | Bungee jumping
43 | Triple jump
44 | Polishing forniture
45 | Layup drill in basketball
46 | Vacuuming floor
47 | Dodgeball
48 | Doing nails
49 | Shot put
50 | Fixing bicycle
51 | Washing hands
52 | Horseback riding
53 | Skateboarding
54 | Wrapping presents
55 | Using the balance beam
56 | Shoveling snow
57 | Preparing pasta
58 | Getting a tattoo
59 | Rock climbing
60 | Smoking hookah
61 | Shaving
62 | Using uneven bars
63 | Springboard diving
64 | Playing squash
65 | High jump
66 | Playing piano
67 | Shaving legs
68 | Smoking a cigarette
69 | Getting a haircut
70 | Playing lacrosse
71 | Cumbia
72 | Washing dishes
73 | Getting a piercing
74 | Painting
75 | Mowing the lawn
76 | Walking the dog
77 | Hammer throw
78 | Polishing shoes
79 | Ballet
80 | Doing step aerobics
81 | Hand washing clothes
82 | Plataform diving
83 | Playing violin
84 | Breakdancing
85 | Windsurfing
86 | Sailing
87 | Hopscotch
88 | Doing motocross
89 | Mixing drinks
90 | Belly dance
91 | Removing curlers
92 | Archery
93 | Playing guitarra
94 | Playing racquetball
95 | Kayaking
96 | Playing kickball
97 | Tumbling
98 | Tai chi
99 | Playing accordion
100 | Playing badminton
101 | Arm wrestling
102 | Assembling bicycle
103 | BMX
104 | Baking cookies
105 | Baton twirling
106 | Beach soccer
107 | Beer pong
108 | Blow-drying hair
109 | Blowing leaves
110 | Playing ten pins
111 | Braiding hair
112 | Building sandcastles
113 | Bullfighting
114 | Calf roping
115 | Camel ride
116 | Canoeing
117 | Capoeira
118 | Carving jack-o-lanterns
119 | Changing car wheel
120 | Cleaning sink
121 | Clipping cat claws
122 | Croquet
123 | Curling
124 | Cutting the grass
125 | Decorating the Christmas tree
126 | Disc dog
127 | Doing a powerbomb
128 | Doing crunches
129 | Drum corps
130 | Elliptical trainer
131 | Doing fencing
132 | Fixing the roof
133 | Fun sliding down
134 | Futsal
135 | Gargling mouthwash
136 | Grooming dog
137 | Hand car wash
138 | Hanging wallpaper
139 | Having an ice cream
140 | Hitting a pinata
141 | Hula hoop
142 | Hurling
143 | Ice fishing
144 | Installing carpet
145 | Kite flying
146 | Kneeling
147 | Knitting
148 | Laying tile
149 | Longboarding
150 | Making a cake
151 | Making a lemonade
152 | Making an omelette
153 | Mooping floor
154 | Painting fence
155 | Painting furniture
156 | Peeling potatoes
157 | Plastering
158 | Playing beach volleyball
159 | Playing blackjack
160 | Playing congas
161 | Playing drums
162 | Playing ice hockey
163 | Playing pool
164 | Playing rubik cube
165 | Powerbocking
166 | Putting in contact lenses
167 | Putting on shoes
168 | Rafting
169 | Raking leaves
170 | Removing ice from car
171 | Riding bumper cars
172 | River tubing
173 | Rock-paper-scissors
174 | Rollerblading
175 | Roof shingle removal
176 | Rope skipping
177 | Running a marathon
178 | Scuba diving
179 | Sharpening knives
180 | Shuffleboard
181 | Skiing
182 | Slacklining
183 | Snow tubing
184 | Snowboarding
185 | Spread mulch
186 | Sumo
187 | Surfing
188 | Swimming
189 | Swinging at the playground
190 | Table soccer
191 | Throwing darts
192 | Trimming branches or hedges
193 | Tug of war
194 | Using the monkey bar
195 | Using the rowing machine
196 | Wakeboarding
197 | Waterskiing
198 | Waxing skis
199 | Welding
200 | Applying sunscreen
201 |
--------------------------------------------------------------------------------
/aher_adv.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import tensorflow as tf
3 | import numpy as np
4 | import aher_anet
5 | from aher_anet import aher_multibox_adv_layer
6 | import tf_utils
7 | from datetime import datetime
8 | import time
9 | import data_loader
10 | import tf_extended as tfe
11 | import os
12 | import sys
13 | import pandas as pd
14 | from multiprocessing import Process,Queue,JoinableQueue
15 | import multiprocessing
16 | import math
17 | import random
18 | from tf_utils import *
19 |
20 | FLAGS = tf.app.flags.FLAGS
21 | tf.app.flags.DEFINE_float('dis_weights', 0.1, 'The weight for the discriminator.')
22 | tf.app.flags.DEFINE_float('gen_weights', 0.1, 'The weight for the generator.')
23 |
24 | class Config(object):
25 | def __init__(self):
26 | self.learning_rates=[0.001]*100+[0.0001]*100
27 | #self.training_epochs = len(self.learning_rates)
28 | self.training_epochs = 1
29 | self.total_batch_num = 15000
30 | self.n_inputs = 2048
31 | self.batch_size = 16
32 | self.input_steps=512
33 | self.input_moment_steps=256
34 | self.gt_hold_num = 25
35 | self.gt_hold_num_th = 25
36 | self.batch_size_val=1
37 |
38 | # generate context feature
39 | def Context_Train(m_feature,ratio_id,pos_id):
40 |     """ Context information network: synthesize untrimmed-video features from action moment features.
41 |         input: m_feature: batch_size x 256 x 2048 moment features
42 |         input: ratio_id, pos_id: batch_size, indices of the temporal scale and position (each multiplied by 0.05 inside)
43 |         output: net_c: batch_size x 512 x 2048 synthetic features; temp_gt: batch_size x 1 x 2 temporal ground truth
44 |     """
45 | config = Config()
46 |
47 | # The start context generator
48 | net1_i=tf.contrib.layers.conv1d(inputs=m_feature[:,:,],num_outputs=1024,kernel_size=3, \
49 | stride=1,padding='same',scope='g_conv_s1')
50 |
51 | net1=tf.contrib.layers.conv1d(inputs=net1_i,num_outputs=2048,kernel_size=3, \
52 | stride=2,padding='same',scope='g_conv_s2')
53 |
54 | # The end context generator
55 | net2_i=tf.contrib.layers.conv1d(inputs=m_feature[:,:,],num_outputs=1024,kernel_size=3,
56 | stride=1,padding='same',scope='g_conv_e1')
57 |
58 | net2=tf.contrib.layers.conv1d(inputs=net2_i,num_outputs=2048,kernel_size=3, \
59 | stride=2,padding='same',scope='g_conv_e2')
60 |
61 |
62 | # random crop and select temporal gt
63 | net_res = []
64 | temporal_gt = []
65 | for i in range(config.batch_size):
66 | ratio = tf.cast(ratio_id[i],tf.float32) * tf.constant(0.05)
67 | posi = tf.cast(pos_id[i],tf.float32) * tf.constant(0.05)
68 |
69 | resize_fea_len = tf.cast(tf.constant(512.0)*ratio, tf.int32)
70 | temp_feature = tf.expand_dims(m_feature,2)
71 | resize_fea = tf.image.resize_images(temp_feature,[resize_fea_len,1])
72 | reduce_fea = tf.squeeze(resize_fea,2)
73 |
74 | net1_len = tf.cast((tf.constant(512)-resize_fea_len)/tf.constant(2),tf.int32)
75 | net2_len = tf.constant(512)-resize_fea_len-net1_len
76 |
77 | temp_net1 = tf.expand_dims(net1,2)
78 | resize_net1 = tf.image.resize_images(temp_net1,[net1_len,1])
79 | inc_net1 = tf.squeeze(resize_net1,2)
80 |
81 | temp_net2 = tf.expand_dims(net2,2)
82 | resize_net2 = tf.image.resize_images(temp_net2,[net2_len,1])
83 | inc_net2 = tf.squeeze(resize_net2,2)
84 |
85 | if i % 2 == 0:
86 | start = tf.cast((tf.constant(1.0)-ratio) * tf.cast(net1_len,tf.float32) * posi, tf.int32)
87 | net_a = inc_net1[i,:start,]
88 | net_b = inc_net1[i,start:,]
89 | net_3 = tf.keras.layers.concatenate(inputs=[net_a,reduce_fea[i,:,],net_b,inc_net2[i,:,]],axis=0)
90 | net_res.append(tf.reshape(net_3,[1,config.input_steps,config.n_inputs]))
91 | temporal_gt.append(tf.reshape(tf.cast(start,tf.float32),[1]))
92 | temporal_gt.append(tf.reshape(tf.cast(start + resize_fea_len,tf.float32),[1]))
93 | else:
94 | start = tf.cast((tf.constant(1.0)-ratio) * tf.cast(net2_len,tf.float32) * posi, tf.int32)
95 | net_a = inc_net2[i,:start,]
96 | net_b = inc_net2[i,start:,]
97 | net_3 = tf.keras.layers.concatenate(inputs=[inc_net1[i,:,],net_a,reduce_fea[i,:,],net_b],axis=0)
98 | net_res.append(tf.reshape(net_3,[1,config.input_steps,config.n_inputs]))
99 | temporal_gt.append(tf.reshape(tf.cast(start + net1_len,tf.float32),[1]))
100 | temporal_gt.append(tf.reshape(tf.cast(start + resize_fea_len,tf.float32),[1]))
101 |
102 | net_c = tf.concat(net_res,axis=0)
103 | temp_gt = tf.concat(temporal_gt,axis=0)
104 | temp_gt = tf.reshape(temp_gt,[config.batch_size,1,2])
105 |
106 | return net_c,temp_gt
107 |
108 | # Discriminator in each anchor layer
109 | def Context_Back_Discriminator(input_points,
110 | feat_layers=aher_anet.AHERNet.default_params.feat_layers,
111 | anchor_sizes=aher_anet.AHERNet.default_params.anchor_sizes,
112 | anchor_ratios=aher_anet.AHERNet.default_params.anchor_ratios,
113 | normalizations=aher_anet.AHERNet.default_params.normalizations,
114 | reuse = None):
115 |
116 | num_classes = 2
117 | D_logits = []
118 | D = []
119 | for i, layer in enumerate(feat_layers):
120 | with tf.variable_scope(layer + '_adv',reuse=reuse):
121 | adv_logits = aher_multibox_adv_layer(input_points[layer],
122 | num_classes,
123 | anchor_sizes[i],
124 | anchor_ratios[i],
125 | normalizations[i])
126 | D_logits.append(adv_logits)
127 | D.append(tf.math.sigmoid(adv_logits))
128 | return D,D_logits
129 |
130 | # Background Discriminator
131 | def Adversary_Back_Train(D, D_logits, D_, D_logits_,
132 | gscore_untrim, gscore_gene,scope=None):
133 |
134 | with tf.name_scope(scope,'aher_adv_losses'):
135 | lshape = tfe.get_shape(D_logits[0], 8)
136 | batch_size = lshape[0]
137 | fgscore_untrim = []
138 | fgscore_gene = []
139 | f_D_logits = []
140 | f_D_logits_ = []
141 | f_D = []
142 | f_D_ = []
143 | for i in range(len(D_logits)):
144 | fgscore_untrim.append(tf.reshape(gscore_untrim[i],[-1]))
145 | fgscore_gene.append(tf.reshape(gscore_gene[i],[-1]))
146 | f_D_logits.append(tf.reshape(D_logits[i],[-1]))
147 | f_D_logits_.append(tf.reshape(D_logits_[i],[-1]))
148 | f_D.append(tf.reshape(D[i],[-1]))
149 | f_D_.append(tf.reshape(D_[i],[-1]))
150 | gscore_untrim = tf.concat(fgscore_untrim, axis=0)
151 | gscore_gene = tf.concat(fgscore_gene, axis=0)
152 | D_logits = tf.concat(f_D_logits, axis=0)
153 | D_logits_ = tf.concat(f_D_logits_, axis=0)
154 | D = tf.concat(f_D, axis=0)
155 | D_ = tf.concat(f_D_, axis=0)
156 | dtype = D_logits.dtype
157 |
158 | # select the background position and logits
159 | pos_mask_untrim = gscore_untrim > 0.70
160 | nmask_untrim = tf.logical_and(tf.logical_not(pos_mask_untrim),gscore_untrim < 0.3)
161 |
162 | pos_mask_gene = gscore_gene > 0.70
163 | nmask_gene = tf.logical_and(tf.logical_not(pos_mask_gene),gscore_gene < 0.3)
164 |
165 | nmask = tf.logical_and(nmask_untrim,nmask_gene)
166 | fnmask = tf.cast(nmask, dtype)
167 | fnmask_num = tf.reduce_sum(fnmask)
168 |
169 | # compute the sigmoid cross entropy loss
170 | d_loss_real=sigmoid_cross_entropy_with_logits(D_logits, tf.ones_like(D))
171 | d_loss_real=tf.div(tf.reduce_sum(d_loss_real*fnmask), fnmask_num/FLAGS.dis_weights, name='d_loss_real')
172 |
173 | d_loss_fake=sigmoid_cross_entropy_with_logits(D_logits_, tf.zeros_like(D_))
174 | d_loss_fake=tf.div(tf.reduce_sum(d_loss_fake*fnmask), fnmask_num/FLAGS.dis_weights, name='d_loss_fake')
175 |
176 | g_loss=sigmoid_cross_entropy_with_logits(D_logits_, tf.ones_like(D_))
177 | g_loss=tf.div(tf.reduce_sum(g_loss*fnmask), fnmask_num/FLAGS.gen_weights, name='g_loss')
178 |
179 | return d_loss_real,d_loss_fake,g_loss
180 |
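# Usage note (a typical GAN training pattern; the exact combination in the released
# training script may differ):
#     d_loss = d_loss_real + d_loss_fake   # minimized w.r.t. the '*_adv' discriminator variables
#     g_loss                               # minimized w.r.t. the 'g_conv_*' context-generator variables
# with two separate optimizers, so that the generated context features become
# indistinguishable from real background features in untrimmed videos.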
--------------------------------------------------------------------------------
/aher_common.py:
--------------------------------------------------------------------------------
1 | """Shared function between different AherNet implementations.
2 | """
3 | import numpy as np
4 | import tensorflow as tf
5 | import tf_extended as tfe
6 |
7 | batch_size_fix=16
8 |
9 | # =========================================================================== #
10 | # TensorFlow implementation of boxes Aher encoding / decoding.
11 | # =========================================================================== #
12 | def tf_aher_bboxes_encode_layer(labels,
13 | bboxes,
14 | anchors_layer,
15 | num_classes,
16 | temporal_shape,
17 | matching_threshold=0.70,
18 | prior_scaling=[1.0,1.0],
19 | dtype=tf.float32):
20 | """Encode groundtruth labels and bounding boxes using 1D anchors from
21 | one layer.
22 |
23 | Arguments:
24 | labels: 1D Tensor(int32) containing groundtruth labels;
25 | bboxes: batch_size x N x 2 Tensor(float) with bboxes relative coordinates;
26 | anchors_layer: Numpy array with layer anchors;
27 | matching_threshold: Threshold for positive match with groundtruth bboxes;
28 | prior_scaling: Scaling of encoded coordinates.
29 |
30 | Return:
31 | (target_labels, target_localizations, target_scores): Target Tensors.
32 | """
33 | # Anchors coordinates and volume.
34 | yref, tref= anchors_layer
35 | ymin = yref - tref / 2.
36 | ymax = yref + tref / 2.
37 |
38 | ymin = tf.maximum(0.0, ymin)
39 | ymax = tf.minimum(ymax, temporal_shape - 1.0)
40 |
41 | vol_anchors = ymax - ymin
42 |
43 | if batch_size_fix == 1:
44 | bboxes = tf.reshape(bboxes,[1,bboxes.shape[0],bboxes.shape[1]])
45 | labels = tf.reshape(labels,[1])
46 |
47 | # Initialize tensors...
48 | shape = (bboxes.shape[0], yref.shape[0], tref.size)
49 | s_shape = (yref.shape[0], tref.size)
50 | feat_scores = tf.zeros(shape, dtype=dtype)
51 | feat_max_iou = tf.zeros(shape, dtype=dtype)
52 |
53 | feat_ymin = tf.zeros(shape, dtype=dtype)
54 | feat_ymax = tf.ones(shape, dtype=dtype)
55 | mask_minus = feat_ymin - feat_ymax
56 | s_max_one = tf.ones(s_shape, dtype=tf.int32)
57 | s_max_one_f = tf.ones(s_shape, dtype=dtype)
58 |
59 | label_shape = (bboxes.shape[0], yref.shape[0], 1)
60 | s_label_shape = (yref.shape[0], 1)
61 | s_label_max_one = tf.ones(s_label_shape, dtype=tf.int32)
62 | feat_labels = tf.zeros(label_shape, dtype=tf.int32)
63 |
64 |
65 | def jaccard_with_anchors(bbox):
66 | """Compute jaccard score between a box and the anchors.
67 | """
68 | int_ymin = tf.maximum(ymin, bbox[0])
69 | int_ymax = tf.minimum(ymax, bbox[1])
70 | t = tf.maximum(int_ymax - int_ymin, 0.)
71 |
72 | # Volumes.
73 | inter_vol = t
74 | union_vol = vol_anchors - inter_vol \
75 | + (bbox[1] - bbox[0])
76 | jaccard = tf.div(inter_vol, union_vol)
77 | return jaccard
78 |
79 | def intersection_with_anchors(bbox):
80 |         """Compute the intersection score between a box and the anchors.
81 | """
82 | int_ymin = tf.maximum(ymin, bbox[0])
83 | int_ymax = tf.minimum(ymax, bbox[1])
84 | t = tf.maximum(int_ymax - int_ymin, 0.)
85 | inter_vol = t
86 | scores = tf.div(inter_vol, vol_anchors)
87 | return scores
88 |
89 | def condition(i,ik, feat_labels, feat_scores,
90 | feat_ymin, feat_ymax):
91 | """Condition: check label index.
92 | """
93 | # remove unusable gt
94 | bbox = bboxes[i]
95 | bbox = tf.reshape(tf.boolean_mask(bbox,tf.not_equal([-1.0,-1.0],bbox)),[-1,2])
96 | r = tf.less(ik, tf.shape(bbox)[0])
97 | return r
98 |
99 | def body(batch_id, i, feat_labels, feat_scores,
100 | feat_ymin, feat_ymax):
101 | """Body: update feature labels, scores and bboxes.
102 | - assign values when jaccard > 0.5;
103 | - only update if beat the score of other bboxes.
104 | """
105 | # Jaccard score.
106 | label = labels[batch_id]
107 | bbox = bboxes[batch_id]
108 |
109 | bbox = tf.reshape(tf.boolean_mask(bbox,tf.not_equal([-1.0,-1.0],bbox)),[-1,2])[i]
110 |
111 | jaccard = jaccard_with_anchors(bbox)
112 | # Mask: check threshold + scores + no annotations + num_classes.
113 | mask = tf.greater(jaccard, feat_scores)
114 | imask = tf.cast(mask, tf.int32)
115 | fmask = tf.cast(mask, dtype)
116 | # Update values using mask.
117 |
118 | #feat_labels = imask * label + (1 - imask) * feat_labels
119 | feat_labels = s_label_max_one * label
120 |
121 | feat_scores = tf.where(mask, jaccard, feat_scores)
122 |
123 | feat_ymin = fmask * bbox[0] + (1 - fmask) * feat_ymin
124 | feat_ymax = fmask * bbox[1] + (1 - fmask) * feat_ymax
125 |
126 | # Check no annotation label: ignore these anchors...
127 | # interscts = intersection_with_anchors(bbox)
128 | # mask = tf.logical_and(interscts > ignore_threshold,
129 | # label == no_annotation_label)
130 | # # Replace scores by -1.
131 | # feat_scores = tf.where(mask, -tf.cast(mask, dtype), feat_scores)
132 |
133 | return [batch_id, i+1, feat_labels, feat_scores,feat_ymin, feat_ymax]
134 |
135 | def batch_condition(i, feat_labels, feat_scores,
136 | feat_ymin, feat_ymax):
137 | r = tf.less(i, tf.shape(bboxes)[0])
138 | return r
139 |
140 | def batch_body(i, feat_labels, feat_scores,
141 | feat_ymin,feat_ymax):
142 | ik = 0
143 | s_feat_labels = tf.zeros(s_label_shape, dtype=tf.int32)
144 | s_feat_scores = tf.zeros(s_shape, dtype=dtype)
145 | s_feat_ymin = tf.zeros(s_shape, dtype=dtype)
146 | s_feat_ymax = tf.ones(s_shape, dtype=dtype)
147 | [i, ik, s_feat_labels, s_feat_scores, s_feat_ymin,s_feat_ymax] = tf.while_loop(condition, body,
148 | [i, ik, s_feat_labels, s_feat_scores, s_feat_ymin, s_feat_ymax])
149 |
150 | s_feat_labels = tf.reshape(s_feat_labels,[1,s_label_shape[0],s_label_shape[1]])
151 | s_feat_scores = tf.reshape(s_feat_scores,[1,s_shape[0],s_shape[1]])
152 | s_feat_ymin = tf.reshape(s_feat_ymin,[1,s_shape[0],s_shape[1]])
153 |
154 |
155 |
156 | # labels
157 | if batch_size_fix != 1:
158 | s_feat_labels_1 = tf.zeros([i,s_label_shape[0],s_label_shape[1]], dtype=tf.int32)
159 | s_feat_labels_2 = tf.zeros([batch_size_fix-i-1,s_label_shape[0],s_label_shape[1]], dtype=tf.int32)
160 | s_feat_labels = tf.concat([s_feat_labels_1,s_feat_labels,s_feat_labels_2],0)
161 | feat_labels = feat_labels + s_feat_labels
162 | else:
163 | feat_labels = s_feat_labels
164 |
165 | # scores
166 | if batch_size_fix != 1:
167 | s_feat_scores_1 = tf.zeros([i,s_shape[0],s_shape[1]], dtype=dtype)
168 | s_feat_scores_2 = tf.zeros([batch_size_fix-i-1,s_shape[0],s_shape[1]], dtype=dtype)
169 | s_feat_scores = tf.concat([s_feat_scores_1,s_feat_scores,s_feat_scores_2],0)
170 | feat_scores = feat_scores + s_feat_scores
171 | else:
172 | feat_scores = s_feat_scores
173 |
174 | # ymin
175 | if batch_size_fix != 1:
176 | s_feat_ymin_1 = tf.zeros([i,s_shape[0],s_shape[1]], dtype=dtype)
177 | s_feat_ymin_2 = tf.zeros([batch_size_fix-i-1,s_shape[0],s_shape[1]], dtype=dtype)
178 | s_feat_ymin = tf.concat([s_feat_ymin_1,s_feat_ymin,s_feat_ymin_2],0)
179 | feat_ymin = feat_ymin + s_feat_ymin
180 | else:
181 | feat_ymin = s_feat_ymin
182 |
183 | # ymax
184 | if batch_size_fix != 1:
185 | s_feat_ymax = s_feat_ymax - s_max_one_f
186 | s_feat_ymax = tf.reshape(s_feat_ymax,[1,s_shape[0],s_shape[1]])
187 | s_feat_ymax_1 = tf.zeros([i,s_shape[0],s_shape[1]], dtype=dtype)
188 | s_feat_ymax_2 = tf.zeros([batch_size_fix-i-1,s_shape[0],s_shape[1]], dtype=dtype)
189 | s_feat_ymax = tf.concat([s_feat_ymax_1,s_feat_ymax,s_feat_ymax_2],0)
190 | feat_ymax = feat_ymax + s_feat_ymax
191 | else:
192 | s_feat_ymax = tf.reshape(s_feat_ymax,[1,s_shape[0],s_shape[1]])
193 | feat_ymax = s_feat_ymax
194 |
195 | #feat_labels = tf.concat([feat_labels,s_feat_labels],0)
196 | #feat_scores = tf.concat([feat_scores,s_feat_scores],0)
197 | #feat_ymin = tf.concat([feat_ymin,s_feat_ymin],0)
198 | #feat_ymax = tf.concat([feat_ymax,s_feat_ymax],0)
199 | #feat_labels[i] = s_feat_labels
200 | #feat_scores[i] = s_feat_scores
201 | #feat_ymin[i] = s_feat_ymin
202 | #feat_ymax[i] = s_feat_ymax
203 | return [i+1,feat_labels,feat_scores,feat_ymin,feat_ymax]
204 |
205 | # Main loop definition.
206 | i = 0
207 | #[i, feat_labels, feat_scores,feat_ymin, feat_ymax] = tf.while_loop(condition, body,
208 | # [i, feat_labels, feat_scores, feat_ymin,feat_ymax])
209 | [i, feat_labels, feat_scores,feat_ymin, feat_ymax] = tf.while_loop(batch_condition, batch_body,
210 | [i, feat_labels, feat_scores, feat_ymin,feat_ymax])
211 |
212 | #feat_labels = labels
213 | mask_zero = tf.equal(feat_scores,0.0)
214 | fmask_zero = tf.cast(mask_zero,dtype)
215 | feat_max_iou = fmask_zero * mask_minus + (1-fmask_zero)*feat_scores
216 |
217 | # Transform to center / size.
218 | feat_cy = (feat_ymax + feat_ymin) / 2.
219 | feat_t = feat_ymax - feat_ymin
220 | # Encode features.
221 | feat_cy = (feat_cy - yref) / tref / prior_scaling[0]
222 | feat_t = tf.log(feat_t / tref) / prior_scaling[1]
223 |     # Stack encoded (center, length) targets along the last axis.
224 | feat_localizations = tf.stack([feat_cy, feat_t], axis=-1)
225 |
226 | return feat_labels, feat_localizations, feat_scores, feat_max_iou
227 |
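# Note: the (cy, t) encoding above, feat_cy = (cy - yref) / tref / prior_scaling[0] and
# feat_t = log(t / tref) / prior_scaling[1], is inverted by tf_aher_bboxes_decode_layer
# below via cy = feat_cy * tref * prior_scaling[0] + yref and t = tref * exp(feat_t * prior_scaling[1]).
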
228 | def tf_aher_bboxes_encode(labels,
229 | bboxes,
230 | anchors,
231 | num_classes,
232 | temporal_shape,
233 | matching_threshold=0.70,
234 | prior_scaling=[1.0,1.0],
235 | dtype=tf.float32,
236 | scope='aher_bboxes_encode'):
237 | """Encode groundtruth labels and bounding boxes using 1D net anchors.
238 | Encoding boxes for all feature layers.
239 |
240 | Arguments:
241 | labels: 1D Tensor(int32) containing groundtruth labels;
242 | bboxes: Nx2 Tensor(float) with bboxes relative coordinates;
243 | anchors: List of Numpy array with layer anchors;
244 | matching_threshold: Threshold for positive match with groundtruth bboxes;
245 | prior_scaling: Scaling of encoded coordinates.
246 |
247 | Return:
248 | (target_labels, target_localizations, target_scores):
249 | Each element is a list of target Tensors.
250 | """
251 | with tf.name_scope(scope):
252 | target_labels = []
253 | target_localizations = []
254 | target_scores = []
255 | target_iou = []
256 | for i, anchors_layer in enumerate(anchors):
257 | with tf.name_scope('bboxes_encode_block_%i' % i):
258 | t_labels, t_loc, t_scores, t_iou = \
259 | tf_aher_bboxes_encode_layer(labels, bboxes, anchors_layer,
260 | num_classes, temporal_shape,
261 | matching_threshold,
262 | prior_scaling, dtype)
263 | target_labels.append(t_labels)
264 | target_localizations.append(t_loc)
265 | target_scores.append(t_scores)
266 | target_iou.append(t_iou)
267 | return target_labels, target_localizations, target_scores, target_iou
268 |
269 | def tf_aher_bboxes_decode_layer(feat_localizations,
270 | duration,
271 | anchors_layer,
272 | prior_scaling=[1.0, 1.0]):
273 | """Compute the relative bounding boxes from the layer features and
274 | reference anchor bounding boxes.
275 |
276 | Arguments:
277 | feat_localizations: Tensor containing localization features.
278 | anchors: List of numpy array containing anchor boxes.
279 |
280 | Return:
281 | Tensor Nx2: ymin, ymax
282 | """
283 | yref, tref = anchors_layer
284 |
285 | # Compute center, height and width
286 | cy = feat_localizations[:, :, :, 0] * tref * prior_scaling[0] + yref # tf.expand_dims(yref,-1)
287 | t = tref * tf.exp(feat_localizations[:, :, :, 1] * prior_scaling[1])
288 | # Boxes coordinates.
289 | ymin = cy - t / 2.
290 | ymax = cy + t / 2.
291 | ymin = tf.maximum(ymin, 0.0) / 512.0 * duration
292 | ymax = tf.minimum(ymax, 512.0) / 512.0 * duration
293 | bboxes = tf.stack([ymin, ymax], axis=-1)
294 | return bboxes
295 |
296 | def tf_aher_bboxes_decode_logits_layer(feat_localizations,
297 | duration,
298 | anchors_layer,
299 | logitsprediction,
300 | prior_scaling=[1.0, 1.0]):
301 | """Compute the relative bounding boxes from the layer features and
302 | reference anchor bounding boxes.
303 |
304 | Arguments:
305 | feat_localizations: Tensor containing localization features.
306 | anchors: List of numpy array containing anchor boxes.
307 |
308 | Return:
309 | Tensor Nx2: ymin, ymax
310 | """
311 | yref, tref = anchors_layer
312 |
313 | # Compute center, height and width
314 | cy = feat_localizations[:, :, :, 0] * tref * prior_scaling[0] + yref # tf.expand_dims(yref,-1)
315 | t = tref * tf.exp(feat_localizations[:, :, :, 1] * prior_scaling[1])
316 | # Boxes coordinates.
317 | ymin = cy - t / 2.
318 | ymax = cy + t / 2.
319 | ymin = tf.maximum(ymin, 0.0) / 512.0 * duration
320 | ymax = tf.minimum(ymax, 512.0) / 512.0 * duration
321 |
322 | maxid = tf.argmax(logitsprediction,axis=-1)
323 | maxidtile = tf.tile(maxid,(1,1,3))
324 | maxidtilef = tf.cast(maxidtile,tf.float32)
325 |
326 | maxscore = tf.reduce_max(logitsprediction,axis=-1)
327 | maxscoretile = tf.tile(maxscore,(1,1,3))
328 |
329 | bboxes = tf.stack([ymin, ymax, maxidtilef, maxscoretile], axis=-1)
330 |
331 | return bboxes
332 |
333 | def tf_aher_bboxes_decode_detect_layer(feat_localizations,
334 | duration,
335 | anchors_layer,
336 | propredictions,
337 | prior_scaling=[1.0, 1.0]):
338 | """Compute the relative bounding boxes from the layer features and
339 | reference anchor bounding boxes.
340 |
341 | Arguments:
342 | feat_localizations: Tensor containing localization features.
343 | anchors: List of numpy array containing anchor boxes.
344 |
345 |
346 | Return:
347 | Tensor Nx2: ymin, ymax
348 | """
349 | yref, tref = anchors_layer
350 |
351 | # Compute center, height and width
352 | cy = feat_localizations[:, :, :, 0] * tref * prior_scaling[0] + yref # tf.expand_dims(yref,-1)
353 | t = tref * tf.exp(feat_localizations[:, :, :, 1] * prior_scaling[1])
354 | # Boxes coordinates.
355 | ymin = cy - t / 2.
356 | ymax = cy + t / 2.
357 | ymin = tf.maximum(ymin, 0.0) / 512.0 * duration
358 | ymax = tf.minimum(ymax, 512.0) / 512.0 * duration
359 |
360 | #maxid = tf.argmax(propredictions,axis=-1)
361 | #maxidtile = tf.tile(maxid,(1,1,3))
362 | #maxidf = tf.cast(maxid,tf.float32)
363 |
364 | #maxscore = tf.reduce_max(propredictions,axis=-1)
365 | #maxscoretile = tf.tile(maxscore,(1,1,3))
366 |
367 | bboxes = tf.stack([ymin, ymax], axis=-1)
368 | bboxes = tf.concat([bboxes, propredictions],axis=-1)
369 |
370 | return bboxes
371 |
372 | def tf_aher_bboxes_decode(feat_localizations,
373 | duration,
374 | anchors,
375 | prior_scaling=[1.0,1.0],
376 | scope='aher_bboxes_decode'):
377 | """Compute the relative bounding boxes from the net features and
378 | reference anchors bounding boxes.
379 |
380 | Arguments:
381 | feat_localizations: List of Tensors containing localization features.
382 | anchors: List of numpy array containing anchor boxes.
383 |
384 | Return:
385 | List of Tensors Nx2: ymin, ymax
386 | """
387 | with tf.name_scope(scope):
388 | bboxes = []
389 | for i, anchors_layer in enumerate(anchors):
390 | bboxes.append(
391 | tf_aher_bboxes_decode_layer(feat_localizations[i],
392 | duration,
393 | anchors_layer,
394 | prior_scaling))
395 | return bboxes
396 |
397 | def tf_aher_bboxes_decode_logits(feat_localizations,
398 | duration,
399 | anchors,
400 | logitsprediction,
401 | prior_scaling=[1.0,1.0],
402 | scope='aher_bboxes_decode'):
403 | """Compute the relative bounding boxes from the net features and
404 | reference anchors bounding boxes.
405 |
406 | Arguments:
407 | feat_localizations: List of Tensors containing localization features.
408 | anchors: List of numpy array containing anchor boxes.
409 |
410 | Return:
411 | List of Tensors Nx2: ymin, ymax
412 | """
413 | with tf.name_scope(scope):
414 | bboxes = []
415 | for i, anchors_layer in enumerate(anchors):
416 | bboxes.append(
417 | tf_aher_bboxes_decode_logits_layer(feat_localizations[i],
418 | duration,
419 | anchors_layer,
420 | logitsprediction[i],
421 | prior_scaling))
422 | return bboxes
423 |
424 | def tf_aher_bboxes_decode_detect(feat_localizations,
425 | duration,
426 | anchors,
427 | propprediction,
428 | prior_scaling=[1.0,1.0],
429 | scope='aher_bboxes_decode'):
430 | """Compute the relative bounding boxes from the net features and
431 | reference anchors bounding boxes.
432 |
433 | Arguments:
434 | feat_localizations: List of Tensors containing localization features.
435 | anchors: List of numpy array containing anchor boxes.
436 |
437 | Return:
438 | List of Tensors Nx2: ymin, ymax
439 | """
440 | with tf.name_scope(scope):
441 | bboxes = []
442 | for i, anchors_layer in enumerate(anchors):
443 | bboxes.append(
444 | tf_aher_bboxes_decode_detect_layer(feat_localizations[i],
445 | duration,
446 | anchors_layer,
447 | propprediction[i],
448 | prior_scaling))
449 | return bboxes
450 |
451 | # =========================================================================== #
452 | # temporal boxes selection.
453 | # =========================================================================== #
454 | def tf_aher_bboxes_select_layer(predictions_layer, localizations_layer,
455 | select_threshold=None,
456 | num_classes=21,
457 | ignore_class=0,
458 | scope=None,
459 | IoU_flag=False):
460 | """Extract classes, scores and bounding boxes from features in one layer.
461 | Batch-compatible: inputs are supposed to have batch-type shapes.
462 |
463 | Args:
464 | predictions_layer: A prediction layer;
465 | localizations_layer: A localization layer;
466 | select_threshold: Classification threshold for selecting a box. All boxes
467 | under the threshold are set to 'zero'. If None, no threshold applied.
468 | Return:
469 | d_scores, d_bboxes: Dictionary of scores and bboxes Tensors of
470 | size Batches X N x 1 | 2. Each key corresponding to a class.
471 | """
472 | select_threshold = 0.0 if select_threshold is None else select_threshold
473 | with tf.name_scope(scope, 'aher_bboxes_select_layer',
474 | [predictions_layer, localizations_layer]):
475 | # Reshape features: Batches x N x N_labels | 4
476 | p_shape = tfe.get_shape(predictions_layer)
477 | predictions_layer = tf.reshape(predictions_layer,
478 | tf.stack([p_shape[0], -1, p_shape[-1]]))
479 | if IoU_flag:
480 | zeros_m = tf.zeros([predictions_layer.shape[1],1])
481 | predictions_layer = tf.reshape(tf.stack([zeros_m,
482 | tf.reshape(predictions_layer,[predictions_layer.shape[1],1])],axis=1),
483 | [predictions_layer.shape[0],predictions_layer.shape[1],2])
484 | l_shape = tfe.get_shape(localizations_layer)
485 | localizations_layer = tf.reshape(localizations_layer,
486 | tf.stack([l_shape[0], -1, l_shape[-1]]))
487 |
488 | d_scores = {}
489 | d_bboxes = {}
490 | for c in range(0, num_classes):
491 | if c != ignore_class:
492 | # Remove boxes under the threshold.
493 | scores = predictions_layer[:, :, c]
494 | fmask = tf.cast(tf.greater_equal(scores, select_threshold), scores.dtype)
495 | scores = scores * fmask
496 | bboxes = localizations_layer * tf.expand_dims(fmask, axis=-1)
497 | # Append to dictionary.
498 | d_scores[c] = scores
499 | d_bboxes[c] = bboxes
500 |
501 | return d_scores, d_bboxes
502 |
503 | def tf_aher_bboxes_select(predictions_net, localizations_net,
504 | select_threshold=None,
505 | num_classes=21,
506 | ignore_class=0,
507 | scope=None,
508 | IoU_flag = False):
509 | """Extract classes, scores and bounding boxes from network output layers.
510 | Batch-compatible: inputs are supposed to have batch-type shapes.
511 |
512 | Args:
513 | predictions_net: List of prediction layers;
514 | localizations_net: List of localization layers;
515 | select_threshold: Classification threshold for selecting a box. All boxes
516 | under the threshold are set to 'zero'. If None, no threshold applied.
517 | Return:
518 | d_scores, d_bboxes: Dictionary of scores and bboxes Tensors of
519 | size Batches X N x 1 | 4. Each key corresponding to a class.
520 | """
521 | with tf.name_scope(scope, 'aher_bboxes_select',
522 | [predictions_net, localizations_net]):
523 | l_scores = []
524 | l_bboxes = []
525 | for i in range(len(predictions_net)):
526 | scores, bboxes = tf_aher_bboxes_select_layer(predictions_net[i],
527 | localizations_net[i],
528 | select_threshold,
529 | num_classes,
530 | ignore_class,
531 | IoU_flag = IoU_flag)
532 | l_scores.append(scores)
533 | l_bboxes.append(bboxes)
534 | # Concat results.
535 | d_scores = {}
536 | d_bboxes = {}
537 | for c in l_scores[0].keys():
538 | ls = [s[c] for s in l_scores]
539 | lb = [b[c] for b in l_bboxes]
540 | d_scores[c] = tf.concat(ls, axis=1)
541 | d_bboxes[c] = tf.concat(lb, axis=1)
542 | return d_scores, d_bboxes
543 |
544 | def tf_aher_bboxes_select_layer_all_classes(predictions_layer, localizations_layer,
545 | select_threshold=None):
546 | """Extract classes, scores and bounding boxes from features in one layer.
547 | Batch-compatible: inputs are supposed to have batch-type shapes.
548 |
549 | Args:
550 | predictions_layer: A prediction layer;
551 | localizations_layer: A localization layer;
552 | select_threshold: Classification threshold for selecting a box. If None,
553 | select boxes whose classification score is higher than 'no class'.
554 | Return:
555 | classes, scores, bboxes: Input Tensors.
556 | """
557 | # Reshape features: Batches x N x N_labels | 4
558 | p_shape = tfe.get_shape(predictions_layer)
559 | predictions_layer = tf.reshape(predictions_layer,
560 | tf.stack([p_shape[0], -1, p_shape[-1]]))
561 | l_shape = tfe.get_shape(localizations_layer)
562 | localizations_layer = tf.reshape(localizations_layer,
563 | tf.stack([l_shape[0], -1, l_shape[-1]]))
564 | # Boxes selection: use threshold or score > no-label criteria.
565 | if select_threshold is None or select_threshold == 0:
566 | # Class prediction and scores: assign 0. to 0-class
567 | classes = tf.argmax(predictions_layer, axis=2)
568 | scores = tf.reduce_max(predictions_layer, axis=2)
569 | scores = scores * tf.cast(classes > 0, scores.dtype)
570 | else:
571 | sub_predictions = predictions_layer[:, :, 1:]
572 | classes = tf.argmax(sub_predictions, axis=2) + 1
573 | scores = tf.reduce_max(sub_predictions, axis=2)
574 | # Only keep predictions higher than threshold.
575 | mask = tf.greater(scores, select_threshold)
576 | classes = classes * tf.cast(mask, classes.dtype)
577 | scores = scores * tf.cast(mask, scores.dtype)
578 | # Assume localization layer already decoded.
579 | bboxes = localizations_layer
580 | return classes, scores, bboxes
581 |
582 | def tf_aher_bboxes_select_all_classes(predictions_net, localizations_net,
583 | select_threshold=None,
584 | scope=None):
585 | """Extract classes, scores and bounding boxes from network output layers.
586 | Batch-compatible: inputs are supposed to have batch-type shapes.
587 |
588 | Args:
589 | predictions_net: List of prediction layers;
590 | localizations_net: List of localization layers;
591 | select_threshold: Classification threshold for selecting a box. If None,
592 | select boxes whose classification score is higher than 'no class'.
593 | Return:
594 | classes, scores, bboxes: Tensors.
595 | """
596 | with tf.name_scope(scope, 'aher_bboxes_select',
597 | [predictions_net, localizations_net]):
598 | l_classes = []
599 | l_scores = []
600 | l_bboxes = []
601 | for i in range(len(predictions_net)):
602 | classes, scores, bboxes = \
603 | tf_aher_bboxes_select_layer_all_classes(predictions_net[i],
604 | localizations_net[i],
605 | select_threshold)
606 | l_classes.append(classes)
607 | l_scores.append(scores)
608 | l_bboxes.append(bboxes)
609 |
610 | classes = tf.concat(l_classes, axis=1)
611 | scores = tf.concat(l_scores, axis=1)
612 | bboxes = tf.concat(l_bboxes, axis=1)
613 | return classes, scores, bboxes
614 |
615 |
--------------------------------------------------------------------------------
/aher_k600_test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import tensorflow as tf
3 | import numpy as np
4 | import aher_anet
5 | import tf_utils
6 | from datetime import datetime
7 | import time
8 | from data_loader import *
9 | import tf_extended as tfe
10 | import os
11 | import sys
12 | import pandas as pd
13 | from multiprocessing import Process,Queue
14 | import multiprocessing
15 | import math
16 | import random
17 | from evaluation import get_proposal_performance
18 |
19 | FLAGS = tf.app.flags.FLAGS
20 | os.environ["TF_CPP_MIN_LOG_LEVEL"]='2'
21 |
22 | tf.app.flags.DEFINE_float('loss_alpha', 1., 'Alpha parameter in the loss function.')
23 | tf.app.flags.DEFINE_float('match_threshold', 0.65, 'Matching threshold in the loss function.') # 0.65
24 | tf.app.flags.DEFINE_float('neg_match_threshold', 0.35, 'Negative threshold in the loss function.') # 0.3
25 | tf.app.flags.DEFINE_float('negative_ratio', 1., 'Negative ratio in the loss function.')
26 | tf.app.flags.DEFINE_float('label_smoothing', 0.0, 'The amount of label smoothing.')
27 | tf.app.flags.DEFINE_integer('output_idx', 1, 'The output index.')
28 | tf.app.flags.DEFINE_integer('training_epochs',7,'The training epochs number')
29 | tf.app.flags.DEFINE_float('cls_weights', 2.0, 'The weight for the classification.') # 0.1
30 | tf.app.flags.DEFINE_float('iou_weights', 25.0, 'The weight for the iou prediction.') # 5.0
31 | tf.app.flags.DEFINE_float('initial_learning_rate', 0.0001,'Initial learning rate.')
32 | # Flags for box generation
33 | tf.app.flags.DEFINE_float('select_threshold', 0.0, 'Selection threshold.')
34 | tf.app.flags.DEFINE_float('nms_threshold', 0.90, 'Non-Maximum Selection threshold.')
35 | tf.app.flags.DEFINE_integer('select_top_k', 765, 'Select top-k detected bounding boxes.')
36 | tf.app.flags.DEFINE_integer('keep_top_k', 100, 'Keep top-k detected objects.')
37 | tf.app.flags.DEFINE_bool('cls_flag', False, 'Utilize classification score.')
38 |
39 | def AHER_init():
40 | # AHER parameter and Model
41 | aher_anet_param = aher_anet.AHERNet.default_params
42 | aher_anet_model = aher_anet.AHERNet(aher_anet_param)
43 | aher_anet_temporal_shape = aher_anet_param.temporal_shape
44 | aher_anet_anchor = aher_anet_model.anchors(aher_anet_temporal_shape)
45 | return aher_anet_model,aher_anet_anchor
46 |
47 | def AHER_Predictor_Cls(aher_anet_model,aher_anet_anchor,
48 | feature,temporal_gt,vname,label,duration,reuse,class_num=600,cls_suffix='_anchor'):
49 |     """ Model and loss function of single-shot action localization
50 | feature: batch_size x 512 x 2048
51 | temporal_gt: batch_size x 25 x 2
52 | vname: batch_size x 1
53 | label: batch_size x 1
54 | duration: batch_size x 1
55 | """
56 |
57 | # Encode groundtruth labels and bboxes.
58 | gclasses, glocalisations, gscores, giou = \
59 | aher_anet_model.bboxes_encode(label, temporal_gt, aher_anet_anchor)
60 |
61 | # predict location and iou
62 | predictions, localisation, logits, proplogits, proppredictions, iouprediction, clsweights, clsbias, end_points = \
63 | aher_anet_model.net_pool_cls(feature, num_classes= class_num, untrim_num=class_num, is_training=True,reuse=reuse,cls_suffix=cls_suffix)
64 |
65 | return predictions, localisation, logits, proplogits, iouprediction, \
66 | clsweights, clsbias, gscores, giou, gclasses, glocalisations, end_points
67 |
68 | def AHER_Detection_Inference(aher_anet_model,aher_anet_anchor,
69 | feature,vname,label,duration, clsweights, clsbias, reuse, n_class,cls_suffix='_anet'):
70 | """ Inference bbox of sigle shot action localization
71 | feature: batch_size x 512 x 4069
72 | vname: batch_size x 1
73 | label: batch_size x 1
74 | duration: batch_size x 1
75 | """
76 |
77 | predictions, localisation, logits, proplogits, proppredictions, iouprediction, end_points \
78 | = aher_anet_model.net_prop_iou(feature,clsweights,clsbias,is_training=True,reuse=reuse,num_classes=n_class,cls_suffix=cls_suffix)
79 |
80 | # decode bounding box and get scores
81 | localisation = aher_anet_model.bboxes_decode_logits(localisation, duration ,aher_anet_anchor, predictions)
82 | if FLAGS.cls_flag:
83 | rscores, rbboxes = aher_anet_model.detected_bboxes_classwise(
84 | proppredictions, localisation,
85 | select_threshold=FLAGS.select_threshold,
86 | nms_threshold=FLAGS.nms_threshold,
87 | clipping_bbox=None,
88 | top_k=FLAGS.select_top_k,
89 | keep_top_k=FLAGS.keep_top_k,
90 | iou_flag=False)
91 | else:
92 | rscores, rbboxes = aher_anet_model.detected_bboxes_classwise(
93 | iouprediction, localisation,
94 | select_threshold=FLAGS.select_threshold,
95 | nms_threshold=FLAGS.nms_threshold,
96 | clipping_bbox=None,
97 | top_k=FLAGS.select_top_k,
98 | keep_top_k=FLAGS.keep_top_k,
99 | iou_flag=True)
100 | # compute pooling score
101 | lshape = tfe.get_shape(predictions[0], 8)
102 | num_classes = lshape[-1]
103 | batch_size = lshape[0]
104 | fprediction = []
105 | for i in range(len(predictions)):
106 | fprediction.append(tf.reshape(predictions[i], [-1, num_classes]))
107 | predictions = tf.concat(fprediction, axis=0)
108 | avergeprediction = tf.reduce_mean(predictions,axis=0)
109 | labelid = tf.argmax(avergeprediction, 0)
110 | argmaxid = tf.argmax(predictions, 1)
111 |
112 | prebbox={"rscores":rscores,"rbboxes":rbboxes,"label":labelid,"avescore":avergeprediction, \
113 | "rawscore":predictions,"argmaxid":argmaxid}
114 | return prebbox
115 |
116 | class Config(object):
117 | def __init__(self):
118 | self.learning_rates=[0.001]*100+[0.0001]*100
119 | #self.training_epochs = 15
120 | self.training_epochs = 12
121 | self.n_inputs = 2048
122 | self.batch_size = 8
123 | self.input_steps=512
124 | self.input_resize_steps=512
125 | self.gt_hold_num = 25
126 | self.batch_size_val=1
127 |
128 | if __name__ == "__main__":
129 | """ define the input and the network"""
130 | config = Config()
131 | LR= tf.placeholder(tf.float32)
132 |
133 | #--------------------------------------------Restore Folder and Feature Folder----------------------------------------------#
134 | eva_step_id = [8000,9000,10000] # The model id array to be evaluated
135 | output_folder_dir = 'p3d_k600_models/AHER_k600_gen_adv' # The restore folder to output localization results
136 | csv_dir = '/data/Kinetics-600/csv_p3d_clip_val_untrim_512' # validation clip feature for Kinetics-600
137 | csv_oris_dir = ''
138 | #------------------------------------------------------------------------------------------------------------------------------#
139 |
140 |
141 |
142 |
143 | #-------------------------------------------------------Test Placeholder-----------------------------------------------------#
144 | # train placehold
145 | feature_seg = tf.placeholder(tf.float32, shape=(config.batch_size,config.input_steps,config.n_inputs))
146 | temporal_gt_seg = tf.placeholder(tf.float32, shape=(config.batch_size, config.gt_hold_num, 2))
147 | vname_seg = tf.placeholder(tf.string, shape=(config.batch_size))
148 | label_seg = tf.placeholder(tf.int32, shape=(config.batch_size))
149 | duration_seg = tf.placeholder(tf.float32,shape=(config.batch_size))
150 |
151 | # val placehold
152 | feature_val = tf.placeholder(tf.float32, shape=(config.batch_size_val,config.input_steps,config.n_inputs))
153 | temporal_gt_val = tf.placeholder(tf.float32, shape=(config.batch_size_val, config.gt_hold_num, 2))
154 | vname_val = tf.placeholder(tf.string, shape=(config.batch_size_val))
155 | label_val = tf.placeholder(tf.int32, shape=(config.batch_size_val))
156 | duration_val = tf.placeholder(tf.float32,shape=(config.batch_size_val))
157 | #------------------------------------------------------------------------------------------------------------------------------#
158 |
159 |
160 |
161 |
162 | #----------------------------------------------- AherNet Structure ------------------------------------------------------------#
163 | aher_anet_model,aher_anet_anchor = AHER_init()
164 | # Initialize the backbone
165 | predictions_gene_seg, localisation_gene_seg, logits_gene_seg, proplogits_gene_seg, iouprediction_gene_seg, \
166 | clsweights_ki, clsbias_ki, gscore_gene_seg, \
167 | giou_gene_seg, gclasses_gene_seg, glocalisations_gene_seg, \
168 | end_points_gene_seg \
169 | = AHER_Predictor_Cls(aher_anet_model,aher_anet_anchor,feature_seg, \
170 | temporal_gt_seg,vname_seg,label_seg,duration_seg,reuse=False,class_num=600,cls_suffix='_anchor')
171 |
172 | # localization inference
173 | bbox = AHER_Detection_Inference(aher_anet_model,aher_anet_anchor,feature_val, \
174 | vname_val,label_val,duration_val, clsweights_ki, clsbias_ki, reuse=True, n_class=600,cls_suffix='_anchor')
175 | #------------------------------------------------------------------------------------------------------------------------------#
176 |
177 |
178 |
179 |
180 | AHER_trainable_variables=tf.trainable_variables()
181 |
182 | """ Init tf"""
183 | model_saver=tf.train.Saver(var_list=AHER_trainable_variables,max_to_keep=80)
184 | full_vars = []
185 | tf_config = tf.ConfigProto()
186 | tf_config.gpu_options.allow_growth = True
187 | tf_config.log_device_placement =True
188 | sess=tf.InteractiveSession(config=tf_config)
189 |
190 |
191 |
192 | #---------------------------------------------------- Test Data Input List -------------------------------------------------------#
193 | data_generate_val = KineticsDatasetLoadTrimList('val_moment',csv_dir,csv_oris_dir,25,512,df_file='./cvs_records_k600/k600_train_val_info.csv')
194 | print('Loading val features, wait...')
195 | while True:
196 | if len(data_generate_val.train_feature_list) == 6459:
197 | data_generate_val.stop_process()
198 | time.sleep(30)
199 | break
200 | print('Wait done.')
201 | print('The val feature len is %d'%(len(data_generate_val.train_feature_list)))
202 | dataset_val = tf.data.Dataset.from_generator(data_generate_val.gen,
203 | (tf.float32, tf.float32, tf.string, tf.int32, tf.float32),
204 | (tf.TensorShape([512, 2048]),tf.TensorShape([25, 2]),tf.TensorShape([]),tf.TensorShape([]),tf.TensorShape([])))
205 | dataset_val = dataset_val.batch(config.batch_size_val)
206 | batch_num_val = int(len(data_generate_val.train_feature_list) / config.batch_size_val)
207 | iterator_val = dataset_val.make_one_shot_iterator()
208 | feature_g_val, \
209 | video_gt_g_val, \
210 | video_name_g_val, \
211 | video_label_g_val, \
212 | video_duration_g_val = iterator_val.get_next()
213 | #---------------------------------------------------------------------------------------------------------------------------------#
214 |
215 |
216 |
217 | #---------------------------------------------------- Category Information --------------------------------------------------------#
218 | # load category information
219 | cateidx_name = {}
220 | category_info = pd.read_csv('k600_category.txt',sep='\n',header=None)
221 | for i in range(len(category_info)):
222 | name = category_info.loc[i][0]
223 | cateidx_name[i] = name
224 | #---------------------------------------------------------------------------------------------------------------------------------------#
225 |
226 |
227 |
228 |
229 | #---------------------------------------------------- Detection Result Extraction --------------------------------------------------------#
230 | fw_log = open('%s/log_res_k600_gene_dec.txt'%(output_folder_dir),'w',1)
231 | with tf.Session() as sess:
232 | tf.global_variables_initializer().run()
233 | tf.local_variables_initializer().run()
234 | for epoch in eva_step_id:
235 | print('Restore model %s/aher_adv_model_checkpoint-step_%d'%(output_folder_dir,epoch))
236 | model_saver.restore(sess,"%s/aher_adv_model_checkpoint-step_%d"%(output_folder_dir,epoch))
237 | """ Validation"""
238 | hit_video_num = 0
239 | total_video_num = 0
240 | print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + \
241 |               "----Epoch-%d Val set validation started." %(epoch))
242 | fw_proposal = open('%s/results_aher_k600_gendec_step-%d.json'%(output_folder_dir,epoch),'w')
243 | fw_proposal.write('{\"version\": \"VERSION 1.3\", \"results\": {')
244 | for idx in range(batch_num_val):
245 | feature_batch_val, \
246 | video_gt_batch_val, \
247 | video_name_batch_val, \
248 | video_label_batch_val, \
249 | video_duration_batch_val = sess.run([feature_g_val,video_gt_g_val,video_name_g_val,video_label_g_val,video_duration_g_val])
250 | out_bbox=sess.run(bbox,feed_dict={feature_val:feature_batch_val,
251 | vname_val:video_name_batch_val,
252 | label_val:video_label_batch_val,
253 | duration_val:video_duration_batch_val})
254 | predict_score = out_bbox['rscores']
255 | predict_bbox = out_bbox['rbboxes']
256 | average_cls_score = out_bbox['avescore']
257 | cls_id = out_bbox['label']
258 | raw_score = out_bbox['rawscore']
259 | argmaxid = out_bbox['argmaxid']
260 | real_label = video_label_batch_val[0]
261 | if real_label == cls_id: hit_video_num += 1
262 | total_video_num += 1
263 | cateids = 1
264 | write_first = 0
265 | # write the json file
266 | if idx!=0: fw_proposal.write(', ')
267 | fw_proposal.write('\"%s\":['%(video_name_batch_val[0].decode("utf-8")))
268 | for kk in range(len(predict_score[1][0])):
269 | # compute the ranking score for each localization result
270 | score = predict_score[cateids][0][kk] * average_cls_score[cls_id]
271 | start_time = predict_bbox[cateids][0][kk][0]
272 | end_time = predict_bbox[cateids][0][kk][1]
273 | if score > 0 and end_time - start_time > 0:
274 | if write_first == 0:
275 | fw_proposal.write('{\"score\": %f, \"segment\": [%f, %f], \"label\": \"%s\"}'% \
276 | (score,start_time,end_time,cateidx_name[cls_id])) #
277 | else:
278 | fw_proposal.write(', {\"score\": %f, \"segment\": [%f, %f], \"label\": \"%s\"}'% \
279 | (score,start_time,end_time,cateidx_name[cls_id]))
280 | write_first += 1
281 | fw_proposal.write(']')
282 | fw_proposal.write('}, \"external_data\": {}}')
283 | fw_proposal.close()
284 | # evaluation for temporal action proposal
285 | an_ar,recall = get_proposal_performance.evaluate_return_area('data_split/k600_val_action.json', \
286 | '%s/results_aher_k600_gendec_step-%d.json'%(output_folder_dir,epoch))
287 | print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + \
288 |               "----Epoch-%d Val set validation finished." %(epoch))
289 | accuracy_video = hit_video_num / total_video_num
290 |
291 | print('*********************************************************')
292 |         print("Epoch-%d Val AUC: %.04f AverageRecall: %.04f Accuracy: %.04f" %(epoch,an_ar,recall,accuracy_video))
293 |         print('*********************************************************')
294 |         fw_log.write('*********************************************************\n')
295 |         fw_log.write("Epoch-%d Val AUC: %.04f AverageRecall: %.04f Accuracy: %.04f\n" %(epoch,an_ar,recall,accuracy_video))
296 | fw_log.write('*********************************************************\n')
297 |
298 | fw_log.close()
299 |
--------------------------------------------------------------------------------
/custom_layers.py:
--------------------------------------------------------------------------------
1 | """Implement some custom layers, not provided by TensorFlow.
2 |
3 | Trying to follow as much as possible the style/standards used in
4 | tf.contrib.layers
5 | """
6 | import tensorflow as tf
7 |
8 | from tensorflow.contrib.framework.python.ops import add_arg_scope
9 | from tensorflow.contrib.layers.python.layers import initializers
10 | from tensorflow.contrib.framework.python.ops import variables
11 | from tensorflow.contrib.layers.python.layers import utils
12 | from tensorflow.python.ops import nn
13 | from tensorflow.python.ops import init_ops
14 | from tensorflow.python.ops import variable_scope
15 |
16 |
17 | def abs_smooth(x):
18 | """Smoothed absolute function. Useful to compute an L1 smooth error.
19 |
20 | Define as:
21 |         x^2 / 2        if abs(x) < 1
22 |         abs(x) - 0.5   if abs(x) >= 1
23 |     We use here a differentiable definition using min(abs(x), 1) and abs(x). Clearly
24 | not optimal, but good enough for our purpose!
25 | """
26 | absx = tf.abs(x)
27 | minx = tf.minimum(absx, 1)
28 | r = 0.5 * ((absx - 1) * minx + absx)
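    # Worked check of the closed form above:
    #   if |x| <  1: minx = |x|, so r = 0.5 * ((|x| - 1) * |x| + |x|) = x^2 / 2
    #   if |x| >= 1: minx = 1,   so r = 0.5 * ((|x| - 1) + |x|)       = |x| - 0.5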
29 | return r
30 |
31 | def single_hinge_loss(x,label,hinge=1.0):
32 | hinge_value = tf.constant(hinge)
33 | mini_hinge = tf.constant(0.0)
34 | r = tf.maximum(hinge_value - label * x, mini_hinge)
35 | return r
36 |
37 | @add_arg_scope
38 | def l2_normalization(
39 | inputs,
40 | scaling=False,
41 | scale_initializer=init_ops.ones_initializer(),
42 | reuse=None,
43 | variables_collections=None,
44 | outputs_collections=None,
45 | data_format='NHWC',
46 | trainable=True,
47 | scope=None):
48 | """Implement L2 normalization on every feature (i.e. spatial normalization).
49 |
50 | Should be extended in some near future to other dimensions, providing a more
51 | flexible normalization framework.
52 |
53 | Args:
54 | inputs: a 4-D tensor with dimensions [batch_size, height, width, channels].
55 | scaling: whether or not to add a post scaling operation along the dimensions
56 | which have been normalized.
57 | scale_initializer: An initializer for the weights.
58 | reuse: whether or not the layer and its variables should be reused. To be
59 | able to reuse the layer scope must be given.
60 | variables_collections: optional list of collections for all the variables or
61 | a dictionary containing a different list of collection per variable.
62 | outputs_collections: collection to add the outputs.
63 | data_format: NHWC or NCHW data format.
64 | trainable: If `True` also add variables to the graph collection
65 | `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
66 | scope: Optional scope for `variable_scope`.
67 | Returns:
68 | A `Tensor` representing the output of the operation.
69 | """
70 |
71 | with variable_scope.variable_scope(
72 | scope, 'L2Normalization', [inputs], reuse=reuse) as sc:
73 | inputs_shape = inputs.get_shape()
74 | inputs_rank = inputs_shape.ndims
75 | dtype = inputs.dtype.base_dtype
76 | if data_format == 'NHWC':
77 | # norm_dim = tf.range(1, inputs_rank-1)
78 | norm_dim = tf.range(inputs_rank-1, inputs_rank)
79 | params_shape = inputs_shape[-1:]
80 | elif data_format == 'NCHW':
81 | # norm_dim = tf.range(2, inputs_rank)
82 | norm_dim = tf.range(1, 2)
83 | params_shape = (inputs_shape[1])
84 |
85 | # Normalize along spatial dimensions.
86 | outputs = nn.l2_normalize(inputs, norm_dim, epsilon=1e-12)
87 | # Additional scaling.
88 | if scaling:
89 | scale_collections = utils.get_variable_collections(
90 | variables_collections, 'scale')
91 | scale = variables.model_variable('gamma',
92 | shape=params_shape,
93 | dtype=dtype,
94 | initializer=scale_initializer,
95 | collections=scale_collections,
96 | trainable=trainable)
97 | if data_format == 'NHWC':
98 | outputs = tf.multiply(outputs, scale)
99 | elif data_format == 'NCHW':
100 | scale = tf.expand_dims(scale, axis=-1)
101 | scale = tf.expand_dims(scale, axis=-1)
102 | outputs = tf.multiply(outputs, scale)
103 | # outputs = tf.transpose(outputs, perm=(0, 2, 3, 1))
104 |
105 | return utils.collect_named_outputs(outputs_collections,
106 | sc.original_name_scope, outputs)
107 |
108 | @add_arg_scope
109 | def pad2d(inputs,
110 | pad=(0, 0),
111 | mode='CONSTANT',
112 | data_format='NHWC',
113 | trainable=True,
114 | scope=None):
115 | """2D Padding layer, adding a symmetric padding to H and W dimensions.
116 |
117 | Aims to mimic padding in Caffe and MXNet, helping the port of models to
118 | TensorFlow. Tries to follow the naming convention of `tf.contrib.layers`.
119 |
120 | Args:
121 | inputs: 4D input Tensor;
122 | pad: 2-Tuple with padding values for H and W dimensions;
123 |       mode: Padding mode, see `tf.pad`;
124 | data_format: NHWC or NCHW data format.
125 | """
126 | with tf.name_scope(scope, 'pad2d', [inputs]):
127 | # Padding shape.
128 | if data_format == 'NHWC':
129 | paddings = [[0, 0], [pad[0], pad[0]], [pad[1], pad[1]], [0, 0]]
130 | elif data_format == 'NCHW':
131 | paddings = [[0, 0], [0, 0], [pad[0], pad[0]], [pad[1], pad[1]]]
132 | net = tf.pad(inputs, paddings, mode=mode)
133 | return net
134 |
135 | @add_arg_scope
136 | def channel_to_last(inputs,
137 | data_format='NHWC',
138 | scope=None):
139 |     """Move the channel axis to the last dimension. This provides
140 |     a single output format regardless of the input data format.
141 |
142 | Args:
143 | inputs: Input Tensor;
144 | data_format: NHWC or NCHW.
145 | Return:
146 | Input in NHWC format.
147 | """
148 | with tf.name_scope(scope, 'channel_to_last', [inputs]):
149 | if data_format == 'NHWC':
150 | net = inputs
151 | elif data_format == 'NCHW':
152 | net = tf.transpose(inputs, perm=(0, 2, 3, 1))
153 | return net
154 |
--------------------------------------------------------------------------------
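
A quick sanity check of the layers above: the sketch below evaluates `abs_smooth` on a few values (reproducing the smooth-L1 definition in its docstring) and runs `l2_normalization` on a random NHWC tensor. It is only a usage sketch and assumes the repository's TensorFlow 1.x environment.

```python
import tensorflow as tf
from custom_layers import abs_smooth, l2_normalization

x = tf.constant([-2.0, -0.5, 0.0, 0.5, 2.0])
y = abs_smooth(x)                               # expected: [1.5, 0.125, 0.0, 0.125, 1.5]

feat = tf.random_normal([1, 1, 64, 512])        # dummy NHWC feature map
normed = l2_normalization(feat, scaling=True)   # L2-normalize the channel axis, then rescale by gamma

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(y))
    print(sess.run(normed).shape)               # (1, 1, 64, 512)
```
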
/data_loader.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import pandas as pd
4 | import json
5 | import scipy.interpolate
6 | import time
7 | import tensorflow as tf
8 | from multiprocessing import Process,Queue,JoinableQueue
9 | import multiprocessing
10 |
11 | x_tdim = 512
12 | fea_dim_th = 2048
13 |
14 | def pool_fea_interpolate_th(fea_temp,feature_count,duration,fnumber,resize_temporal_dim=100,pool_type="mean"):
15 |
16 | num_bin = 1
17 | num_sample_bin=3
18 |
19 | num_prop = resize_temporal_dim
20 | video_frame = fnumber
21 | video_second = duration
22 |
23 | data = fea_temp[0:feature_count]
24 |
25 |
26 | feature_frame=feature_count*8
27 | corrected_second=float(feature_frame)/video_frame*video_second
28 | fps=float(video_frame)/video_second
29 | st=8/fps
30 |
31 | if feature_count==1:
32 | video_feature=np.stack([data]*num_prop)
33 | video_feature=np.reshape(video_feature,[num_prop,fea_dim_th])
34 | return video_feature
35 |
36 | x=[st/2+ii*st for ii in range(feature_count)]
37 | f=scipy.interpolate.interp1d(x,data,axis=0)
38 |
39 | video_feature=[]
40 | zero_sample=np.zeros(num_bin*fea_dim_th)
41 | tmp_anchor_xmin=[1.0/num_prop*i for i in range(num_prop)]
42 | tmp_anchor_xmax=[1.0/num_prop*i for i in range(1,num_prop+1)]
43 |
44 | num_sample=num_bin*num_sample_bin
45 | for idx in range(num_prop):
46 | xmin=max(x[0]+0.0001,tmp_anchor_xmin[idx]*corrected_second)
47 | xmax=min(x[-1]-0.0001,tmp_anchor_xmax[idx]*corrected_second)
48 | if xmaxx[-1]:
52 | video_feature.append(zero_sample)
53 | continue
54 |
55 | plen=(xmax-xmin)/(num_sample-1)
56 | x_new=[xmin+plen*ii for ii in range(num_sample)]
57 | y_new=f(x_new)
58 | y_new_pool=[]
59 | for b in range(num_bin):
60 | tmp_y_new=y_new[num_sample_bin*b:num_sample_bin*(b+1)]
61 | if pool_type=="mean":
62 |                 tmp_y_new=np.mean(tmp_y_new,axis=0)
63 |             elif pool_type=="max":
64 |                 tmp_y_new=np.max(tmp_y_new,axis=0)
65 | y_new_pool.append(tmp_y_new)
66 | y_new_pool=np.stack(y_new_pool)
67 | y_new_pool=np.reshape(y_new_pool,[-1])
68 | video_feature.append(y_new_pool)
69 | video_feature=np.stack(video_feature)
70 | return video_feature
71 |
72 | def load_json(file):
73 | with open(file) as json_file:
74 | data = json.load(json_file)
75 | return data
76 |
77 | def getDatasetDict(df_file='./cvs_records/anet_train_val_info.csv'):
78 | """Load dataset file
79 | """
80 | df=pd.read_csv(df_file)
81 | json_data= load_json("./cvs_records_anet/anet_annotation_json.json")
82 | database=json_data
83 | train_dict={}
84 | val_dict={}
85 | test_dict={}
86 | miss_count = 0
87 | for i in range(len(df)):
88 | video_name=df.video.values[i]
89 | if video_name not in database.keys():
90 | miss_count += 1
91 | continue
92 | video_info=database[video_name]
93 | video_new_info={}
94 | video_new_info['duration_frame']=video_info['duration_frame']
95 | video_new_info['duration_second']=video_info['duration_second']
96 | video_new_info["feature_frame"]=video_info['feature_frame']
97 | video_subset=df.subset.values[i]
98 | video_new_info['annotations']=video_info['annotations']
99 | if video_subset=="training":
100 | train_dict[video_name]=video_new_info
101 | elif video_subset=="validation":
102 | val_dict[video_name]=video_new_info
103 | elif video_subset=="testing":
104 | test_dict[video_name]=video_new_info
105 |     print('Missing video annotations: %d'%(miss_count))
106 | return train_dict,val_dict,test_dict
107 |
108 | def getDatasetDictK600(df_file='./cvs_records_k600/k600_train_val_info.csv'):
109 | """Load dataset file
110 | """
111 | df=pd.read_csv(df_file)
112 | json_data= load_json("./cvs_records_k600/k600_annotation_json.json")
113 | database=json_data
114 | train_dict={}
115 | val_dict={}
116 | test_dict={}
117 | miss_count = 0
118 | for i in range(len(df)):
119 | video_name=df.video.values[i]
120 | if video_name not in database.keys():
121 | miss_count += 1
122 | continue
123 | video_info=database[video_name]
124 | video_new_info={}
125 | video_new_info['duration_frame']=video_info['duration_frame']
126 | video_new_info['duration_second']=video_info['duration_second']
127 | video_new_info["feature_frame"]=video_info['feature_frame']
128 | video_subset=df.subset.values[i]
129 | video_new_info['annotations']=video_info['annotations']
130 | if video_subset=="training":
131 | train_dict[video_name]=video_new_info
132 | elif video_subset=="validation":
133 | val_dict[video_name]=video_new_info
134 | elif video_subset=="testing":
135 | test_dict[video_name]=video_new_info
136 |     print('Missing video annotations: %d'%(miss_count))
137 | return train_dict,val_dict,test_dict
138 |
139 | def getSplitsetDict(df_file='./cvs_records/anet_train_val_info.csv'):
140 | """Load dataset file
141 | ActivityNet:
142 | untrim set class number: 87
143 | moment set class number: 113
144 | """
145 | cate_info = pd.read_csv('activitynet_train_val_gt.txt',sep = ' ',header = None)
146 | cate_dict = {}
147 | for v_id in range(len(cate_info)):
148 | v = cate_info.loc[v_id][0]
149 | cate = cate_info.loc[v_id][2]
150 | cate_dict[v] = cate
151 |
152 | moment_info = pd.read_csv('data_split/label_map_moment.txt',sep = '\n',header = None)
153 | moment_cate = []
154 | for idm in range(len(moment_info)):
155 | moment_cate.append(moment_info.loc[idm][0])
156 |
157 | untrim_info = pd.read_csv('data_split/label_map_untrim.txt',sep = '\n',header = None)
158 | untrim_cate = []
159 | for idm in range(len(untrim_info)):
160 | untrim_cate.append(untrim_info.loc[idm][0])
161 |
162 | df=pd.read_csv("./cvs_records/anet_train_val_info.csv")
163 | json_data= load_json("./cvs_records/anet_annotation_json.json")
164 | database=json_data
165 | train_untrim_dict={}
166 | train_moment_dict={}
167 | val_untrim_dict={}
168 | val_moment_dict={}
169 | miss_count = 0
170 | for i in range(len(df)):
171 | video_name=df.video.values[i]
172 | if video_name not in database.keys():
173 | miss_count += 1
174 | continue
175 | video_info=database[video_name]
176 | video_cate=cate_dict[video_name]
177 | video_new_info={}
178 | video_new_info['duration_frame']=video_info['duration_frame']
179 | video_new_info['duration_second']=video_info['duration_second']
180 | video_new_info["feature_frame"]=video_info['feature_frame']
181 | video_subset=df.subset.values[i]
182 | video_new_info['annotations']=video_info['annotations']
183 |
184 | if video_subset=="training" and video_cate in untrim_cate:
185 | train_untrim_dict[video_name]=video_new_info
186 | if video_subset=="training" and video_cate in moment_cate:
187 | train_moment_dict[video_name]=video_new_info
188 | if video_subset=="validation" and video_cate in untrim_cate:
189 | val_untrim_dict[video_name]=video_new_info
190 | if video_subset=="validation" and video_cate in moment_cate:
191 | val_moment_dict[video_name]=video_new_info
192 |
193 |     print('Missing video annotations: %d'%(miss_count))
194 | return train_untrim_dict,train_moment_dict,val_untrim_dict,val_moment_dict
195 |
196 | class AnetDatasetLoadForeQueue():
197 | def __init__(self,dataSet,csv_dir,csv_oris_dir,gt_hold_num,x_tdim,df_file='./cvs_records_anet/anet_train_val_reduce_info.csv',untrim_start=87,resize_dim=512):
198 | self.train_dict,self.val_dict,self.test_dict=getDatasetDict(df_file=df_file)
199 | if dataSet == 'train':
200 | self.train_list = list(self.train_dict.keys())
201 | else:
202 | self.train_list = list(self.val_dict.keys())
203 | self.train_list_copy = self.train_list
204 | cate_info = pd.read_csv('activitynet_train_val_gt.txt',sep = ' ',header = None)
205 | self.cate_dict = {}
206 | for v_id in range(len(cate_info)):
207 | v = cate_info.loc[v_id][0]
208 | cate = cate_info.loc[v_id][2]
209 | self.cate_dict[v] = cate
210 | self.shuffle()
211 | self.resize_dim = resize_dim
212 | self.dataSet = dataSet
213 | self.csv_dir = csv_dir
214 | self.csv_oris_dir = csv_oris_dir
215 | self.gt_hold_num = gt_hold_num
216 | self.x_tdim = x_tdim
217 | self.num_samples = len(self.train_list)
218 | self.batch_size = 1
219 | self.queue = Queue(maxsize=1536) #Joinable
220 | self.train_feature_list = multiprocessing.Manager().list()
221 | self.process_num = 32
222 | self.process_array = []
223 | self.start_train_idx = 0
224 | for pid in range(self.process_num):
225 | pid_num = int(self.num_samples / self.process_num)
226 | start_idx = pid*pid_num
227 | end_idx = (pid+1)*pid_num
228 | if pid == (self.process_num - 1):
229 | if end_idx < self.num_samples: end_idx = self.num_samples
230 | t = Process(target=self.load_queue,args=(start_idx,end_idx))
231 | self.process_array.append(t)
232 | t.start()
233 | def shuffle(self):
234 | randperm = np.random.permutation(len(self.train_list))
235 | self.train_list = [self.train_list_copy[int(randperm[idx])] for idx in range(len(randperm))]
236 | def load_single_feature(self,start):
237 | result = []
238 | name = self.train_list[start]
239 | s_bbox = []
240 | r_bbox = []
241 | video_info=self.train_dict[name]
242 | video_frame=video_info['duration_frame']
243 | video_second=video_info['duration_second']
244 | feature_frame=video_info['feature_frame']
245 | feature_num = int(feature_frame / 8)
246 | corrected_second=float(feature_frame)/video_frame*video_second
247 | video_labels=video_info['annotations']
248 | vlabels = self.cate_dict[name]
249 | for j in range(len(video_labels)):
250 | tmp_info=video_labels[j]
251 | tmp_start=tmp_info['segment'][0]
252 | tmp_end=tmp_info['segment'][1]
253 | tmp_start=max(min(1,tmp_start/corrected_second),0)
254 | tmp_end=max(min(1,tmp_end/corrected_second),0)
255 | tmp_fea_start = tmp_start*self.x_tdim
256 | tmp_fea_end = tmp_end*self.x_tdim
257 | s_bbox.append([tmp_fea_start,tmp_fea_end])
258 | tmp_len = tmp_end - tmp_start
259 | if tmp_len*feature_num >= 32: r_bbox.append([tmp_start,tmp_end])
260 |
261 | tmp_df=pd.read_csv(self.csv_oris_dir+"/"+name+".csv",skiprows=1,header=None,sep=',')
262 | tmp_values = tmp_df.values[:,:]
263 | for m in range(len(r_bbox)):
264 | tmp_start_fea,tmp_end_fea = int(r_bbox[m][0]*feature_num),int(r_bbox[m][1]*feature_num)
265 | proposal_fea = tmp_values[tmp_start_fea:tmp_end_fea,:]
266 | proposal_duration,proposal_frame = (r_bbox[m][1]-r_bbox[m][0])*video_second,(r_bbox[m][1]-r_bbox[m][0])*video_frame
267 | proposal_inter_fea = pool_fea_interpolate_th(proposal_fea,len(proposal_fea),proposal_duration,proposal_frame,self.resize_dim)
268 | result.append([proposal_inter_fea,[[128.0,384.0]],name,vlabels,proposal_duration*2])
269 | return result
270 | def load_queue(self,start,end):
271 | bindex = 0
272 | lendata = end - start
273 | while True:
274 | t = self.load_single_feature(bindex+start)
275 | if self.dataSet == 'train_untrim' or \
276 | self.dataSet == 'train_momwhole' or \
277 | self.dataSet == 'val_untrim' or \
278 | self.dataSet == 'val_moment':
279 | self.queue.put(t)
280 | else:
281 | for z in range(len(t)):
282 | self.queue.put(t[z])
283 | bindex += 1
284 | if bindex == lendata:
285 | bindex = 0
286 | def gen(self):
287 | while True:
288 | t = self.queue.get()
289 | yield t[0],t[1],t[2],t[3],t[4]
290 | def stop_process(self):
291 | for t in self.process_array:
292 | t.terminate()
293 | t.join()
294 |         print('Stopped the processes successfully.')
295 |
296 | class KineticsDatasetLoadTrimMultiProcess():
297 | def __init__(self,dataSet,csv_dir,csv_oris_dir,gt_hold_num,x_tdim,df_file='./cvs_records_k600/k600_train_val_info.csv',resize_dim=256,ratio_sid=1,ratio_eid=9):
298 | train_video_dict, val_video_dict, test_video_dict=getDatasetDictK600(df_file=df_file)
299 | if "train" in dataSet:
300 | self.train_dict=train_video_dict
301 | else:
302 | self.train_dict=val_video_dict
303 | self.dataSet = dataSet
304 | self.resize_dim = resize_dim
305 | self.ratio_sid = ratio_sid
306 | self.ratio_eid = ratio_eid
307 | self.train_list = list(self.train_dict.keys())
308 | self.train_list_copy = self.train_list
309 | print('The %s video num is: %d'%(self.dataSet,len(self.train_list)))
310 | print('Load k600 train val gt info...')
311 | cate_info = pd.read_csv('k600_train_val_gt.txt',sep = ' ',header = None)
312 | self.cate_dict = {}
313 | for v_id in range(len(cate_info)):
314 | v = cate_info.loc[v_id][0]
315 | cate = cate_info.loc[v_id][2]
316 | self.cate_dict[v] = cate
317 | print('Load Done.')
318 | self.shuffle()
319 | self.csv_dir = csv_dir
320 | self.csv_oris_dir = csv_oris_dir
321 | self.gt_hold_num = gt_hold_num
322 | self.x_tdim = x_tdim
323 | self.num_samples = len(self.train_list)
324 | self.batch_size = 1
325 | #self.train_feature_list = multiprocessing.Manager().list()
326 | self.process_num = 32
327 | self.process_array = []
328 | self.queue = Queue(maxsize=1536) #Joinable
329 | self.start_train_idx = 0
330 | for pid in range(self.process_num):
331 | pid_num = int(self.num_samples / self.process_num)
332 | start_idx = pid*pid_num
333 | end_idx = (pid+1)*pid_num
334 | if pid == (self.process_num - 1):
335 | if end_idx < self.num_samples: end_idx = self.num_samples
336 | t = Process(target=self.load_queue,args=(start_idx,end_idx))
337 | self.process_array.append(t)
338 | t.start()
339 | def shuffle(self):
340 | randperm = np.random.permutation(len(self.train_list))
341 | self.train_list = [self.train_list_copy[int(randperm[idx])] for idx in range(len(randperm))]
342 | def load_single_feature(self,start):
343 | result = []
344 | name = self.train_list[start]
345 | s_bbox = []
346 | r_bbox = []
347 | label_bbox=[]
348 | slabel_box = []
349 | rlabel_bbox = []
350 | video_info=self.train_dict[name]
351 | video_frame=video_info['duration_frame']
352 | video_second=video_info['duration_second']
353 | feature_frame=video_info['feature_frame']
354 | feature_num = int(feature_frame / 8)
355 | corrected_second=float(feature_frame)/video_frame*video_second
356 | video_labels=video_info['annotations']
357 | vlabels = self.cate_dict[name]
358 | for j in range(len(video_labels)):
359 | tmp_info=video_labels[j]
360 | tmp_start=tmp_info['segment'][0]
361 | tmp_end=tmp_info['segment'][1]
362 | tmp_label=tmp_info['label']
363 | tmp_start=max(min(1,tmp_start/corrected_second),0)
364 | tmp_end=max(min(1,tmp_end/corrected_second),0)
365 | tmp_fea_start = tmp_start*self.x_tdim
366 | tmp_fea_end = tmp_end*self.x_tdim
367 | s_bbox.append([tmp_fea_start,tmp_fea_end])
368 | tmp_len = tmp_end - tmp_start
369 | if tmp_len*feature_num >= 8:
370 | r_bbox.append([tmp_start,tmp_end])
371 | if self.dataSet == 'val_moment':
372 | tmp_df=pd.read_csv(self.csv_dir+"/"+name+".csv",skiprows=1,header=None,sep=',')
373 | anchor_feature = tmp_df.values[:,:]
374 | for kk in range(len(s_bbox),self.gt_hold_num):
375 | s_bbox.append([-1.0,-1.0])
376 | result.append(anchor_feature)
377 | result.append(s_bbox)
378 | result.append(name)
379 | result.append(vlabels)
380 | result.append(video_second)
381 | return result
382 | else:
383 | tmp_df=pd.read_csv(self.csv_oris_dir+"/"+name+".csv",skiprows=1,header=None,sep=',')
384 | tmp_values = tmp_df.values[:,:]
385 | for m in range(len(r_bbox)):
386 | tmp_start_fea,tmp_end_fea = int(r_bbox[m][0]*feature_num),int(r_bbox[m][1]*feature_num)
387 | proposal_fea = tmp_values[tmp_start_fea:tmp_end_fea,:]
388 | proposal_duration,proposal_frame = (r_bbox[m][1]-r_bbox[m][0])*video_second,(r_bbox[m][1]-r_bbox[m][0])*video_frame
389 | proposal_inter_fea = pool_fea_interpolate_th(proposal_fea,len(proposal_fea),proposal_duration,proposal_frame,self.resize_dim)
390 | result.append([proposal_inter_fea,[[128.0,384.0]],name,vlabels,proposal_duration*2])
391 | return result
392 | def load_queue(self,start,end):
393 | bindex = 0
394 | lendata = end - start
395 | while True:
396 | t = self.load_single_feature(bindex+start)
397 | if self.dataSet == 'val_moment':
398 | self.queue.put(t)
399 | else:
400 | for z in range(len(t)):
401 | self.queue.put(t[z])
402 | bindex += 1
403 | if bindex == lendata: bindex = 0
404 | def gen(self):
405 | while True:
406 | if self.queue.empty():
407 | print('The Kinetics %s Queue is empty, loading feature...'%(self.dataSet))
408 | time.sleep(10)
409 | else:
410 | t = self.queue.get_nowait()
411 | yield t[0],t[1],t[2],t[3],t[4]
412 | def stop_process(self):
413 | for t in self.process_array:
414 | t.terminate()
415 | t.join()
416 |         print('Stopped the processes successfully.')
417 | def gen_pos_random(self):
418 | while True:
419 | t = self.queue.get()
420 | yield t[0],t[1],t[2],t[3],t[4],random.randint(self.ratio_sid,self.ratio_eid),random.randint(1,19)
421 |
422 | class AnetDatasetLoadMultiProcessQueue():
423 | def __init__(self,dataSet,csv_dir,gt_hold_num,x_tdim,df_file='./cvs_records_anet/anet_train_val_reduce_info.csv'):
424 | self.train_dict,self.val_dict,self.test_dict=getDatasetDict(df_file=df_file)
425 | if dataSet == 'train':
426 | self.train_list = list(self.train_dict.keys())
427 | else:
428 | self.train_list = list(self.val_dict.keys())
429 | self.train_list_copy = self.train_list
430 | cate_info = pd.read_csv('activitynet_train_val_gt.txt',sep = ' ',header = None)
431 | self.cate_dict = {}
432 | for v_id in range(len(cate_info)):
433 | v = cate_info.loc[v_id][0]
434 | cate = cate_info.loc[v_id][2]
435 | self.cate_dict[v] = cate
436 | self.shuffle()
437 | self.dataSet = dataSet
438 | self.csv_dir = csv_dir
439 | self.gt_hold_num = gt_hold_num
440 | self.x_tdim = x_tdim
441 | self.num_samples = len(self.train_list)
442 | self.batch_size = 1
443 | self.queue = Queue(maxsize=1536) #Joinable
444 | self.train_feature_list = multiprocessing.Manager().list()
445 | self.process_num = 32
446 | self.process_array = []
447 | self.start_train_idx = 0
448 | for pid in range(self.process_num):
449 | pid_num = int(self.num_samples / self.process_num)
450 | start_idx = pid*pid_num
451 | end_idx = (pid+1)*pid_num
452 | if pid == (self.process_num - 1):
453 | if end_idx < self.num_samples: end_idx = self.num_samples
454 | t = Process(target=self.load_queue,args=(start_idx,end_idx))
455 | self.process_array.append(t)
456 | t.start()
457 | def shuffle(self):
458 | randperm = np.random.permutation(len(self.train_list))
459 | self.train_list = [self.train_list_copy[int(randperm[idx])] for idx in range(len(randperm))]
460 | def load_single_feature(self,start):
461 | result = []
462 | name = self.train_list[start]
463 | s_bbox = []
464 | if self.dataSet == 'train':
465 | video_info=self.train_dict[name]
466 | else:
467 | video_info=self.val_dict[name]
468 | video_frame=video_info['duration_frame']
469 | video_second=video_info['duration_second']
470 | feature_frame=video_info['feature_frame']
471 | corrected_second=float(feature_frame)/video_frame*video_second
472 | video_labels=video_info['annotations']
473 | vlabels = self.cate_dict[name]
474 | for j in range(len(video_labels)):
475 | tmp_info=video_labels[j]
476 | tmp_start=tmp_info['segment'][0]
477 | tmp_end=tmp_info['segment'][1]
478 | tmp_start=max(min(1,tmp_start/corrected_second),0)
479 | tmp_end=max(min(1,tmp_end/corrected_second),0)
480 | tmp_fea_start = tmp_start*self.x_tdim
481 | tmp_fea_end = tmp_end*self.x_tdim
482 | s_bbox.append([tmp_fea_start,tmp_fea_end])
483 | for kk in range(len(s_bbox),self.gt_hold_num):
484 | s_bbox.append([-1.0,-1.0])
485 | tmp_df=pd.read_csv(self.csv_dir+"/"+name+".csv",skiprows=1,header=None,sep=',')
486 | anchor_feature = tmp_df.values[:,:]
487 | result.append(anchor_feature)
488 | result.append(s_bbox)
489 | result.append(name)
490 | result.append(vlabels)
491 | result.append(video_second)
492 | return result
493 | def load_queue(self,start,end):
494 | bindex = 0
495 | lendata = end - start
496 | while True:
497 | t = self.load_single_feature(bindex+start)
498 | self.queue.put(t)
499 | bindex += 1
500 | if bindex == lendata:
501 | bindex = 0
502 | def gen(self):
503 | while True:
504 | t = self.queue.get()
505 | yield t[0],t[1],t[2],t[3],t[4]
506 | def stop_process(self):
507 | for t in self.process_array:
508 | t.terminate()
509 | t.join()
510 |         print('Stopped the processes successfully.')
511 |
512 | #-------------------------- Test Data Loader -----------------------------------#
513 | class KineticsDatasetLoadTrimList():
514 | def __init__(self,dataSet,csv_dir,csv_oris_dir,gt_hold_num,x_tdim,df_file='./cvs_records_k600/k600_train_val_info.csv',resize_dim=256):
515 | train_video_dict, val_video_dict, test_video_dict=getDatasetDictK600(df_file=df_file)
516 | if "train" in dataSet:
517 | self.train_dict=train_video_dict
518 | else:
519 | self.train_dict=val_video_dict
520 | self.dataSet = dataSet
521 | self.resize_dim = resize_dim
522 | self.train_list = list(self.train_dict.keys())
523 | self.train_list_copy = self.train_list
524 | print('The %s video num is: %d'%(self.dataSet,len(self.train_list)))
525 | print('Load k600 train val gt info...')
526 | cate_info = pd.read_csv('k600_train_val_gt.txt',sep = ' ',header = None)
527 | self.cate_dict = {}
528 | for v_id in range(len(cate_info)):
529 | v = cate_info.loc[v_id][0]
530 | cate = cate_info.loc[v_id][2]
531 | self.cate_dict[v] = cate
532 | print('Load Done.')
533 | self.shuffle()
534 | self.csv_dir = csv_dir
535 | self.csv_oris_dir = csv_oris_dir
536 | self.gt_hold_num = gt_hold_num
537 | self.x_tdim = x_tdim
538 | self.num_samples = len(self.train_list)
539 | self.batch_size = 1
540 | self.train_feature_list = multiprocessing.Manager().list()
541 | self.process_num = 32
542 | self.process_array = []
543 | self.start_train_idx = 0
544 | for pid in range(self.process_num):
545 | pid_num = int(self.num_samples / self.process_num)
546 | start_idx = pid*pid_num
547 | end_idx = (pid+1)*pid_num
548 | if pid == (self.process_num - 1):
549 | if end_idx < self.num_samples: end_idx = self.num_samples
550 | t = Process(target=self.load_queue,args=(start_idx,end_idx))
551 | self.process_array.append(t)
552 | t.start()
553 | def shuffle(self):
554 | randperm = np.random.permutation(len(self.train_list))
555 | self.train_list = [self.train_list_copy[int(randperm[idx])] for idx in range(len(randperm))]
556 | def load_single_feature(self,start):
557 | result = []
558 | name = self.train_list[start]
559 | s_bbox = []
560 | r_bbox = []
561 | label_bbox=[]
562 | slabel_box = []
563 | rlabel_bbox = []
564 | video_info=self.train_dict[name]
565 | video_frame=video_info['duration_frame']
566 | video_second=video_info['duration_second']
567 | feature_frame=video_info['feature_frame']
568 | feature_num = int(feature_frame / 8)
569 | corrected_second=float(feature_frame)/video_frame*video_second
570 | video_labels=video_info['annotations']
571 | vlabels = self.cate_dict[name]
572 | for j in range(len(video_labels)):
573 | tmp_info=video_labels[j]
574 | tmp_start=tmp_info['segment'][0]
575 | tmp_end=tmp_info['segment'][1]
576 | tmp_label=tmp_info['label']
577 | tmp_start=max(min(1,tmp_start/corrected_second),0)
578 | tmp_end=max(min(1,tmp_end/corrected_second),0)
579 | tmp_fea_start = tmp_start*self.x_tdim
580 | tmp_fea_end = tmp_end*self.x_tdim
581 | s_bbox.append([tmp_fea_start,tmp_fea_end])
582 | tmp_len = tmp_end - tmp_start
583 | if tmp_len*feature_num >= 8:
584 | r_bbox.append([tmp_start,tmp_end])
585 | if self.dataSet == 'val_moment':
586 | tmp_df=pd.read_csv(self.csv_dir+"/"+name+".csv",skiprows=1,header=None,sep=',')
587 | anchor_feature = tmp_df.values[:,:]
588 | for kk in range(len(s_bbox),self.gt_hold_num):
589 | s_bbox.append([-1.0,-1.0])
590 | result.append(anchor_feature)
591 | result.append(s_bbox)
592 | result.append(name)
593 | result.append(vlabels)
594 | result.append(video_second)
595 | return result
596 | else:
597 | tmp_df=pd.read_csv(self.csv_oris_dir+"/"+name+".csv",skiprows=1,header=None,sep=',')
598 | tmp_values = tmp_df.values[:,:]
599 | for m in range(len(r_bbox)):
600 | tmp_start_fea,tmp_end_fea = int(r_bbox[m][0]*feature_num),int(r_bbox[m][1]*feature_num)
601 | proposal_fea = tmp_values[tmp_start_fea:tmp_end_fea,:]
602 | proposal_duration,proposal_frame = (r_bbox[m][1]-r_bbox[m][0])*video_second,(r_bbox[m][1]-r_bbox[m][0])*video_frame
603 | proposal_inter_fea = pool_fea_interpolate_th(proposal_fea,len(proposal_fea),proposal_duration,proposal_frame,self.resize_dim)
604 | result.append([proposal_inter_fea,[[128.0,384.0]],name,vlabels,proposal_duration*2])
605 | return result
606 | def load_queue(self,start,end):
607 | bindex = 0
608 | lendata = end - start
609 | for i in range(lendata):
610 | t = self.load_single_feature(bindex+start)
611 | if self.dataSet == 'val_moment':
612 | self.train_feature_list.append(t)
613 | else:
614 | for z in range(len(t)):
615 | self.train_feature_list.append(t[z])
616 | bindex += 1
617 | if bindex == lendata: bindex = 0
618 | def gen(self):
619 | self.num_samples = len(self.train_feature_list)
620 | print('The sample number of generator %s is %d'%(self.dataSet,self.num_samples))
621 | while True:
622 | t = self.train_feature_list[self.start_train_idx]
623 | self.start_train_idx += 1
624 | if self.start_train_idx == self.num_samples:
625 | self.start_train_idx = 0
626 | yield t[0],t[1],t[2],t[3],t[4]
627 | def stop_process(self):
628 | for t in self.process_array:
629 | t.terminate()
630 | t.join()
631 |         print('Stopped the processes successfully.')
632 |
--------------------------------------------------------------------------------
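
All the loaders above rely on `pool_fea_interpolate_th` to rescale a variable-length snippet-feature sequence to a fixed temporal length before it is consumed by the network. A minimal sketch with random features; the duration and frame count below are made-up placeholders:

```python
import numpy as np
from data_loader import pool_fea_interpolate_th, fea_dim_th

raw_fea = np.random.rand(40, fea_dim_th)        # 40 snippet features, 2048-d each
resized = pool_fea_interpolate_th(raw_fea, 40,
                                  duration=25.6,            # seconds (placeholder)
                                  fnumber=640,              # frame count (placeholder)
                                  resize_temporal_dim=100,
                                  pool_type="mean")
print(resized.shape)                            # (100, 2048)
```
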
/evaluation/eval_detection.py:
--------------------------------------------------------------------------------
1 | import json
2 | #import urllib2
3 |
4 | import numpy as np
5 | import pandas as pd
6 |
7 | from utils import get_blocked_videos
8 | from utils import interpolated_prec_rec
9 | from utils import segment_iou
10 |
11 | class ANETdetection(object):
12 |
13 | GROUND_TRUTH_FIELDS = ['database', 'version'] #'taxonomy',
14 | PREDICTION_FIELDS = ['results', 'version', 'external_data']
15 |
16 | def __init__(self, ground_truth_filename=None, prediction_filename=None,
17 | ground_truth_fields=GROUND_TRUTH_FIELDS,
18 | prediction_fields=PREDICTION_FIELDS,
19 | tiou_thresholds=np.linspace(0.5, 0.95, 10),
20 | subset='validation', verbose=False,
21 | check_status=True):
22 | if not ground_truth_filename:
23 | raise IOError('Please input a valid ground truth file.')
24 | if not prediction_filename:
25 | raise IOError('Please input a valid prediction file.')
26 | self.subset = subset
27 | self.tiou_thresholds = tiou_thresholds
28 | self.verbose = verbose
29 | self.gt_fields = ground_truth_fields
30 | self.pred_fields = prediction_fields
31 | self.ap = None
32 | self.check_status = check_status
33 |
34 | self.blocked_videos = {}
35 | # Retrieve blocked videos from server.
36 | #if self.check_status:
37 | # self.blocked_videos = get_blocked_videos()
38 | #else:
39 | # self.blocked_videos = list()
40 | # Import ground truth and predictions.
41 | self.ground_truth, self.activity_index = self._import_ground_truth(
42 | ground_truth_filename)
43 | self.prediction = self._import_prediction(prediction_filename)
44 |
45 | if self.verbose:
46 | print('[INIT] Loaded annotations from {} subset.'.format(subset))
47 | nr_gt = len(self.ground_truth)
48 | print('\tNumber of ground truth instances: {}'.format(nr_gt))
49 | nr_pred = len(self.prediction)
50 | print('\tNumber of predictions: {}'.format(nr_pred))
51 | print('\tFixed threshold for tiou score: {}'.format(self.tiou_thresholds))
52 |
53 | def _import_ground_truth(self, ground_truth_filename):
54 | """Reads ground truth file, checks if it is well formatted, and returns
55 | the ground truth instances and the activity classes.
56 |
57 | Parameters
58 | ----------
59 | ground_truth_filename : str
60 | Full path to the ground truth json file.
61 |
62 | Outputs
63 | -------
64 | ground_truth : df
65 | Data frame containing the ground truth instances.
66 | activity_index : dict
67 | Dictionary containing class index.
68 | """
69 | with open(ground_truth_filename, 'r') as fobj:
70 | data = json.load(fobj)
71 | # Checking format
72 | if not all([field in data.keys() for field in self.gt_fields]):
73 | raise IOError('Please input a valid ground truth file.')
74 |
75 | # Read ground truth data.
76 | activity_index, cidx = {}, 0
77 | video_lst, t_start_lst, t_end_lst, label_lst = [], [], [], []
78 | for videoid, v in data['database'].items():
79 | if self.subset != v['subset']:
80 | continue
81 | if videoid in self.blocked_videos:
82 | continue
83 | for ann in v['annotations']:
84 | if ann['label'] not in activity_index:
85 | activity_index[ann['label']] = cidx
86 | cidx += 1
87 | video_lst.append(videoid)
88 | t_start_lst.append(ann['segment'][0])
89 | t_end_lst.append(ann['segment'][1])
90 | label_lst.append(activity_index[ann['label']])
91 |
92 | ground_truth = pd.DataFrame({'video-id': video_lst,
93 | 't-start': t_start_lst,
94 | 't-end': t_end_lst,
95 | 'label': label_lst})
96 | return ground_truth, activity_index
97 |
98 | def _import_prediction(self, prediction_filename):
99 | """Reads prediction file, checks if it is well formatted, and returns
100 | the prediction instances.
101 |
102 | Parameters
103 | ----------
104 | prediction_filename : str
105 | Full path to the prediction json file.
106 |
107 | Outputs
108 | -------
109 | prediction : df
110 | Data frame containing the prediction instances.
111 | """
112 | with open(prediction_filename, 'r') as fobj:
113 | data = json.load(fobj)
114 | # Checking format...
115 | if not all([field in data.keys() for field in self.pred_fields]):
116 | raise IOError('Please input a valid prediction file.')
117 |
118 |         # Read predictions.
119 | video_lst, t_start_lst, t_end_lst = [], [], []
120 | label_lst, score_lst = [], []
121 | for videoid, v in data['results'].items():
122 | if videoid not in self.ground_truth['video-id'].values: continue
123 | if videoid in self.blocked_videos:
124 | continue
125 | for result in v:
126 | if result['label'] not in self.activity_index.keys(): continue
127 | label = self.activity_index[result['label']]
128 | video_lst.append(videoid)
129 | t_start_lst.append(result['segment'][0])
130 | t_end_lst.append(result['segment'][1])
131 | label_lst.append(label)
132 | score_lst.append(result['score'])
133 | prediction = pd.DataFrame({'video-id': video_lst,
134 | 't-start': t_start_lst,
135 | 't-end': t_end_lst,
136 | 'label': label_lst,
137 | 'score': score_lst})
138 | return prediction
139 |
140 | def wrapper_compute_average_precision(self):
141 | """Computes average precision for each class in the subset.
142 | """
143 | ap = np.zeros((len(self.tiou_thresholds), len(self.activity_index.items())))
144 | for activity, cidx in self.activity_index.items():
145 | gt_idx = self.ground_truth['label'] == cidx
146 | pred_idx = self.prediction['label'] == cidx
147 | ap[:,cidx] = compute_average_precision_detection(
148 | self.ground_truth.loc[gt_idx].reset_index(drop=True),
149 | self.prediction.loc[pred_idx].reset_index(drop=True),
150 | tiou_thresholds=self.tiou_thresholds)
151 | return ap
152 |
153 | def evaluate(self,output_log):
154 |         """Evaluates a prediction file. For the detection task we report the
155 |         interpolated mean average precision as the measure of a method's
156 |         performance.
157 | """
158 | self.ap = self.wrapper_compute_average_precision()
159 | self.mAP = self.ap.mean(axis=1)
160 | if self.verbose:
161 | print('[RESULTS] Performance on ActivityNet detection task.')
162 | print('AP in each IOU: {}'.format(self.mAP))
163 | print('\tAverage-mAP: {}'.format(self.mAP.mean()))
164 | if len(output_log) != 0:
165 | with open(output_log,'w') as fw:
166 | mAPc = self.ap.mean(axis=0)
167 | for activity, cidx in self.activity_index.items():
168 | #cname = self.activity_index[i]
169 | apc = mAPc[cidx]
170 | fw.write('%s %0.4f\n'%(activity,apc))
171 |
172 | def compute_average_precision_detection(ground_truth, prediction, tiou_thresholds=np.linspace(0.5, 0.95, 10)):
173 | """Compute average precision (detection task) between ground truth and
174 |     prediction data frames. If multiple predictions occur for the same
175 |     predicted segment, only the one with the highest score is matched as a
176 |     true positive. This code is greatly inspired by the Pascal VOC devkit.
177 |
178 | Parameters
179 | ----------
180 | ground_truth : df
181 | Data frame containing the ground truth instances.
182 | Required fields: ['video-id', 't-start', 't-end']
183 | prediction : df
184 | Data frame containing the prediction instances.
185 | Required fields: ['video-id, 't-start', 't-end', 'score']
186 | tiou_thresholds : 1darray, optional
187 | Temporal intersection over union threshold.
188 |
189 | Outputs
190 | -------
191 | ap : float
192 | Average precision score.
193 | """
194 | npos = float(len(ground_truth))
195 | lock_gt = np.ones((len(tiou_thresholds),len(ground_truth))) * -1
196 | # Sort predictions by decreasing score order.
197 | sort_idx = prediction['score'].values.argsort()[::-1]
198 | prediction = prediction.loc[sort_idx].reset_index(drop=True)
199 |
200 | # Initialize true positive and false positive vectors.
201 | tp = np.zeros((len(tiou_thresholds), len(prediction)))
202 | fp = np.zeros((len(tiou_thresholds), len(prediction)))
203 |
204 | # Adaptation to query faster
205 | ground_truth_gbvn = ground_truth.groupby('video-id')
206 |
207 |     # Assigning true positives to the matching ground truth instances.
208 | for idx, this_pred in prediction.iterrows():
209 |
210 | try:
211 | # Check if there is at least one ground truth in the video associated.
212 | ground_truth_videoid = ground_truth_gbvn.get_group(this_pred['video-id'])
213 | except Exception as e:
214 | fp[:, idx] = 1
215 | continue
216 |
217 | this_gt = ground_truth_videoid.reset_index()
218 | tiou_arr = segment_iou(this_pred[['t-start', 't-end']].values,
219 | this_gt[['t-start', 't-end']].values)
220 | # We would like to retrieve the predictions with highest tiou score.
221 | tiou_sorted_idx = tiou_arr.argsort()[::-1]
222 | for tidx, tiou_thr in enumerate(tiou_thresholds):
223 | for jdx in tiou_sorted_idx:
224 | if tiou_arr[jdx] < tiou_thr:
225 | fp[tidx, idx] = 1
226 | break
227 | if lock_gt[tidx, this_gt.loc[jdx]['index']] >= 0:
228 | continue
229 | # Assign as true positive after the filters above.
230 | tp[tidx, idx] = 1
231 | lock_gt[tidx, this_gt.loc[jdx]['index']] = idx
232 | break
233 |
234 | if fp[tidx, idx] == 0 and tp[tidx, idx] == 0:
235 | fp[tidx, idx] = 1
236 |
237 | ap = np.zeros(len(tiou_thresholds))
238 |
239 | for tidx in range(len(tiou_thresholds)):
240 | # Computing prec-rec
241 | this_tp = np.cumsum(tp[tidx,:]).astype(np.float)
242 | this_fp = np.cumsum(fp[tidx,:]).astype(np.float)
243 | rec = this_tp / npos
244 | prec = this_tp / (this_tp + this_fp)
245 | ap[tidx] = interpolated_prec_rec(prec, rec)
246 |
247 | return ap
248 |
--------------------------------------------------------------------------------
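
`compute_average_precision_detection` can be exercised directly on small data frames. The snippet below is a sketch with two made-up ground-truth segments and two predictions; it assumes it is run from inside `evaluation/` (so that `utils` resolves) with the repository's older numpy, since `np.float` is used above.

```python
import numpy as np
import pandas as pd
from eval_detection import compute_average_precision_detection

gt = pd.DataFrame({'video-id': ['v1', 'v1'],
                   't-start':  [0.0, 10.0],
                   't-end':    [5.0, 15.0]})
pred = pd.DataFrame({'video-id': ['v1', 'v1'],
                     't-start':  [0.5, 9.0],
                     't-end':    [5.5, 16.0],
                     'score':    [0.9, 0.8]})
ap = compute_average_precision_detection(gt, pred,
                                         tiou_thresholds=np.linspace(0.5, 0.95, 10))
print(ap)   # one average-precision value per tIoU threshold
```
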
/evaluation/eval_proposal.py:
--------------------------------------------------------------------------------
1 | import json
2 | #import urllib2
3 |
4 | import numpy as np
5 | import pandas as pd
6 |
7 | from utils import get_blocked_videos
8 | from utils import interpolated_prec_rec
9 | from utils import segment_iou
10 | from utils import wrapper_segment_iou
11 |
12 | class ANETproposal(object):
13 |
14 | GROUND_TRUTH_FIELDS = ['database', 'version']
15 | PROPOSAL_FIELDS = ['results', 'version', 'external_data']
16 |
17 | def __init__(self, ground_truth_filename=None, proposal_filename=None,
18 | ground_truth_fields=GROUND_TRUTH_FIELDS,
19 | proposal_fields=PROPOSAL_FIELDS,
20 | tiou_thresholds=np.linspace(0.5, 0.95, 10),
21 | max_avg_nr_proposals=None,
22 | subset='validation', verbose=False,
23 | check_status=True):
24 | if not ground_truth_filename:
25 | raise IOError('Please input a valid ground truth file.')
26 | if not proposal_filename:
27 | raise IOError('Please input a valid proposal file.')
28 | self.subset = subset
29 | self.tiou_thresholds = tiou_thresholds
30 | self.max_avg_nr_proposals = max_avg_nr_proposals
31 | self.verbose = verbose
32 | self.gt_fields = ground_truth_fields
33 | self.pred_fields = proposal_fields
34 | self.recall = None
35 | self.avg_recall = None
36 | self.proposals_per_video = None
37 | self.check_status = check_status
38 |
39 | # Include the Block video in server (modified in 6/3/2017)
40 | # Retrieve blocked videos from server.
41 | #if self.check_status:
42 | # self.blocked_videos = get_blocked_videos()
43 | #else:
44 | # self.blocked_videos = list()
45 | self.blocked_videos = list()
46 | #print('Block Video Num: %d'%(len(self.blocked_videos)))
47 |
48 | #with open('block_video_list.txt','w') as fw:
49 | # for i in range(len(self.blocked_videos)):
50 | # fw.write('%s\n'%self.blocked_videos[i])
51 |
52 | self.lesstime_videos = list()
53 |         print('The number of excluded less-time videos: %d'%(len(self.lesstime_videos)))
54 | # Import ground truth and proposals.
55 | self.ground_truth, self.activity_index = self._import_ground_truth(
56 | ground_truth_filename)
57 | self.proposal = self._import_proposal(proposal_filename)
58 |
59 | #if self.verbose:
60 | # print '[INIT] Loaded annotations from {} subset.'.format(subset)
61 | # nr_gt = len(self.ground_truth)
62 | # print '\tNumber of ground truth instances: {}'.format(nr_gt)
63 | # nr_pred = len(self.proposal)
64 | # print '\tNumber of proposals: {}'.format(nr_pred)
65 | # print '\tFixed threshold for tiou score: {}'.format(self.tiou_thresholds)
66 |
67 | def _import_ground_truth(self, ground_truth_filename):
68 | """Reads ground truth file, checks if it is well formatted, and returns
69 | the ground truth instances and the activity classes.
70 |
71 | Parameters
72 | ----------
73 | ground_truth_filename : str
74 | Full path to the ground truth json file.
75 |
76 | Outputs
77 | -------
78 | ground_truth : df
79 | Data frame containing the ground truth instances.
80 | activity_index : dict
81 | Dictionary containing class index.
82 | """
83 | with open(ground_truth_filename, 'r') as fobj:
84 | data = json.load(fobj)
85 | # Checking format
86 | if not all([field in data.keys() for field in self.gt_fields]):
87 | raise IOError('Please input a valid ground truth file.')
88 |
89 | # Read ground truth data.
90 | activity_index, cidx, block_video_n = {}, 0, 0
91 | video_lst, t_start_lst, t_end_lst, label_lst = [], [], [], []
92 | for videoid, v in data['database'].items():
93 | if self.subset != v['subset']:
94 | continue
95 | if videoid in self.blocked_videos:
96 | block_video_n += 1
97 | continue
98 | if videoid in self.lesstime_videos:
99 | continue
100 | for ann in v['annotations']:
101 | if ann['label'] not in activity_index:
102 | activity_index[ann['label']] = cidx
103 | cidx += 1
104 | video_lst.append(videoid)
105 | t_start_lst.append(ann['segment'][0])
106 | t_end_lst.append(ann['segment'][1])
107 | label_lst.append(activity_index[ann['label']])
108 |         print('ground truth instance number: %d'%(len(video_lst)))
109 | print('block video number: %d'%(block_video_n))
110 | ground_truth = pd.DataFrame({'video-id': video_lst,
111 | 't-start': t_start_lst,
112 | 't-end': t_end_lst,
113 | 'label': label_lst})
114 | return ground_truth, activity_index
115 |
116 | def _import_proposal(self, proposal_filename):
117 | """Reads proposal file, checks if it is well formatted, and returns
118 | the proposal instances.
119 |
120 | Parameters
121 | ----------
122 | proposal_filename : str
123 | Full path to the proposal json file.
124 |
125 | Outputs
126 | -------
127 | proposal : df
128 | Data frame containing the proposal instances.
129 | """
130 | with open(proposal_filename, 'r') as fobj:
131 | data = json.load(fobj)
132 | # Checking format...
133 | if not all([field in data.keys() for field in self.pred_fields]):
134 | raise IOError('Please input a valid proposal file.')
135 |
136 | # Read predictions.
137 | video_lst, t_start_lst, t_end_lst = [], [], []
138 | score_lst = []
139 | for videoid, v in data['results'].items():
140 | if videoid in self.blocked_videos:
141 | continue
142 | for result in v:
143 | video_lst.append(videoid)
144 | t_start_lst.append(result['segment'][0])
145 | t_end_lst.append(result['segment'][1])
146 | score_lst.append(result['score'])
147 | proposal = pd.DataFrame({'video-id': video_lst,
148 | 't-start': t_start_lst,
149 | 't-end': t_end_lst,
150 | 'score': score_lst})
151 | return proposal
152 |
153 | def evaluate(self):
154 | """Evaluates a proposal file. To measure the performance of a
155 |         method for the proposal task, we compute the area under the
156 | average recall vs average number of proposals per video curve.
157 | """
158 | recall, avg_recall, proposals_per_video = average_recall_vs_avg_nr_proposals(
159 | self.ground_truth, self.proposal,
160 | max_avg_nr_proposals=self.max_avg_nr_proposals,
161 | tiou_thresholds=self.tiou_thresholds)
162 |
163 | area_under_curve = np.trapz(avg_recall, proposals_per_video)
164 |
165 | if self.verbose:
166 | print('[RESULTS] Performance on ActivityNet proposal task.')
167 | print('\tArea Under the AR vs AN curve: {}%'.format(100.*float(area_under_curve)/proposals_per_video[-1]))
168 | print('\tProposal 100, IOU from 0.5:0.05:0.95, Average Recall: {}%'.format(100.*float(avg_recall[-1])))
169 |
170 | self.recall = recall
171 | self.avg_recall = avg_recall
172 | self.proposals_per_video = proposals_per_video
173 |
174 | def average_recall_vs_avg_nr_proposals(ground_truth, proposals,
175 | max_avg_nr_proposals=None,
176 | tiou_thresholds=np.linspace(0.5, 0.95, 10)):
177 | """ Computes the average recall given an average number
178 | of proposals per video.
179 |
180 | Parameters
181 | ----------
182 | ground_truth : df
183 | Data frame containing the ground truth instances.
184 | Required fields: ['video-id', 't-start', 't-end']
185 |     proposals : df
186 | Data frame containing the proposal instances.
187 | Required fields: ['video-id, 't-start', 't-end', 'score']
188 | tiou_thresholds : 1darray, optional
189 | array with tiou thresholds.
190 |
191 | Outputs
192 | -------
193 | recall : 2darray
194 |         recall[i,j] is the recall at the ith tiou threshold for the jth average number of proposals per video.
195 |     average_recall : 1darray
196 |         recall averaged over the list of tiou thresholds. This is equivalent to recall.mean(axis=0).
197 | proposals_per_video : 1darray
198 | average number of proposals per video.
199 | """
200 |
201 | # Get list of videos.
202 | video_lst = ground_truth['video-id'].unique()
203 |
204 | if not max_avg_nr_proposals:
205 | max_avg_nr_proposals = float(proposals.shape[0])/video_lst.shape[0]
206 |
207 | ratio = max_avg_nr_proposals*float(video_lst.shape[0])/proposals.shape[0]
208 | print('max avg p: %d video_lst shape0: %d proposals_shape0: %d'%(max_avg_nr_proposals,video_lst.shape[0],proposals.shape[0]))
209 | print('\ttotal proposal: %d ratio: %0.4f'%(proposals.shape[0],ratio))
210 |
211 | # Adaptation to query faster
212 | ground_truth_gbvn = ground_truth.groupby('video-id')
213 | proposals_gbvn = proposals.groupby('video-id')
214 |
215 | # For each video, computes tiou scores among the retrieved proposals.
216 | score_lst = []
217 | total_nr_proposals = 0
218 | for videoid in video_lst:
219 |
220 | # Get ground-truth instances associated to this video.
221 | ground_truth_videoid = ground_truth_gbvn.get_group(videoid)
222 | this_video_ground_truth = ground_truth_videoid.loc[:,['t-start', 't-end']].values
223 |
224 | # Get proposals for this video.
225 | try:
226 | proposals_videoid = proposals_gbvn.get_group(videoid)
227 | this_video_proposals = proposals_videoid.loc[:, ['t-start', 't-end']].values
228 | except:
229 | n = this_video_ground_truth.shape[0]
230 | score_lst.append(np.zeros((n, 1)))
231 | continue
232 |
233 | # Sort proposals by score.
234 | sort_idx = proposals_videoid['score'].argsort()[::-1]
235 | this_video_proposals = this_video_proposals[sort_idx, :]
236 |
237 | if this_video_proposals.shape[0] == 0:
238 | n = this_video_ground_truth.shape[0]
239 | score_lst.append(np.zeros((n, 1)))
240 | continue
241 |
242 | if this_video_proposals.ndim != 2:
243 | this_video_proposals = np.expand_dims(this_video_proposals, axis=0)
244 | if this_video_ground_truth.ndim != 2:
245 | this_video_ground_truth = np.expand_dims(this_video_ground_truth, axis=0)
246 |
247 | nr_proposals = np.minimum(int(this_video_proposals.shape[0] * ratio), this_video_proposals.shape[0])
248 | total_nr_proposals += nr_proposals
249 | this_video_proposals = this_video_proposals[:nr_proposals, :]
250 |
251 | # Compute tiou scores.
252 | tiou = wrapper_segment_iou(this_video_proposals, this_video_ground_truth)
253 | score_lst.append(tiou)
254 |
255 | # Given that the length of the videos is really varied, we
256 | # compute the number of proposals in terms of a ratio of the total
257 | # proposals retrieved, i.e. average recall at a percentage of proposals
258 | # retrieved per video.
259 |
260 | print('\ttotal proposal: %d'%total_nr_proposals)
261 | print('\tvideo list number: %d'%video_lst.shape[0])
262 | # Computes average recall.
263 | pcn_lst = np.arange(1, 101) / 100.0 *(max_avg_nr_proposals*float(video_lst.shape[0])/total_nr_proposals)
264 | matches = np.empty((video_lst.shape[0], pcn_lst.shape[0]))
265 | positives = np.empty(video_lst.shape[0])
266 | recall = np.empty((tiou_thresholds.shape[0], pcn_lst.shape[0]))
267 | # Iterates over each tiou threshold.
268 | for ridx, tiou in enumerate(tiou_thresholds):
269 |
270 | # Inspect positives retrieved per video at different
271 | # number of proposals (percentage of the total retrieved).
272 | for i, score in enumerate(score_lst):
273 | # Total positives per video.
274 | positives[i] = score.shape[0]
275 | # Find proposals that satisfies minimum tiou threshold.
276 | true_positives_tiou = score >= tiou
277 | # Get number of proposals as a percentage of total retrieved.
278 | pcn_proposals = np.minimum((score.shape[1] * pcn_lst).astype(np.int), score.shape[1])
279 |
280 | for j, nr_proposals in enumerate(pcn_proposals):
281 | # Compute the number of matches for each percentage of the proposals
282 | matches[i, j] = np.count_nonzero((true_positives_tiou[:, :nr_proposals]).sum(axis=1))
283 | # Computes recall given the set of matches per video.
284 | recall[ridx, :] = matches.sum(axis=0) / positives.sum()
285 |
286 | # Recall is averaged.
287 | avg_recall = recall.mean(axis=0)
288 |
289 | # Get the average number of proposals per video.
290 | proposals_per_video = pcn_lst * (float(total_nr_proposals) / video_lst.shape[0])
291 |
292 | return recall, avg_recall, proposals_per_video
293 |
294 |
--------------------------------------------------------------------------------
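
Similarly, `average_recall_vs_avg_nr_proposals` can be probed with tiny synthetic data frames. The video ids and segments below are made up, and the same `evaluation/` working-directory and older-numpy assumptions apply (`np.int` is used above):

```python
import numpy as np
import pandas as pd
from eval_proposal import average_recall_vs_avg_nr_proposals

gt = pd.DataFrame({'video-id': ['v1', 'v2'],
                   't-start':  [1.0, 3.0],
                   't-end':    [6.0, 9.0]})
props = pd.DataFrame({'video-id': ['v1', 'v1', 'v2'],
                      't-start':  [0.8, 4.0, 2.5],
                      't-end':    [5.9, 7.0, 9.5],
                      'score':    [0.9, 0.4, 0.7]})
recall, avg_recall, props_per_video = average_recall_vs_avg_nr_proposals(
    gt, props, max_avg_nr_proposals=100)
print(np.trapz(avg_recall, props_per_video))   # area under the AR-AN curve
```
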
/evaluation/get_detection_performance.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 |
4 | from eval_detection import ANETdetection
5 |
6 | def main(ground_truth_filename, prediction_filename,
7 | subset='validation', tiou_thresholds=np.linspace(0.5, 0.95, 10),
8 | verbose=True, check_status=True):
9 |
10 | anet_detection = ANETdetection(ground_truth_filename, prediction_filename,
11 | subset=subset, tiou_thresholds=tiou_thresholds,
12 | verbose=verbose, check_status=True)
13 |     anet_detection.evaluate('')  # evaluate() requires an output_log path; '' skips the per-class log
14 |
15 | def detect(prediction_filename, ground_truth_filename = 'activity_net.v1-3.min.json',output_log=''):
16 | anet_detection = ANETdetection(ground_truth_filename, prediction_filename,
17 | subset='validation', tiou_thresholds=np.linspace(0.5, 0.95, 10),
18 | verbose=True, check_status=True)
19 | anet_detection.evaluate(output_log)
20 | return anet_detection.mAP
21 |
22 |
23 |
24 | def parse_input():
25 | description = ('This script allows you to evaluate the ActivityNet '
26 | 'detection task which is intended to evaluate the ability '
27 | 'of algorithms to temporally localize activities in '
28 | 'untrimmed video sequences.')
29 | p = argparse.ArgumentParser(description=description)
30 | p.add_argument('ground_truth_filename',
31 | help='Full path to json file containing the ground truth.')
32 | p.add_argument('prediction_filename',
33 | help='Full path to json file containing the predictions.')
34 | p.add_argument('--subset', default='validation',
35 | help=('String indicating subset to evaluate: '
36 | '(training, validation)'))
37 | p.add_argument('--tiou_thresholds', type=float, default=np.linspace(0.5, 0.95, 10),
38 | help='Temporal intersection over union threshold.')
39 | p.add_argument('--verbose', type=bool, default=True)
40 | p.add_argument('--check_status', type=bool, default=True)
41 | return p.parse_args()
42 |
43 | if __name__ == '__main__':
44 | args = parse_input()
45 | main(**vars(args))
46 |
--------------------------------------------------------------------------------
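
The `detect` helper above wraps `ANETdetection` end to end and returns the mAP array over the tIoU thresholds. A hedged usage sketch with placeholder file names:

```python
from get_detection_performance import detect

# 'results_detection.json' is a placeholder prediction file in the submission
# format; set output_log='' to skip the per-class AP dump.
mAP = detect('results_detection.json',
             ground_truth_filename='activity_net.v1-3.min.json',
             output_log='per_class_ap.txt')
print(mAP.mean())   # average mAP over tIoU 0.5:0.05:0.95
```
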
/evaluation/get_proposal_performance.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import sys
4 | sys.path.append('./evaluation')
5 | import matplotlib.pyplot as plt
6 | import json
7 |
8 | from eval_proposal import ANETproposal
9 |
10 | def main(ground_truth_filename, proposal_filename, max_avg_nr_proposals=100,
11 | tiou_thresholds=np.linspace(0.5, 0.95, 10),
12 | subset='validation', verbose=True, check_status=True):
13 |
14 | anet_proposal = ANETproposal(ground_truth_filename, proposal_filename,
15 | tiou_thresholds=tiou_thresholds,
16 | max_avg_nr_proposals=max_avg_nr_proposals,
17 | subset=subset, verbose=True, check_status=True)
18 | anet_proposal.evaluate()
19 |
20 | def run_evaluation(ground_truth_filename, proposal_filename,
21 | max_avg_nr_proposals=100,
22 | tiou_thresholds=np.linspace(0.5, 0.95, 10),
23 | subset='validation'):
24 |
25 | anet_proposal = ANETproposal(ground_truth_filename, proposal_filename,
26 | tiou_thresholds=tiou_thresholds,
27 | max_avg_nr_proposals=max_avg_nr_proposals,
28 | subset=subset, verbose=True, check_status=True)
29 | anet_proposal.evaluate()
30 |
31 | recall = anet_proposal.recall
32 | average_recall = anet_proposal.avg_recall
33 | average_nr_proposals = anet_proposal.proposals_per_video
34 |
35 | return (average_nr_proposals, average_recall, recall)
36 |
37 | def plot_metric(average_nr_proposals, average_recall, recall,
38 | tiou_thresholds=np.linspace(0.5, 0.95, 10)):
39 | fn_size = 14
40 | plt.figure(num=None, figsize=(6, 5))
41 | ax = plt.subplot(1,1,1)
42 |
43 | colors = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9']
44 |
45 | area_under_curve = np.zeros_like(tiou_thresholds)
46 | for i in range(recall.shape[0]):
47 | area_under_curve[i] = np.trapz(recall[i], average_nr_proposals)
48 |
49 | for idx, tiou in enumerate(tiou_thresholds[::2]):
50 | ax.plot(average_nr_proposals, recall[2*idx,:], color=colors[idx+1],
51 | label="tiou=[" + str(tiou) + "], area=" + str(int(area_under_curve[2*idx]*100)/100.),
52 | linewidth=4, linestyle='--', marker=None)
53 |
54 | # Plots Average Recall vs Average number of proposals.
55 | ax.plot(average_nr_proposals, average_recall, color=colors[0],
56 | label="tiou = 0.5:0.05:0.95," + " area=" + str(int(np.trapz(average_recall, average_nr_proposals)*100)/100.),
57 | linewidth=4, linestyle='-', marker=None)#
58 |
59 | handles, labels = ax.get_legend_handles_labels()
60 | ax.legend([handles[-1]] + handles[:-1], [labels[-1]] + labels[:-1], loc='best')
61 |
62 | plt.ylabel('Average Recall', fontsize=fn_size)
63 | plt.xlabel('Average Number of Proposals per Video', fontsize=fn_size)
64 | plt.grid(b=True, which="both")
65 | plt.ylim([0, 1.0])
66 | plt.setp(plt.axes().get_xticklabels(), fontsize=fn_size)
67 | plt.setp(plt.axes().get_yticklabels(), fontsize=fn_size)
68 |
69 | plt.show()
70 |
71 | def plot_figure(ground_truth_filename, proposal_filename, max_avg_nr_proposals=100,
72 | tiou_thresholds=np.linspace(0.5, 0.95, 10),
73 | subset='validation', verbose=True, check_status=True):
74 | uniform_average_nr_proposals_valid, uniform_average_recall_valid, uniform_recall_valid = run_evaluation(
75 | ground_truth_filename,
76 | proposal_filename,
77 | max_avg_nr_proposals=100,
78 | tiou_thresholds=np.linspace(0.5, 0.95, 10),
79 | subset='validation')
80 | plot_metric(uniform_average_nr_proposals_valid, uniform_average_recall_valid, uniform_recall_valid)
81 |
82 | def evaluate_return_area(ground_truth_filename,proposal_filename,max_avg_nr_proposals=100,
83 | tiou_thresholds=np.linspace(0.5,0.95,10),
84 | subset='validation',verbose = True, check_status = True):
85 | anet_proposal = ANETproposal(ground_truth_filename, proposal_filename,
86 | tiou_thresholds=tiou_thresholds,
87 | max_avg_nr_proposals=max_avg_nr_proposals,
88 | subset=subset, verbose=True, check_status=True)
89 | anet_proposal.evaluate()
90 |
91 | recall = anet_proposal.recall
92 | average_recall = anet_proposal.avg_recall
93 | average_nr_proposals = anet_proposal.proposals_per_video
94 |
95 | area_under_curve = np.trapz(average_recall, average_nr_proposals)
96 | AR_AN = 100.*float(area_under_curve)/average_nr_proposals[-1]
97 | Recall_all = 100.*float(average_recall[-1])
98 |
99 | return (AR_AN,Recall_all)
100 |
101 | def parse_input():
102 | description = ('This script allows you to evaluate the ActivityNet '
103 | 'proposal task which is intended to evaluate the ability '
104 | 'of algorithms to generate activity proposals that temporally '
105 | 'localize activities in untrimmed video sequences.')
106 | p = argparse.ArgumentParser(description=description)
107 | p.add_argument('ground_truth_filename',
108 | help='Full path to json file containing the ground truth.')
109 | p.add_argument('proposal_filename',
110 | help='Full path to json file containing the proposals.')
111 | p.add_argument('--subset', default='validation',
112 | help=('String indicating subset to evaluate: '
113 | '(training, validation)'))
114 | p.add_argument('--verbose', type=bool, default=True)
115 | p.add_argument('--check_status', type=bool, default=True)
116 | return p.parse_args()
117 |
118 | def write_ar_an(gt, json_path, out_file):
119 | uniform_average_nr_proposals_valid, uniform_average_recall_valid, uniform_recall_valid = run_evaluation(
120 | gt,json_path,
121 | max_avg_nr_proposals=100,
122 | tiou_thresholds=np.linspace(0.5, 0.95, 10),
123 | subset='validation')
124 | with open(out_file,'w') as fw:
125 | for k in range(len(uniform_average_nr_proposals_valid)):
126 | fw.write('%f %f\n'%(uniform_average_nr_proposals_valid[k],uniform_average_recall_valid[k]))
127 |
--------------------------------------------------------------------------------
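A minimal usage sketch of the evaluation entry points above. The JSON paths are placeholders; both files are assumed to follow the ActivityNet proposal submission format, and `ANETproposal` / `run_evaluation` come from the evaluation scripts in this folder.

```python
import numpy as np

# Hypothetical file paths.
gt_json = 'activity_net.v1-3.min.json'      # ground-truth temporal annotations
prop_json = 'aher_proposal_results.json'    # generated proposals

# Area under the AR-AN curve (in %) and average recall at 100 proposals per video.
ar_an, recall_at_100 = evaluate_return_area(gt_json, prop_json,
                                            max_avg_nr_proposals=100,
                                            tiou_thresholds=np.linspace(0.5, 0.95, 10),
                                            subset='validation')
print('AR-AN area: %.2f  AR@100: %.2f' % (ar_an, recall_at_100))

# Optionally dump the raw AR/AN curve points for external plotting.
write_ar_an(gt_json, prop_json, 'ar_an_curve.txt')
```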
/evaluation/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import urllib.request
3 |
4 | import numpy as np
5 |
6 | API = 'http://ec2-52-11-11-89.us-west-2.compute.amazonaws.com/challenge16/api.py'
7 |
8 | def get_blocked_videos(api=API):
9 | api_url = '{}?action=get_blocked'.format(api)
10 |     req = urllib.request.Request(api_url)
11 |     response = urllib.request.urlopen(req)
12 | return json.loads(response.read())
13 |
14 | def interpolated_prec_rec(prec, rec):
15 | """Interpolated AP - VOCdevkit from VOC 2011.
16 | """
17 | mprec = np.hstack([[0], prec, [0]])
18 | mrec = np.hstack([[0], rec, [1]])
19 | for i in range(len(mprec) - 1)[::-1]:
20 | mprec[i] = max(mprec[i], mprec[i + 1])
21 | idx = np.where(mrec[1::] != mrec[0:-1])[0] + 1
22 | ap = np.sum((mrec[idx] - mrec[idx - 1]) * mprec[idx])
23 | return ap
24 |
25 | def segment_iou(target_segment, candidate_segments):
26 | """Compute the temporal intersection over union between a
27 | target segment and all the test segments.
28 |
29 | Parameters
30 | ----------
31 | target_segment : 1d array
32 | Temporal target segment containing [starting, ending] times.
33 | candidate_segments : 2d array
34 | Temporal candidate segments containing N x [starting, ending] times.
35 |
36 | Outputs
37 | -------
38 | tiou : 1d array
40 |         Temporal intersection over union score of the N candidate segments.
40 | """
41 | tt1 = np.maximum(target_segment[0], candidate_segments[:, 0])
42 | tt2 = np.minimum(target_segment[1], candidate_segments[:, 1])
43 | # Intersection including Non-negative overlap score.
44 | segments_intersection = (tt2 - tt1).clip(0)
45 | # Segment union.
46 | segments_union = (candidate_segments[:, 1] - candidate_segments[:, 0]) \
47 | + (target_segment[1] - target_segment[0]) - segments_intersection
48 | # Compute overlap as the ratio of the intersection
49 | # over union of two segments.
50 | tIoU = segments_intersection.astype(float) / segments_union
51 | return tIoU
52 |
53 | def wrapper_segment_iou(target_segments, candidate_segments):
54 | """Compute intersection over union btw segments
55 | Parameters
56 | ----------
57 | target_segments : ndarray
58 | 2-dim array in format [m x 2:=[init, end]]
59 | candidate_segments : ndarray
60 | 2-dim array in format [n x 2:=[init, end]]
61 | Outputs
62 | -------
63 | tiou : ndarray
64 | 2-dim array [n x m] with IOU ratio.
65 |     Note: It assumes that candidate-segments are scarcer than target-segments.
66 | """
67 | if candidate_segments.ndim != 2 or target_segments.ndim != 2:
68 | raise ValueError('Dimension of arguments is incorrect')
69 |
70 | n, m = candidate_segments.shape[0], target_segments.shape[0]
71 | tiou = np.empty((n, m))
72 | for i in range(m):
73 | tiou[:, i] = segment_iou(target_segments[i,:], candidate_segments)
74 |
75 | return tiou
76 |
--------------------------------------------------------------------------------
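A small numeric check of `segment_iou` above (the import path is illustrative; within the evaluation scripts the helper is imported directly):

```python
import numpy as np
from utils import segment_iou  # illustrative import; adjust to your path

target = np.array([10.0, 20.0])              # [start, end] in seconds
candidates = np.array([[12.0, 18.0],         # nested segment
                       [19.0, 30.0],         # 1-second overlap
                       [25.0, 40.0]])        # disjoint
print(segment_iou(target, candidates))       # -> [0.6  0.05 0.  ]
```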
/k600_category.txt:
--------------------------------------------------------------------------------
1 | abseiling
2 | acting_in_play
3 | adjusting_glasses
4 | air_drumming
5 | alligator_wrestling
6 | answering_questions
7 | applauding
8 | applying_cream
9 | archaeological_excavation
10 | archery
11 | arguing
12 | arm_wrestling
13 | arranging_flowers
14 | assembling_bicycle
15 | assembling_computer
16 | attending_conference
17 | auctioning
18 | backflip_(human)
19 | baking_cookies
20 | bandaging
21 | barbequing
22 | bartending
23 | base_jumping
24 | bathing_dog
25 | battle_rope_training
26 | beatboxing
27 | bee_keeping
28 | belly_dancing
29 | bench_pressing
30 | bending_back
31 | bending_metal
32 | biking_through_snow
33 | blasting_sand
34 | blowdrying_hair
35 | blowing_bubble_gum
36 | blowing_glass
37 | blowing_leaves
38 | blowing_nose
39 | blowing_out_candles
40 | bobsledding
41 | bodysurfing
42 | bookbinding
43 | bottling
44 | bouncing_on_bouncy_castle
45 | bouncing_on_trampoline
46 | bowling
47 | braiding_hair
48 | breading_or_breadcrumbing
49 | breakdancing
50 | breaking_boards
51 | breathing_fire
52 | brush_painting
53 | brushing_hair
54 | brushing_teeth
55 | building_cabinet
56 | building_lego
57 | building_sandcastle
58 | building_shed
59 | bull_fighting
60 | bulldozing
61 | bungee_jumping
62 | burping
63 | busking
64 | calculating
65 | calligraphy
66 | canoeing_or_kayaking
67 | capoeira
68 | capsizing
69 | card_stacking
70 | card_throwing
71 | carrying_baby
72 | cartwheeling
73 | carving_ice
74 | carving_pumpkin
75 | casting_fishing_line
76 | catching_fish
77 | catching_or_throwing_baseball
78 | catching_or_throwing_frisbee
79 | catching_or_throwing_softball
80 | celebrating
81 | changing_gear_in_car
82 | changing_oil
83 | changing_wheel_(not_on_bike)
84 | checking_tires
85 | cheerleading
86 | chewing_gum
87 | chiseling_stone
88 | chiseling_wood
89 | chopping_meat
90 | chopping_vegetables
91 | chopping_wood
92 | clam_digging
93 | clapping
94 | clay_pottery_making
95 | clean_and_jerk
96 | cleaning_gutters
97 | cleaning_pool
98 | cleaning_shoes
99 | cleaning_toilet
100 | cleaning_windows
101 | climbing_a_rope
102 | climbing_ladder
103 | climbing_tree
104 | coloring_in
105 | combing_hair
106 | contact_juggling
107 | contorting
108 | cooking_egg
109 | cooking_on_campfire
110 | cooking_sausages_(not_on_barbeque)
111 | cooking_scallops
112 | cosplaying
113 | counting_money
114 | country_line_dancing
115 | cracking_back
116 | cracking_knuckles
117 | cracking_neck
118 | crawling_baby
119 | crossing_eyes
120 | crossing_river
121 | crying
122 | cumbia
123 | curling_(sport)
124 | curling_hair
125 | cutting_apple
126 | cutting_nails
127 | cutting_orange
128 | cutting_pineapple
129 | cutting_watermelon
130 | dancing_ballet
131 | dancing_charleston
132 | dancing_gangnam_style
133 | dancing_macarena
134 | deadlifting
135 | decorating_the_christmas_tree
136 | delivering_mail
137 | dining
138 | directing_traffic
139 | disc_golfing
140 | diving_cliff
141 | docking_boat
142 | dodgeball
143 | doing_aerobics
144 | doing_jigsaw_puzzle
145 | doing_laundry
146 | doing_nails
147 | drawing
148 | dribbling_basketball
149 | drinking_shots
150 | driving_car
151 | driving_tractor
152 | drooling
153 | drop_kicking
154 | drumming_fingers
155 | dumpster_diving
156 | dunking_basketball
157 | dyeing_eyebrows
158 | dyeing_hair
159 | eating_burger
160 | eating_cake
161 | eating_carrots
162 | eating_chips
163 | eating_doughnuts
164 | eating_hotdog
165 | eating_ice_cream
166 | eating_spaghetti
167 | eating_watermelon
168 | egg_hunting
169 | embroidering
170 | exercising_with_an_exercise_ball
171 | extinguishing_fire
172 | faceplanting
173 | falling_off_bike
174 | falling_off_chair
175 | feeding_birds
176 | feeding_fish
177 | feeding_goats
178 | fencing_(sport)
179 | fidgeting
180 | finger_snapping
181 | fixing_bicycle
182 | fixing_hair
183 | flint_knapping
184 | flipping_pancake
185 | fly_tying
186 | flying_kite
187 | folding_clothes
188 | folding_napkins
189 | folding_paper
190 | front_raises
191 | frying_vegetables
192 | geocaching
193 | getting_a_haircut
194 | getting_a_piercing
195 | getting_a_tattoo
196 | giving_or_receiving_award
197 | gold_panning
198 | golf_chipping
199 | golf_driving
200 | golf_putting
201 | gospel_singing_in_church
202 | grinding_meat
203 | grooming_dog
204 | grooming_horse
205 | gymnastics_tumbling
206 | hammer_throw
207 | hand_washing_clothes
208 | head_stand
209 | headbanging
210 | headbutting
211 | high_jump
212 | high_kick
213 | historical_reenactment
214 | hitting_baseball
215 | hockey_stop
216 | holding_snake
217 | home_roasting_coffee
218 | hopscotch
219 | hoverboarding
220 | huddling
221 | hugging_(not_baby)
222 | hugging_baby
223 | hula_hooping
224 | hurdling
225 | hurling_(sport)
226 | ice_climbing
227 | ice_fishing
228 | ice_skating
229 | ice_swimming
230 | inflating_balloons
231 | installing_carpet
232 | ironing_hair
233 | ironing
234 | javelin_throw
235 | jaywalking
236 | jetskiing
237 | jogging
238 | juggling_balls
239 | juggling_fire
240 | juggling_soccer_ball
241 | jumping_bicycle
242 | jumping_into_pool
243 | jumping_jacks
244 | jumpstyle_dancing
245 | karaoke
246 | kicking_field_goal
247 | kicking_soccer_ball
248 | kissing
249 | kitesurfing
250 | knitting
251 | krumping
252 | land_sailing
253 | laughing
254 | lawn_mower_racing
255 | laying_bricks
256 | laying_concrete
257 | laying_stone
258 | laying_tiles
259 | leatherworking
260 | licking
261 | lifting_hat
262 | lighting_fire
263 | lock_picking
264 | long_jump
265 | longboarding
266 | looking_at_phone
267 | luge
268 | lunge
269 | making_a_cake
270 | making_a_sandwich
271 | making_balloon_shapes
272 | making_bubbles
273 | making_cheese
274 | making_horseshoes
275 | making_jewelry
276 | making_paper_aeroplanes
277 | making_pizza
278 | making_snowman
279 | making_sushi
280 | making_tea
281 | making_the_bed
282 | marching
283 | marriage_proposal
284 | massaging_back
285 | massaging_feet
286 | massaging_legs
287 | massaging_neck
288 | massaging_person's_head
289 | milking_cow
290 | moon_walking
291 | mopping_floor
292 | mosh_pit_dancing
293 | motorcycling
294 | mountain_climber_(exercise)
295 | moving_furniture
296 | mowing_lawn
297 | mushroom_foraging
298 | needle_felting
299 | news_anchoring
300 | opening_bottle_(not_wine)
301 | opening_door
302 | opening_present
303 | opening_refrigerator
304 | opening_wine_bottle
305 | packing
306 | paragliding
307 | parasailing
308 | parkour
309 | passing_American_football_(in_game)
310 | passing_american_football_(not_in_game)
311 | passing_soccer_ball
312 | peeling_apples
313 | peeling_potatoes
314 | person_collecting_garbage
315 | petting_animal_(not_cat)
316 | petting_cat
317 | photobombing
318 | photocopying
319 | picking_fruit
320 | pillow_fight
321 | pinching
322 | pirouetting
323 | planing_wood
324 | planting_trees
325 | plastering
326 | playing_accordion
327 | playing_badminton
328 | playing_bagpipes
329 | playing_basketball
330 | playing_bass_guitar
331 | playing_beer_pong
332 | playing_blackjack
333 | playing_cello
334 | playing_chess
335 | playing_clarinet
336 | playing_controller
337 | playing_cricket
338 | playing_cymbals
339 | playing_darts
340 | playing_didgeridoo
341 | playing_dominoes
342 | playing_drums
343 | playing_field_hockey
344 | playing_flute
345 | playing_gong
346 | playing_guitar
347 | playing_hand_clapping_games
348 | playing_harmonica
349 | playing_harp
350 | playing_ice_hockey
351 | playing_keyboard
352 | playing_kickball
353 | playing_laser_tag
354 | playing_lute
355 | playing_maracas
356 | playing_marbles
357 | playing_monopoly
358 | playing_netball
359 | playing_ocarina
360 | playing_organ
361 | playing_paintball
362 | playing_pan_pipes
363 | playing_piano
364 | playing_pinball
365 | playing_ping_pong
366 | playing_poker
367 | playing_polo
368 | playing_recorder
369 | playing_rubiks_cube
370 | playing_saxophone
371 | playing_scrabble
372 | playing_squash_or_racquetball
373 | playing_tennis
374 | playing_trombone
375 | playing_trumpet
376 | playing_ukulele
377 | playing_violin
378 | playing_volleyball
379 | playing_with_trains
380 | playing_xylophone
381 | poking_bellybutton
382 | pole_vault
383 | polishing_metal
384 | popping_balloons
385 | pouring_beer
386 | preparing_salad
387 | presenting_weather_forecast
388 | pull_ups
389 | pumping_fist
390 | pumping_gas
391 | punching_bag
392 | punching_person_(boxing)
393 | push_up
394 | pushing_car
395 | pushing_cart
396 | pushing_wheelbarrow
397 | pushing_wheelchair
398 | putting_in_contact_lenses
399 | putting_on_eyeliner
400 | putting_on_foundation
401 | putting_on_lipstick
402 | putting_on_mascara
403 | putting_on_sari
404 | putting_on_shoes
405 | raising_eyebrows
406 | reading_book
407 | reading_newspaper
408 | recording_music
409 | repairing_puncture
410 | riding_a_bike
411 | riding_camel
412 | riding_elephant
413 | riding_mechanical_bull
414 | riding_mule
415 | riding_or_walking_with_horse
416 | riding_scooter
417 | riding_snow_blower
418 | riding_unicycle
419 | ripping_paper
420 | roasting_marshmallows
421 | roasting_pig
422 | robot_dancing
423 | rock_climbing
424 | rock_scissors_paper
425 | roller_skating
426 | rolling_pastry
427 | rope_pushdown
428 | running_on_treadmill
429 | sailing
430 | salsa_dancing
431 | sanding_floor
432 | sausage_making
433 | sawing_wood
434 | scrambling_eggs
435 | scrapbooking
436 | scrubbing_face
437 | scuba_diving
438 | separating_eggs
439 | setting_table
440 | sewing
441 | shaking_hands
442 | shaking_head
443 | shaping_bread_dough
444 | sharpening_knives
445 | sharpening_pencil
446 | shaving_head
447 | shaving_legs
448 | shearing_sheep
449 | shining_flashlight
450 | shining_shoes
451 | shooting_basketball
452 | shooting_goal_(soccer)
453 | shopping
454 | shot_put
455 | shoveling_snow
456 | shucking_oysters
457 | shuffling_cards
458 | shuffling_feet
459 | side_kick
460 | sign_language_interpreting
461 | singing
462 | sipping_cup
463 | situp
464 | skateboarding
465 | ski_jumping
466 | skiing_crosscountry
467 | skiing_mono
468 | skiing_slalom
469 | skipping_rope
470 | skipping_stone
471 | skydiving
472 | slacklining
473 | slapping
474 | sled_dog_racing
475 | sleeping
476 | smashing
477 | smelling_feet
478 | smoking_hookah
479 | smoking_pipe
480 | smoking
481 | snatch_weight_lifting
482 | sneezing
483 | snorkeling
484 | snowboarding
485 | snowkiting
486 | snowmobiling
487 | somersaulting
488 | spelunking
489 | spinning_poi
490 | spray_painting
491 | springboard_diving
492 | square_dancing
493 | squat
494 | standing_on_hands
495 | staring
496 | steer_roping
497 | sticking_tongue_out
498 | stomping_grapes
499 | stretching_arm
500 | stretching_leg
501 | sucking_lolly
502 | surfing_crowd
503 | surfing_water
504 | sweeping_floor
505 | swimming_backstroke
506 | swimming_breast_stroke
507 | swimming_butterfly_stroke
508 | swimming_front_crawl
509 | swing_dancing
510 | swinging_baseball_bat
511 | swinging_on_something
512 | sword_fighting
513 | sword_swallowing
514 | tackling
515 | tagging_graffiti
516 | tai_chi
517 | talking_on_cell_phone
518 | tango_dancing
519 | tap_dancing
520 | tapping_guitar
521 | tapping_pen
522 | tasting_beer
523 | tasting_food
524 | tasting_wine
525 | testifying
526 | texting
527 | threading_needle
528 | throwing_axe
529 | throwing_ball_(not_baseball_or_American_football)
530 | throwing_discus
531 | throwing_knife
532 | throwing_snowballs
533 | throwing_tantrum
534 | throwing_water_balloon
535 | tickling
536 | tie_dying
537 | tightrope_walking
538 | tiptoeing
539 | tobogganing
540 | tossing_coin
541 | training_dog
542 | trapezing
543 | trimming_or_shaving_beard
544 | trimming_shrubs
545 | trimming_trees
546 | triple_jump
547 | twiddling_fingers
548 | tying_bow_tie
549 | tying_knot_(not_on_a_tie)
550 | tying_necktie
551 | tying_shoe_laces
552 | unboxing
553 | unloading_truck
554 | using_a_microscope
555 | using_a_paint_roller
556 | using_a_power_drill
557 | using_a_sledge_hammer
558 | using_a_wrench
559 | using_atm
560 | using_bagging_machine
561 | using_circular_saw
562 | using_inhaler
563 | using_puppets
564 | using_remote_controller_(not_gaming)
565 | using_segway
566 | vacuuming_floor
567 | visiting_the_zoo
568 | wading_through_mud
569 | wading_through_water
570 | waiting_in_line
571 | waking_up
572 | walking_the_dog
573 | walking_through_snow
574 | washing_dishes
575 | washing_feet
576 | washing_hair
577 | washing_hands
578 | watching_tv
579 | water_skiing
580 | water_sliding
581 | watering_plants
582 | waving_hand
583 | waxing_back
584 | waxing_chest
585 | waxing_eyebrows
586 | waxing_legs
587 | weaving_basket
588 | weaving_fabric
589 | welding
590 | whistling
591 | windsurfing
592 | winking
593 | wood_burning_(art)
594 | wrapping_present
595 | wrestling
596 | writing
597 | yarn_spinning
598 | yawning
599 | yoga
600 | zumba
601 |
--------------------------------------------------------------------------------
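The 600 class names above are presumably consumed as a name-to-index mapping by the training/testing scripts; a minimal sketch of one way to load the file (the exact loading code in `aher_k600_train.py` may differ):

```python
# Build a name <-> index mapping from the category file.
with open('k600_category.txt') as f:
    categories = [line.strip() for line in f if line.strip()]
assert len(categories) == 600
name_to_id = {name: idx for idx, name in enumerate(categories)}
print(name_to_id['abseiling'], name_to_id['zumba'])   # 0 599
```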
/pic/eccv_framework.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FuchenUSTC/AherNet/d3a34053b419345b79f2e395d05d60ac5e1cac98/pic/eccv_framework.JPG
--------------------------------------------------------------------------------
/tf_extended/__init__.py:
--------------------------------------------------------------------------------
1 | """TF Extended: additional metrics.
2 | """
3 |
4 | # pylint: disable=unused-import,line-too-long,g-importing-member,wildcard-import
5 | from tf_extended.metrics import *
6 | from tf_extended.tensors import *
7 | from tf_extended.bboxes import *
8 | from tf_extended.math import *
9 |
10 |
--------------------------------------------------------------------------------
/tf_extended/bboxes.py:
--------------------------------------------------------------------------------
1 | """TF Extended: additional bounding boxes methods.
2 | """
3 | import numpy as np
4 | import tensorflow as tf
5 |
6 | from tf_extended import tensors as tfe_tensors
7 | from tf_extended import math as tfe_math
8 |
9 |
10 | # =========================================================================== #
11 | # Standard boxes algorithms.
12 | # =========================================================================== #
13 | def bboxes_sort_all_classes(classes, scores, bboxes, top_k=400, scope=None):
14 | """Sort bounding boxes by decreasing order and keep only the top_k.
15 | Assume the input Tensors mix-up objects with different classes.
16 | Assume a batch-type input.
17 |
18 | Args:
19 | classes: Batch x N Tensor containing integer classes.
20 | scores: Batch x N Tensor containing float scores.
21 | bboxes: Batch x N x 4 Tensor containing boxes coordinates.
22 | top_k: Top_k boxes to keep.
23 | Return:
24 | classes, scores, bboxes: Sorted tensors of shape Batch x Top_k.
25 | """
26 | with tf.name_scope(scope, 'bboxes_sort', [classes, scores, bboxes]):
27 | scores, idxes = tf.nn.top_k(scores, k=top_k, sorted=True)
28 |
29 | # Trick to be able to use tf.gather: map for each element in the batch.
30 | def fn_gather(classes, bboxes, idxes):
31 | cl = tf.gather(classes, idxes)
32 | bb = tf.gather(bboxes, idxes)
33 | return [cl, bb]
34 | r = tf.map_fn(lambda x: fn_gather(x[0], x[1], x[2]),
35 | [classes, bboxes, idxes],
36 | dtype=[classes.dtype, bboxes.dtype],
37 | parallel_iterations=10,
38 | back_prop=False,
39 | swap_memory=False,
40 | infer_shape=True)
41 | classes = r[0]
42 | bboxes = r[1]
43 | return classes, scores, bboxes
44 |
45 | def bboxes_sort(scores, bboxes, top_k=400, scope=None):
46 | """Sort bounding boxes by decreasing order and keep only the top_k.
47 |     If inputs are dictionaries, assume every key is a different class.
48 | Assume a batch-type input.
49 |
50 | Args:
51 | scores: Batch x N Tensor/Dictionary containing float scores.
52 | bboxes: Batch x N x 4 Tensor/Dictionary containing boxes coordinates.
53 | top_k: Top_k boxes to keep.
54 | Return:
55 | scores, bboxes: Sorted Tensors/Dictionaries of shape Batch x Top_k x 1|4.
56 | """
57 | # Dictionaries as inputs.
58 | if isinstance(scores, dict) or isinstance(bboxes, dict):
59 | with tf.name_scope(scope, 'bboxes_sort_dict'):
60 | d_scores = {}
61 | d_bboxes = {}
62 | for c in scores.keys():
63 | s, b = bboxes_sort(scores[c], bboxes[c], top_k=top_k)
64 | d_scores[c] = s
65 | d_bboxes[c] = b
66 | return d_scores, d_bboxes
67 |
68 | # Tensors inputs.
69 | with tf.name_scope(scope, 'bboxes_sort', [scores, bboxes]):
70 | # Sort scores...
71 | scores, idxes = tf.nn.top_k(scores, k=top_k, sorted=True)
72 |
73 | # Trick to be able to use tf.gather: map for each element in the first dim.
74 | def fn_gather(bboxes, idxes):
75 | bb = tf.gather(bboxes, idxes)
76 | return [bb]
77 | r = tf.map_fn(lambda x: fn_gather(x[0], x[1]),
78 | [bboxes, idxes],
79 | dtype=[bboxes.dtype],
80 | parallel_iterations=10,
81 | back_prop=False,
82 | swap_memory=False,
83 | infer_shape=True)
84 | bboxes = r[0]
85 | return scores, bboxes
86 |
87 | def bboxes_clip(bbox_ref, bboxes, scope=None):
88 | """Clip bounding boxes to a reference box.
89 | Batch-compatible if the first dimension of `bbox_ref` and `bboxes`
90 | can be broadcasted.
91 |
92 | Args:
93 | bbox_ref: Reference bounding box. Nx4 or 4 shaped-Tensor;
94 | bboxes: Bounding boxes to clip. Nx4 or 4 shaped-Tensor or dictionary.
95 | Return:
96 | Clipped bboxes.
97 | """
98 | # Bboxes is dictionary.
99 | if isinstance(bboxes, dict):
100 | with tf.name_scope(scope, 'bboxes_clip_dict'):
101 | d_bboxes = {}
102 | for c in bboxes.keys():
103 | d_bboxes[c] = bboxes_clip(bbox_ref, bboxes[c])
104 | return d_bboxes
105 |
106 | # Tensors inputs.
107 | with tf.name_scope(scope, 'bboxes_clip'):
108 | # Easier with transposed bboxes. Especially for broadcasting.
109 | bbox_ref = tf.transpose(bbox_ref)
110 | bboxes = tf.transpose(bboxes)
111 | # Intersection bboxes and reference bbox.
112 | ymin = tf.maximum(bboxes[0], bbox_ref[0])
113 | xmin = tf.maximum(bboxes[1], bbox_ref[1])
114 | ymax = tf.minimum(bboxes[2], bbox_ref[2])
115 | xmax = tf.minimum(bboxes[3], bbox_ref[3])
116 | # Double check! Empty boxes when no-intersection.
117 | ymin = tf.minimum(ymin, ymax)
118 | xmin = tf.minimum(xmin, xmax)
119 | bboxes = tf.transpose(tf.stack([ymin, xmin, ymax, xmax], axis=0))
120 | return bboxes
121 |
122 | def bboxes_resize(bbox_ref, bboxes, name=None):
123 | """Resize bounding boxes based on a reference bounding box,
124 | assuming that the latter is [0, 0, 1, 1] after transform. Useful for
125 | updating a collection of boxes after cropping an image.
126 | """
127 | # Bboxes is dictionary.
128 | if isinstance(bboxes, dict):
129 | with tf.name_scope(name, 'bboxes_resize_dict'):
130 | d_bboxes = {}
131 | for c in bboxes.keys():
132 | d_bboxes[c] = bboxes_resize(bbox_ref, bboxes[c])
133 | return d_bboxes
134 |
135 | # Tensors inputs.
136 | with tf.name_scope(name, 'bboxes_resize'):
137 | # Translate.
138 | v = tf.stack([bbox_ref[0], bbox_ref[1], bbox_ref[0], bbox_ref[1]])
139 | bboxes = bboxes - v
140 | # Scale.
141 | s = tf.stack([bbox_ref[2] - bbox_ref[0],
142 | bbox_ref[3] - bbox_ref[1],
143 | bbox_ref[2] - bbox_ref[0],
144 | bbox_ref[3] - bbox_ref[1]])
145 | bboxes = bboxes / s
146 | return bboxes
147 |
148 | def bbox_1d_temporal_nms(bboxes, scores, keep_top_k, nms_threshold):
149 | t1 = bboxes[:,0]
150 | t2 = bboxes[:,1]
151 | score = scores
152 | union = t2 - t1
153 | order = tf.contrib.framework.argsort(score,direction='DESCENDING')
154 | keep = []
155 |
156 | def condition(keep, order):
157 | r = tf.greater(tf.size(order),0)
158 | return r
159 |
160 | def body(keep, order):
161 | i = order[0]
162 | keep.append(i)
163 | tt1 = tf.maximum(t1[i],tf.gather(t1,order[1:]))
164 | tt2 = tf.minimum(t2[i],tf.gather(t2,order[1:]))
165 | w = tf.maximum(0.0,tt2-tt1)
166 | inte = w
167 | ovr = inte/(union[i]+tf.gather(union,order[1:]) - inte)
168 | inds = tf.where(ovr<=nms_threshold)[:,-1]
169 | order = tf.gather(order,inds+1)
170 | return keep, order
171 |
172 |     # [keep, order] = tf.while_loop(condition, body, [keep, order])  # disabled: this greedy 1-D NMS is unused; bboxes_nms relies on tf.image.non_max_suppression
173 | topk = tf.minimum(len(keep), keep_top_k)
174 | return keep[:topk]
175 |
176 | def bboxes_nms(scores, bboxes, nms_threshold=0.5, keep_top_k=200, scope=None):
177 | """Apply non-maximum selection to bounding boxes. In comparison to TF
178 | implementation, use classes information for matching.
179 | Should only be used on single-entries. Use batch version otherwise.
180 |
181 | Args:
182 | scores: N Tensor containing float scores.
183 | bboxes: N x 4 Tensor containing boxes coordinates.
184 | nms_threshold: Matching threshold in NMS algorithm;
185 | keep_top_k: Number of total object to keep after NMS.
186 | Return:
187 | classes, scores, bboxes Tensors, sorted by score.
188 | Padded with zero if necessary.
189 | """
190 | with tf.name_scope(scope, 'bboxes_nms_single', [scores, bboxes]):
191 | # Apply NMS algorithm.
192 |         # Lift 1-D temporal segments [start, end] to degenerate 2-D boxes
193 |         # [start, 0., end, 1.] so tf.image.non_max_suppression computes the temporal IoU.
194 |         shape1 = (bboxes.shape[0])
195 |         ones_m = tf.ones(shape1, dtype=tf.float32)
196 |         zeros_m = tf.zeros(shape1, dtype=tf.float32)
197 |         boxxes_n = tf.stack([bboxes[:, 0], zeros_m, bboxes[:, 1], ones_m], axis=1)
198 |
199 | idxes = tf.image.non_max_suppression(boxxes_n, scores,
200 | keep_top_k, nms_threshold)
201 | #idxes = bbox_1d_temporal_nms(bboxes, scores, keep_top_k, nms_threshold)
202 | scores = tf.gather(scores, idxes)
203 | bboxes = tf.gather(bboxes, idxes)
204 | # Pad results.
205 | scores = tfe_tensors.pad_axis(scores, 0, keep_top_k, axis=0)
206 | bboxes = tfe_tensors.pad_axis(bboxes, 0, keep_top_k, axis=0)
207 | return scores, bboxes
208 |
209 | def bboxes_nms_batch(scores, bboxes, nms_threshold=0.5, keep_top_k=200,
210 | scope=None):
211 | """Apply non-maximum selection to bounding boxes. In comparison to TF
212 | implementation, use classes information for matching.
213 | Use only on batched-inputs. Use zero-padding in order to batch output
214 | results.
215 |
216 | Args:
217 | scores: Batch x N Tensor/Dictionary containing float scores.
218 | bboxes: Batch x N x 4 Tensor/Dictionary containing boxes coordinates.
219 | nms_threshold: Matching threshold in NMS algorithm;
220 | keep_top_k: Number of total object to keep after NMS.
221 | Return:
222 | scores, bboxes Tensors/Dictionaries, sorted by score.
223 | Padded with zero if necessary.
224 | """
225 | # Dictionaries as inputs.
226 | if isinstance(scores, dict) or isinstance(bboxes, dict):
227 | with tf.name_scope(scope, 'bboxes_nms_batch_dict'):
228 | d_scores = {}
229 | d_bboxes = {}
230 | for c in scores.keys():
231 | s, b = bboxes_nms_batch(scores[c], bboxes[c],
232 | nms_threshold=nms_threshold,
233 | keep_top_k=keep_top_k)
234 | d_scores[c] = s
235 | d_bboxes[c] = b
236 | return d_scores, d_bboxes
237 |
238 | # Tensors inputs.
239 | with tf.name_scope(scope, 'bboxes_nms_batch'):
240 | r = tf.map_fn(lambda x: bboxes_nms(x[0], x[1],
241 | nms_threshold, keep_top_k),
242 | (scores, bboxes),
243 | dtype=(scores.dtype, bboxes.dtype),
244 | parallel_iterations=10,
245 | back_prop=False,
246 | swap_memory=False,
247 | infer_shape=True)
248 | scores, bboxes = r
249 | return scores, bboxes
250 |
251 | def bboxes_matching(label, scores, bboxes,
252 | glabels, gbboxes, gdifficults,
253 | matching_threshold=0.5, scope=None):
254 | """Matching a collection of detected boxes with groundtruth values.
255 | Does not accept batched-inputs.
256 | The algorithm goes as follows: for every detected box, check
257 |     if one groundtruth box is matching. If none, it is considered a False Positive.
258 |     If the groundtruth box is already matched with another one, it also counts
259 |     as a False Positive. We refer to the Pascal VOC documentation for the details.
260 |
261 | Args:
262 | rclasses, rscores, rbboxes: N(x4) Tensors. Detected objects, sorted by score;
263 | glabels, gbboxes: Groundtruth bounding boxes. May be zero padded, hence
264 | zero-class objects are ignored.
265 | matching_threshold: Threshold for a positive match.
266 | Return: Tuple of:
267 |       n_gbboxes: Scalar Tensor with the number of groundtruth boxes (may differ
268 |         from size because of zero padding).
269 |       tp_match: (N,)-shaped boolean Tensor containing True Positives.
270 |       fp_match: (N,)-shaped boolean Tensor containing False Positives.
271 | """
272 | with tf.name_scope(scope, 'bboxes_matching_single',
273 | [scores, bboxes, glabels, gbboxes]):
274 | rsize = tf.size(scores)
275 | rshape = tf.shape(scores)
276 | rlabel = tf.cast(label, glabels.dtype)
277 | # Number of groundtruth boxes.
278 | gdifficults = tf.cast(gdifficults, tf.bool)
279 | n_gbboxes = tf.count_nonzero(tf.logical_and(tf.equal(glabels, label),
280 | tf.logical_not(gdifficults)))
281 |         # Groundtruth matching arrays.
282 | gmatch = tf.zeros(tf.shape(glabels), dtype=tf.bool)
283 | grange = tf.range(tf.size(glabels), dtype=tf.int32)
284 | # True/False positive matching TensorArrays.
285 | sdtype = tf.bool
286 | ta_tp_bool = tf.TensorArray(sdtype, size=rsize, dynamic_size=False, infer_shape=True)
287 | ta_fp_bool = tf.TensorArray(sdtype, size=rsize, dynamic_size=False, infer_shape=True)
288 |
289 | # Loop over returned objects.
290 | def m_condition(i, ta_tp, ta_fp, gmatch):
291 | r = tf.less(i, rsize)
292 | return r
293 |
294 | def m_body(i, ta_tp, ta_fp, gmatch):
295 | # Jaccard score with groundtruth bboxes.
296 | rbbox = bboxes[i]
297 | jaccard = bboxes_jaccard(rbbox, gbboxes)
298 | jaccard = jaccard * tf.cast(tf.equal(glabels, rlabel), dtype=jaccard.dtype)
299 |
300 | # Best fit, checking it's above threshold.
301 | idxmax = tf.cast(tf.argmax(jaccard, axis=0), tf.int32)
302 | jcdmax = jaccard[idxmax]
303 | match = jcdmax > matching_threshold
304 | existing_match = gmatch[idxmax]
305 | not_difficult = tf.logical_not(gdifficults[idxmax])
306 |
307 | # TP: match & no previous match and FP: previous match | no match.
308 | # If difficult: no record, i.e FP=False and TP=False.
309 | tp = tf.logical_and(not_difficult,
310 | tf.logical_and(match, tf.logical_not(existing_match)))
311 | ta_tp = ta_tp.write(i, tp)
312 | fp = tf.logical_and(not_difficult,
313 | tf.logical_or(existing_match, tf.logical_not(match)))
314 | ta_fp = ta_fp.write(i, fp)
315 |             # Update groundtruth match.
316 | mask = tf.logical_and(tf.equal(grange, idxmax),
317 | tf.logical_and(not_difficult, match))
318 | gmatch = tf.logical_or(gmatch, mask)
319 |
320 | return [i+1, ta_tp, ta_fp, gmatch]
321 | # Main loop definition.
322 | i = 0
323 | [i, ta_tp_bool, ta_fp_bool, gmatch] = \
324 | tf.while_loop(m_condition, m_body,
325 | [i, ta_tp_bool, ta_fp_bool, gmatch],
326 | parallel_iterations=1,
327 | back_prop=False)
328 | # TensorArrays to Tensors and reshape.
329 | tp_match = tf.reshape(ta_tp_bool.stack(), rshape)
330 | fp_match = tf.reshape(ta_fp_bool.stack(), rshape)
331 |
332 | # Some debugging information...
333 | # tp_match = tf.Print(tp_match,
334 | # [n_gbboxes,
335 | # tf.reduce_sum(tf.cast(tp_match, tf.int64)),
336 | # tf.reduce_sum(tf.cast(fp_match, tf.int64)),
337 | # tf.reduce_sum(tf.cast(gmatch, tf.int64))],
338 | # 'Matching (NG, TP, FP, GM): ')
339 | return n_gbboxes, tp_match, fp_match
340 |
341 | def bboxes_matching_batch(labels, scores, bboxes,
342 | glabels, gbboxes, gdifficults,
343 | matching_threshold=0.5, scope=None):
344 | """Matching a collection of detected boxes with groundtruth values.
345 | Batched-inputs version.
346 |
347 | Args:
348 | rclasses, rscores, rbboxes: BxN(x4) Tensors. Detected objects, sorted by score;
349 | glabels, gbboxes: Groundtruth bounding boxes. May be zero padded, hence
350 | zero-class objects are ignored.
351 | matching_threshold: Threshold for a positive match.
352 | Return: Tuple or Dictionaries with:
353 |       n_gbboxes: Scalar Tensor with the number of groundtruth boxes (may differ
354 |         from size because of zero padding).
355 |       tp: (B, N)-shaped boolean Tensor containing True Positives.
356 |       fp: (B, N)-shaped boolean Tensor containing False Positives.
357 | """
358 | # Dictionaries as inputs.
359 | if isinstance(scores, dict) or isinstance(bboxes, dict):
360 | with tf.name_scope(scope, 'bboxes_matching_batch_dict'):
361 | d_n_gbboxes = {}
362 | d_tp = {}
363 | d_fp = {}
364 | for c in labels:
365 | n, tp, fp, _ = bboxes_matching_batch(c, scores[c], bboxes[c],
366 | glabels, gbboxes, gdifficults,
367 | matching_threshold)
368 | d_n_gbboxes[c] = n
369 | d_tp[c] = tp
370 | d_fp[c] = fp
371 | return d_n_gbboxes, d_tp, d_fp, scores
372 |
373 | with tf.name_scope(scope, 'bboxes_matching_batch',
374 | [scores, bboxes, glabels, gbboxes]):
375 | r = tf.map_fn(lambda x: bboxes_matching(labels, x[0], x[1],
376 | x[2], x[3], x[4],
377 | matching_threshold),
378 | (scores, bboxes, glabels, gbboxes, gdifficults),
379 | dtype=(tf.int64, tf.bool, tf.bool),
380 | parallel_iterations=10,
381 | back_prop=False,
382 | swap_memory=True,
383 | infer_shape=True)
384 | return r[0], r[1], r[2], scores
385 |
386 |
387 | # =========================================================================== #
388 | # Some filtering methods.
389 | # =========================================================================== #
390 | def bboxes_filter_center(labels, bboxes, margins=[0., 0., 0., 0.],
391 | scope=None):
392 | """Filter out bounding boxes whose center are not in
393 | the rectangle [0, 0, 1, 1] + margins. The margin Tensor
394 | can be used to enforce or loosen this condition.
395 |
396 | Return:
397 | labels, bboxes: Filtered elements.
398 | """
399 | with tf.name_scope(scope, 'bboxes_filter', [labels, bboxes]):
400 | cy = (bboxes[:, 0] + bboxes[:, 2]) / 2.
401 | cx = (bboxes[:, 1] + bboxes[:, 3]) / 2.
402 | mask = tf.greater(cy, margins[0])
403 | mask = tf.logical_and(mask, tf.greater(cx, margins[1]))
404 |         mask = tf.logical_and(mask, tf.less(cy, 1. + margins[2]))
405 | mask = tf.logical_and(mask, tf.less(cx, 1. + margins[3]))
406 | # Boolean masking...
407 | labels = tf.boolean_mask(labels, mask)
408 | bboxes = tf.boolean_mask(bboxes, mask)
409 | return labels, bboxes
410 |
411 | def bboxes_filter_overlap(labels, bboxes,
412 | threshold=0.5, assign_negative=False,
413 | scope=None):
414 | """Filter out bounding boxes based on (relative )overlap with reference
415 | box [0, 0, 1, 1]. Remove completely bounding boxes, or assign negative
416 | labels to the one outside (useful for latter processing...).
417 |
418 | Return:
419 | labels, bboxes: Filtered (or newly assigned) elements.
420 | """
421 | with tf.name_scope(scope, 'bboxes_filter', [labels, bboxes]):
422 | scores = bboxes_intersection(tf.constant([0, 0, 1, 1], bboxes.dtype),
423 | bboxes)
424 | mask = scores > threshold
425 | if assign_negative:
426 | labels = tf.where(mask, labels, -labels)
427 | # bboxes = tf.where(mask, bboxes, bboxes)
428 | else:
429 | labels = tf.boolean_mask(labels, mask)
430 | bboxes = tf.boolean_mask(bboxes, mask)
431 | return labels, bboxes
432 |
433 | def bboxes_filter_labels(labels, bboxes,
434 | out_labels=[], num_classes=np.inf,
435 | scope=None):
436 | """Filter out labels from a collection. Typically used to get
437 | of DontCare elements. Also remove elements based on the number of classes.
438 |
439 | Return:
440 | labels, bboxes: Filtered elements.
441 | """
442 | with tf.name_scope(scope, 'bboxes_filter_labels', [labels, bboxes]):
443 |         mask = tf.less(labels, num_classes)
444 |         for l in out_labels:
445 | mask = tf.logical_and(mask, tf.not_equal(labels, l))
446 | labels = tf.boolean_mask(labels, mask)
447 | bboxes = tf.boolean_mask(bboxes, mask)
448 | return labels, bboxes
449 |
450 | # =========================================================================== #
451 | # Standard boxes computation.
452 | # =========================================================================== #
453 | def bboxes_jaccard(bbox_ref, bboxes, name=None):
454 | """Compute jaccard score between a reference box and a collection
455 | of bounding boxes.
456 |
457 | Args:
458 | bbox_ref: (N, 4) or (4,) Tensor with reference bounding box(es).
459 | bboxes: (N, 4) Tensor, collection of bounding boxes.
460 | Return:
461 | (N,) Tensor with Jaccard scores.
462 | """
463 | with tf.name_scope(name, 'bboxes_jaccard'):
464 | # Should be more efficient to first transpose.
465 | bboxes = tf.transpose(bboxes)
466 | bbox_ref = tf.transpose(bbox_ref)
467 | # Intersection bbox and volume.
468 | int_ymin = tf.maximum(bboxes[0], bbox_ref[0])
469 | int_xmin = tf.maximum(bboxes[1], bbox_ref[1])
470 | int_ymax = tf.minimum(bboxes[2], bbox_ref[2])
471 | int_xmax = tf.minimum(bboxes[3], bbox_ref[3])
472 | h = tf.maximum(int_ymax - int_ymin, 0.)
473 | w = tf.maximum(int_xmax - int_xmin, 0.)
474 | # Volumes.
475 | inter_vol = h * w
476 | union_vol = -inter_vol \
477 | + (bboxes[2] - bboxes[0]) * (bboxes[3] - bboxes[1]) \
478 | + (bbox_ref[2] - bbox_ref[0]) * (bbox_ref[3] - bbox_ref[1])
479 | jaccard = tfe_math.safe_divide(inter_vol, union_vol, 'jaccard')
480 | return jaccard
481 |
482 | def bboxes_intersection(bbox_ref, bboxes, name=None):
483 | """Compute relative intersection between a reference box and a
484 | collection of bounding boxes. Namely, compute the quotient between
485 | intersection area and box area.
486 |
487 | Args:
488 | bbox_ref: (N, 4) or (4,) Tensor with reference bounding box(es).
489 | bboxes: (N, 4) Tensor, collection of bounding boxes.
490 | Return:
491 | (N,) Tensor with relative intersection.
492 | """
493 | with tf.name_scope(name, 'bboxes_intersection'):
494 | # Should be more efficient to first transpose.
495 | bboxes = tf.transpose(bboxes)
496 | bbox_ref = tf.transpose(bbox_ref)
497 | # Intersection bbox and volume.
498 | int_ymin = tf.maximum(bboxes[0], bbox_ref[0])
499 | int_xmin = tf.maximum(bboxes[1], bbox_ref[1])
500 | int_ymax = tf.minimum(bboxes[2], bbox_ref[2])
501 | int_xmax = tf.minimum(bboxes[3], bbox_ref[3])
502 | h = tf.maximum(int_ymax - int_ymin, 0.)
503 | w = tf.maximum(int_xmax - int_xmin, 0.)
504 | # Volumes.
505 | inter_vol = h * w
506 | bboxes_vol = (bboxes[2] - bboxes[0]) * (bboxes[3] - bboxes[1])
507 | scores = tfe_math.safe_divide(inter_vol, bboxes_vol, 'intersection')
508 | return scores
509 |
--------------------------------------------------------------------------------
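The core trick in `bboxes_nms` above is to lift 1-D temporal segments into degenerate 2-D boxes so that `tf.image.non_max_suppression` computes the temporal IoU. A minimal sketch, assuming a TF 1.x graph-mode session:

```python
import tensorflow as tf

# Temporal segments as [start, end] (normalized), plus confidence scores.
segments = tf.constant([[0.10, 0.40], [0.12, 0.45], [0.60, 0.80]])
scores = tf.constant([0.9, 0.8, 0.7])

# [ymin=start, xmin=0, ymax=end, xmax=1]: box IoU then equals segment tIoU.
zeros = tf.zeros_like(scores)
ones = tf.ones_like(scores)
boxes = tf.stack([segments[:, 0], zeros, segments[:, 1], ones], axis=1)
keep = tf.image.non_max_suppression(boxes, scores, max_output_size=2,
                                    iou_threshold=0.5)
with tf.Session() as sess:
    print(sess.run(keep))   # [0 2]: segment 1 overlaps segment 0 with tIoU 0.8
```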
/tf_extended/math.py:
--------------------------------------------------------------------------------
1 | """TF Extended: additional math functions.
2 | """
3 | import tensorflow as tf
4 |
5 | from tensorflow.python.ops import array_ops
6 | from tensorflow.python.ops import math_ops
7 | from tensorflow.python.framework import dtypes
8 | from tensorflow.python.framework import ops
9 |
10 | def safe_divide(numerator, denominator, name):
11 | """Divides two values, returning 0 if the denominator is <= 0.
12 | Args:
13 | numerator: A real `Tensor`.
14 | denominator: A real `Tensor`, with dtype matching `numerator`.
15 | name: Name for the returned op.
16 | Returns:
17 | 0 if `denominator` <= 0, else `numerator` / `denominator`
18 | """
19 | return tf.where(
20 | math_ops.greater(denominator, 0),
21 | math_ops.divide(numerator, denominator),
22 | tf.zeros_like(numerator),
23 | name=name)
24 |
25 | def cummax(x, reverse=False, name=None):
26 | """Compute the cumulative maximum of the tensor `x` along `axis`. This
27 | operation is similar to the more classic `cumsum`. Only support 1D Tensor
28 | for now.
29 |
30 | Args:
31 | x: A `Tensor`. Must be one of the following types: `float32`, `float64`,
32 | `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`,
33 | `complex128`, `qint8`, `quint8`, `qint32`, `half`.
34 | axis: A `Tensor` of type `int32` (default: 0).
35 | reverse: A `bool` (default: False).
36 | name: A name for the operation (optional).
37 | Returns:
38 | A `Tensor`. Has the same type as `x`.
39 | """
40 | with ops.name_scope(name, "Cummax", [x]) as name:
41 | x = ops.convert_to_tensor(x, name="x")
42 | # Not very optimal: should directly integrate reverse into tf.scan.
43 | if reverse:
44 | x = tf.reverse(x, axis=[0])
45 |         # 'Accumulating' maximum: ensure it is always increasing.
46 | cmax = tf.scan(lambda a, y: tf.maximum(a, y), x,
47 | initializer=None, parallel_iterations=1,
48 | back_prop=False, swap_memory=False)
49 | if reverse:
50 | cmax = tf.reverse(cmax, axis=[0])
51 | return cmax
52 |
--------------------------------------------------------------------------------
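`cummax` is what `average_precision_voc12` (next file) uses to turn the precision curve into a non-increasing envelope. A quick illustration in a TF 1.x session:

```python
import tensorflow as tf
import tf_extended as tfe

x = tf.constant([0.2, 0.5, 0.3, 0.7, 0.1])
with tf.Session() as sess:
    print(sess.run(tfe.cummax(x)))                 # [0.2 0.5 0.5 0.7 0.7]
    print(sess.run(tfe.cummax(x, reverse=True)))   # [0.7 0.7 0.7 0.7 0.1]
```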
/tf_extended/metrics.py:
--------------------------------------------------------------------------------
1 | """TF Extended: additional metrics.
2 | """
3 | import tensorflow as tf
4 | import numpy as np
5 |
6 | from tensorflow.contrib.framework.python.ops import variables as contrib_variables
7 | from tensorflow.python.framework import dtypes
8 | from tensorflow.python.framework import ops
9 | from tensorflow.python.ops import array_ops
10 | from tensorflow.python.ops import math_ops
11 | from tensorflow.python.ops import nn
12 | from tensorflow.python.ops import state_ops
13 | from tensorflow.python.ops import variable_scope
14 | from tensorflow.python.ops import variables
15 |
16 | from tf_extended import math as tfe_math
17 |
18 |
19 | # =========================================================================== #
20 | # TensorFlow utils
21 | # =========================================================================== #
22 | def _create_local(name, shape, collections=None, validate_shape=True,
23 | dtype=dtypes.float32):
24 | """Creates a new local variable.
25 | Args:
26 | name: The name of the new or existing variable.
27 | shape: Shape of the new or existing variable.
28 | collections: A list of collection names to which the Variable will be added.
29 | validate_shape: Whether to validate the shape of the variable.
30 | dtype: Data type of the variables.
31 | Returns:
32 | The created variable.
33 | """
34 | # Make sure local variables are added to tf.GraphKeys.LOCAL_VARIABLES
35 | collections = list(collections or [])
36 | collections += [ops.GraphKeys.LOCAL_VARIABLES]
37 | return variables.Variable(
38 | initial_value=array_ops.zeros(shape, dtype=dtype),
39 | name=name,
40 | trainable=False,
41 | collections=collections,
42 | validate_shape=validate_shape)
43 |
44 | def _safe_div(numerator, denominator, name):
45 | """Divides two values, returning 0 if the denominator is <= 0.
46 | Args:
47 | numerator: A real `Tensor`.
48 | denominator: A real `Tensor`, with dtype matching `numerator`.
49 | name: Name for the returned op.
50 | Returns:
51 | 0 if `denominator` <= 0, else `numerator` / `denominator`
52 | """
53 | return tf.where(
54 | math_ops.greater(denominator, 0),
55 | math_ops.divide(numerator, denominator),
56 | tf.zeros_like(numerator),
57 | name=name)
58 |
59 | def _broadcast_weights(weights, values):
60 | """Broadcast `weights` to the same shape as `values`.
61 | This returns a version of `weights` following the same broadcast rules as
62 | `mul(weights, values)`. When computing a weighted average, use this function
63 | to broadcast `weights` before summing them; e.g.,
64 | `reduce_sum(w * v) / reduce_sum(_broadcast_weights(w, v))`.
65 | Args:
66 | weights: `Tensor` whose shape is broadcastable to `values`.
67 | values: `Tensor` of any shape.
68 | Returns:
69 | `weights` broadcast to `values` shape.
70 | """
71 | weights_shape = weights.get_shape()
72 | values_shape = values.get_shape()
73 | if(weights_shape.is_fully_defined() and
74 | values_shape.is_fully_defined() and
75 | weights_shape.is_compatible_with(values_shape)):
76 | return weights
77 |     return math_ops.multiply(
78 |         weights, array_ops.ones_like(values), name='broadcast_weights')
79 |
80 | # =========================================================================== #
81 | # TF Extended metrics: TP and FP arrays.
82 | # =========================================================================== #
83 | def precision_recall(num_gbboxes, num_detections, tp, fp, scores,
84 | dtype=tf.float64, scope=None):
85 | """Compute precision and recall from scores, true positives and false
86 | positives booleans arrays
87 | """
88 | # Input dictionaries: dict outputs as streaming metrics.
89 | if isinstance(scores, dict):
90 | d_precision = {}
91 | d_recall = {}
92 | for c in num_gbboxes.keys():
93 | scope = 'precision_recall_%s' % c
94 | p, r = precision_recall(num_gbboxes[c], num_detections[c],
95 | tp[c], fp[c], scores[c],
96 | dtype, scope)
97 | d_precision[c] = p
98 | d_recall[c] = r
99 | return d_precision, d_recall
100 |
101 | # Sort by score.
102 | with tf.name_scope(scope, 'precision_recall',
103 | [num_gbboxes, num_detections, tp, fp, scores]):
104 | # Sort detections by score.
105 | scores, idxes = tf.nn.top_k(scores, k=num_detections, sorted=True)
106 | tp = tf.gather(tp, idxes)
107 | fp = tf.gather(fp, idxes)
108 |         # Compute recall and precision.
109 | tp = tf.cumsum(tf.cast(tp, dtype), axis=0)
110 | fp = tf.cumsum(tf.cast(fp, dtype), axis=0)
111 | recall = _safe_div(tp, tf.cast(num_gbboxes, dtype), 'recall')
112 | precision = _safe_div(tp, tp + fp, 'precision')
113 | return tf.tuple([precision, recall])
114 |
115 | def streaming_tp_fp_arrays(num_gbboxes, tp, fp, scores,
116 | remove_zero_scores=True,
117 | metrics_collections=None,
118 | updates_collections=None,
119 | name=None):
120 | """Streaming computation of True and False Positive arrays. This metrics
121 | also keeps track of scores and number of grountruth objects.
122 | """
123 | # Input dictionaries: dict outputs as streaming metrics.
124 | if isinstance(scores, dict) or isinstance(fp, dict):
125 | d_values = {}
126 | d_update_ops = {}
127 | for c in num_gbboxes.keys():
128 | scope = 'streaming_tp_fp_%s' % c
129 | v, up = streaming_tp_fp_arrays(num_gbboxes[c], tp[c], fp[c], scores[c],
130 | remove_zero_scores,
131 | metrics_collections,
132 | updates_collections,
133 | name=scope)
134 | d_values[c] = v
135 | d_update_ops[c] = up
136 | return d_values, d_update_ops
137 |
138 | # Input Tensors...
139 | with variable_scope.variable_scope(name, 'streaming_tp_fp',
140 | [num_gbboxes, tp, fp, scores]):
141 | num_gbboxes = math_ops.to_int64(num_gbboxes)
142 | scores = math_ops.to_float(scores)
143 | stype = tf.bool
144 | tp = tf.cast(tp, stype)
145 | fp = tf.cast(fp, stype)
146 | # Reshape TP and FP tensors and clean away 0 class values.
147 | scores = tf.reshape(scores, [-1])
148 | tp = tf.reshape(tp, [-1])
149 | fp = tf.reshape(fp, [-1])
150 | # Remove TP and FP both false.
151 | mask = tf.logical_or(tp, fp)
152 | if remove_zero_scores:
153 | rm_threshold = 1e-4
154 | mask = tf.logical_and(mask, tf.greater(scores, rm_threshold))
155 | scores = tf.boolean_mask(scores, mask)
156 | tp = tf.boolean_mask(tp, mask)
157 | fp = tf.boolean_mask(fp, mask)
158 |
159 |         # Local variables accumulating information over batches.
160 | v_nobjects = _create_local('v_num_gbboxes', shape=[], dtype=tf.int64)
161 | v_ndetections = _create_local('v_num_detections', shape=[], dtype=tf.int32)
162 | v_scores = _create_local('v_scores', shape=[0, ])
163 | v_tp = _create_local('v_tp', shape=[0, ], dtype=stype)
164 | v_fp = _create_local('v_fp', shape=[0, ], dtype=stype)
165 |
166 | # Update operations.
167 | nobjects_op = state_ops.assign_add(v_nobjects,
168 | tf.reduce_sum(num_gbboxes))
169 | ndetections_op = state_ops.assign_add(v_ndetections,
170 | tf.size(scores, out_type=tf.int32))
171 | scores_op = state_ops.assign(v_scores, tf.concat([v_scores, scores], axis=0),
172 | validate_shape=False)
173 | tp_op = state_ops.assign(v_tp, tf.concat([v_tp, tp], axis=0),
174 | validate_shape=False)
175 | fp_op = state_ops.assign(v_fp, tf.concat([v_fp, fp], axis=0),
176 | validate_shape=False)
177 |
178 | # Value and update ops.
179 | val = (v_nobjects, v_ndetections, v_tp, v_fp, v_scores)
180 | with ops.control_dependencies([nobjects_op, ndetections_op,
181 | scores_op, tp_op, fp_op]):
182 | update_op = (nobjects_op, ndetections_op, tp_op, fp_op, scores_op)
183 |
184 | if metrics_collections:
185 | ops.add_to_collections(metrics_collections, val)
186 | if updates_collections:
187 | ops.add_to_collections(updates_collections, update_op)
188 | return val, update_op
189 |
190 | # =========================================================================== #
191 | # Average precision computations.
192 | # =========================================================================== #
193 | def average_precision_voc12(precision, recall, name=None):
194 | """Compute (interpolated) average precision from precision and recall Tensors.
195 |
196 | The implementation follows Pascal 2012 and ILSVRC guidelines.
197 | See also: https://sanchom.wordpress.com/tag/average-precision/
198 | """
199 | with tf.name_scope(name, 'average_precision_voc12', [precision, recall]):
200 | # Convert to float64 to decrease error on Riemann sums.
201 | precision = tf.cast(precision, dtype=tf.float64)
202 | recall = tf.cast(recall, dtype=tf.float64)
203 |
204 | # Add bounds values to precision and recall.
205 | precision = tf.concat([[0.], precision, [0.]], axis=0)
206 | recall = tf.concat([[0.], recall, [1.]], axis=0)
207 | # Ensures precision is increasing in reverse order.
208 | precision = tfe_math.cummax(precision, reverse=True)
209 |
210 | # Riemann sums for estimating the integral.
211 | # mean_pre = (precision[1:] + precision[:-1]) / 2.
212 | mean_pre = precision[1:]
213 | diff_rec = recall[1:] - recall[:-1]
214 | ap = tf.reduce_sum(mean_pre * diff_rec)
215 | return ap
216 |
217 | def average_precision_voc07(precision, recall, name=None):
218 | """Compute (interpolated) average precision from precision and recall Tensors.
219 |
220 | The implementation follows Pascal 2007 guidelines.
221 | See also: https://sanchom.wordpress.com/tag/average-precision/
222 | """
223 | with tf.name_scope(name, 'average_precision_voc07', [precision, recall]):
224 | # Convert to float64 to decrease error on cumulated sums.
225 | precision = tf.cast(precision, dtype=tf.float64)
226 | recall = tf.cast(recall, dtype=tf.float64)
227 | # Add zero-limit value to avoid any boundary problem...
228 | precision = tf.concat([precision, [0.]], axis=0)
229 | recall = tf.concat([recall, [np.inf]], axis=0)
230 |
231 | # Split the integral into 10 bins.
232 | l_aps = []
233 | for t in np.arange(0., 1.1, 0.1):
234 | mask = tf.greater_equal(recall, t)
235 | v = tf.reduce_max(tf.boolean_mask(precision, mask))
236 | l_aps.append(v / 11.)
237 | ap = tf.add_n(l_aps)
238 | return ap
239 |
240 | def precision_recall_values(xvals, precision, recall, name=None):
241 | """Compute values on the precision/recall curve.
242 |
243 | Args:
244 | x: Python list of floats;
245 | precision: 1D Tensor decreasing.
246 | recall: 1D Tensor increasing.
247 | Return:
248 | list of precision values.
249 | """
250 | with ops.name_scope(name, "precision_recall_values",
251 | [precision, recall]) as name:
252 | # Add bounds values to precision and recall.
253 | precision = tf.concat([[0.], precision, [0.]], axis=0)
254 | recall = tf.concat([[0.], recall, [1.]], axis=0)
255 | precision = tfe_math.cummax(precision, reverse=True)
256 |
257 | prec_values = []
258 | for x in xvals:
259 | mask = tf.less_equal(recall, x)
260 | val = tf.reduce_min(tf.boolean_mask(precision, mask))
261 | prec_values.append(val)
262 | return tf.tuple(prec_values)
263 |
264 | # =========================================================================== #
265 | # TF Extended metrics: old stuff!
266 | # =========================================================================== #
267 | def _precision_recall(n_gbboxes, n_detections, scores, tp, fp, scope=None):
268 | """Compute precision and recall from scores, true positives and false
269 | positives booleans arrays
270 | """
271 | # Sort by score.
272 | with tf.name_scope(scope, 'prec_rec', [n_gbboxes, scores, tp, fp]):
273 | # Sort detections by score.
274 | scores, idxes = tf.nn.top_k(scores, k=n_detections, sorted=True)
275 | tp = tf.gather(tp, idxes)
276 | fp = tf.gather(fp, idxes)
277 |         # Compute recall and precision.
278 | dtype = tf.float64
279 | tp = tf.cumsum(tf.cast(tp, dtype), axis=0)
280 | fp = tf.cumsum(tf.cast(fp, dtype), axis=0)
281 | recall = _safe_div(tp, tf.cast(n_gbboxes, dtype), 'recall')
282 | precision = _safe_div(tp, tp + fp, 'precision')
283 |
284 | return tf.tuple([precision, recall])
285 |
286 |
287 | def streaming_precision_recall_arrays(n_gbboxes, rclasses, rscores,
288 | tp_tensor, fp_tensor,
289 | remove_zero_labels=True,
290 | metrics_collections=None,
291 | updates_collections=None,
292 | name=None):
293 | """Streaming computation of precision / recall arrays. This metrics
294 | keeps tracks of boolean True positives and False positives arrays.
295 | """
296 | with variable_scope.variable_scope(name, 'stream_precision_recall',
297 | [n_gbboxes, rclasses, tp_tensor, fp_tensor]):
298 | n_gbboxes = math_ops.to_int64(n_gbboxes)
299 | rclasses = math_ops.to_int64(rclasses)
300 | rscores = math_ops.to_float(rscores)
301 |
302 | stype = tf.int32
303 | tp_tensor = tf.cast(tp_tensor, stype)
304 | fp_tensor = tf.cast(fp_tensor, stype)
305 |
306 | # Reshape TP and FP tensors and clean away 0 class values.
307 | rclasses = tf.reshape(rclasses, [-1])
308 | rscores = tf.reshape(rscores, [-1])
309 | tp_tensor = tf.reshape(tp_tensor, [-1])
310 | fp_tensor = tf.reshape(fp_tensor, [-1])
311 | if remove_zero_labels:
312 | mask = tf.greater(rclasses, 0)
313 | rclasses = tf.boolean_mask(rclasses, mask)
314 | rscores = tf.boolean_mask(rscores, mask)
315 | tp_tensor = tf.boolean_mask(tp_tensor, mask)
316 | fp_tensor = tf.boolean_mask(fp_tensor, mask)
317 |
318 |         # Local variables accumulating information over batches.
319 | v_nobjects = _create_local('v_nobjects', shape=[], dtype=tf.int64)
320 | v_ndetections = _create_local('v_ndetections', shape=[], dtype=tf.int32)
321 | v_scores = _create_local('v_scores', shape=[0, ])
322 | v_tp = _create_local('v_tp', shape=[0, ], dtype=stype)
323 | v_fp = _create_local('v_fp', shape=[0, ], dtype=stype)
324 |
325 | # Update operations.
326 | nobjects_op = state_ops.assign_add(v_nobjects,
327 | tf.reduce_sum(n_gbboxes))
328 | ndetections_op = state_ops.assign_add(v_ndetections,
329 | tf.size(rscores, out_type=tf.int32))
330 | scores_op = state_ops.assign(v_scores, tf.concat([v_scores, rscores], axis=0),
331 | validate_shape=False)
332 | tp_op = state_ops.assign(v_tp, tf.concat([v_tp, tp_tensor], axis=0),
333 | validate_shape=False)
334 | fp_op = state_ops.assign(v_fp, tf.concat([v_fp, fp_tensor], axis=0),
335 | validate_shape=False)
336 |
337 | # Precision and recall computations.
338 | # r = _precision_recall(nobjects_op, scores_op, tp_op, fp_op, 'value')
339 | r = _precision_recall(v_nobjects, v_ndetections, v_scores,
340 | v_tp, v_fp, 'value')
341 |
342 | with ops.control_dependencies([nobjects_op, ndetections_op,
343 | scores_op, tp_op, fp_op]):
344 | update_op = _precision_recall(nobjects_op, ndetections_op,
345 | scores_op, tp_op, fp_op, 'update_op')
346 |
347 | # update_op = tf.Print(update_op,
348 | # [tf.reduce_sum(tf.cast(mask, tf.int64)),
349 | # tf.reduce_sum(tf.cast(mask2, tf.int64)),
350 | # tf.reduce_min(rscores),
351 | # tf.reduce_sum(n_gbboxes)],
352 | # 'Metric: ')
353 | # Some debugging stuff!
354 | # update_op = tf.Print(update_op,
355 | # [tf.shape(tp_op),
356 | # tf.reduce_sum(tf.cast(tp_op, tf.int64), axis=0)],
357 | # 'TP and FP shape: ')
358 | # update_op[0] = tf.Print(update_op,
359 | # [nobjects_op],
360 | # '# Groundtruth bboxes: ')
361 | # update_op = tf.Print(update_op,
362 | # [update_op[0][0],
363 | # update_op[0][-1],
364 | # tf.reduce_min(update_op[0]),
365 | # tf.reduce_max(update_op[0]),
366 | # tf.reduce_min(update_op[1]),
367 | # tf.reduce_max(update_op[1])],
368 | # 'Precision and recall :')
369 |
370 | if metrics_collections:
371 | ops.add_to_collections(metrics_collections, r)
372 | if updates_collections:
373 | ops.add_to_collections(updates_collections, update_op)
374 | return r, update_op
375 |
376 |
--------------------------------------------------------------------------------
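A toy precision/recall curve run through `average_precision_voc12` above; the reverse `cummax` makes precision non-increasing before the Riemann sum (TF 1.x session assumed):

```python
import tensorflow as tf
import tf_extended as tfe

precision = tf.constant([1.0, 1.0, 0.66, 0.75])
recall = tf.constant([0.25, 0.5, 0.5, 0.75])
ap = tfe.average_precision_voc12(precision, recall)
with tf.Session() as sess:
    print(sess.run(ap))   # 0.6875 for this toy curve
```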
/tf_extended/tensors.py:
--------------------------------------------------------------------------------
1 | """TF Extended: additional tensors operations.
2 | """
3 | import tensorflow as tf
4 |
5 | from tensorflow.contrib.framework.python.ops import variables as contrib_variables
6 | from tensorflow.contrib.metrics.python.ops import set_ops
7 | from tensorflow.python.framework import dtypes
8 | from tensorflow.python.framework import ops
9 | from tensorflow.python.framework import sparse_tensor
10 | from tensorflow.python.ops import array_ops
11 | from tensorflow.python.ops import check_ops
12 | from tensorflow.python.ops import control_flow_ops
13 | from tensorflow.python.ops import math_ops
14 | from tensorflow.python.ops import nn
15 | from tensorflow.python.ops import state_ops
16 | from tensorflow.python.ops import variable_scope
17 | from tensorflow.python.ops import variables
18 |
19 |
20 | def get_shape(x, rank=None):
21 | """Returns the dimensions of a Tensor as list of integers or scale tensors.
22 |
23 | Args:
24 | x: N-d Tensor;
25 | rank: Rank of the Tensor. If None, will try to guess it.
26 | Returns:
27 | A list of `[d1, d2, ..., dN]` corresponding to the dimensions of the
28 | input tensor. Dimensions that are statically known are python integers,
29 | otherwise they are integer scalar tensors.
30 | """
31 | if x.get_shape().is_fully_defined():
32 | return x.get_shape().as_list()
33 | else:
34 | static_shape = x.get_shape()
35 | if rank is None:
36 | static_shape = static_shape.as_list()
37 | rank = len(static_shape)
38 | else:
39 | static_shape = x.get_shape().with_rank(rank).as_list()
40 | dynamic_shape = tf.unstack(tf.shape(x), rank)
41 | return [s if s is not None else d
42 | for s, d in zip(static_shape, dynamic_shape)]
43 |
44 | def pad_axis(x, offset, size, axis=0, name=None):
45 | """Pad a tensor on an axis, with a given offset and output size.
46 |     The tensor is padded with zero (i.e. CONSTANT mode). Note that if
47 |     `size` is smaller than the existing size + `offset`, the output tensor
48 |     keeps the latter, larger dimension.
49 |
50 | Args:
51 | x: Tensor to pad;
52 | offset: Offset to add on the dimension chosen;
53 | size: Final size of the dimension.
54 | Return:
55 | Padded tensor whose dimension on `axis` is `size`, or greater if
56 | the input vector was larger.
57 | """
58 | with tf.name_scope(name, 'pad_axis'):
59 | shape = get_shape(x)
60 | rank = len(shape)
61 | # Padding description.
62 | new_size = tf.maximum(size-offset-shape[axis], 0)
63 | pad1 = tf.stack([0]*axis + [offset] + [0]*(rank-axis-1))
64 | pad2 = tf.stack([0]*axis + [new_size] + [0]*(rank-axis-1))
65 | paddings = tf.stack([pad1, pad2], axis=1)
66 | x = tf.pad(x, paddings, mode='CONSTANT')
67 | # Reshape, to get fully defined shape if possible.
68 | # TODO: fix with tf.slice
69 | shape[axis] = size
70 | x = tf.reshape(x, tf.stack(shape))
71 | return x
72 |
73 |
74 |
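A quick illustration of `pad_axis` (a sketch; it assumes `pad_axis` is re-exported through `tf_extended/__init__.py`, as the rest of the code suggests by calling `tfe.get_shape`):

```python
import tensorflow as tf
import tf_extended as tfe  # assumes this repository is on PYTHONPATH

# pad_axis zero-pads one axis: `offset` zeros in front, then enough zeros at
# the end so the axis reaches `size` (or more if it was already longer).
x = tf.constant([1.0, 2.0, 3.0])
y = tfe.pad_axis(x, offset=1, size=6, axis=0)

with tf.Session() as sess:
    print(sess.run(y))  # expected: [0. 1. 2. 3. 0. 0.]
```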
--------------------------------------------------------------------------------
/tf_utils.py:
--------------------------------------------------------------------------------
1 | """Diverse TensorFlow utils, for training, evaluation and so on!
2 | """
3 | import os
4 | from pprint import pprint
5 | import tf_extended as tfe
6 |
7 | import tensorflow as tf
8 | from tensorflow.contrib.slim.python.slim.data import parallel_reader
9 |
10 | slim = tf.contrib.slim
11 |
12 |
13 | # =========================================================================== #
14 | # General tools.
15 | # =========================================================================== #
16 | def reshape_list(l, shape=None):
17 |     """Reshape a list (of lists): flatten a 2D list to 1D, or regroup a 1D list into 2D.
18 |
19 | Args:
20 | l: List or List of list.
21 | shape: 1D or 2D shape.
22 |     Returns:
23 | Reshaped list.
24 | """
25 | r = []
26 | if shape is None:
27 | # Flatten everything.
28 | for a in l:
29 | if isinstance(a, (list, tuple)):
30 | r = r + list(a)
31 | else:
32 | r.append(a)
33 | else:
34 | # Reshape to list of list.
35 | i = 0
36 | for s in shape:
37 | if s == 1:
38 | r.append(l[i])
39 | else:
40 | r.append(l[i:i+s])
41 | i += s
42 | return r
43 |
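A small usage sketch of `reshape_list` with plain Python lists (importing `tf_utils` does require TensorFlow 1.x to be installed, since the module imports it at load time):

```python
from tf_utils import reshape_list

# Flatten a nested list, then regroup it with the original per-group sizes.
nested = [['a', 'b'], 'c', ['d', 'e', 'f']]
flat = reshape_list(nested)                      # ['a', 'b', 'c', 'd', 'e', 'f']
regrouped = reshape_list(flat, shape=[2, 1, 3])  # [['a', 'b'], 'c', ['d', 'e', 'f']]
print(flat, regrouped)
```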
44 |
45 | # =========================================================================== #
46 | # Training utils.
47 | # =========================================================================== #
48 | def print_configuration(flags, aher_params, data_sources, save_dir=None):
49 | """Print the training configuration.
50 | """
51 | def print_config(stream=None):
52 | print('\n# =========================================================================== #', file=stream)
53 | print('# Training | Evaluation flags:', file=stream)
54 | print('# =========================================================================== #', file=stream)
55 | pprint(flags, stream=stream)
56 |
57 | print('\n# =========================================================================== #', file=stream)
58 | print('# AHER net parameters:', file=stream)
59 | print('# =========================================================================== #', file=stream)
60 | pprint(dict(aher_params._asdict()), stream=stream)
61 |
62 | print('\n# =========================================================================== #', file=stream)
63 | print('# Training | Evaluation dataset files:', file=stream)
64 | print('# =========================================================================== #', file=stream)
65 | data_files = parallel_reader.get_data_files(data_sources)
66 | pprint(sorted(data_files), stream=stream)
67 | print('', file=stream)
68 |
69 | print_config(None)
70 | # Save to a text file as well.
71 | if save_dir is not None:
72 | if not os.path.exists(save_dir):
73 | os.makedirs(save_dir)
74 | path = os.path.join(save_dir, 'training_config.txt')
75 | with open(path, "w") as out:
76 | print_config(out)
77 |
78 | def configure_learning_rate(flags, num_samples_per_epoch, global_step):
79 | """Configures the learning rate.
80 |
81 | Args:
82 | num_samples_per_epoch: The number of samples in each epoch of training.
83 | global_step: The global_step tensor.
84 | Returns:
85 | A `Tensor` representing the learning rate.
86 | """
87 | decay_steps = int(num_samples_per_epoch / flags.batch_size *
88 | flags.num_epochs_per_decay)
89 |
90 | if flags.learning_rate_decay_type == 'exponential':
91 | return tf.train.exponential_decay(flags.learning_rate,
92 | global_step,
93 | decay_steps,
94 | flags.learning_rate_decay_factor,
95 | staircase=True,
96 | name='exponential_decay_learning_rate')
97 | elif flags.learning_rate_decay_type == 'fixed':
98 | return tf.constant(flags.learning_rate, name='fixed_learning_rate')
99 | elif flags.learning_rate_decay_type == 'polynomial':
100 | return tf.train.polynomial_decay(flags.learning_rate,
101 | global_step,
102 | decay_steps,
103 | flags.end_learning_rate,
104 | power=1.0,
105 | cycle=False,
106 | name='polynomial_decay_learning_rate')
107 | else:
108 |         raise ValueError('learning_rate_decay_type [%s] was not recognized' %
109 |                          flags.learning_rate_decay_type)
110 |
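For intuition, `decay_steps` is simply the number of optimizer steps in `num_epochs_per_decay` epochs. A worked example with illustrative numbers (not taken from the repository):

```python
# Illustrative numbers only: 20,000 training samples, batch_size 16,
# num_epochs_per_decay 4  ->  the learning rate decays every 5,000 steps.
decay_steps = int(20000 / 16 * 4)
print(decay_steps)  # 5000
```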
111 | def configure_optimizer(flags, learning_rate):
112 | """Configures the optimizer used for training.
113 |
114 | Args:
115 | learning_rate: A scalar or `Tensor` learning rate.
116 | Returns:
117 | An instance of an optimizer.
118 | """
119 | if flags.optimizer == 'adadelta':
120 | optimizer = tf.train.AdadeltaOptimizer(
121 | learning_rate,
122 | rho=flags.adadelta_rho,
123 | epsilon=flags.opt_epsilon)
124 | elif flags.optimizer == 'adagrad':
125 | optimizer = tf.train.AdagradOptimizer(
126 | learning_rate,
127 | initial_accumulator_value=flags.adagrad_initial_accumulator_value)
128 | elif flags.optimizer == 'adam':
129 | optimizer = tf.train.AdamOptimizer(
130 | learning_rate,
131 | beta1=flags.adam_beta1,
132 | beta2=flags.adam_beta2,
133 | epsilon=flags.opt_epsilon)
134 | elif flags.optimizer == 'ftrl':
135 | optimizer = tf.train.FtrlOptimizer(
136 | learning_rate,
137 | learning_rate_power=flags.ftrl_learning_rate_power,
138 | initial_accumulator_value=flags.ftrl_initial_accumulator_value,
139 | l1_regularization_strength=flags.ftrl_l1,
140 | l2_regularization_strength=flags.ftrl_l2)
141 | elif flags.optimizer == 'momentum':
142 | optimizer = tf.train.MomentumOptimizer(
143 | learning_rate,
144 | momentum=flags.momentum,
145 | name='Momentum')
146 | elif flags.optimizer == 'rmsprop':
147 | optimizer = tf.train.RMSPropOptimizer(
148 | learning_rate,
149 | decay=flags.rmsprop_decay,
150 | momentum=flags.rmsprop_momentum,
151 | epsilon=flags.opt_epsilon)
152 | elif flags.optimizer == 'sgd':
153 | optimizer = tf.train.GradientDescentOptimizer(learning_rate)
154 | else:
155 |         raise ValueError('Optimizer [%s] was not recognized' % flags.optimizer)
156 | return optimizer
157 |
158 | def add_variables_summaries(learning_rate):
159 | summaries = []
160 | for variable in slim.get_model_variables():
161 | summaries.append(tf.summary.histogram(variable.op.name, variable))
162 | summaries.append(tf.summary.scalar('training/Learning Rate', learning_rate))
163 | return summaries
164 |
165 | def update_model_scope(var, ckpt_scope, new_scope):
166 |     return var.op.name.replace(new_scope, 'vgg_16')  # NOTE: ckpt_scope is unused; the checkpoint scope is hard-coded to 'vgg_16'
167 |
168 | def get_init_fn(flags):
169 | """Returns a function run by the chief worker to warm-start the training.
170 | Note that the init_fn is only run when initializing the model during the very
171 | first global step.
172 |
173 | Returns:
174 | An init function run by the supervisor.
175 | """
176 | if flags.checkpoint_path is None:
177 | return None
178 | # Warn the user if a checkpoint exists in the train_dir. Then ignore.
179 | if tf.train.latest_checkpoint(flags.train_dir):
180 | tf.logging.info(
181 | 'Ignoring --checkpoint_path because a checkpoint already exists in %s'
182 | % flags.train_dir)
183 | return None
184 |
185 | exclusions = []
186 | if flags.checkpoint_exclude_scopes:
187 | exclusions = [scope.strip()
188 | for scope in flags.checkpoint_exclude_scopes.split(',')]
189 |
190 | # TODO(sguada) variables.filter_variables()
191 | variables_to_restore = []
192 | for var in slim.get_model_variables():
193 | excluded = False
194 | for exclusion in exclusions:
195 | if var.op.name.startswith(exclusion):
196 | excluded = True
197 | break
198 | if not excluded:
199 | variables_to_restore.append(var)
200 | # Change model scope if necessary.
201 | if flags.checkpoint_model_scope is not None:
202 | variables_to_restore = \
203 | {var.op.name.replace(flags.model_name,
204 | flags.checkpoint_model_scope): var
205 | for var in variables_to_restore}
206 |
207 |
208 | if tf.gfile.IsDirectory(flags.checkpoint_path):
209 | checkpoint_path = tf.train.latest_checkpoint(flags.checkpoint_path)
210 | else:
211 | checkpoint_path = flags.checkpoint_path
212 | tf.logging.info('Fine-tuning from %s. Ignoring missing vars: %s' % (checkpoint_path, flags.ignore_missing_vars))
213 |
214 | return slim.assign_from_checkpoint_fn(
215 | checkpoint_path,
216 | variables_to_restore,
217 | ignore_missing_vars=flags.ignore_missing_vars)
218 |
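The `--checkpoint_exclude_scopes` flag is a comma-separated list of variable-scope prefixes; any model variable whose op name starts with one of them is skipped when restoring. A tiny sketch of that filtering, with hypothetical scope and variable names:

```python
# Hypothetical flag value and variable names, purely for illustration.
checkpoint_exclude_scopes = 'aher_anet/cls_head, aher_anet/loc_head'
exclusions = [scope.strip() for scope in checkpoint_exclude_scopes.split(',')]

var_names = ['aher_anet/backbone/conv1/w', 'aher_anet/cls_head/fc/w']
restored = [v for v in var_names
            if not any(v.startswith(e) for e in exclusions)]
print(restored)  # ['aher_anet/backbone/conv1/w']
```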
219 | def get_variables_to_train(flags):
220 | """Returns a list of variables to train.
221 |
222 | Returns:
223 | A list of variables to train by the optimizer.
224 | """
225 | if flags.trainable_scopes is None:
226 | return tf.trainable_variables()
227 | else:
228 | scopes = [scope.strip() for scope in flags.trainable_scopes.split(',')]
229 |
230 | variables_to_train = []
231 | for scope in scopes:
232 | variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
233 | variables_to_train.extend(variables)
234 | return variables_to_train
235 |
236 | def sigmoid_cross_entropy_with_logits(x, y):
237 | try:
238 | return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, labels=y)
239 |     except TypeError:  # older TF versions name the argument 'targets' instead of 'labels'
240 | return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, targets=y)
241 |
242 |
243 | # =========================================================================== #
244 | # depth-wise convolution
245 | # =========================================================================== #
246 |
247 | # The original 1D input data format is NWC;
248 | # reshape it to NHWC so that depthwise 2D convolution can be applied.
249 | def depthwise_conv1d(input, k_h=1, k_w=3, channel_multiplier= 1, strides=1,
250 | padding='SAME', stddev=0.02, name='depthwise_conv1d', bias=True, weight_decay=0.0001):
251 | lshape = tfe.get_shape(input, 3)
252 | input = tf.reshape(input,[lshape[0],1,lshape[1],lshape[2]])
253 | with tf.variable_scope(name):
254 | in_channel=input.get_shape().as_list()[-1]
255 | w = tf.get_variable('w', [k_h, k_w, in_channel, channel_multiplier],
256 | regularizer=tf.contrib.layers.l2_regularizer(weight_decay),
257 | initializer=tf.truncated_normal_initializer(stddev=stddev))
258 | conv = tf.nn.depthwise_conv2d(input, w, [1,strides,strides,1], padding, rate=None,name=None,data_format=None)
259 | if bias:
260 | biases = tf.get_variable('bias', [in_channel*channel_multiplier], initializer=tf.constant_initializer(0.0))
261 | conv = tf.nn.bias_add(conv, biases)
262 | cshape = tfe.get_shape(conv,4)
263 |         conv = tf.reshape(conv,[cshape[0],cshape[2],cshape[3]]) # convert back to the original 1D NWC format
264 | return conv
265 |
266 |
267 |
268 |
269 |
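A minimal sketch of calling `depthwise_conv1d` on a batch of temporal features in NWC layout (the feature dimensions here are illustrative, and the repository is assumed to be on PYTHONPATH):

```python
import tensorflow as tf
from tf_utils import depthwise_conv1d

# A batch of 4 temporal feature sequences: length 100, 256 channels (NWC).
features = tf.placeholder(tf.float32, shape=[4, 100, 256])

# One 1x3 filter per input channel (channel_multiplier=1), stride 1 and SAME
# padding, so the output keeps the [4, 100, 256] shape.
out = depthwise_conv1d(features, k_w=3, strides=1, name='dwconv_demo')
print(out.get_shape())  # (4, 100, 256)
```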
--------------------------------------------------------------------------------