├── Examples
│   ├── 00110.png
│   ├── 03deb7ad95.mp4
│   └── readme.txt
├── Output
│   └── readme.txt
├── README.md
├── caps_layers_cond.py
├── caps_main.py
├── caps_network_test.py
├── caps_network_train.py
├── config.py
├── inference.py
├── load_youtube_data_multi.py
├── load_youtubevalid_data.py
├── network_parts
│   ├── lstm_capsnet_cond2_test.py
│   └── lstm_capsnet_cond2_train.py
├── network_saves
│   └── readme.txt
└── network_saves_best
    └── readme.txt

/Examples/00110.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KevinDuarte/CapsuleVOS/0b390da90291b3f5527e444cb7fe943d319e6952/Examples/00110.png
--------------------------------------------------------------------------------
/Examples/03deb7ad95.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KevinDuarte/CapsuleVOS/0b390da90291b3f5527e444cb7fe943d319e6952/Examples/03deb7ad95.mp4
--------------------------------------------------------------------------------
/Examples/readme.txt:
--------------------------------------------------------------------------------
1 | ## Include the .mp4 files which will be used in inference here
2 |
--------------------------------------------------------------------------------
/Output/readme.txt:
--------------------------------------------------------------------------------
1 | This folder will hold the inference frames.
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## CapsuleVOS
2 |
3 | This is the code for the ICCV 2019 paper CapsuleVOS: Semi-Supervised Video Object Segmentation Using Capsule Routing.
4 |
5 | Arxiv Link: https://arxiv.org/abs/1910.00132
6 |
7 | The network is implemented using TensorFlow 1.4.1.
8 |
9 | Python packages used: numpy, scipy, scikit-video
10 |
11 | ## Files and their use
12 |
13 | 1. caps_layers_cond.py: Contains the functions required to construct capsule layers (primary, convolutional, and fully-connected) and conditional capsule routing.
14 | 2. caps_network_train.py: Contains the CapsuleVOS model for training.
15 | 3. caps_network_test.py: Contains the CapsuleVOS model for testing.
16 | 4. caps_main.py: Contains the main function, which is called to train the network.
17 | 5. config.py: Contains the hyperparameters used for the network, training, and inference.
18 | 6. inference.py: Contains the inference code.
19 | 7. load_youtube_data_multi.py: Contains the training data generator for the YoutubeVOS 2018 dataset.
20 | 8. load_youtubevalid_data.py: Contains the validation data generator for the YoutubeVOS 2018 dataset.
21 |
22 | ## Data Used
23 |
24 | We have supplied the code for training and inference of the model on the YoutubeVOS-2018 dataset. The files load_youtube_data_multi.py and load_youtubevalid_data.py create two data generators - one for training and one for validation. The data_loc variable at the top of each file should be set to the base directory which contains the frames and annotations.
25 |
26 | To run this code, you need to do the following:
27 | 1. Download the YoutubeVOS dataset
28 | 2. Perform interpolation for the training frames following the paper's instructions
29 |
30 | ## Training the Model
31 |
32 | Once the data is set up, you can train (and test) the network by calling python3 caps_main.py.
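For reference, caps_main.py drives the training data generator roughly as follows (a minimal sketch of the loop in train_one_epoch; the full version also builds the LSTM conditioning inputs and runs the training op):

```python
import config
from load_youtube_data_multi import YoutubeTrainDataGen as TrainDataGen

# One epoch: the generator threads stream clips until the epoch's videos are exhausted.
data_gen = TrainDataGen(config.wait_for_data,
                        crop_size=config.hr_frame_size,
                        n_frames=config.n_frames,
                        rand_frame_skip=config.rand_frame_skip,
                        multi_objects=config.multiple_objects)

while data_gen.has_data():
    x_batch, seg_batch, crop1_batch, crop2_batch = data_gen.get_batch(config.batch_size)
    # ...feed the batch to the CapsNet train_op (see train_one_epoch in caps_main.py)
```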
33 | 34 | The config.py file contains several hyper-parameters which are useful for training the network. 35 | 36 | ## Output File 37 | 38 | During training and testing, metrics are printed to stdout as well as an output*.txt file. During training/validation, the losses and accuracies are printed out to the terminal and to an output file. 39 | 40 | ## Saved Weights 41 | 42 | Pretrained weights for the network are available [here](https://drive.google.com/open?id=12zvvqd5i3EVNzPF-hEfq_hi2CEzRRSjS). To use them for inference, place them in the network_saves_best folder. 43 | 44 | ## Inference 45 | 46 | If you just want to test the trained model with the weights above, run the inference code by calling python3 inference.py. This code will read in an .mp4 file and a reference segmentation mask, and output the segmented frames of the video to the Output folder. 47 | 48 | An example video is available in the Example folder. 49 | 50 | -------------------------------------------------------------------------------- /caps_layers_cond.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from functools import reduce 4 | import config 5 | 6 | epsilon = 1e-7 7 | 8 | 9 | # Standard EM Routing 10 | def em_routing(v, a_i, beta_v, beta_a, n_iterations=3): 11 | batch_size = tf.shape(v)[0] 12 | _, _, n_caps_j, mat_len = v.get_shape().as_list() 13 | n_caps_j, mat_len = map(int, [n_caps_j, mat_len]) 14 | n_caps_i = tf.shape(v)[1] 15 | 16 | a_i = tf.expand_dims(a_i, axis=-1) 17 | 18 | # Prior probabilities for routing 19 | r = tf.ones(shape=(batch_size, n_caps_i, n_caps_j, 1), dtype=tf.float32)/float(n_caps_j) 20 | r = tf.multiply(r, a_i) 21 | 22 | den = tf.reduce_sum(r, axis=1, keep_dims=True) + epsilon 23 | 24 | # Mean: shape=(N, 1, Ch_j, mat_len) 25 | m_num = tf.reduce_sum(v*r, axis=1, keep_dims=True) 26 | m = m_num/(den + epsilon) 27 | 28 | # Stddev: shape=(N, 1, Ch_j, mat_len) 29 | s_num = tf.reduce_sum(r * tf.square(v - m), axis=1, keep_dims=True) 30 | s = s_num/(den + epsilon) 31 | 32 | # cost_h: shape=(N, 1, Ch_j, mat_len) 33 | cost = (beta_v + tf.log(tf.sqrt(s + epsilon) + epsilon)) * den 34 | # cost_h: shape=(N, 1, Ch_j, 1) 35 | cost = tf.reduce_sum(cost, axis=-1, keep_dims=True) 36 | 37 | # calculates the mean and std_deviation of the cost, used for numerical stability 38 | cost_mean = tf.reduce_mean(cost, axis=-2, keep_dims=True) 39 | cost_stdv = tf.sqrt( 40 | tf.reduce_sum( 41 | tf.square(cost - cost_mean), axis=-2, keep_dims=True 42 | ) / n_caps_j + epsilon 43 | ) 44 | 45 | # calculates the activations for the capsules in layer j 46 | a_j = tf.sigmoid(float(config.inv_temp) * (beta_a + (cost_mean - cost) / (cost_stdv + epsilon))) 47 | 48 | def condition(mean, stdsqr, act_j, r_temp, counter): 49 | return tf.less(counter, n_iterations) 50 | 51 | def route(mean, stdsqr, act_j, r_temp, counter): 52 | exp = tf.reduce_sum(tf.square(v - mean) / (2 * stdsqr + epsilon), axis=-1) 53 | coef = 0 - .5 * tf.reduce_sum(tf.log(2 * np.pi * stdsqr + epsilon), axis=-1) 54 | log_p_j = coef - exp 55 | 56 | log_ap = tf.reshape(tf.log(act_j + epsilon), (batch_size, 1, n_caps_j)) + log_p_j 57 | r_ij = tf.nn.softmax(log_ap + epsilon) # ap / (tf.reduce_sum(ap, axis=-1, keep_dims=True) + epsilon) 58 | 59 | r_ij = tf.multiply(tf.expand_dims(r_ij, axis=-1), a_i) 60 | 61 | denom = tf.reduce_sum(r_ij, axis=1, keep_dims=True) + epsilon 62 | m_numer = tf.reduce_sum(v * r_ij, axis=1, keep_dims=True) 63 | mean = m_numer / (denom + epsilon) 64 | 
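        # M-step (continued): with the refreshed assignments r_ij, the capsule
        # variance, the cost of explaining the votes, and the output activations
        # are recomputed below; the sigmoid's inverse temperature grows by
        # config.inv_temp_delta on every routing iteration.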
65 | s_numer = tf.reduce_sum(r_ij * tf.square(v - mean), axis=1, keep_dims=True) 66 | stdsqr = s_numer / (denom + epsilon) 67 | 68 | cost_h = (beta_v + tf.log(tf.sqrt(stdsqr) + epsilon)) * denom 69 | 70 | cost_h = tf.reduce_sum(cost_h, axis=-1, keep_dims=True) 71 | cost_h_mean = tf.reduce_mean(cost_h, axis=-2, keep_dims=True) 72 | cost_h_stdv = tf.sqrt( 73 | tf.reduce_sum( 74 | tf.square(cost_h - cost_h_mean), axis=-2, keep_dims=True 75 | ) / n_caps_j 76 | ) 77 | 78 | inv_temp = config.inv_temp + counter * config.inv_temp_delta 79 | act_j = tf.sigmoid(inv_temp * (beta_a + (cost_h_mean - cost_h) / (cost_h_stdv + epsilon))) 80 | 81 | return mean, stdsqr, act_j, r_ij, tf.add(counter, 1) 82 | 83 | [mean, _, act_j, r_new, _] = tf.while_loop(condition, route, [m, s, a_j, r, 1.0]) 84 | 85 | return tf.reshape(mean, (batch_size, n_caps_j, mat_len)), tf.reshape(act_j, (batch_size, n_caps_j, 1)) 86 | 87 | 88 | # Attention Routing layer 89 | def em_routing_cond(v_v1, v_v2, a_i_v, v_f, a_i_f, beta_v_v, beta_a_v, beta_v_f, beta_a_f, n_iterations=3): 90 | batch_size = tf.shape(v_f)[0] 91 | _, _, n_caps_j, mat_len = v_f.get_shape().as_list() 92 | n_caps_j, mat_len = map(int, [n_caps_j, mat_len]) 93 | n_caps_i_f = tf.shape(v_f)[1] 94 | 95 | a_i_f = tf.expand_dims(a_i_f, axis=-1) 96 | 97 | # Prior probabilities for routing 98 | r_f = tf.ones(shape=(batch_size, n_caps_i_f, n_caps_j, 1), dtype=tf.float32)/float(n_caps_j) 99 | r_f = tf.multiply(r_f, a_i_f) 100 | 101 | den_f = tf.reduce_sum(r_f, axis=1, keep_dims=True) + epsilon 102 | 103 | m_num_f = tf.reduce_sum(v_f*r_f, axis=1, keep_dims=True) # Mean: shape=(N, 1, Ch_j, mat_len) 104 | m_f = m_num_f/(den_f + epsilon) 105 | 106 | s_num_f = tf.reduce_sum(r_f * tf.square(v_f - m_f), axis=1, keep_dims=True) # Stddev: shape=(N, 1, Ch_j, mat_len) 107 | s_f = s_num_f/(den_f + epsilon) 108 | 109 | cost_f = (beta_v_f + tf.log(tf.sqrt(s_f + epsilon) + epsilon)) * den_f # cost_h: shape=(N, 1, Ch_j, mat_len) 110 | cost_f = tf.reduce_sum(cost_f, axis=-1, keep_dims=True) # cost_h: shape=(N, 1, Ch_j, 1) 111 | 112 | # calculates the mean and std_deviation of the cost 113 | cost_mean_f = tf.reduce_mean(cost_f, axis=-2, keep_dims=True) 114 | cost_stdv_f = tf.sqrt( 115 | tf.reduce_sum( 116 | tf.square(cost_f - cost_mean_f), axis=-2, keep_dims=True 117 | ) / n_caps_j + epsilon 118 | ) 119 | 120 | # calculates the activations for the capsules in layer j for the frame capsules 121 | a_j_f = tf.sigmoid(float(config.inv_temp) * (beta_a_f + (cost_mean_f - cost_f) / (cost_stdv_f + epsilon))) 122 | 123 | def condition(mean_f, stdsqr_f, act_j_f, counter): 124 | return tf.less(counter, n_iterations) 125 | 126 | def route(mean_f, stdsqr_f, act_j_f, counter): 127 | # Performs E-step for frames 128 | exp_f = tf.reduce_sum(tf.square(v_f - mean_f) / (2 * stdsqr_f + epsilon), axis=-1) 129 | coef_f = 0 - .5 * tf.reduce_sum(tf.log(2 * np.pi * stdsqr_f + epsilon), axis=-1) 130 | log_p_j_f = coef_f - exp_f 131 | 132 | log_ap_f = tf.reshape(tf.log(act_j_f + epsilon), (batch_size, 1, n_caps_j)) + log_p_j_f 133 | r_ij_f = tf.nn.softmax(log_ap_f + epsilon) 134 | 135 | # Performs M-step for frames 136 | r_ij_f = tf.multiply(tf.expand_dims(r_ij_f, axis=-1), a_i_f) 137 | 138 | denom_f = tf.reduce_sum(r_ij_f, axis=1, keep_dims=True) + epsilon 139 | m_numer_f = tf.reduce_sum(v_f * r_ij_f, axis=1, keep_dims=True) 140 | mean_f = m_numer_f / (denom_f + epsilon) 141 | 142 | s_numer_f = tf.reduce_sum(r_ij_f * tf.square(v_f - mean_f), axis=1, keep_dims=True) 143 | stdsqr_f = s_numer_f / (denom_f + epsilon) 144 | 
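        # Note: only the frame capsules (v_f) are refined inside this loop. The video
        # capsules are updated once, after the loop finishes, by assigning their votes
        # to the converged frame-capsule means with a softmax over negative squared
        # distances (the attention-routing step further below).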
145 | cost_h_f = (beta_v_f + tf.log(tf.sqrt(stdsqr_f + epsilon) + epsilon)) * denom_f 146 | 147 | cost_h_f = tf.reduce_sum(cost_h_f, axis=-1, keep_dims=True) 148 | cost_h_mean_f = tf.reduce_mean(cost_h_f, axis=-2, keep_dims=True) 149 | cost_h_stdv_f = tf.sqrt( 150 | tf.reduce_sum( 151 | tf.square(cost_h_f - cost_h_mean_f), axis=-2, keep_dims=True 152 | ) / n_caps_j + epsilon 153 | ) 154 | 155 | inv_temp = config.inv_temp + counter * config.inv_temp_delta 156 | act_j_f = tf.sigmoid(inv_temp * (beta_a_f + (cost_h_mean_f - cost_h_f) / (cost_h_stdv_f + epsilon))) 157 | 158 | return mean_f, stdsqr_f, act_j_f, tf.add(counter, 1) 159 | 160 | [mean_f_fin, _, act_j_f_fin, _] = tf.while_loop(condition, route, [m_f, s_f, a_j_f, 1.0]) 161 | 162 | # performs m step for the video capsules 163 | a_i_v = tf.expand_dims(a_i_v, axis=-1) 164 | 165 | dist_v = tf.reduce_sum(tf.square(v_v1 - mean_f_fin), axis=-1) 166 | r_v = tf.expand_dims(tf.nn.softmax(0 - dist_v), axis=-1) * a_i_v 167 | 168 | den_v = tf.reduce_sum(r_v, axis=1, keep_dims=True) + epsilon 169 | 170 | m_num_v = tf.reduce_sum(v_v2 * r_v, axis=1, keep_dims=True) # Mean: shape=(N, 1, Ch_j, mat_len) 171 | m_v = m_num_v / (den_v + epsilon) 172 | 173 | s_num_v = tf.reduce_sum(r_v * tf.square(v_v2 - m_v), axis=1, keep_dims=True) # Stddev: shape=(N, 1, Ch_j, mat_len) 174 | s_v = s_num_v / (den_v + epsilon) 175 | 176 | cost_v = (beta_v_v + tf.log(tf.sqrt(s_v + epsilon) + epsilon)) * den_v # cost_h: shape=(N, 1, Ch_j, mat_len) 177 | cost_v = tf.reduce_sum(cost_v, axis=-1, keep_dims=True) # cost_h: shape=(N, 1, Ch_j, 1) 178 | 179 | # calculates the mean and std_deviation of the cost 180 | cost_mean_v = tf.reduce_mean(cost_v, axis=-2, keep_dims=True) 181 | cost_stdv_v = tf.sqrt( 182 | tf.reduce_sum( 183 | tf.square(cost_v - cost_mean_v), axis=-2, keep_dims=True 184 | ) / n_caps_j + epsilon 185 | ) 186 | 187 | # calculates the activations for the capsules in layer j for the frame capsules 188 | a_j_v = tf.sigmoid(float(config.inv_temp) * (beta_a_v + (cost_mean_v - cost_v) / (cost_stdv_v + epsilon))) 189 | 190 | return (tf.reshape(m_v, (batch_size, n_caps_j, mat_len)), tf.reshape(a_j_v, (batch_size, n_caps_j, 1))), (tf.reshape(mean_f_fin, (batch_size, n_caps_j, mat_len)), tf.reshape(act_j_f_fin, (batch_size, n_caps_j, 1))) 191 | 192 | 193 | def create_prim_conv3d_caps(inputs, channels, kernel_size, strides, name, padding='VALID', activation=None, mdim=4): 194 | mdim2 = mdim*mdim 195 | batch_size = tf.shape(inputs)[0] 196 | poses = tf.layers.conv3d(inputs=inputs, filters=channels * mdim2, kernel_size=kernel_size, 197 | strides=strides, padding=padding, activation=activation, name=name+'_pose') 198 | 199 | _, d, h, w, _ = poses.get_shape().as_list() 200 | d, h, w = map(int, [d, h, w]) 201 | 202 | pose = tf.reshape(poses, (batch_size, d, h, w, channels, mdim2), name=name+'_pose_res') 203 | #pose = tf.nn.l2_normalize(pose, dim=-1) 204 | 205 | acts = tf.layers.conv3d(inputs=inputs, filters=channels, kernel_size=kernel_size, 206 | strides=strides, padding=padding, activation=tf.nn.sigmoid, name=name+'_act') 207 | activation = tf.reshape(acts, (batch_size, d, h, w, channels, 1), name=name+'_act_res') 208 | 209 | return pose, activation 210 | 211 | 212 | def create_coords_mat(pose, rel_center, mdim=4): 213 | """ 214 | 215 | :param pose: the incoming map of pose matrices, shape (N, ..., Ch_i, 16) where ... is the dimensions of the map can 216 | be 1, 2 or 3 dimensional. 
217 | :param rel_center: whether or not the coordinates are relative to the center of the map 218 | :return: returns the coordinates (padded to 16) fir the incoming capsules 219 | """ 220 | batch_size = tf.shape(pose)[0] 221 | shape_list = [int(x) for x in pose.get_shape().as_list()[1:-2]] 222 | ch = int(pose.get_shape().as_list()[-2]) 223 | n_dims = len(shape_list) 224 | 225 | if n_dims == 3: 226 | d, h, w = shape_list 227 | elif n_dims == 2: 228 | d = 1 229 | h, w = shape_list 230 | else: 231 | d, h = 1, 1 232 | w = shape_list[0] 233 | 234 | subs = [0, 0, 0] 235 | if rel_center: 236 | subs = [int(d / 2), int(h / 2), int(w / 2)] 237 | 238 | c_mats = [] 239 | if n_dims >= 3: 240 | c_mats.append(tf.tile(tf.reshape(tf.range(d, dtype=tf.float32), (1, d, 1, 1, 1, 1)), [batch_size, 1, h, w, ch, 1])-subs[0]) 241 | if n_dims >= 2: 242 | c_mats.append(tf.tile(tf.reshape(tf.range(h, dtype=tf.float32), (1, 1, h, 1, 1, 1)), [batch_size, d, 1, w, ch, 1])-subs[1]) 243 | if n_dims >= 1: 244 | c_mats.append(tf.tile(tf.reshape(tf.range(w, dtype=tf.float32), (1, 1, 1, w, 1, 1)), [batch_size, d, h, 1, ch, 1])-subs[2]) 245 | add_coords = tf.concat(c_mats, axis=-1) 246 | add_coords = tf.cast(tf.reshape(add_coords, (batch_size, d*h*w, ch, n_dims)), dtype=tf.float32) 247 | 248 | mdim2 = mdim*mdim 249 | zeros = tf.zeros((batch_size, d*h*w, ch, mdim2-n_dims)) 250 | 251 | return tf.concat([zeros, add_coords], axis=-1) 252 | 253 | 254 | def create_dense_caps(inputs, n_caps_j, name, route_min=0.0, coord_add=False, rel_center=False, 255 | ch_same_w=True, mdim=4): 256 | """ 257 | 258 | :param inputs: The input capsule layer. Shape ((N, K, Ch_i, 16), (N, K, Ch_i, 1)) or 259 | ((N, ..., Ch_i, 16), (N, ..., Ch_i, 1)) where K is the number of capsules per channel and '...' is if you are 260 | inputting an map of capsules like W or HxW or DxHxW. 261 | :param n_caps_j: The number of capsules in the following layer 262 | :param name: name of the capsule layer 263 | :param route_min: A threshold activation to route 264 | :param coord_add: A boolean, whether or not to to do coordinate addition 265 | :param rel_center: A boolean, whether or not the coordinate addition is relative to the center 266 | :param routing_type: The type of routing 267 | :return: Returns a layer of capsules. 
Shape ((N, n_caps_j, 16), (N, n_caps_j, 1)) 268 | """ 269 | mdim2 = mdim*mdim 270 | pose, activation = inputs 271 | batch_size = tf.shape(pose)[0] 272 | shape_list = [int(x) for x in pose.get_shape().as_list()[1:]] 273 | ch = int(shape_list[-2]) 274 | n_capsch_i = 1 if len(shape_list) == 2 else reduce((lambda x, y: x * y), shape_list[:-2]) 275 | 276 | u_i = tf.reshape(pose, (batch_size, n_capsch_i, ch, mdim2)) 277 | activation = tf.reshape(activation, (batch_size, n_capsch_i, ch, 1)) 278 | coords = create_coords_mat(pose, rel_center) if coord_add else tf.zeros_like(u_i) 279 | 280 | # reshapes the input capsules 281 | u_i = tf.reshape(u_i, (batch_size, n_capsch_i, ch, mdim, mdim)) 282 | u_i = tf.expand_dims(u_i, axis=-3) 283 | u_i = tf.tile(u_i, [1, 1, 1, n_caps_j, 1, 1]) 284 | 285 | if ch_same_w: 286 | weights = tf.get_variable(name=name + '_weights', shape=(ch, n_caps_j, mdim, mdim), 287 | initializer=tf.initializers.random_normal(stddev=0.1), 288 | regularizer=tf.contrib.layers.l2_regularizer(0.1)) 289 | 290 | votes = tf.einsum('ijab,ntijbc->ntijac', weights, u_i) 291 | votes = tf.reshape(votes, (batch_size, n_capsch_i * ch, n_caps_j, mdim2), name=name+'_votes') 292 | else: 293 | weights = tf.get_variable(name=name + '_weights', shape=(n_capsch_i, ch, n_caps_j, mdim, mdim), 294 | initializer=tf.initializers.random_normal(stddev=0.1), 295 | regularizer=tf.contrib.layers.l2_regularizer(0.1)) 296 | votes = tf.einsum('tijab,ntijbc->ntijac', weights, u_i) 297 | votes = tf.reshape(votes, (batch_size, n_capsch_i * ch, n_caps_j, mdim2), name=name+'_votes') 298 | 299 | if coord_add: 300 | coords = tf.reshape(coords, (batch_size, n_capsch_i * ch, 1, mdim2)) 301 | votes = votes + tf.tile(coords, [1, 1, n_caps_j, 1]) 302 | 303 | acts = tf.reshape(activation, (batch_size, n_capsch_i * ch, 1)) 304 | activations = tf.where(tf.greater_equal(acts, tf.constant(route_min)), acts, tf.zeros_like(acts)) 305 | 306 | beta_v = tf.get_variable(name=name + '_beta_v', shape=(n_caps_j, mdim2), 307 | initializer=tf.initializers.random_normal(stddev=0.1), 308 | regularizer=tf.contrib.layers.l2_regularizer(0.1)) 309 | 310 | beta_a = tf.get_variable(name=name + '_beta_a', shape=(n_caps_j, 1), 311 | initializer=tf.initializers.random_normal(stddev=0.1), 312 | regularizer=tf.contrib.layers.l2_regularizer(0.1)) 313 | 314 | capsules = em_routing(votes, activations, beta_v, beta_a) 315 | 316 | return capsules 317 | 318 | 319 | def create_conv3d_caps(inputs, channels, kernel_size, strides, name, padding='VALID', 320 | coord_add=False, rel_center=True, route_mean=True, ch_same_w=True, mdim=4): 321 | mdim2 = mdim*mdim 322 | inputs = tf.concat(inputs, axis=-1) 323 | 324 | if padding == 'SAME': 325 | d_padding, h_padding, w_padding = int(float(kernel_size[0]) / 2), int(float(kernel_size[1]) / 2), int(float(kernel_size[2]) / 2) 326 | u_padded = tf.pad(inputs, [[0, 0], [d_padding, d_padding], [h_padding, h_padding], [w_padding, w_padding], [0, 0], [0, 0]]) 327 | else: 328 | u_padded = inputs 329 | 330 | batch_size = tf.shape(u_padded)[0] 331 | _, d, h, w, ch, _ = u_padded.get_shape().as_list() 332 | d, h, w, ch = map(int, [d, h, w, ch]) 333 | 334 | # gets indices for kernels 335 | d_offsets = [[(d_ + k) for k in range(kernel_size[0])] for d_ in range(0, d + 1 - kernel_size[0], strides[0])] 336 | h_offsets = [[(h_ + k) for k in range(kernel_size[1])] for h_ in range(0, h + 1 - kernel_size[1], strides[1])] 337 | w_offsets = [[(w_ + k) for k in range(kernel_size[2])] for w_ in range(0, w + 1 - kernel_size[2], strides[2])] 338 | 339 | 
# output dimensions 340 | d_out, h_out, w_out = len(d_offsets), len(h_offsets), len(w_offsets) 341 | 342 | # gathers the capsules into shape (N, D2, H2, W2, KD, KH, KW, Ch_in, 17) 343 | d_gathered = tf.gather(u_padded, d_offsets, axis=1) 344 | h_gathered = tf.gather(d_gathered, h_offsets, axis=3) 345 | w_gathered = tf.gather(h_gathered, w_offsets, axis=5) 346 | w_gathered = tf.transpose(w_gathered, [0, 1, 3, 5, 2, 4, 6, 7, 8]) 347 | 348 | if route_mean: 349 | kernels_reshaped = tf.reshape(w_gathered, [batch_size * d_out * h_out * w_out, kernel_size[0]* kernel_size[1]* kernel_size[2], ch, mdim2 + 1]) 350 | kernels_reshaped = tf.reduce_mean(kernels_reshaped, axis=1) 351 | capsules = create_dense_caps((kernels_reshaped[:, :, :-1], kernels_reshaped[:, :, -1:]), channels, name, 352 | ch_same_w=ch_same_w, mdim=mdim) 353 | else: 354 | kernels_reshaped = tf.reshape(w_gathered, [batch_size * d_out * h_out * w_out, kernel_size[0], kernel_size[1], kernel_size[2], ch, mdim2 + 1]) 355 | capsules = create_dense_caps((kernels_reshaped[:, :, :, :, :, :-1], kernels_reshaped[:, :, :, :, :, -1:]), 356 | channels, name, coord_add=coord_add, rel_center=rel_center, ch_same_w=ch_same_w, mdim=mdim) 357 | 358 | poses = tf.reshape(capsules[0][:, :, :mdim2], (batch_size, d_out, h_out, w_out, channels, mdim2), name=name+'_pose') 359 | activations = tf.reshape(capsules[1], (batch_size, d_out, h_out, w_out, channels, 1), name=name+'_act') 360 | 361 | return poses, activations 362 | 363 | 364 | def create_dense_caps_cond(inputs, n_caps_j, name, coord_add=False, rel_center=False, 365 | ch_same_w=True, mdim=4, n_cond_caps=0): 366 | """ 367 | 368 | :param inputs: The input capsule layer. Shape ((N, K, Ch_i, 16), (N, K, Ch_i, 1)) or 369 | ((N, ..., Ch_i, 16), (N, ..., Ch_i, 1)) where K is the number of capsules per channel and '...' is if you are 370 | inputting an map of capsules like W or HxW or DxHxW. 371 | :param n_caps_j: The number of capsules in the following layer 372 | :param name: name of the capsule layer 373 | :param coord_add: A boolean, whether or not to to do coordinate addition 374 | :param rel_center: A boolean, whether or not the coordinate addition is relative to the center 375 | :param routing_type: The type of routing 376 | :return: Returns a layer of capsules. 
Shape ((N, n_caps_j, 16), (N, n_caps_j, 1)) 377 | """ 378 | mdim2 = mdim*mdim 379 | pose, activation = inputs 380 | batch_size = tf.shape(pose)[0] 381 | shape_list = [int(x) for x in pose.get_shape().as_list()[1:]] 382 | ch = int(shape_list[-2]) 383 | n_capsch_i = 1 if len(shape_list) == 2 else reduce((lambda x, y: x * y), shape_list[:-2]) 384 | 385 | u_i = tf.reshape(pose, (batch_size, n_capsch_i, ch, mdim2)) 386 | activation = tf.reshape(activation, (batch_size, n_capsch_i, ch, 1)) 387 | coords = create_coords_mat(pose, rel_center) if coord_add else tf.zeros_like(u_i) 388 | 389 | 390 | # reshapes the input capsules 391 | u_i = tf.reshape(u_i, (batch_size, n_capsch_i, ch, mdim, mdim)) 392 | u_i = tf.expand_dims(u_i, axis=-3) 393 | u_i = tf.tile(u_i, [1, 1, 1, n_caps_j, 1, 1]) 394 | 395 | if ch_same_w: 396 | weights = tf.get_variable(name=name + '_weights', shape=(ch, n_caps_j, mdim, mdim), 397 | initializer=tf.initializers.random_normal(stddev=0.1), 398 | regularizer=tf.contrib.layers.l2_regularizer(0.1)) 399 | 400 | votes = tf.einsum('ijab,ntijbc->ntijac', weights, u_i) 401 | votes = tf.reshape(votes, (batch_size, n_capsch_i * ch, n_caps_j, mdim2), name=name+'_votes') 402 | else: 403 | weights = tf.get_variable(name=name + '_weights', shape=(n_capsch_i, ch, n_caps_j, mdim, mdim), 404 | initializer=tf.initializers.random_normal(stddev=0.1), 405 | regularizer=tf.contrib.layers.l2_regularizer(0.1)) 406 | votes = tf.einsum('tijab,ntijbc->ntijac', weights, u_i) 407 | votes = tf.reshape(votes, (batch_size, n_capsch_i * ch, n_caps_j, mdim2), name=name+'_votes') 408 | 409 | if coord_add: 410 | coords = tf.reshape(coords, (batch_size, n_capsch_i * ch, 1, mdim2)) 411 | votes = votes + tf.tile(coords, [1, 1, n_caps_j, 1]) 412 | 413 | 414 | if n_cond_caps == 0: 415 | beta_v = tf.get_variable(name=name + '_beta_v', shape=(n_caps_j, mdim2), 416 | initializer=tf.initializers.random_normal(stddev=0.1), 417 | regularizer=tf.contrib.layers.l2_regularizer(0.1)) 418 | beta_a = tf.get_variable(name=name + '_beta_a', shape=(n_caps_j, 1), 419 | initializer=tf.initializers.random_normal(stddev=0.1), 420 | regularizer=tf.contrib.layers.l2_regularizer(0.1)) 421 | 422 | acts = tf.reshape(activation, (batch_size, n_capsch_i * ch, 1)) 423 | 424 | capsules1 = em_routing(votes, acts, beta_v, beta_a) 425 | capsules2 = capsules1 426 | else: 427 | beta_v1 = tf.get_variable(name=name + '_beta_v1', shape=(n_caps_j, mdim2), 428 | initializer=tf.initializers.random_normal(stddev=0.1), 429 | regularizer=tf.contrib.layers.l2_regularizer(0.1)) 430 | beta_a1 = tf.get_variable(name=name + '_beta_a1', shape=(n_caps_j, 1), 431 | initializer=tf.initializers.random_normal(stddev=0.1), 432 | regularizer=tf.contrib.layers.l2_regularizer(0.1)) 433 | 434 | beta_v2 = tf.get_variable(name=name + '_beta_v2', shape=(n_caps_j, mdim2), 435 | initializer=tf.initializers.random_normal(stddev=0.1), 436 | regularizer=tf.contrib.layers.l2_regularizer(0.1)) 437 | beta_a2 = tf.get_variable(name=name + '_beta_a2', shape=(n_caps_j, 1), 438 | initializer=tf.initializers.random_normal(stddev=0.1), 439 | regularizer=tf.contrib.layers.l2_regularizer(0.1)) 440 | 441 | votes = tf.reshape(votes, (batch_size, n_capsch_i, ch, n_caps_j, mdim2)) 442 | 443 | votes1 = tf.reshape(votes[:, :, :ch - n_cond_caps], 444 | (batch_size, n_capsch_i * (ch - n_cond_caps), n_caps_j, mdim2)) 445 | votes2 = tf.reshape(votes[:, :, ch - n_cond_caps:], (batch_size, n_capsch_i * n_cond_caps, n_caps_j, mdim2)) 446 | 447 | acts = tf.reshape(activation, (batch_size, n_capsch_i, ch, 1)) 
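        # Split the inputs for conditional routing: the first (ch - n_cond_caps) channels
        # carry the video-capsule votes/activations, and the last n_cond_caps channels
        # carry the conditioning capsules, which are routed as v_f inside em_routing_cond.
        # A second set of video votes (votes_2, with separate weights) is computed for
        # the attention-routing step.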
448 | 449 | acts1 = tf.reshape(acts[:, :, :ch - n_cond_caps], (batch_size, n_capsch_i * (ch - n_cond_caps), 1)) 450 | acts2 = tf.reshape(acts[:, :, ch - n_cond_caps:], (batch_size, n_capsch_i * n_cond_caps, 1)) 451 | 452 | weights_2 = tf.get_variable(name=name + '_weights_2', shape=(ch - n_cond_caps, n_caps_j, mdim, mdim), 453 | initializer=tf.initializers.random_normal(stddev=0.1), 454 | regularizer=tf.contrib.layers.l2_regularizer(0.1)) 455 | 456 | votes_2 = tf.einsum('ijab,ntijbc->ntijac', weights_2, u_i[:, :, :ch - n_cond_caps]) 457 | votes_2 = tf.reshape(votes_2, (batch_size, n_capsch_i * (ch - n_cond_caps), n_caps_j, mdim2), name=name + '_votes2') 458 | 459 | capsules1, capsules2 = em_routing_cond(votes1, votes_2, acts1, votes2, acts2, beta_v1, beta_a1, beta_v2, 460 | beta_a2) 461 | 462 | return capsules1, capsules2 463 | 464 | 465 | def create_conv3d_caps_cond(inputs, channels, kernel_size, strides, name, padding='VALID', 466 | coord_add=False, rel_center=True, route_mean=True, ch_same_w=True, mdim=4, n_cond_caps=0): 467 | mdim2 = mdim*mdim 468 | inputs = tf.concat(inputs, axis=-1) 469 | 470 | if padding == 'SAME': 471 | d_padding, h_padding, w_padding = int(float(kernel_size[0]) / 2), int(float(kernel_size[1]) / 2), int(float(kernel_size[2]) / 2) 472 | u_padded = tf.pad(inputs, [[0, 0], [d_padding, d_padding], [h_padding, h_padding], [w_padding, w_padding], [0, 0], [0, 0]]) 473 | else: 474 | u_padded = inputs 475 | 476 | batch_size = tf.shape(u_padded)[0] 477 | _, d, h, w, ch, _ = u_padded.get_shape().as_list() 478 | d, h, w, ch = map(int, [d, h, w, ch]) 479 | 480 | # gets indices for kernels 481 | d_offsets = [[(d_ + k) for k in range(kernel_size[0])] for d_ in range(0, d + 1 - kernel_size[0], strides[0])] 482 | h_offsets = [[(h_ + k) for k in range(kernel_size[1])] for h_ in range(0, h + 1 - kernel_size[1], strides[1])] 483 | w_offsets = [[(w_ + k) for k in range(kernel_size[2])] for w_ in range(0, w + 1 - kernel_size[2], strides[2])] 484 | 485 | # output dimensions 486 | d_out, h_out, w_out = len(d_offsets), len(h_offsets), len(w_offsets) 487 | 488 | # gathers the capsules into shape (N, D2, H2, W2, KD, KH, KW, Ch_in, 17) 489 | d_gathered = tf.gather(u_padded, d_offsets, axis=1) 490 | h_gathered = tf.gather(d_gathered, h_offsets, axis=3) 491 | w_gathered = tf.gather(h_gathered, w_offsets, axis=5) 492 | w_gathered = tf.transpose(w_gathered, [0, 1, 3, 5, 2, 4, 6, 7, 8]) 493 | 494 | if route_mean: 495 | kernels_reshaped = tf.reshape(w_gathered, [batch_size * d_out * h_out * w_out, kernel_size[0]* kernel_size[1]* kernel_size[2], ch, mdim2 + 1]) 496 | kernels_reshaped = tf.reduce_mean(kernels_reshaped, axis=1) 497 | capsules1, capsules2 = create_dense_caps_cond((kernels_reshaped[:, :, :-1], kernels_reshaped[:, :, -1:]), channels, name, 498 | ch_same_w=ch_same_w, mdim=mdim, n_cond_caps=n_cond_caps) 499 | else: 500 | kernels_reshaped = tf.reshape(w_gathered, [batch_size * d_out * h_out * w_out, kernel_size[0], kernel_size[1], kernel_size[2], ch, mdim2 + 1]) 501 | capsules1, capsules2 = create_dense_caps_cond((kernels_reshaped[:, :, :, :, :, :-1], kernels_reshaped[:, :, :, :, :, -1:]), 502 | channels, name, coord_add=coord_add, rel_center=rel_center, 503 | ch_same_w=ch_same_w, mdim=mdim, n_cond_caps=n_cond_caps) 504 | 505 | poses1 = tf.reshape(capsules1[0][:, :, :mdim2], (batch_size, d_out, h_out, w_out, channels, mdim2), name=name+'_pose1') 506 | activations1 = tf.reshape(capsules1[1], (batch_size, d_out, h_out, w_out, channels, 1), name=name+'_act1') 507 | 508 | poses2 = 
tf.reshape(capsules2[0][:, :, :mdim2], (batch_size, d_out, h_out, w_out, channels, mdim2), name=name + '_pose2') 509 | activations2 = tf.reshape(capsules2[1], (batch_size, d_out, h_out, w_out, channels, 1), name=name + '_act2') 510 | 511 | return (poses1, activations1), (poses2, activations2) 512 | 513 | 514 | def layer_shape(layer): 515 | return str(layer[0].get_shape()) + ' ' + str(layer[1].get_shape()) 516 | 517 | -------------------------------------------------------------------------------- /caps_main.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import config 3 | from caps_network_train import CapsNet 4 | import sys 5 | from load_youtube_data_multi import YoutubeTrainDataGen as TrainDataGen 6 | from load_youtubevalid_data import YoutubeValidDataGen as ValidDataGen 7 | import numpy as np 8 | import time 9 | 10 | 11 | def get_num_params(): 12 | total_parameters = 0 13 | for variable in tf.trainable_variables(): 14 | # shape is an array of tf.Dimension 15 | shape = variable.get_shape() 16 | variable_parameters = 1 17 | for dim in shape: 18 | variable_parameters *= dim.value 19 | total_parameters += variable_parameters 20 | print('Num of parameters:', total_parameters) 21 | sys.stdout.flush() 22 | 23 | 24 | def train_one_epoch(sess, capsnet, data_gen, epoch): 25 | start_time = time.time() 26 | # continues until no more training data is generated 27 | batch, s_losses, seg_acc, reg_losses = 0.0, 0, 0, 0 28 | 29 | while data_gen.has_data(): 30 | x_batch, seg_batch, crop1_batch, crop2_batch = data_gen.get_batch(config.batch_size) 31 | 32 | if config.multi_gpu and len(x_batch) == 1: 33 | print('Batch size of one, not running') 34 | continue 35 | 36 | n_samples = len(x_batch) 37 | 38 | use_gt_seg = epoch <= config.n_epochs_for_gt_seg 39 | use_gt_crop = epoch <= config.n_epochs_for_gt_crop 40 | 41 | hr_lstm_input = np.zeros((n_samples, config.hr_lstm_size[0], config.hr_lstm_size[1], config.hr_lstm_feats)) 42 | lr_lstm_input = np.zeros((n_samples, config.lr_lstm_size[0], config.lr_lstm_size[1], config.lr_lstm_feats)) 43 | 44 | outputs = sess.run([capsnet.train_op, capsnet.segmentation_loss, capsnet.pred_caps, capsnet.seg_acc, 45 | capsnet.regression_loss], 46 | feed_dict={capsnet.x_input_video: x_batch, capsnet.y_segmentation: seg_batch, 47 | capsnet.hr_cond_input: hr_lstm_input, capsnet.lr_cond_input: lr_lstm_input, 48 | capsnet.use_gt_seg: use_gt_seg, capsnet.use_gt_crop: use_gt_crop, 49 | capsnet.gt_crops1: crop1_batch, capsnet.gt_crops2: crop2_batch}) 50 | 51 | _, s_loss, cap_vals, s_acc, reg_loss = outputs 52 | s_losses += s_loss 53 | seg_acc += s_acc 54 | reg_losses += reg_loss 55 | 56 | batch += 1 57 | 58 | if np.isnan(cap_vals[0][0]): 59 | print(cap_vals[0][:10]) 60 | print('NAN encountered.') 61 | config.write_output('NAN encountered.\n') 62 | return -1, -1, -1 63 | 64 | if batch % config.batches_until_print == 0: 65 | print('Finished %d batches. %d(s) since start. Avg Segmentation Loss is %.4f. Avg Regression Loss is %.4f. ' 66 | 'Seg Acc is %.4f.' 67 | % (batch, time.time() - start_time, s_losses / batch, reg_losses / batch, seg_acc / batch)) 68 | sys.stdout.flush() 69 | 70 | print('Finish Epoch in %d(s). Avg Segmentation Loss is %.4f. Avg Regression Loss is %.4f. Seg Acc is %.4f.' 
% 71 | (time.time() - start_time, s_losses / batch, reg_losses / batch, seg_acc / batch)) 72 | sys.stdout.flush() 73 | 74 | return s_losses / batch, reg_losses / batch, seg_acc / batch 75 | 76 | 77 | def validate(sess, capsnet, data_gen): 78 | batch, s_losses, seg_acc = 0.0, 0, 0 79 | start_time = time.time() 80 | 81 | while data_gen.has_data(): 82 | batch_data = data_gen.get_batch(config.batch_size) 83 | x_batch, seg_batch, crop1_batch = batch_data 84 | 85 | hr_lstm_input = np.zeros((len(x_batch), config.hr_lstm_size[0], config.hr_lstm_size[1], config.hr_lstm_feats)) 86 | lr_lstm_input = np.zeros((len(x_batch), config.lr_lstm_size[0], config.lr_lstm_size[1], config.lr_lstm_feats)) 87 | 88 | val_ouputs = sess.run([capsnet.val_seg_loss, capsnet.val_seg_acc], 89 | feed_dict={capsnet.x_input_video: x_batch, capsnet.y_segmentation: seg_batch, 90 | capsnet.hr_cond_input: hr_lstm_input, capsnet.lr_cond_input: lr_lstm_input, 91 | capsnet.use_gt_seg: True, capsnet.use_gt_crop: True, 92 | capsnet.gt_crops1: crop1_batch, capsnet.gt_crops2: crop1_batch}) 93 | 94 | s_loss, s_acc = val_ouputs 95 | 96 | s_losses += s_loss 97 | seg_acc += s_acc 98 | 99 | batch += 1 100 | 101 | if batch % config.batches_until_print == 0: 102 | print('Tested %d batches. %d(s) since start.' % (batch, time.time() - start_time)) 103 | sys.stdout.flush() 104 | 105 | print('Evaluation done in %d(s).' % (time.time() - start_time)) 106 | print('Test Segmentation Loss: %.4f. Test Segmentation Acc: %.4f' % (s_losses / batch, seg_acc / batch)) 107 | sys.stdout.flush() 108 | 109 | return s_losses / batch, seg_acc / batch 110 | 111 | 112 | def train_network(gpu_config): 113 | capsnet = CapsNet() 114 | 115 | with tf.Session(graph=capsnet.graph, config=gpu_config) as sess: 116 | tf.global_variables_initializer().run() 117 | 118 | get_num_params() 119 | if config.start_at_epoch <= 1: 120 | config.clear_output() 121 | else: 122 | capsnet.load(sess, config.save_file_best_name % (config.start_at_epoch - 1)) 123 | print('Loading from epoch %d.' % (config.start_at_epoch-1)) 124 | 125 | best_loss = 1000000 126 | best_epoch = 1 127 | print('Training on YoutubeVOS') 128 | for ep in range(config.start_at_epoch, config.n_epochs + 1): 129 | print(20 * '*', 'epoch', ep, 20 * '*') 130 | sys.stdout.flush() 131 | 132 | # Trains network for 1 epoch 133 | nan_tries = 0 134 | while nan_tries < 3: 135 | data_gen = TrainDataGen(config.wait_for_data, crop_size=config.hr_frame_size, n_frames=config.n_frames, 136 | rand_frame_skip=config.rand_frame_skip, multi_objects=config.multiple_objects) 137 | seg_loss, reg_loss, seg_acc = train_one_epoch(sess, capsnet, data_gen, ep) 138 | 139 | if seg_loss < 0 or seg_acc < 0: 140 | nan_tries += 1 141 | capsnet.load(sess, config.save_file_best_name % best_epoch) # loads in the previous epoch 142 | while data_gen.has_data(): 143 | data_gen.get_batch(config.batch_size) 144 | else: 145 | config.write_output('Epoch %d: SL: %.4f. RL: %.4f. SegAcc: %.4f.\n' % (ep, seg_loss, reg_loss, seg_acc)) 146 | break 147 | 148 | if nan_tries == 3: 149 | print('Network cannot be trained. Too many NaN issues.') 150 | exit() 151 | 152 | # Validates network 153 | data_gen = ValidDataGen(config.wait_for_data, crop_size=config.hr_frame_size, n_frames=config.n_frames) 154 | seg_loss, seg_acc = validate(sess, capsnet, data_gen) 155 | 156 | config.write_output('Validation\tSL: %.4f. 
SA: %.4f.\n' % (seg_loss, seg_acc)) 157 | 158 | # saves every 10 epochs 159 | if ep % config.save_every_n_epochs == 0: 160 | try: 161 | capsnet.save(sess, config.save_file_name % ep) 162 | config.write_output('Saved Network\n') 163 | except: 164 | print('Failed to save network!!!') 165 | sys.stdout.flush() 166 | 167 | # saves when validation loss becomes smaller (after 50 epochs to save space) 168 | t_loss = seg_loss 169 | 170 | if t_loss < best_loss: 171 | best_loss = t_loss 172 | try: 173 | capsnet.save(sess, config.save_file_best_name % ep) 174 | best_epoch = ep 175 | config.write_output('Saved Network - Minimum val\n') 176 | except: 177 | print('Failed to save network!!!') 178 | sys.stdout.flush() 179 | 180 | tf.reset_default_graph() 181 | 182 | 183 | def main(): 184 | gpu_config = tf.ConfigProto(allow_soft_placement=True) 185 | gpu_config.gpu_options.allow_growth = True 186 | 187 | train_network(gpu_config) 188 | 189 | 190 | sys.settrace(main()) 191 | 192 | 193 | -------------------------------------------------------------------------------- /caps_network_test.py: -------------------------------------------------------------------------------- 1 | import config 2 | import tensorflow as tf 3 | import sys 4 | from network_parts.lstm_capsnet_cond2_test import create_network 5 | 6 | 7 | class CapsNet(object): 8 | def __init__(self, graph=None): 9 | if graph is None: 10 | self.graph = tf.Graph() 11 | else: 12 | self.graph = graph 13 | 14 | with self.graph.as_default(): 15 | hr_h, hr_w = config.hr_frame_size 16 | 17 | n_frames = config.n_frames 18 | 19 | self.x_input_video = tf.placeholder(dtype=tf.float32, shape=(None, n_frames, hr_h, hr_w, 3), 20 | name='x_input_video') 21 | self.y_segmentation = tf.placeholder(dtype=tf.float32, shape=(None, n_frames, hr_h, hr_w, 1), 22 | name='y_segmentation') 23 | 24 | self.x_first_seg = tf.placeholder(dtype=tf.float32, shape=(None, hr_h, hr_w, 1), name='x_first_seg') 25 | 26 | self.use_gt_crop = tf.placeholder(dtype=tf.bool) 27 | 28 | self.gt_crops1 = tf.placeholder(dtype=tf.float32, shape=(None, 4)) # [y1, x1, alpha, 0] between 0 and 1 29 | 30 | cond_h, cond_w = config.hr_lstm_size 31 | self.hr_cond_input = tf.placeholder(dtype=tf.float32, shape=(None, cond_h, cond_w, config.hr_lstm_feats), 32 | name='hr_lstm_input') 33 | cond_h, cond_w = config.lr_lstm_size 34 | self.lr_cond_input = tf.placeholder(dtype=tf.float32, shape=(None, cond_h, cond_w, config.lr_lstm_feats), 35 | name='lr_lstm_input') 36 | 37 | self.init_network() 38 | 39 | #self.init_seg_loss() 40 | #self.init_regression_loss() 41 | #self.total_loss = self.segmentation_loss + self.regression_loss 42 | 43 | #self.init_optimizer() 44 | 45 | self.saver = tf.train.Saver() 46 | 47 | def init_network(self): 48 | print('Building Caps3d Model') 49 | 50 | with tf.variable_scope('network') as scope: 51 | #scope.reuse_variables() 52 | network_outputs = create_network(self.x_input_video, self.x_first_seg, self.hr_cond_input, self.lr_cond_input, self.use_gt_crop, self.gt_crops1) 53 | self.segment_layer, self.segment_layer_sig, self.pred_caps, self.state_t, self.state_t_lr, self.pred_crops1 = network_outputs 54 | 55 | sys.stdout.flush() 56 | 57 | def init_seg_loss(self): 58 | # Segmentation loss 59 | segment = self.segment_layer 60 | y_seg = self.segment_layer#self.y_segmentation 61 | 62 | segmentation_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_seg, logits=segment) 63 | segmentation_loss = tf.reduce_mean(tf.reduce_sum(segmentation_loss, axis=[1, 2, 3, 4])) 64 | 65 | pred_seg = 
tf.cast(tf.greater(segment, 0.0), tf.float32) 66 | seg_acc = tf.reduce_mean(tf.cast(tf.equal(pred_seg, y_seg), tf.float32)) 67 | 68 | frame_segment = self.segment_layer[:, 0, :, :, :] 69 | y_frame_segment = self.segment_layer[:, 0, :, :, :]#self.y_segmentation[:, 0, :, :, :] 70 | 71 | val_seg_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_frame_segment, logits=frame_segment) 72 | val_seg_loss = tf.reduce_mean(tf.reduce_sum(val_seg_loss, axis=[1, 2, 3])) 73 | 74 | val_pred_seg = tf.cast(tf.greater(frame_segment, 0.0), tf.float32) 75 | val_seg_acc = tf.reduce_mean(tf.cast(tf.equal(val_pred_seg, y_frame_segment), tf.float32)) 76 | 77 | self.segmentation_loss = segmentation_loss 78 | self.val_seg_loss = val_seg_loss 79 | 80 | self.seg_acc = seg_acc 81 | self.val_seg_acc = val_seg_acc 82 | 83 | print('Segmentation Loss Initialized') 84 | 85 | def init_regression_loss(self): 86 | regression_loss = tf.square(self.gt_crops1 - self.pred_crops1) 87 | self.regression_loss = tf.reduce_mean(tf.reduce_sum(regression_loss, axis=1)) 88 | 89 | print('Regression Loss Initialized') 90 | 91 | def init_optimizer(self): 92 | optimizer = tf.train.AdamOptimizer(learning_rate=config.learning_rate, beta1=config.beta1, name='Adam', 93 | epsilon=config.epsilon) 94 | 95 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 96 | with tf.control_dependencies(update_ops): 97 | self.train_op = optimizer.minimize(loss=self.total_loss, colocate_gradients_with_ops=True) 98 | 99 | def save(self, sess, file_name): 100 | save_path = self.saver.save(sess, file_name) 101 | print("Model saved in file: %s" % save_path) 102 | sys.stdout.flush() 103 | 104 | def load(self, sess, file_name): 105 | self.saver.restore(sess, file_name) 106 | print('Model restored.') 107 | sys.stdout.flush() 108 | 109 | -------------------------------------------------------------------------------- /caps_network_train.py: -------------------------------------------------------------------------------- 1 | import config 2 | import tensorflow as tf 3 | import sys 4 | from network_parts.lstm_capsnet_cond2_train import create_network 5 | 6 | 7 | class CapsNet(object): 8 | def __init__(self, reg_mult=1.0): 9 | self.graph = tf.Graph() 10 | 11 | with self.graph.as_default(): 12 | hr_h, hr_w = config.hr_frame_size 13 | 14 | n_frames = config.n_frames*2-1 15 | 16 | self.x_input_video = tf.placeholder(dtype=tf.float32, shape=(None, n_frames, hr_h, hr_w, 3), 17 | name='x_input_video') 18 | self.y_segmentation = tf.placeholder(dtype=tf.float32, shape=(None, n_frames, hr_h, hr_w, 1), 19 | name='y_segmentation') 20 | 21 | self.use_gt_seg = tf.placeholder(dtype=tf.bool) 22 | self.use_gt_crop = tf.placeholder(dtype=tf.bool) 23 | 24 | self.gt_crops1 = tf.placeholder(dtype=tf.float32, shape=(None, 4)) # [y1, x1, y2, x2] between 0 and 1 25 | self.gt_crops2 = tf.placeholder(dtype=tf.float32, shape=(None, 4)) # [y1, x1, y2, x2] between 0 and 1 26 | 27 | cond_h, cond_w = config.hr_lstm_size 28 | self.hr_cond_input = tf.placeholder(dtype=tf.float32, shape=(None, cond_h, cond_w, config.hr_lstm_feats), 29 | name='hr_lstm_input') 30 | cond_h, cond_w = config.lr_lstm_size 31 | self.lr_cond_input = tf.placeholder(dtype=tf.float32, shape=(None, cond_h, cond_w, config.lr_lstm_feats), 32 | name='lr_lstm_input') 33 | self.init_network() 34 | 35 | self.init_seg_loss() 36 | self.init_regression_loss() 37 | self.total_loss = self.segmentation_loss + self.regression_loss*reg_mult 38 | 39 | self.init_optimizer() 40 | 41 | self.saver = tf.train.Saver() 42 | 43 | def 
init_network(self): 44 | print('Building Caps3d Model') 45 | 46 | with tf.variable_scope('network') as scope: 47 | if config.multi_gpu: 48 | b = tf.cast(tf.shape(self.x_input_video)[0] / 2, tf.int32) 49 | with tf.device(config.devices[0]): 50 | segment_layer1, segment_layer_sig1, prim_caps1, state_t1, pred_crops11, pred_crops21 = create_network(self.x_input_video[:b], 51 | self.y_segmentation[:b], 52 | self.hr_cond_input[:b], 53 | self.lr_cond_input[:b], 54 | self.use_gt_seg, 55 | self.use_gt_crop, 56 | self.gt_crops1[:b], 57 | self.gt_crops2[:b]) 58 | 59 | scope.reuse_variables() 60 | with tf.device(config.devices[1]): 61 | segment_layer2, segment_layer_sig2, prim_caps2, state_t2, pred_crops12, pred_crops22 = create_network(self.x_input_video[b:], 62 | self.y_segmentation[b:], 63 | self.hr_cond_input[b:], 64 | self.lr_cond_input[b:], 65 | self.use_gt_seg, 66 | self.use_gt_crop, 67 | self.gt_crops1[b:], 68 | self.gt_crops2[b:]) 69 | 70 | self.segment_layer = tf.concat([segment_layer1, segment_layer2], axis=0) 71 | self.segment_layer_sig = tf.concat([segment_layer_sig1, segment_layer_sig2], axis=0) 72 | self.pred_caps = tf.concat([prim_caps1, prim_caps2], axis=0) 73 | self.state_t = tf.concat([state_t1, state_t2], axis=0) 74 | self.pred_crops1 = tf.concat([pred_crops11, pred_crops12], axis=0) 75 | self.pred_crops2 = tf.concat([pred_crops21, pred_crops22], axis=0) 76 | else: 77 | network_outputs = create_network(self.x_input_video, self.y_segmentation, self.hr_cond_input, self.lr_cond_input, self.use_gt_seg, self.use_gt_crop, self.gt_crops1, self.gt_crops2) 78 | self.segment_layer, self.segment_layer_sig, self.pred_caps, self.state_t, self.pred_crops1, self.pred_crops2 = network_outputs 79 | 80 | sys.stdout.flush() 81 | 82 | def init_seg_loss(self): 83 | # Segmentation loss 84 | segment = self.segment_layer 85 | y_seg = self.y_segmentation 86 | 87 | segmentation_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_seg, logits=segment) 88 | #segmentation_loss = tf.reduce_mean(tf.reduce_sum(segmentation_loss, axis=[1, 2, 3, 4])) 89 | segmentation_loss = tf.reduce_mean(tf.reduce_mean(segmentation_loss, axis=[1, 2, 3, 4])) 90 | 91 | pred_seg = tf.cast(tf.greater(segment, 0.0), tf.float32) 92 | seg_acc = tf.reduce_mean(tf.cast(tf.equal(pred_seg, y_seg), tf.float32)) 93 | 94 | frame_segment = self.segment_layer[:, 0, :, :, :] 95 | y_frame_segment = self.y_segmentation[:, 0, :, :, :] 96 | 97 | val_seg_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_frame_segment, logits=frame_segment) 98 | val_seg_loss = tf.reduce_mean(tf.reduce_sum(val_seg_loss, axis=[1, 2, 3])) 99 | 100 | val_pred_seg = tf.cast(tf.greater(frame_segment, 0.0), tf.float32) 101 | val_seg_acc = tf.reduce_mean(tf.cast(tf.equal(val_pred_seg, y_frame_segment), tf.float32)) 102 | 103 | segment_sig = self.segment_layer_sig 104 | 105 | p_times_r = segment_sig*y_seg 106 | p_plus_r = segment_sig + y_seg 107 | inv_p_times_r = (1-segment_sig) * (1-y_seg) 108 | inv_p_plus_r = 2-segment_sig-y_seg 109 | eps = 1e-8 110 | 111 | term1 = (tf.reduce_sum(p_times_r, axis=[1, 2, 3])+eps)/(tf.reduce_sum(p_plus_r, axis=[1, 2, 3])+eps) 112 | term2 = (tf.reduce_sum(inv_p_times_r, axis=[1, 2, 3])+eps)/(tf.reduce_sum(inv_p_plus_r, axis=[1, 2, 3])+eps) 113 | 114 | dice_loss = tf.reduce_mean(1 - term1 - term2) 115 | 116 | self.segmentation_loss = segmentation_loss + dice_loss 117 | # self.segmentation_loss = segmentation_loss 118 | self.val_seg_loss = val_seg_loss 119 | 120 | self.seg_acc = seg_acc 121 | self.val_seg_acc = val_seg_acc 122 | 123 | 124 
| print('Segmentation Loss Initialized') 125 | 126 | def init_regression_loss(self): 127 | regression_loss = tf.square(self.gt_crops1 - self.pred_crops1) + tf.square(self.gt_crops2 - self.pred_crops2) 128 | self.regression_loss = tf.reduce_mean(tf.reduce_sum(regression_loss, axis=1)) 129 | 130 | print('Regression Loss Initialized') 131 | 132 | def init_optimizer(self): 133 | optimizer = tf.train.AdamOptimizer(learning_rate=config.learning_rate, beta1=config.beta1, name='Adam', 134 | epsilon=config.epsilon) 135 | 136 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 137 | with tf.control_dependencies(update_ops): 138 | self.train_op = optimizer.minimize(loss=self.total_loss, colocate_gradients_with_ops=True) 139 | 140 | def save(self, sess, file_name): 141 | save_path = self.saver.save(sess, file_name) 142 | print("Model saved in file: %s" % save_path) 143 | sys.stdout.flush() 144 | 145 | def load(self, sess, file_name): 146 | self.saver.restore(sess, file_name) 147 | print('Model restored.') 148 | sys.stdout.flush() 149 | 150 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 3 | devices = ['/gpu:0', '/gpu:0'] 4 | multi_gpu = len(set(devices)) == 2 5 | 6 | batch_size = 4 7 | n_epochs = 200 8 | 9 | learning_rate, beta1, epsilon = 0.0001, 0.5, 1e-6 10 | 11 | n_frames = 8 12 | n_epochs_for_gt_seg = 0 13 | n_epochs_for_gt_crop = 500 14 | 15 | hr_frame_size = (128*4, 224*4) 16 | hr_lstm_size = (1, 1) 17 | hr_lstm_feats = 256 18 | 19 | lr_frame_size = (128, 224) 20 | lr_lstm_size = (lr_frame_size[0]//4, lr_frame_size[1]//4) 21 | lr_lstm_feats = 256 22 | 23 | model_num = 4 24 | 25 | save_every_n_epochs = 50 26 | output_file_name = './output%d.txt' % model_num 27 | save_file_name = ('./network_saves/model%d' % model_num) + '_%d.ckpt' 28 | save_file_best_name = ('./network_saves_best/model%d' % model_num) + '_%d.ckpt' 29 | start_at_epoch = 1 30 | 31 | output_inference_file = './Anns2/Annotations/' 32 | epoch_save = 394 33 | save_file_inference = ('./network_saves_best/model%d' % model_num) + '_%d.ckpt' 34 | 35 | multiple_objects = False 36 | rand_frame_skip = 1 37 | wait_for_data = 5 # in seconds 38 | batches_until_print = 1 39 | 40 | inv_temp = 0.5 41 | inv_temp_delta = 0.1 42 | pose_dimension = 4 43 | 44 | print_layers = True 45 | 46 | 47 | def clear_output(): 48 | with open(output_file_name, 'w') as f: 49 | print('Writing to ' + output_file_name) 50 | f.write('Model #: %d. 
Batch Size: %d.\n' % (model_num, batch_size)) 51 | 52 | 53 | def write_output(string): 54 | try: 55 | output_log = open(output_file_name, 'a') 56 | output_log.write(string) 57 | output_log.close() 58 | except: 59 | print('Unable to save to output log') 60 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from caps_network_test import CapsNet 3 | from skvideo.io import vread 4 | import numpy as np 5 | from PIL import Image 6 | from scipy.misc import imresize 7 | import os 8 | import config 9 | import time 10 | 11 | 12 | def load_video(video_name): 13 | video = vread(video_name) 14 | 15 | t, h, w, _ = video.shape 16 | 17 | resized_video = [] 18 | for frame in video: 19 | resized_video.append(imresize(frame, config.hr_frame_size)) 20 | 21 | resized_video = np.stack(resized_video, axis=0) 22 | 23 | return resized_video / 255., (h, w) 24 | 25 | 26 | def load_first_frame(frame_name): 27 | image = Image.open(frame_name) 28 | palette = image.getpalette() 29 | image_np = np.array(image) 30 | 31 | return imresize(image_np, config.hr_frame_size, interp='nearest'), palette 32 | 33 | 34 | def process_first_frame(first_frame): 35 | unique_seg_colors = np.unique(first_frame) 36 | 37 | fin_segmentations = {} 38 | for color in unique_seg_colors: 39 | if color == 0: 40 | continue 41 | 42 | gt_seg = np.where(first_frame == color, 1, 0) 43 | if np.sum(gt_seg) == 0: 44 | continue 45 | gt_seg = np.expand_dims(gt_seg, axis=-1) 46 | fin_segmentations[color] = (0, gt_seg) 47 | 48 | return fin_segmentations 49 | 50 | 51 | def get_bounds(img): 52 | h_sum = np.sum(img, axis=1) 53 | w_sum = np.sum(img, axis=0) 54 | 55 | hs = np.where(h_sum > 0) 56 | ws = np.where(w_sum > 0) 57 | 58 | try: 59 | h0 = hs[0][0] 60 | h1 = hs[0][-1] 61 | w0 = ws[0][0] 62 | w1 = ws[0][-1] 63 | except: 64 | return -1, -1, -1, -1 65 | 66 | return h0, h1, w0, w1 67 | 68 | 69 | def get_crop_to_use(h0, h1, w0, w1, h, w, prev_crop): 70 | # uses the h, w of predicted, and the center of gt 71 | if h0 == -1: 72 | use_gt_crop = False 73 | crop_to_use = prev_crop 74 | else: 75 | use_gt_crop = False 76 | crop_to_use = np.clip(np.array([((h0+h1)/2)/h, ((w0+w1)/2)/w, 1.0, 0]), 0, 1) 77 | 78 | return use_gt_crop, crop_to_use 79 | 80 | 81 | def get_seg_for_clip_gt(sess, capsnet, clip, frame_start, lstm_cond, lstm_cond_lr, prev_crop): 82 | f, h, w, _ = clip.shape 83 | 84 | # print(frame_start.min()) 85 | # print(frame_start.max()) 86 | first_frame_seg_full = frame_start # np.round(frame_start) # frame_start # 87 | 88 | new_video_in = clip 89 | len_clip = new_video_in.shape[0] 90 | if len_clip < config.n_frames: 91 | new_video_in = np.concatenate((new_video_in, np.tile(new_video_in[-1:], 92 | [config.n_frames - len_clip, 1, 1, 1])), axis=0) 93 | 94 | # gets the bounds of the segmentations 95 | h0, h1, w0, w1 = get_bounds(np.round(first_frame_seg_full[:, :, 0])) 96 | 97 | use_gt_crop, crop_to_use = get_crop_to_use(h0, h1, w0, w1, h, w, prev_crop) 98 | 99 | # runs through the network 100 | seg_pred, lstm_cond, lstm_cond_lr, pred_crops = sess.run([capsnet.segment_layer_sig, capsnet.state_t, capsnet.state_t_lr, capsnet.pred_crops1], 101 | feed_dict={capsnet.x_input_video: [new_video_in], 102 | capsnet.x_first_seg: [first_frame_seg_full], 103 | capsnet.hr_cond_input: lstm_cond, 104 | capsnet.lr_cond_input: lstm_cond_lr, 105 | capsnet.use_gt_crop: use_gt_crop, 106 | capsnet.gt_crops1: [crop_to_use]}) 107 | 
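    # seg_pred holds the sigmoid segmentation for all config.n_frames frames of the clip;
    # the two LSTM states are carried over to the next clip, and pred_crops is the crop
    # predicted by the network, whose scale is combined below with the segmentation-derived
    # centre to form the crop for the next clip.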
108 | # resizes crop and places it back into original frame size 109 | seg_pred = seg_pred[0] 110 | 111 | overlap_frames = 3 112 | 113 | if use_gt_crop: 114 | crop_to_use = crop_to_use 115 | else: 116 | crop_to_use = np.concatenate((crop_to_use[:2], pred_crops[0][2:]), axis=-1) 117 | 118 | return seg_pred, lstm_cond, lstm_cond_lr, overlap_frames, crop_to_use 119 | 120 | 121 | def generate_inference(sess, capsnet, video, segmentations, orig_dim, vid_name, img_palette): 122 | orig_h, orig_w = orig_dim 123 | n_objects = int(max(segmentations.keys())) 124 | 125 | lstm_conds = np.zeros((n_objects + 1, config.hr_lstm_size[0], config.hr_lstm_size[1], config.hr_lstm_feats)) 126 | lstm_conds_lr = np.zeros((n_objects + 1, config.lr_lstm_size[0], config.lr_lstm_size[1], config.lr_lstm_feats)) 127 | 128 | prev_coords = np.zeros((n_objects + 1, 4)) 129 | 130 | f, h, w, _ = video.shape 131 | 132 | segmentation_maps = np.zeros((config.n_frames, h, w, n_objects + 1)) 133 | segmentation_maps[:, :, :, 0] = 0.5 134 | final_segmentation = np.zeros((h, w, 1)) 135 | cur_i = np.ones((n_objects + 1,), np.uint8) 136 | overlaps = np.ones((n_objects + 1,), np.uint8) 137 | 138 | vid_dir = 'Output/' + vid_name + '/' 139 | mkdir(vid_dir) 140 | 141 | for i in range(f): 142 | for color in range(1, n_objects + 1): 143 | if color not in segmentations.keys(): 144 | continue 145 | 146 | cur_i[color] += 1 147 | start_frame, start_seg = segmentations[color] 148 | 149 | if i < start_frame: # the current frame occurs before the object appears 150 | cur_i[color] = 7 151 | continue 152 | elif i == start_frame: # the current frame is the first frame of the object (use given segmentation) 153 | segmentation_maps[-1, :, :, color:color + 1] = start_seg 154 | cur_i[color] = 7 155 | continue 156 | 157 | if cur_i[color] != config.n_frames - overlaps[color] + 1: # the current frame's segmentation has been predicted 158 | # cur_overlap[color] -= 1 159 | # segmentation_maps[:-1, :, :, color] = segmentation_maps[1:, :, :, color] 160 | continue 161 | 162 | cur_0 = cur_i[color] - 1 163 | 164 | # cond_frame_seg = segmentation_maps[i-1, :, :, color:color+1] # This is the naive approach 165 | # cond_frame_seg = (final_segmentation[i-1] == color).astype(np.float32) # winner take all approach 166 | cond_frame_seg = ((final_segmentation == color).astype(np.float32) + (final_segmentation == 0).astype( 167 | np.float32)) * segmentation_maps[cur_0, :, :, color:color + 1] # winner take all approach 2 168 | 169 | vid_to_use = video[i - 1:i + config.n_frames - 1] 170 | orig_len = vid_to_use.shape[0] 171 | if orig_len < config.n_frames: 172 | vid_to_use = np.concatenate( 173 | [vid_to_use] + [vid_to_use[-1:] for reps in range(config.n_frames - orig_len)], axis=0) 174 | 175 | # use previous frame to generate segmentation for future N frames 176 | pred_seg, lstm_cond, lstm_cond_lr, overlap_frames, coords_used = get_seg_for_clip_gt(sess, capsnet, 177 | vid_to_use, 178 | cond_frame_seg, 179 | lstm_conds[ 180 | color:color + 1], 181 | lstm_conds_lr[ 182 | color:color + 1], 183 | prev_coords[color]) 184 | 185 | segmentation_maps[:, :, :, color:color + 1] = pred_seg 186 | lstm_conds[color:color + 1] = lstm_cond 187 | lstm_conds_lr[color:color + 1] = lstm_cond_lr 188 | overlaps[color] = overlap_frames 189 | 190 | prev_coords[color] = coords_used 191 | # print(overlap_frames) 192 | 193 | cur_i[color] = 1 194 | 195 | final_segmentation[:, :, 0] = np.argmax(segmentation_maps[cur_i, :, :, range(n_objects + 1)], axis=0) 196 | 197 | frame_name = vid_dir + ('%d.png' % 
i).zfill(5) 198 | 199 | fb_segs_argmax = imresize(final_segmentation[:, :, 0].astype(dtype=np.uint8), (orig_h, orig_w), interp='nearest') 200 | c = Image.fromarray(fb_segs_argmax, mode='P') 201 | c.putpalette(img_palette) 202 | c.save(frame_name, "PNG", mode='P') 203 | 204 | 205 | def mkdir(dl_path): 206 | if not os.path.exists(dl_path): 207 | print("path doesn't exist. trying to make %s" % dl_path) 208 | os.mkdir(dl_path) 209 | else: 210 | print('%s exists, cannot make directory' % dl_path) 211 | 212 | 213 | def inf(): 214 | gpu_config = tf.ConfigProto() 215 | gpu_config.gpu_options.allow_growth = True 216 | 217 | capsnet = CapsNet() 218 | with tf.Session(graph=capsnet.graph, config=gpu_config) as sess: 219 | capsnet.load(sess, config.save_file_inference % config.epoch_save) 220 | 221 | # loads in video 222 | video_name = '03deb7ad95' 223 | video, orig_dims = load_video('./Examples/' + video_name + '.mp4') 224 | first_frame, img_palette = load_first_frame('./Examples/00110.png') 225 | 226 | processed_first_frame = process_first_frame(first_frame) 227 | 228 | start_time = time.time() 229 | print('Starting Inference') 230 | 231 | generate_inference(sess, capsnet, video, processed_first_frame, orig_dims, video_name, img_palette) 232 | 233 | print('Finished Inference in %d(s)' % (time.time()-start_time)) 234 | 235 | 236 | if __name__ == "__main__": 237 | inf() 238 | -------------------------------------------------------------------------------- /load_youtube_data_multi.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from scipy.misc import imread, imresize 4 | import numpy as np 5 | import random 6 | from threading import Thread 7 | import time 8 | from PIL import Image 9 | import sys 10 | #from scipy.signal import medfilt, medfilt2d 11 | import matplotlib.pyplot as plt 12 | from scipy.ndimage.filters import median_filter 13 | 14 | data_loc = '/home/kevin/HD2TB/Datasets/YoutubeVOS2018/' 15 | max_vids = 25 16 | 17 | 18 | def get_split_names(tr_or_val): 19 | split_file = data_loc + '%s/meta.json' % tr_or_val 20 | 21 | all_files = [] 22 | with open(split_file) as f: 23 | data = json.load(f) 24 | files = sorted(list(data['videos'].keys())) 25 | for file_name in files: 26 | fdict = data['videos'][file_name]['objects'] 27 | frames = sorted(list(set([fr for x in fdict.keys() for fr in fdict[x]['frames']]))) 28 | all_files.append((file_name, data['videos'][file_name]['objects'], frames)) 29 | 30 | return all_files 31 | 32 | 33 | def load_video(file_name, allowable_frames, tr_or_val='train', shuffle=True, n_frames=8, frame_skip=1): 34 | video_dir = data_loc + ('%s_all_frames/JPEGImages/%s/' % (tr_or_val, file_name)) 35 | segment_dir = data_loc + ('%s_all_frames/Annotations/%s/' % (tr_or_val, file_name)) 36 | 37 | frame_names = sorted(os.listdir(video_dir)) 38 | seg_frame_names = sorted(os.listdir(segment_dir)) 39 | frame_names = sorted([x for x in frame_names if x[:-4] + '.png' in seg_frame_names]) 40 | 41 | start_ind = frame_names.index(allowable_frames[0] + '.jpg') 42 | 43 | if shuffle: 44 | start_frame = np.random.randint(start_ind, max(len(frame_names) - n_frames * frame_skip, start_ind + 1)) 45 | else: 46 | start_frame = start_ind 47 | 48 | while start_frame > start_ind and frame_names[start_frame][:-4] not in allowable_frames: # ensures the first frame is not interpolated 49 | start_frame -= 1 50 | 51 | # loads video 52 | frames = [] 53 | for f in range(start_frame, start_frame + n_frames*frame_skip, frame_skip): 54 | try: 55 
| frames.append(imread(video_dir + frame_names[f])) 56 | #frames.append(np.array(Image.open(video_dir + frame_names[f]))) 57 | except: 58 | frames.append(frames[-1]) 59 | 60 | video = np.stack(frames, axis=0) 61 | 62 | # loads segmentations 63 | frames = [] 64 | for f in range(start_frame, start_frame + n_frames*frame_skip, frame_skip): 65 | try: 66 | frames.append(np.array(Image.open(segment_dir + seg_frame_names[f]))) 67 | except: 68 | frames.append(frames[-1]) 69 | 70 | segmentation = np.stack(frames, axis=0) 71 | 72 | return video, segmentation, frame_names[start_frame][:-4] 73 | 74 | 75 | def resize_video(video, segmentation, target_size=(120, 120)): 76 | frames, h, w, _ = video.shape 77 | 78 | t_h, t_w = target_size 79 | 80 | video_res = np.zeros((frames, t_h, t_w, 3), np.uint8) 81 | segment_res = np.zeros((frames, t_h, t_w), np.uint8) 82 | for frame in range(frames): 83 | video_res[frame] = imresize(video[frame], (t_h, t_w)) 84 | segment_res[frame] = imresize(segmentation[frame], (t_h, t_w), interp='nearest') 85 | 86 | return video_res/255., segment_res 87 | 88 | 89 | def flip_clip(clip, segment_clip): 90 | flip_y = np.random.random_sample() 91 | #flip_x = np.random.random_sample() 92 | if flip_y >= 0.5: 93 | clip = np.flip(clip, axis=2) 94 | segment_clip = np.flip(segment_clip, axis=2) 95 | # if flip_x >= 0.5: 96 | # clip = np.flip(clip, axis=1) 97 | # segment_clip = np.flip(segment_clip, axis=1) 98 | 99 | return clip, segment_clip 100 | 101 | 102 | def get_bounds(img): 103 | h_sum = np.sum(img, axis=1) 104 | w_sum = np.sum(img, axis=0) 105 | 106 | hs = np.where(h_sum > 0) 107 | ws = np.where(w_sum > 0) 108 | 109 | try: 110 | h0 = hs[0][0] 111 | h1 = hs[0][-1] 112 | w0 = ws[0][0] 113 | w1 = ws[0][-1] 114 | except: 115 | return 0, img.shape[0], 0, img.shape[1] 116 | 117 | return h0, h1, w0, w1 118 | 119 | 120 | def get_bounds2(img): 121 | h_sum = np.sum(img, axis=1) 122 | w_sum = np.sum(img, axis=0) 123 | 124 | hs = np.where(h_sum > 0) 125 | ws = np.where(w_sum > 0) 126 | 127 | try: 128 | h0 = hs[0][0] 129 | h1 = hs[0][-1] 130 | w0 = ws[0][0] 131 | w1 = ws[0][-1] 132 | except: 133 | return -1, -1, -1, -1 134 | 135 | return h0, h1, w0, w1 136 | 137 | 138 | def get_bounds_frames(frames): 139 | y0, y1, x0, x1 = [], [], [], [] 140 | 141 | y1_0, y2_0, x1_0, x2_0 = 0, 0, 0, 0 142 | 143 | for i in range(frames.shape[0]): 144 | h0, h1, w0, w1 = get_bounds2(frames[i]) 145 | if i == 0 and h0 == -1: 146 | return -1, -1, -1, -1, -1, -1 147 | elif h0 == -1: 148 | continue 149 | else: 150 | y0.append(h0) 151 | y1.append(h1) 152 | x0.append(w0) 153 | x1.append(w1) 154 | 155 | if i == 0: 156 | y1_0, y2_0, x1_0, x2_0 = h0, h1, w0, w1 157 | 158 | center_y1, center_x1 = ((y2_0 + y1_0) / 2), ((x2_0 + x1_0) / 2) 159 | 160 | min_y, max_y, min_x, max_x = min(y0), max(y1), min(x0), max(x1) 161 | 162 | return center_y1, center_x1, min_y, max_y, min_x, max_x 163 | 164 | 165 | def get_crops2(segmentations, n_frames, leeway=.15): 166 | _, h, w = segmentations.shape 167 | leeway_h, leeway_w = leeway * h, leeway * w 168 | 169 | y1_0, y2_0, x1_0, x2_0 = get_bounds(segmentations[0]) 170 | y1_1, y2_1, x1_1, x2_1 = get_bounds(segmentations[n_frames - 1]) 171 | y1_2, y2_2, x1_2, x2_2 = get_bounds(segmentations[n_frames*2 - 2]) 172 | 173 | min_y = min(y1_0, y1_1) 174 | min_x = min(x1_0, x1_1) 175 | max_y = max(y2_0, y2_1) 176 | max_x = max(x2_0, x2_1) 177 | 178 | center_y1, center_x1 = ((y2_0 + y1_0)/2), ((x2_0 + x1_0)/2) 179 | h1, w1 = 2*max(center_y1 - min_y, max_y - center_y1) + leeway_h, 2*max(center_x1 - min_x, 
max_x - center_x1) + leeway_w 180 | h1, w1 = max(h1, h//8), max(w1, w//8) 181 | alpha1 = max(h1/h, w1/w) 182 | h1, w1 = alpha1*h, alpha1*w 183 | 184 | y1 = center_y1 - 0.5*h1 185 | x1 = center_x1 - 0.5*w1 186 | if y1 + h1 > h: 187 | y1 -= (y1+h1-h) 188 | if x1 + w1 > w: 189 | x1 -= (x1+w1-w) 190 | 191 | min_y = min(y1_2, y1_1) 192 | min_x = min(x1_2, x1_1) 193 | max_y = max(y2_2, y2_1) 194 | max_x = max(x2_2, x2_1) 195 | 196 | center_y2, center_x2 = ((y2_1 + y1_1) / 2), ((x2_1 + x1_1) / 2) 197 | h2, w2 = 2 * max(center_y2 - min_y, max_y - center_y2) + leeway_h, 2 * max(center_x2 - min_x, max_x - center_x2) + leeway_w 198 | h2, w2 = max(h2, h // 8), max(w2, w // 8) 199 | alpha2 = max(h2 / h, w2 / w) 200 | h2, w2 = alpha2 * h, alpha2 * w 201 | 202 | y2 = center_y2 - 0.5 * h2 203 | x2 = center_x2 - 0.5 * w2 204 | if y2 + h2 > h: 205 | y2 -= (y2+h2-h) 206 | if x2 + w2 > w: 207 | x2 -= (x2+w2-w) 208 | 209 | return np.clip(np.array([y1/h, x1/w, alpha1+0.01, 0]), 0, 1), np.clip(np.array([y2/h, x2/w, alpha2+0.01, 0]), 0, 1) 210 | 211 | 212 | def get_crops_fin(segmentations, n_frames, leeway=0.15): 213 | n_in_frames, h, w = segmentations.shape 214 | leeway_h, leeway_w = leeway * h, leeway * w 215 | 216 | assert (n_in_frames - n_frames) % (n_frames-1) == 0 217 | 218 | crops = [] 219 | i = 0 220 | 221 | while i < n_in_frames: 222 | frames = segmentations[i:i+n_frames] 223 | 224 | center_y1, center_x1, min_y, max_y, min_x, max_x = get_bounds_frames(frames) 225 | 226 | if center_y1 == -1: 227 | crops.append(crops[-1]) 228 | i += 7 229 | continue 230 | 231 | h1, w1 = 2 * max(center_y1 - min_y, max_y - center_y1) + leeway_h, 2 * max(center_x1 - min_x, max_x - center_x1) + leeway_w 232 | h1, w1 = max(h1, h // 8), max(w1, w // 8) 233 | alpha1 = max(h1 / h, w1 / w) 234 | h1, w1 = alpha1 * h, alpha1 * w 235 | 236 | y1 = center_y1 - 0.5 * h1 237 | x1 = center_x1 - 0.5 * w1 238 | if y1 + h1 > h: 239 | y1 -= (y1 + h1 - h) 240 | if x1 + w1 > w: 241 | x1 -= (x1 + w1 - w) 242 | 243 | crops.append(np.clip(np.array([y1/h, x1/w, alpha1+0.01, 0]), 0, 1)) 244 | i += 7 245 | 246 | return crops 247 | 248 | 249 | class YoutubeTrainDataGen(object): 250 | def __init__(self, sec_to_wait=5, n_threads=10, crop_size=(256, 448), augment_data=True, n_frames=8, 251 | rand_frame_skip=4, multi_objects=False): 252 | self.train_files = get_split_names('train') 253 | 254 | self.sec_to_wait = sec_to_wait 255 | 256 | self.augment = augment_data 257 | self.rand_frame_skip = rand_frame_skip 258 | 259 | self.crop_size = crop_size 260 | self.multi_objects = multi_objects 261 | 262 | self.n_frames = n_frames 263 | 264 | np.random.seed(None) 265 | random.shuffle(self.train_files) 266 | 267 | self.data_queue = [] 268 | 269 | self.thread_list = [] 270 | for i in range(n_threads): 271 | load_thread = Thread(target=self.__load_and_process_data) 272 | load_thread.start() 273 | self.thread_list.append(load_thread) 274 | 275 | print('Waiting %d (s) to load data' % sec_to_wait) 276 | sys.stdout.flush() 277 | time.sleep(self.sec_to_wait) 278 | 279 | def __load_and_process_data(self): 280 | while self.train_files: 281 | while len(self.data_queue) >= max_vids: 282 | time.sleep(1) 283 | 284 | try: 285 | vid_name, fdict, allowable_frames = self.train_files.pop() 286 | except: 287 | continue # Thread issue 288 | 289 | frame_skip = np.random.randint(self.rand_frame_skip) + 1 290 | 291 | video, segmentation, frame_name = load_video(vid_name, allowable_frames, tr_or_val='train', n_frames=self.n_frames*2-1, frame_skip=frame_skip) 292 | 293 | video_res, seg_res 
= resize_video(video, segmentation, self.crop_size) 294 | 295 | # find objects in the first frame 296 | allowable_colors = [] 297 | for obj_id in fdict.keys(): 298 | if frame_name in fdict[obj_id]['frames']: 299 | allowable_colors.append(int(obj_id)) 300 | 301 | # no objects in the first frame 302 | if len(allowable_colors) == 0: 303 | print(vid_name, 'has no colors - SHOULD BE IMPOSSIBLE - POSSIBLE BUG FOUND!') 304 | continue 305 | 306 | # selects the objects which will be chosen from clip 307 | colors_to_select = [x+1 for x in range(len(allowable_colors))] 308 | 309 | if self.multi_objects: 310 | n_colors_foreground = random.choice(colors_to_select) 311 | else: 312 | n_colors_foreground = 1 313 | 314 | random.shuffle(allowable_colors) 315 | selected_colors = allowable_colors[:n_colors_foreground] 316 | 317 | gt_seg = np.zeros_like(seg_res, dtype=np.float32) 318 | for color in selected_colors: 319 | gt_seg += np.where(seg_res == color, 1, 0) 320 | gt_seg = np.clip(gt_seg, 0, 1) 321 | 322 | #gt_seg = np.round(medfilt(gt_seg, (1, 3, 3))) 323 | for frame in range(gt_seg.shape[0]): 324 | if frame % 5 == 0: 325 | continue 326 | gt_seg[frame] = np.round(median_filter(gt_seg[frame], 3)) 327 | 328 | if np.sum(gt_seg[0]) == 0: 329 | print('No segmentation found. ERROR.') 330 | continue 331 | 332 | if self.augment: 333 | video_res, gt_seg = flip_clip(video_res, gt_seg) 334 | 335 | leeway = np.random.random_sample() * 0.1 + 0.15 336 | gt_crop1, gr_crop2 = get_crops2(gt_seg, self.n_frames, leeway=leeway) 337 | 338 | gt_seg = np.expand_dims(gt_seg, axis=-1) 339 | 340 | self.data_queue.append((video_res, gt_seg, gt_crop1, gr_crop2)) 341 | print('Loading data thread finished') 342 | sys.stdout.flush() 343 | 344 | def get_batch(self, batch_size=5): 345 | while len(self.data_queue) < batch_size and self.train_files: 346 | print('Waiting on data. 
# Already Loaded = %s' % str(len(self.data_queue))) 347 | sys.stdout.flush() 348 | time.sleep(self.sec_to_wait) 349 | 350 | batch_size = min(batch_size, len(self.data_queue)) 351 | batch_x, batch_seg, batch_crop1, batch_crop2 = [], [], [], [] 352 | for i in range(batch_size): 353 | vid, seg, cr1, cr2 = self.data_queue.pop(0) 354 | batch_x.append(vid) 355 | batch_seg.append(seg) 356 | batch_crop1.append(cr1) 357 | batch_crop2.append(cr2) 358 | # print(cr1) 359 | # print(cr2) 360 | 361 | return batch_x, batch_seg, batch_crop1, batch_crop2 362 | 363 | def has_data(self): 364 | return self.data_queue != [] or self.train_files != [] 365 | 366 | 367 | # 368 | # def main(): 369 | # a = YoutubeTrainDataGen(n_threads=10, crop_size=(256*2, 448*2)) 370 | # i = 0 371 | # while a.has_data(): 372 | # v, s, cr1, cr2 = a.get_batch(1) 373 | # print(i) 374 | # i+=1 375 | # print(cr1) 376 | # cy, cx, h, w = cr1[0] 377 | # cy, cx, h, w = cy*256*2, cx*448*2, h*256*2, w*448*2 378 | # 379 | # mask = np.ones_like(s[0][0, :, :, 0])*2 380 | # mask[np.clip(int(cy-0.5*h), 0, 256*2):np.clip(int(cy+0.5*h), 0, 256*2), np.clip(int(cx-0.5*w), 0, 448*2):np.clip(int(cx+0.5*w), 0, 448*2)] = 0 381 | # 382 | # plt.imshow(s[0][0, :, :, 0] + mask) 383 | # plt.show(plt) 384 | # plt.imshow(s[0][7, :, :, 0] + mask) 385 | # plt.show(plt) 386 | # 387 | # cy, cx, h, w = cr2[0] 388 | # cy, cx, h, w = cy * 256 * 2, cx * 448 * 2, h * 256 * 2, w * 448 * 2 389 | # 390 | # mask = np.ones_like(s[0][0, :, :, 0]) * 2 391 | # mask[np.clip(int(cy - 0.5 * h), 0, 256 * 2):np.clip(int(cy + 0.5 * h), 0, 256 * 2), 392 | # np.clip(int(cx - 0.5 * w), 0, 448 * 2):np.clip(int(cx + 0.5 * w), 0, 448 * 2)] = 0 393 | # 394 | # plt.imshow(s[0][7, :, :, 0] + mask) 395 | # plt.show(plt) 396 | # plt.imshow(s[0][7, :, :, -1] + mask) 397 | # plt.show(plt) 398 | # 399 | # 400 | # main() 401 | -------------------------------------------------------------------------------- /load_youtubevalid_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from scipy.misc import imread, imresize 4 | import numpy as np 5 | from threading import Thread 6 | import time 7 | from PIL import Image 8 | import sys 9 | import matplotlib.pyplot as plt 10 | 11 | data_loc = '/home/kevin/HD2TB/Datasets/YoutubeVOS2018/' 12 | max_vids = 32 13 | 14 | 15 | def get_split_names(tr_or_val): 16 | split_file = data_loc + '%s/meta.json' % tr_or_val 17 | 18 | all_files = [] 19 | with open(split_file) as f: 20 | data = json.load(f) 21 | files = sorted(list(data['videos'].keys())) 22 | for file_name in files: 23 | all_files.append((file_name, data['videos'][file_name]['objects'])) 24 | 25 | return all_files 26 | 27 | 28 | # Loads in 8 frames 29 | def load_video(file_name, tr_or_val='train', n_frames=8): 30 | video_dir = data_loc + ('%s_all_frames/JPEGImages/%s/' % (tr_or_val, file_name)) 31 | 32 | # loads segmentations 33 | segment_dir = data_loc + ('%s/Annotations/%s/' % (tr_or_val, file_name)) 34 | 35 | seg_frame_name = sorted(os.listdir(segment_dir))[0] 36 | seg_frame = np.array(Image.open(segment_dir + seg_frame_name)) 37 | 38 | frame_names = sorted(os.listdir(video_dir)) 39 | 40 | start_frame = frame_names.index(seg_frame_name[:-4] + '.jpg') 41 | 42 | # loads video 43 | frames = [] 44 | for f in range(start_frame, start_frame+n_frames): 45 | try: 46 | frames.append(imread(video_dir + frame_names[f])) 47 | except: 48 | frames.append(frames[-1]) 49 | 50 | video = np.stack(frames, axis=0) 51 | 52 | return video, seg_frame 53 | 54 | 
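# Note on load_video() above: the validation loader reads only the first
# annotated mask of each clip (sorted Annotations/<video>/ listing, index 0)
# and stacks n_frames consecutive JPEG frames starting at that frame, so it
# returns video of shape (n_frames, H, W, 3) uint8 and seg_frame of shape
# (H, W) as a palette-indexed array; if the clip is shorter than n_frames,
# the last available frame is repeated to fill the window.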
55 | def resize_video(video, segmentation, target_size=(120, 120)): 56 | frames, h, w, _ = video.shape 57 | 58 | t_h, t_w = target_size 59 | 60 | video_res = np.zeros((frames, t_h, t_w, 3), np.uint8) 61 | for frame in range(frames): 62 | video_res[frame] = imresize(video[frame], (t_h, t_w)) 63 | 64 | segment_res = imresize(segmentation, (t_h, t_w), interp='nearest') 65 | 66 | return video_res/255., segment_res 67 | 68 | 69 | def get_bounds(img): 70 | h_sum = np.sum(img, axis=1) 71 | w_sum = np.sum(img, axis=0) 72 | 73 | hs = np.where(h_sum > 0) 74 | ws = np.where(w_sum > 0) 75 | 76 | try: 77 | h0 = hs[0][0] 78 | h1 = hs[0][-1] 79 | w0 = ws[0][0] 80 | w1 = ws[0][-1] 81 | except: 82 | return 0, img.shape[0], 0, img.shape[1] 83 | 84 | return h0, h1, w0, w1 85 | 86 | 87 | def perform_window_crop(gt_seg): 88 | h, w = gt_seg.shape 89 | 90 | # gets the bounds of the segmentation and the center of the object 91 | h0, h1, w0, w1 = get_bounds(gt_seg) 92 | obj_h, obj_w = h1 - h0, w1 - w0 93 | center_h, center_w = h0 + int(obj_h / 2), w0 + int(obj_w / 2) 94 | 95 | # defines the window size around the object 96 | if obj_h <= 0.2*h and obj_w <= 0.2*w: 97 | crop_dims = (0.25*h, 0.25*w) 98 | elif obj_h <= 0.4*h and obj_w <= 0.4*w: 99 | crop_dims = (0.5*h, 0.5*w) 100 | elif obj_h <= 0.65*h and obj_w <= 0.65*w: 101 | crop_dims = (0.75*h, 0.75*w) 102 | else: 103 | crop_dims = (h, w) 104 | 105 | y1 = max(0, center_h - crop_dims[0] / 2) 106 | x1 = max(0, center_w - crop_dims[1] / 2) 107 | if y1 + crop_dims[0] > h: 108 | y1 -= (y1+crop_dims[0]-h) 109 | if x1 + crop_dims[1] > w: 110 | x1 -= (x1+crop_dims[1]-w) 111 | 112 | return np.clip(np.array([y1/h, x1/w, crop_dims[0]/h, crop_dims[1]/w]), 0, 1) 113 | 114 | 115 | def perform_window_crop2(gt_seg): 116 | h, w = gt_seg.shape 117 | 118 | # gets the bounds of the segmentation and the center of the object 119 | h0, h1, w0, w1 = get_bounds(gt_seg) 120 | obj_h, obj_w = h1 - h0, w1 - w0 121 | center_h, center_w = h0 + int(obj_h / 2), w0 + int(obj_w / 2) 122 | min_y, max_y, min_x, max_x = h0, h1, w0, w1 123 | 124 | leeway = 0.15 125 | leeway_h, leeway_w = leeway * h, leeway * w 126 | center_y1, center_x1 = h0 + int(obj_h / 2), w0 + int(obj_w / 2) 127 | h1, w1 = 2 * max(center_y1 - min_y, max_y - center_y1) + leeway_h, 2 * max(center_x1 - min_x, 128 | max_x - center_x1) + leeway_w 129 | h1, w1 = max(h1, h // 8), max(w1, w // 8) 130 | alpha1 = max(h1 / h, w1 / w) 131 | h1, w1 = alpha1 * h, alpha1 * w 132 | 133 | y1 = center_y1 - 0.5 * h1 134 | x1 = center_x1 - 0.5 * w1 135 | if y1 + h1 > h: 136 | y1 -= (y1 + h1 - h) 137 | if x1 + w1 > w: 138 | x1 -= (x1 + w1 - w) 139 | 140 | return np.clip(np.array([y1/h, x1/w, alpha1+0.01, 0]), 0, 1) 141 | 142 | 143 | class YoutubeValidDataGen(object): 144 | def __init__(self, sec_to_wait=5, n_threads=10, crop_size=(256, 448), n_frames=8): 145 | self.train_files = get_split_names('valid') 146 | 147 | self.sec_to_wait = sec_to_wait 148 | 149 | self.crop_size = crop_size 150 | 151 | self.n_frames = n_frames 152 | 153 | self.data_queue = [] 154 | 155 | self.thread_list = [] 156 | for i in range(n_threads): 157 | load_thread = Thread(target=self.__load_and_process_data) 158 | load_thread.start() 159 | self.thread_list.append(load_thread) 160 | 161 | print('Waiting %d (s) to load data' % sec_to_wait) 162 | sys.stdout.flush() 163 | time.sleep(self.sec_to_wait) 164 | 165 | def __load_and_process_data(self): 166 | while self.train_files: 167 | while len(self.data_queue) >= max_vids: 168 | time.sleep(1) 169 | 170 | try: 171 | vid_name, fdict = 
self.train_files.pop() 172 | except: 173 | continue # Thread issue 174 | 175 | video, segmentation = load_video(vid_name, tr_or_val='valid', n_frames=self.n_frames*2-1) 176 | 177 | video, segmentation = resize_video(video, segmentation, target_size=self.crop_size) 178 | 179 | # find objects in the first frame 180 | color, gt_seg = 0, 0 181 | for obj_id in sorted(fdict.keys()): 182 | color = int(obj_id) 183 | gt_seg = np.where(segmentation == color, 1, 0) 184 | if np.sum(gt_seg) > 0: 185 | break 186 | else: 187 | color = 0 188 | 189 | # no objects in the first frame 190 | if color == 0: 191 | print('%s has no foreground segmentation.' % vid_name) 192 | continue 193 | 194 | gt_crop = perform_window_crop2(gt_seg) 195 | 196 | gt_seg = np.expand_dims(gt_seg, axis=-1) 197 | 198 | seg_vid = [gt_seg] 199 | 200 | for i in range(self.n_frames*2-1 - 1): 201 | seg_vid.append(np.zeros_like(gt_seg)) 202 | seg = np.stack(seg_vid, axis=0) 203 | 204 | self.data_queue.append((video, seg, gt_crop)) 205 | 206 | print('Loading data thread finished') 207 | sys.stdout.flush() 208 | 209 | def get_batch(self, batch_size=5): 210 | while len(self.data_queue) < batch_size and self.train_files: 211 | print('Waiting on data') 212 | sys.stdout.flush() 213 | time.sleep(self.sec_to_wait) 214 | 215 | batch_size = min(batch_size, len(self.data_queue)) 216 | batch_x, batch_seg, batch_crop = [], [], [] 217 | for i in range(batch_size): 218 | vid, seg, crp = self.data_queue.pop(0) 219 | batch_x.append(vid) 220 | batch_seg.append(seg) 221 | batch_crop.append(crp) 222 | 223 | return batch_x, batch_seg, batch_crop 224 | 225 | def has_data(self): 226 | return self.data_queue != [] or self.train_files != [] 227 | 228 | 229 | # def main(): 230 | # a = YoutubeValidDataGen(n_threads=1, crop_size=(256*2, 448*2)) 231 | # 232 | # while a.has_data(): 233 | # v, s, cr1 = a.get_batch(1) 234 | # print(cr1) 235 | # cy, cx, alpha, _ = cr1[0] 236 | # cy, cx, h, w = int(cy*256*2), int(cx*448*2), int(alpha*256*2), int(alpha*448*2) 237 | # 238 | # mask = np.ones_like(s[0][0, :, :, 0])*2 239 | # mask[cy:cy+h, cx:cx+w] = 0 240 | # 241 | # plt.imshow(s[0][0, :, :, 0] + mask) 242 | # plt.show(plt) 243 | # 244 | # 245 | # 246 | # 247 | # main() 248 | 249 | 250 | 251 | -------------------------------------------------------------------------------- /network_parts/lstm_capsnet_cond2_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from caps_layers_cond import create_prim_conv3d_caps, create_conv3d_caps, layer_shape, create_conv3d_caps_cond 3 | import config 4 | 5 | 6 | # Basic network: 7 | # HR frame branch has LSTM right before prediction and bounding box of shape y1, x1, alpha (keeps 128x224 aspect ratio) 8 | # LR frame branch has ConvLSTM at 1/4 LR input resolution 9 | # Skip connections only from conditioned capsules and video capsules (no frame capsules given to decoder) 10 | 11 | # Batch size = 3 on TitanXp 12 | 13 | def conv_lstm_layer(x_t, c_tm1, h_tm1, n_feats, kernel_size=(3, 3), strides=(1, 1), padding='SAME', name='lstm_convs'): 14 | inp_cat = tf.concat([x_t, h_tm1], axis=-1) 15 | conv_outputs = tf.layers.conv2d(inp_cat, n_feats*4, kernel_size, strides, padding=padding, name=name) 16 | 17 | input_gate, new_input, forget_gate, output_gate = tf.split(conv_outputs, 4, axis=-1) 18 | 19 | forget_bias = 0.0 # changed this from 1.0 20 | 21 | c_t = tf.nn.sigmoid(forget_gate + forget_bias) * c_tm1 22 | c_t += tf.nn.sigmoid(input_gate) * tf.nn.tanh(new_input) # 
tf.nn.tanh(new_input) 23 | h_t = tf.nn.tanh(c_t) * tf.sigmoid(output_gate) 24 | 25 | return c_t, h_t 26 | 27 | 28 | def transposed_conv3d(inputs, n_units, kernel_size, strides, padding='VALID', activation=tf.nn.relu, name='deconv', use_bias=True): 29 | conv = tf.layers.conv3d_transpose(inputs, n_units, kernel_size=kernel_size, strides=strides, padding=padding, 30 | use_bias=False, activation=activation, name=name) 31 | if use_bias: 32 | bias = tf.get_variable(name=name + '_bias', shape=(1, 1, 1, 1, n_units)) 33 | return activation(conv + bias) 34 | else: 35 | return activation(conv) 36 | 37 | 38 | def create_skip_connection(in_caps_layer, n_units, kernel_size, strides=(1, 1, 1), padding='VALID', name='skip', activation=tf.nn.relu): 39 | in_caps_layer = in_caps_layer[0] 40 | batch_size = tf.shape(in_caps_layer)[0] 41 | _, d, h, w, ch, dim = in_caps_layer.get_shape() 42 | d, h, w, ch, dim = map(int, [d, h, w, ch, dim]) 43 | 44 | in_caps_res = tf.reshape(in_caps_layer, [batch_size, d, h, w, ch * dim]) 45 | 46 | return tf.layers.conv3d_transpose(in_caps_res, n_units, kernel_size=kernel_size, strides=strides, padding=padding, 47 | use_bias=False, activation=activation, name=name) 48 | 49 | 50 | def video_encoder(x_input): 51 | x = tf.layers.conv3d(x_input, 32, kernel_size=[1, 3, 3], padding='SAME', strides=[1, 1, 1], 52 | activation=tf.nn.relu, name='conv1_2d') 53 | x = tf.layers.conv3d(x, 64, kernel_size=[3, 1, 1], padding='SAME', strides=[1, 1, 1], 54 | activation=tf.nn.relu, name='conv1_1d') 55 | 56 | x = tf.layers.conv3d(x, 64, kernel_size=[1, 3, 3], padding='SAME', strides=[1, 2, 2], 57 | activation=tf.nn.relu, name='conv2_2d') 58 | x = tf.layers.conv3d(x, 128, kernel_size=[3, 1, 1], padding='SAME', strides=[1, 1, 1], 59 | activation=tf.nn.relu, name='conv2_1d') 60 | 61 | x = tf.layers.conv3d(x, 256, kernel_size=[1, 3, 3], padding='SAME', strides=[1, 1, 1], 62 | activation=tf.nn.relu, name='conv3a_2d') 63 | x = tf.layers.conv3d(x, 256, kernel_size=[3, 1, 1], padding='SAME', strides=[1, 1, 1], 64 | activation=tf.nn.relu, name='conv3a_1d') 65 | x = tf.layers.conv3d(x, 256, kernel_size=[1, 3, 3], padding='SAME', strides=[1, 2, 2], 66 | activation=tf.nn.relu, name='conv3b_2d') 67 | x = tf.layers.conv3d(x, 256, kernel_size=[3, 1, 1], padding='SAME', strides=[1, 1, 1], 68 | activation=tf.nn.relu, name='conv3b_1d') 69 | 70 | x = tf.layers.conv3d(x, 512, kernel_size=[1, 3, 3], padding='SAME', strides=[1, 1, 1], 71 | activation=tf.nn.relu, name='conv4a_2d') 72 | x = tf.layers.conv3d(x, 512, kernel_size=[3, 1, 1], padding='SAME', strides=[1, 1, 1], 73 | activation=tf.nn.relu, name='conv4a_1d') 74 | x = tf.layers.conv3d(x, 512, kernel_size=[1, 3, 3], padding='SAME', strides=[1, 1, 1], 75 | activation=tf.nn.relu, name='conv4b_2d') 76 | x = tf.layers.conv3d(x, 512, kernel_size=[3, 1, 1], padding='SAME', strides=[1, 1, 1], 77 | activation=tf.nn.relu, name='conv4b_1d') 78 | 79 | return x 80 | 81 | 82 | def lr_frame_encoder(frame_plus_seg, c_tm1_lr, h_tm1_lr): 83 | fr_conv1 = tf.layers.conv2d(frame_plus_seg, 32, kernel_size=[3, 3], padding='SAME', strides=[1, 1], 84 | activation=tf.nn.relu, name='lr_fr_conv1') 85 | fr_conv2 = tf.layers.conv2d(fr_conv1, 64, kernel_size=[3, 3], padding='SAME', strides=[2, 2], 86 | activation=tf.nn.relu, name='lr_fr_conv2') 87 | 88 | fr_conv3 = tf.layers.conv2d(fr_conv2, 64, kernel_size=[3, 3], padding='SAME', strides=[1, 1], 89 | activation=tf.nn.relu, name='lr_fr_conv3') 90 | fr_conv4 = tf.layers.conv2d(fr_conv3, 128, kernel_size=[3, 3], padding='SAME', strides=[2, 2], 91 
| activation=tf.nn.relu, name='lr_fr_conv4') 92 | 93 | fr_conv5 = tf.layers.conv2d(fr_conv4, config.lr_lstm_feats // 2, kernel_size=[3, 3], padding='SAME', strides=[1, 1], 94 | activation=tf.nn.relu, name='lr_fr_conv5') 95 | c_t, h_t = conv_lstm_layer(fr_conv5, c_tm1_lr, h_tm1_lr, config.lr_lstm_feats // 2, name='lr_fr_lstm') 96 | 97 | return c_t, h_t 98 | 99 | 100 | def hr_frame_encoder(frame_plus_seg, c_tm1, h_tm1): 101 | fr_conv1 = tf.layers.conv2d(frame_plus_seg, 32, kernel_size=[3, 3], padding='SAME', strides=[1, 1], 102 | activation=tf.nn.relu, name='hr_fr_conv1') 103 | fr_conv2 = tf.layers.conv2d(fr_conv1, 64, kernel_size=[3, 3], padding='SAME', strides=[2, 2], 104 | activation=tf.nn.relu, name='hr_fr_conv2') # 256, 448 105 | 106 | fr_conv3 = tf.layers.conv2d(fr_conv2, 128, kernel_size=[3, 3], padding='SAME', strides=[2, 2], 107 | activation=tf.nn.relu, name='hr_fr_conv3') # 128, 224 108 | fr_conv4 = tf.layers.conv2d(fr_conv3, 128, kernel_size=[3, 3], padding='SAME', strides=[2, 2], 109 | activation=tf.nn.relu, name='hr_fr_conv4') # 64, 112 110 | 111 | fr_conv5 = tf.layers.conv2d(fr_conv4, 256, kernel_size=[3, 3], padding='SAME', strides=[2, 2], 112 | activation=tf.nn.relu, name='hr_fr_conv5') # 32, 56 113 | fr_conv6 = tf.layers.conv2d(fr_conv5, 256, kernel_size=[3, 3], padding='SAME', strides=[2, 2], 114 | activation=tf.nn.relu, name='hr_fr_conv6') # 16, 28 115 | 116 | fr_conv7 = tf.layers.conv2d(fr_conv6, 512, kernel_size=[3, 3], padding='SAME', strides=[2, 2], 117 | activation=tf.nn.relu, name='hr_fr_conv7') # 8, 14 118 | 119 | fr_conv8 = tf.layers.conv2d(fr_conv7, config.hr_lstm_feats // 2, kernel_size=[8, 14], padding='VALID', 120 | strides=[1, 1], activation=tf.nn.relu, name='hr_fr_conv8') # 1, 1 121 | 122 | c_t, h_t = conv_lstm_layer(fr_conv8, c_tm1, h_tm1, config.hr_lstm_feats // 2, kernel_size=(1, 1)) 123 | 124 | crop_pred = tf.layers.conv2d(h_t, 3, kernel_size=[1, 1], padding='VALID', strides=[1, 1], 125 | activation=tf.nn.sigmoid, name='crop_reg') 126 | 127 | y1, x1, alpha = tf.split(crop_pred[:, 0, 0, :], 3, axis=-1) 128 | 129 | alpha = 0.875 * alpha + 0.125 130 | zero = tf.zeros_like(alpha) 131 | 132 | return c_t, h_t, tf.concat((y1, x1, alpha, zero), axis=-1) 133 | 134 | 135 | def convert_crop(crop, use_gt): 136 | y1, x1, alpha, _ = tf.split(crop, 4, axis=-1) 137 | y1 = tf.cond(use_gt, lambda: y1, lambda: y1) 138 | x1 = tf.cond(use_gt, lambda: x1, lambda: x1) 139 | alpha = tf.cond(use_gt, lambda: alpha, lambda: alpha) 140 | y1, x1, alpha = y1 - 0.05, x1 - 0.05, alpha + 0.1 141 | #y1, x1, alpha = y1 - 0.1, x1 - 0.1, alpha + 0.2 142 | 143 | y2 = y1+alpha 144 | x2 = x1+alpha 145 | 146 | crop_new = tf.concat((y1, x1, y2, x2), axis=-1) 147 | crop_new = tf.clip_by_value(crop_new, 0, 1) 148 | 149 | return crop_new 150 | 151 | 152 | def uncrop(seg_out, crop_used): 153 | hr_h, hr_w = config.hr_frame_size 154 | 155 | y1, x1, y2, x2 = tf.split(crop_used, 4, axis=-1) 156 | 157 | up_padding = tf.cast(tf.floor(y1 * hr_h), tf.int32) 158 | left_padding = tf.cast(tf.floor(x1 * hr_w), tf.int32) 159 | 160 | img_res_h = tf.cast(tf.ceil((y2 - y1) * hr_h), tf.int32) 161 | img_res_w = tf.cast(tf.ceil((x2 - x1) * hr_w), tf.int32) 162 | 163 | down_padding = hr_h - img_res_h - up_padding 164 | right_padding = hr_w - img_res_w - left_padding 165 | 166 | fin_seg_init = tf.TensorArray(tf.float32, size=tf.shape(seg_out)[0]) 167 | 168 | def cond(fin_seg, counter): 169 | return tf.less(counter, tf.shape(seg_out)[0]) 170 | 171 | def res_and_pad(fin_seg, counter): 172 | segmentation = seg_out[counter] 
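# The steps below resize this crop's logits back to the crop's pixel extent
# (img_res_h x img_res_w) and pad them into the full high-resolution frame;
# the -1000 padding acts as a strongly negative logit, so the sigmoid applied
# later maps everything outside the crop to ~0 (background).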
173 | 174 | res_img = tf.image.resize_images(segmentation, (img_res_h[counter, 0], img_res_w[counter, 0])) 175 | 176 | padded_img = tf.pad(res_img, [[up_padding[counter, 0], down_padding[counter, 0]], 177 | [left_padding[counter, 0], right_padding[counter, 0]], 178 | [0, 0]], constant_values=-1000) 179 | 180 | fin_seg = fin_seg.write(counter, padded_img) 181 | 182 | return fin_seg, counter+1 183 | 184 | fin_seg, _ = tf.while_loop(cond, res_and_pad, [fin_seg_init, 0]) 185 | 186 | return fin_seg.stack() 187 | 188 | 189 | def create_decoder_network(pred_caps, sec_caps, prim_caps, print_layers=True): 190 | deconv1 = create_skip_connection(pred_caps, 128, kernel_size=[3, 3, 3], strides=[2, 2, 2], padding='SAME', 191 | name='deconv1') 192 | 193 | skip_connection1 = create_skip_connection(sec_caps, 128, kernel_size=[3, 3, 3], strides=[1, 1, 1], padding='SAME', 194 | name='skip_1') 195 | deconv1 = tf.concat([deconv1, skip_connection1], axis=-1) 196 | 197 | deconv2 = transposed_conv3d(deconv1, 128, kernel_size=[3, 3, 3], strides=[2, 2, 2], padding='SAME', 198 | activation=tf.nn.relu, name='deconv2') 199 | 200 | skip_connection2 = create_skip_connection(prim_caps, 128, kernel_size=[1, 3, 3], 201 | strides=[1, 1, 1], padding='SAME', name='skip_2') 202 | deconv2 = tf.concat([deconv2, skip_connection2], axis=-1) 203 | 204 | deconv3 = transposed_conv3d(deconv2, 256, kernel_size=[1, 3, 3], strides=[1, 2, 2], padding='SAME', 205 | activation=tf.nn.relu, name='deconv3') 206 | deconv4 = transposed_conv3d(deconv3, 256, kernel_size=[1, 3, 3], strides=[1, 2, 2], padding='SAME', 207 | activation=tf.nn.relu, name='deconv4') 208 | deconv5 = transposed_conv3d(deconv4, 128, kernel_size=[1, 3, 3], strides=[1, 2, 2], padding='SAME', 209 | activation=tf.nn.relu, name='deconv5') 210 | 211 | segment_layer = tf.layers.conv3d(deconv5, 1, kernel_size=[1, 1, 1], strides=[1, 1, 1], padding='SAME', 212 | activation=None, name='segment_layer') 213 | #segment_layer_sig = tf.nn.sigmoid(segment_layer) 214 | 215 | if print_layers: 216 | print('Deconv Layer 1:', deconv1.get_shape()) 217 | print('Deconv Layer 2:', deconv2.get_shape()) 218 | print('Deconv Layer 3:', deconv3.get_shape()) 219 | print('Deconv Layer 4:', deconv4.get_shape()) 220 | print('Deconv Layer 5:', deconv5.get_shape()) 221 | print('Segment Layer:', segment_layer.get_shape()) 222 | 223 | return segment_layer 224 | 225 | 226 | def create_network_one_pass(hr_frames, hr_first_frame_seg, c_tm1, h_tm1, c_tm1_lr, h_tm1_lr, use_gt, gt_crop, 227 | coord_addition, print_layers=True): 228 | # encodes hr frame, and gets the predicted crop 229 | hr_frame_plus_seg = tf.concat([hr_frames[:, 0], hr_first_frame_seg], axis=-1) 230 | c_t, h_t, pred_crop = hr_frame_encoder(hr_frame_plus_seg, c_tm1, h_tm1) 231 | 232 | center_y, center_x, alpha_gt, _ = tf.split(gt_crop, 4, axis=-1) 233 | _, _, alpha, _ = tf.split(pred_crop, 4, -1) 234 | 235 | y0 = tf.clip_by_value(center_y - 0.5*alpha, 0, 1) 236 | x0 = tf.clip_by_value(center_x - 0.5 * alpha, 0, 1) 237 | y1 = tf.clip_by_value(center_y - 0.5 * alpha_gt, 0, 1) 238 | x1 = tf.clip_by_value(center_x - 0.5 * alpha_gt, 0, 1) 239 | 240 | crop_to_use = tf.cond(use_gt, lambda:tf.concat([y1, x1, alpha_gt, alpha_gt], axis=-1), lambda: tf.concat([y0, x0, alpha, alpha], axis=-1)) 241 | # crop_to_use = tf.cond(use_gt, lambda: gt_crop, lambda: pred_crop) 242 | 243 | # crop_to_use = tf.cond(use_gt, lambda: gt_crop, lambda: pred_crop) 244 | crop_to_use = convert_crop(crop_to_use, use_gt) 245 | 246 | frame_h, frame_w = config.lr_frame_size 247 | 248 | # 
crops the low resolution frame+seg 249 | range_crops = tf.range(tf.shape(crop_to_use)[0]) 250 | lr_frame_plus_seg = tf.image.crop_and_resize(hr_frame_plus_seg, crop_to_use, range_crops, (frame_h, frame_w)) 251 | 252 | c_t_lr, h_t_lr = lr_frame_encoder(lr_frame_plus_seg, c_tm1_lr, h_tm1_lr) 253 | 254 | # crops the video 255 | tiled_crop_to_use = tf.reshape(tf.tile(tf.expand_dims(crop_to_use, 1), [1, config.n_frames, 1]), (-1, 4)) 256 | video_res = tf.reshape(hr_frames, (-1, config.hr_frame_size[0], config.hr_frame_size[1], 3)) 257 | cropped_video = tf.image.crop_and_resize(video_res, tiled_crop_to_use, tf.range(tf.shape(tiled_crop_to_use)[0]), 258 | (frame_h, frame_w)) 259 | cropped_video = tf.reshape(cropped_video, (-1, config.n_frames, frame_h, frame_w, 3)) 260 | 261 | # creates video capsules 262 | lr_video_encoding = video_encoder(cropped_video) 263 | 264 | vid_caps = create_prim_conv3d_caps(lr_video_encoding, 12, kernel_size=[1, 3, 3], strides=[1, 2, 2], padding='SAME', 265 | name='vid_caps') 266 | 267 | vid_caps2 = vid_caps 268 | # vid_caps2 = (vid_caps[0] + coord_addition, vid_caps[1]) 269 | 270 | # creates frame capsules, tiles them, and performs coordinate addition 271 | frame_caps = create_prim_conv3d_caps(tf.expand_dims(h_t_lr, axis=1), 8, kernel_size=[1, 3, 3], 272 | strides=[1, 2, 2], padding='SAME', name='frame_caps') 273 | frame_caps = (tf.tile(frame_caps[0], [1, config.n_frames, 1, 1, 1, 1]), 274 | tf.tile(frame_caps[1], [1, config.n_frames, 1, 1, 1, 1])) 275 | 276 | frame_caps = (frame_caps[0] + coord_addition, frame_caps[1]) 277 | 278 | # merges video and frame capsules 279 | prim_caps = (tf.concat([vid_caps2[0], frame_caps[0]], axis=-2), tf.concat([vid_caps2[1], frame_caps[1]], axis=-2)) 280 | 281 | # performs capsule routing 282 | sec_caps, _ = create_conv3d_caps_cond(prim_caps, 16, kernel_size=[3, 3, 3], strides=[2, 2, 2], padding='SAME', 283 | name='sec_caps', route_mean=True, n_cond_caps=8) 284 | pred_caps = create_conv3d_caps(sec_caps, 16, kernel_size=[3, 3, 3], strides=[2, 2, 2], padding='SAME', 285 | name='third_caps', route_mean=True) 286 | fin_caps = tf.reduce_mean(pred_caps[1], [2, 3, 4, 5]) 287 | 288 | if print_layers: 289 | print('Primary Caps:', layer_shape(prim_caps)) 290 | print('Secondary Caps:', layer_shape(sec_caps)) 291 | print('Prediction Caps:', layer_shape(pred_caps)) 292 | 293 | # obtains the segmentations 294 | seg = create_decoder_network(pred_caps, sec_caps, vid_caps, print_layers=print_layers) 295 | 296 | hr_seg_out = uncrop(tf.reshape(seg, (-1, frame_h, frame_w, 1)), tiled_crop_to_use) 297 | hr_seg = tf.reshape(hr_seg_out, (-1, config.n_frames, config.hr_frame_size[0], config.hr_frame_size[1], 1)) 298 | hr_seg_sig = tf.nn.sigmoid(hr_seg) 299 | 300 | return hr_seg, hr_seg_sig, c_t, h_t, c_t_lr, h_t_lr, pred_crop, fin_caps 301 | 302 | 303 | def create_network(x_input, y_segmentation, f_tm1, f_tm1_lr, use_gt_crop, gt_crop1): 304 | coords_to_add = tf.reshape(tf.range(config.n_frames, dtype=tf.float32) / (config.n_frames - 1), 305 | (1, config.n_frames, 1, 1, 1, 1)) 306 | zeros_to_add = tf.zeros((1, config.n_frames, 1, 1, 1, 15), dtype=tf.float32) 307 | coords_to_add = tf.concat((zeros_to_add, coords_to_add), axis=-1) 308 | 309 | c_tm1, h_tm1 = f_tm1[:, :, :, :config.hr_lstm_feats // 2], f_tm1[:, :, :, config.hr_lstm_feats // 2:] 310 | c_tm1_lr, h_tm1_lr = f_tm1_lr[:, :, :, :config.lr_lstm_feats // 2], f_tm1_lr[:, :, :, config.lr_lstm_feats // 2:] 311 | 312 | network_outs1 = create_network_one_pass(x_input[:, :config.n_frames], y_segmentation, 
c_tm1, h_tm1, c_tm1_lr, 313 | h_tm1_lr, use_gt_crop, gt_crop1, coords_to_add, print_layers=True) 314 | 315 | seg, seg_sig, c_t, h_t, c_t_lr, h_t_lr, pred_crop1, fin_caps = network_outs1 316 | 317 | return seg, seg_sig, fin_caps, tf.concat([c_t, h_t], axis=-1), tf.concat([c_t_lr, h_t_lr], axis=-1), pred_crop1 318 | -------------------------------------------------------------------------------- /network_parts/lstm_capsnet_cond2_train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from caps_layers_cond import create_prim_conv3d_caps, create_conv3d_caps, layer_shape, create_conv3d_caps_cond 3 | import config 4 | 5 | 6 | # Basic network: 7 | # HR frame branch has LSTM right before prediction and bounding box of shape y1, x1, alpha (keeps 128x224 aspect ratio) 8 | # LR frame branch has ConvLSTM at 1/4 LR input resolution 9 | # Skip connections only from conditioned capsules and video capsules (no frame capsules given to decoder) 10 | 11 | # Batch size = 4 on TitanXp 12 | 13 | def conv_lstm_layer(x_t, c_tm1, h_tm1, n_feats, kernel_size=(3, 3), strides=(1, 1), padding='SAME', name='lstm_convs'): 14 | inp_cat = tf.concat([x_t, h_tm1], axis=-1) 15 | conv_outputs = tf.layers.conv2d(inp_cat, n_feats*4, kernel_size, strides, padding=padding, name=name) 16 | 17 | input_gate, new_input, forget_gate, output_gate = tf.split(conv_outputs, 4, axis=-1) 18 | 19 | forget_bias = 0.0 # changed this from 1.0 20 | 21 | c_t = tf.nn.sigmoid(forget_gate + forget_bias) * c_tm1 22 | c_t += tf.nn.sigmoid(input_gate) * tf.nn.tanh(new_input) # tf.nn.tanh(new_input) 23 | h_t = tf.nn.tanh(c_t) * tf.sigmoid(output_gate) 24 | 25 | return c_t, h_t 26 | 27 | 28 | def transposed_conv3d(inputs, n_units, kernel_size, strides, padding='VALID', activation=tf.nn.relu, name='deconv', use_bias=True): 29 | conv = tf.layers.conv3d_transpose(inputs, n_units, kernel_size=kernel_size, strides=strides, padding=padding, 30 | use_bias=False, activation=activation, name=name) 31 | if use_bias: 32 | bias = tf.get_variable(name=name + '_bias', shape=(1, 1, 1, 1, n_units)) 33 | return activation(conv + bias) 34 | else: 35 | return activation(conv) 36 | 37 | 38 | def create_skip_connection(in_caps_layer, n_units, kernel_size, strides=(1, 1, 1), padding='VALID', name='skip', activation=tf.nn.relu): 39 | in_caps_layer = in_caps_layer[0] 40 | batch_size = tf.shape(in_caps_layer)[0] 41 | _, d, h, w, ch, dim = in_caps_layer.get_shape() 42 | d, h, w, ch, dim = map(int, [d, h, w, ch, dim]) 43 | 44 | in_caps_res = tf.reshape(in_caps_layer, [batch_size, d, h, w, ch * dim]) 45 | 46 | return tf.layers.conv3d_transpose(in_caps_res, n_units, kernel_size=kernel_size, strides=strides, padding=padding, 47 | use_bias=False, activation=activation, name=name) 48 | 49 | 50 | def video_encoder(x_input): 51 | x = tf.layers.conv3d(x_input, 32, kernel_size=[1, 3, 3], padding='SAME', strides=[1, 1, 1], 52 | activation=tf.nn.relu, name='conv1_2d') 53 | x = tf.layers.conv3d(x, 64, kernel_size=[3, 1, 1], padding='SAME', strides=[1, 1, 1], 54 | activation=tf.nn.relu, name='conv1_1d') 55 | 56 | x = tf.layers.conv3d(x, 64, kernel_size=[1, 3, 3], padding='SAME', strides=[1, 2, 2], 57 | activation=tf.nn.relu, name='conv2_2d') 58 | x = tf.layers.conv3d(x, 128, kernel_size=[3, 1, 1], padding='SAME', strides=[1, 1, 1], 59 | activation=tf.nn.relu, name='conv2_1d') 60 | 61 | x = tf.layers.conv3d(x, 256, kernel_size=[1, 3, 3], padding='SAME', strides=[1, 1, 1], 62 | activation=tf.nn.relu, name='conv3a_2d') 63 | x = 
tf.layers.conv3d(x, 256, kernel_size=[3, 1, 1], padding='SAME', strides=[1, 1, 1], 64 | activation=tf.nn.relu, name='conv3a_1d') 65 | x = tf.layers.conv3d(x, 256, kernel_size=[1, 3, 3], padding='SAME', strides=[1, 2, 2], 66 | activation=tf.nn.relu, name='conv3b_2d') 67 | x = tf.layers.conv3d(x, 256, kernel_size=[3, 1, 1], padding='SAME', strides=[1, 1, 1], 68 | activation=tf.nn.relu, name='conv3b_1d') 69 | 70 | x = tf.layers.conv3d(x, 512, kernel_size=[1, 3, 3], padding='SAME', strides=[1, 1, 1], 71 | activation=tf.nn.relu, name='conv4a_2d') 72 | x = tf.layers.conv3d(x, 512, kernel_size=[3, 1, 1], padding='SAME', strides=[1, 1, 1], 73 | activation=tf.nn.relu, name='conv4a_1d') 74 | x = tf.layers.conv3d(x, 512, kernel_size=[1, 3, 3], padding='SAME', strides=[1, 1, 1], 75 | activation=tf.nn.relu, name='conv4b_2d') 76 | x = tf.layers.conv3d(x, 512, kernel_size=[3, 1, 1], padding='SAME', strides=[1, 1, 1], 77 | activation=tf.nn.relu, name='conv4b_1d') 78 | 79 | return x 80 | 81 | 82 | def lr_frame_encoder(frame_plus_seg, c_tm1_lr, h_tm1_lr): 83 | fr_conv1 = tf.layers.conv2d(frame_plus_seg, 32, kernel_size=[3, 3], padding='SAME', strides=[1, 1], 84 | activation=tf.nn.relu, name='lr_fr_conv1') 85 | fr_conv2 = tf.layers.conv2d(fr_conv1, 64, kernel_size=[3, 3], padding='SAME', strides=[2, 2], 86 | activation=tf.nn.relu, name='lr_fr_conv2') 87 | 88 | fr_conv3 = tf.layers.conv2d(fr_conv2, 64, kernel_size=[3, 3], padding='SAME', strides=[1, 1], 89 | activation=tf.nn.relu, name='lr_fr_conv3') 90 | fr_conv4 = tf.layers.conv2d(fr_conv3, 128, kernel_size=[3, 3], padding='SAME', strides=[2, 2], 91 | activation=tf.nn.relu, name='lr_fr_conv4') 92 | 93 | fr_conv5 = tf.layers.conv2d(fr_conv4, config.lr_lstm_feats // 2, kernel_size=[3, 3], padding='SAME', strides=[1, 1], 94 | activation=tf.nn.relu, name='lr_fr_conv5') 95 | c_t, h_t = conv_lstm_layer(fr_conv5, c_tm1_lr, h_tm1_lr, config.lr_lstm_feats // 2, name='lr_fr_lstm') 96 | 97 | return c_t, h_t 98 | 99 | 100 | def hr_frame_encoder(frame_plus_seg, c_tm1, h_tm1): 101 | fr_conv1 = tf.layers.conv2d(frame_plus_seg, 32, kernel_size=[3, 3], padding='SAME', strides=[1, 1], 102 | activation=tf.nn.relu, name='hr_fr_conv1') 103 | fr_conv2 = tf.layers.conv2d(fr_conv1, 64, kernel_size=[3, 3], padding='SAME', strides=[2, 2], 104 | activation=tf.nn.relu, name='hr_fr_conv2') # 256, 448 105 | 106 | fr_conv3 = tf.layers.conv2d(fr_conv2, 128, kernel_size=[3, 3], padding='SAME', strides=[2, 2], 107 | activation=tf.nn.relu, name='hr_fr_conv3') # 128, 224 108 | fr_conv4 = tf.layers.conv2d(fr_conv3, 128, kernel_size=[3, 3], padding='SAME', strides=[2, 2], 109 | activation=tf.nn.relu, name='hr_fr_conv4') # 64, 112 110 | 111 | fr_conv5 = tf.layers.conv2d(fr_conv4, 256, kernel_size=[3, 3], padding='SAME', strides=[2, 2], 112 | activation=tf.nn.relu, name='hr_fr_conv5') # 32, 56 113 | fr_conv6 = tf.layers.conv2d(fr_conv5, 256, kernel_size=[3, 3], padding='SAME', strides=[2, 2], 114 | activation=tf.nn.relu, name='hr_fr_conv6') # 16, 28 115 | 116 | fr_conv7 = tf.layers.conv2d(fr_conv6, 512, kernel_size=[3, 3], padding='SAME', strides=[2, 2], 117 | activation=tf.nn.relu, name='hr_fr_conv7') # 8, 14 118 | 119 | fr_conv8 = tf.layers.conv2d(fr_conv7, config.hr_lstm_feats // 2, kernel_size=[8, 14], padding='VALID', 120 | strides=[1, 1], activation=tf.nn.relu, name='hr_fr_conv8') # 1, 1 121 | 122 | c_t, h_t = conv_lstm_layer(fr_conv8, c_tm1, h_tm1, config.hr_lstm_feats // 2, kernel_size=(1, 1)) 123 | 124 | crop_pred = tf.layers.conv2d(h_t, 3, kernel_size=[1, 1], padding='VALID', 
strides=[1, 1], 125 | activation=tf.nn.sigmoid, name='crop_reg') 126 | 127 | y1, x1, alpha = tf.split(crop_pred[:, 0, 0, :], 3, axis=-1) 128 | 129 | alpha = 0.875 * alpha + 0.125 130 | zero = tf.zeros_like(alpha) 131 | 132 | return c_t, h_t, tf.concat((y1, x1, alpha, zero), axis=-1) 133 | 134 | 135 | def convert_crop(crop): 136 | y1, x1, alpha, _ = tf.split(crop, 4, axis=-1) 137 | #y1, x1, alpha = y1 - 0.05, x1 - 0.05, alpha + 0.1 138 | 139 | y2 = y1+alpha 140 | x2 = x1+alpha 141 | 142 | crop_new = tf.concat((y1, x1, y2, x2), axis=-1) 143 | crop_new = tf.clip_by_value(crop_new, 0, 1) 144 | 145 | return crop_new 146 | 147 | 148 | def uncrop(seg_out, crop_used): 149 | hr_h, hr_w = config.hr_frame_size 150 | 151 | y1, x1, y2, x2 = tf.split(crop_used, 4, axis=-1) 152 | 153 | up_padding = tf.cast(tf.floor(y1 * hr_h), tf.int32) 154 | left_padding = tf.cast(tf.floor(x1 * hr_w), tf.int32) 155 | 156 | img_res_h = tf.cast(tf.ceil((y2 - y1) * hr_h), tf.int32) 157 | img_res_w = tf.cast(tf.ceil((x2 - x1) * hr_w), tf.int32) 158 | 159 | down_padding = hr_h - img_res_h - up_padding 160 | right_padding = hr_w - img_res_w - left_padding 161 | 162 | fin_seg_init = tf.TensorArray(tf.float32, size=tf.shape(seg_out)[0]) 163 | 164 | def cond(fin_seg, counter): 165 | return tf.less(counter, tf.shape(seg_out)[0]) 166 | 167 | def res_and_pad(fin_seg, counter): 168 | segmentation = seg_out[counter] 169 | 170 | res_img = tf.image.resize_images(segmentation, (img_res_h[counter, 0], img_res_w[counter, 0])) 171 | 172 | padded_img = tf.pad(res_img, [[up_padding[counter, 0], down_padding[counter, 0]], 173 | [left_padding[counter, 0], right_padding[counter, 0]], 174 | [0, 0]], constant_values=-1000) 175 | 176 | fin_seg = fin_seg.write(counter, padded_img) 177 | 178 | return fin_seg, counter+1 179 | 180 | fin_seg, _ = tf.while_loop(cond, res_and_pad, [fin_seg_init, 0]) 181 | 182 | return fin_seg.stack() 183 | 184 | 185 | def create_decoder_network(pred_caps, sec_caps, prim_caps, print_layers=True): 186 | deconv1 = create_skip_connection(pred_caps, 128, kernel_size=[3, 3, 3], strides=[2, 2, 2], padding='SAME', 187 | name='deconv1') 188 | 189 | skip_connection1 = create_skip_connection(sec_caps, 128, kernel_size=[3, 3, 3], strides=[1, 1, 1], padding='SAME', 190 | name='skip_1') 191 | deconv1 = tf.concat([deconv1, skip_connection1], axis=-1) 192 | 193 | deconv2 = transposed_conv3d(deconv1, 128, kernel_size=[3, 3, 3], strides=[2, 2, 2], padding='SAME', 194 | activation=tf.nn.relu, name='deconv2') 195 | 196 | skip_connection2 = create_skip_connection(prim_caps, 128, kernel_size=[1, 3, 3], 197 | strides=[1, 1, 1], padding='SAME', name='skip_2') 198 | deconv2 = tf.concat([deconv2, skip_connection2], axis=-1) 199 | 200 | deconv3 = transposed_conv3d(deconv2, 256, kernel_size=[1, 3, 3], strides=[1, 2, 2], padding='SAME', 201 | activation=tf.nn.relu, name='deconv3') 202 | deconv4 = transposed_conv3d(deconv3, 256, kernel_size=[1, 3, 3], strides=[1, 2, 2], padding='SAME', 203 | activation=tf.nn.relu, name='deconv4') 204 | deconv5 = transposed_conv3d(deconv4, 128, kernel_size=[1, 3, 3], strides=[1, 2, 2], padding='SAME', 205 | activation=tf.nn.relu, name='deconv5') 206 | 207 | segment_layer = tf.layers.conv3d(deconv5, 1, kernel_size=[1, 1, 1], strides=[1, 1, 1], padding='SAME', 208 | activation=None, name='segment_layer') 209 | #segment_layer_sig = tf.nn.sigmoid(segment_layer) 210 | 211 | if print_layers: 212 | print('Deconv Layer 1:', deconv1.get_shape()) 213 | print('Deconv Layer 2:', deconv2.get_shape()) 214 | print('Deconv Layer 3:', 
deconv3.get_shape()) 215 | print('Deconv Layer 4:', deconv4.get_shape()) 216 | print('Deconv Layer 5:', deconv5.get_shape()) 217 | print('Segment Layer:', segment_layer.get_shape()) 218 | 219 | return segment_layer 220 | 221 | 222 | def create_network_one_pass(hr_frames, hr_first_frame_seg, c_tm1, h_tm1, c_tm1_lr, h_tm1_lr, use_gt, gt_crop, 223 | coord_addition, print_layers=True): 224 | # encodes hr frame, and gets the predicted crop 225 | hr_frame_plus_seg = tf.concat([hr_frames[:, 0], hr_first_frame_seg], axis=-1) 226 | c_t, h_t, pred_crop = hr_frame_encoder(hr_frame_plus_seg, c_tm1, h_tm1) 227 | 228 | crop_to_use = tf.cond(use_gt, lambda: gt_crop, lambda: pred_crop) 229 | crop_to_use = convert_crop(crop_to_use) 230 | 231 | frame_h, frame_w = config.lr_frame_size 232 | 233 | # crops the low resolution frame+seg 234 | range_crops = tf.range(tf.shape(crop_to_use)[0]) 235 | lr_frame_plus_seg = tf.image.crop_and_resize(hr_frame_plus_seg, crop_to_use, range_crops, (frame_h, frame_w)) 236 | 237 | c_t_lr, h_t_lr = lr_frame_encoder(lr_frame_plus_seg, c_tm1_lr, h_tm1_lr) 238 | 239 | # crops the video 240 | tiled_crop_to_use = tf.reshape(tf.tile(tf.expand_dims(crop_to_use, 1), [1, config.n_frames, 1]), (-1, 4)) 241 | video_res = tf.reshape(hr_frames, (-1, config.hr_frame_size[0], config.hr_frame_size[1], 3)) 242 | cropped_video = tf.image.crop_and_resize(video_res, tiled_crop_to_use, tf.range(tf.shape(tiled_crop_to_use)[0]), 243 | (frame_h, frame_w)) 244 | cropped_video = tf.reshape(cropped_video, (-1, config.n_frames, frame_h, frame_w, 3)) 245 | 246 | # creates video capsules 247 | lr_video_encoding = video_encoder(cropped_video) 248 | 249 | vid_caps = create_prim_conv3d_caps(lr_video_encoding, 12, kernel_size=[1, 3, 3], strides=[1, 2, 2], padding='SAME', 250 | name='vid_caps') 251 | 252 | vid_caps2 = vid_caps 253 | # vid_caps2 = (vid_caps[0] + coord_addition, vid_caps[1]) 254 | 255 | # creates frame capsules, tiles them, and performs coordinate addition 256 | frame_caps = create_prim_conv3d_caps(tf.expand_dims(h_t_lr, axis=1), 8, kernel_size=[1, 3, 3], 257 | strides=[1, 2, 2], padding='SAME', name='frame_caps') 258 | frame_caps = (tf.tile(frame_caps[0], [1, config.n_frames, 1, 1, 1, 1]), 259 | tf.tile(frame_caps[1], [1, config.n_frames, 1, 1, 1, 1])) 260 | 261 | frame_caps = (frame_caps[0] + coord_addition, frame_caps[1]) 262 | 263 | # merges video and frame capsules 264 | prim_caps = (tf.concat([vid_caps2[0], frame_caps[0]], axis=-2), tf.concat([vid_caps2[1], frame_caps[1]], axis=-2)) 265 | 266 | # performs capsule routing 267 | sec_caps, _ = create_conv3d_caps_cond(prim_caps, 16, kernel_size=[3, 3, 3], strides=[2, 2, 2], padding='SAME', 268 | name='sec_caps', route_mean=True, n_cond_caps=8) 269 | pred_caps = create_conv3d_caps(sec_caps, 16, kernel_size=[3, 3, 3], strides=[2, 2, 2], padding='SAME', 270 | name='third_caps', route_mean=True) 271 | fin_caps = tf.reduce_mean(pred_caps[1], [2, 3, 4, 5]) 272 | 273 | if print_layers: 274 | print('Primary Caps:', layer_shape(prim_caps)) 275 | print('Secondary Caps:', layer_shape(sec_caps)) 276 | print('Prediction Caps:', layer_shape(pred_caps)) 277 | 278 | # obtains the segmentations 279 | seg = create_decoder_network(pred_caps, sec_caps, vid_caps, print_layers=print_layers) 280 | 281 | hr_seg_out = uncrop(tf.reshape(seg, (-1, frame_h, frame_w, 1)), tiled_crop_to_use) 282 | hr_seg = tf.reshape(hr_seg_out, (-1, config.n_frames, config.hr_frame_size[0], config.hr_frame_size[1], 1)) 283 | hr_seg_sig = tf.nn.sigmoid(hr_seg) 284 | 285 | return hr_seg, 
hr_seg_sig, c_t, h_t, c_t_lr, h_t_lr, pred_crop, fin_caps 286 | 287 | 288 | def create_network(x_input, y_segmentation, f_tm1, f_tm1_lr, use_gt_seg, use_gt_crop, gt_crop1, gt_crop2): 289 | coords_to_add = tf.reshape(tf.range(config.n_frames, dtype=tf.float32) / (config.n_frames - 1), 290 | (1, config.n_frames, 1, 1, 1, 1)) 291 | zeros_to_add = tf.zeros((1, config.n_frames, 1, 1, 1, 15), dtype=tf.float32) 292 | coords_to_add = tf.concat((zeros_to_add, coords_to_add), axis=-1) 293 | 294 | c_tm1, h_tm1 = f_tm1[:, :, :, :config.hr_lstm_feats // 2], f_tm1[:, :, :, config.hr_lstm_feats // 2:] 295 | c_tm1_lr, h_tm1_lr = f_tm1_lr[:, :, :, :config.lr_lstm_feats // 2], f_tm1_lr[:, :, :, config.lr_lstm_feats // 2:] 296 | 297 | network_outs1 = create_network_one_pass(x_input[:, :config.n_frames], y_segmentation[:, 0], c_tm1, h_tm1, c_tm1_lr, 298 | h_tm1_lr, use_gt_crop, gt_crop1, coords_to_add, print_layers=True) 299 | 300 | hr_seg1, hr_seg_sig1, c_t, h_t, c_t_lr, h_t_lr, pred_crop1, fin_caps = network_outs1 301 | 302 | # reuse variables 303 | tf.get_variable_scope().reuse_variables() 304 | 305 | first_seg2 = tf.cond(use_gt_seg, lambda: y_segmentation[:, config.n_frames - 1], lambda: hr_seg_sig1[:, -1]) 306 | 307 | network_outs2 = create_network_one_pass(x_input[:, config.n_frames-1:], first_seg2, c_t, h_t, c_t_lr, h_t_lr, 308 | use_gt_crop, gt_crop2, coords_to_add, print_layers=False) 309 | 310 | hr_seg2, hr_seg_sig2, c_t2, h_t2, c_t2_lr, h_t2_lr, pred_crop2, _ = network_outs2 311 | 312 | fin_seg = tf.concat([hr_seg1, hr_seg2[:, 1:]], axis=1) 313 | fin_seg_sig = tf.concat([hr_seg_sig1, hr_seg_sig2[:, 1:]], axis=1) 314 | 315 | return fin_seg, fin_seg_sig, fin_caps, tf.concat([c_t2, h_t2], axis=-1), pred_crop1, pred_crop2 316 | -------------------------------------------------------------------------------- /network_saves/readme.txt: -------------------------------------------------------------------------------- 1 | This will hold network weights, which are periodically saved. 2 | -------------------------------------------------------------------------------- /network_saves_best/readme.txt: -------------------------------------------------------------------------------- 1 | This will hold network weights which are saved when the validation loss is minimized. 2 | --------------------------------------------------------------------------------
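The training data generator above is meant to be drained from a training loop. The snippet below is an illustrative sketch only, not repository code: it mirrors the commented-out main() at the bottom of load_youtube_data_multi.py, assumes that example's 512x896 frame size and the default 8-frame clips, and decodes the crop vector to a pixel box the same way convert_crop() in network_parts/lstm_capsnet_cond2_train.py does.

import numpy as np
from load_youtube_data_multi import YoutubeTrainDataGen

# Sketch of how the threaded generator is consumed (sizes and batch size follow
# the commented-out main() in load_youtube_data_multi.py, not a prescribed config).
gen = YoutubeTrainDataGen(n_threads=10, crop_size=(256 * 2, 448 * 2), n_frames=8)
while gen.has_data():
    vids, segs, crops1, crops2 = gen.get_batch(batch_size=1)
    if not vids:  # queue drained between has_data() and get_batch()
        break
    # vids[i]:  (2*n_frames - 1, 512, 896, 3) clip scaled to [0, 1]
    # segs[i]:  (2*n_frames - 1, 512, 896, 1) binary foreground mask
    # crops1[i], crops2[i]: normalized [y1, x1, alpha, 0] crop for each half clip
    y1, x1, alpha, _ = crops1[0]
    # decode to a pixel box the same way convert_crop() does (y2 = y1 + alpha)
    box = (int(y1 * 512), int(x1 * 896),
           int(min(y1 + alpha, 1.0) * 512), int(min(x1 + alpha, 1.0) * 896))
    print(np.array(vids).shape, box)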